"vscode:/vscode.git/clone" did not exist on "7f03f7ceae05d4d45fc4b12b81736c55a13d872c"
Unverified Commit 04822f40 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix calibration in PyTorch example (#322)



* Fix example
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0707552e
...@@ -68,7 +68,7 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8): ...@@ -68,7 +68,7 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
break break
def calibrate(model, device, test_loader): def calibrate(model, device, test_loader, fp8):
"""Calibration function.""" """Calibration function."""
model.eval() model.eval()
test_loss = 0 test_loss = 0
...@@ -76,7 +76,7 @@ def calibrate(model, device, test_loader): ...@@ -76,7 +76,7 @@ def calibrate(model, device, test_loader):
with torch.no_grad(): with torch.no_grad():
for data, target in test_loader: for data, target in test_loader:
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=False, calibrating=True): with te.fp8_autocast(enabled=fp8, calibrating=True):
output = model(data) output = model(data)
def test(model, device, test_loader, use_fp8): def test(model, device, test_loader, use_fp8):
...@@ -182,9 +182,6 @@ def main(): ...@@ -182,9 +182,6 @@ def main():
assert use_cuda, "CUDA needed for FP8 execution." assert use_cuda, "CUDA needed for FP8 execution."
args.use_te = True args.use_te = True
if args.use_fp8_infer:
assert not args.use_fp8, "fp8-infer path currently only supports calibration from a bfloat checkpoint"
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu") device = torch.device("cuda" if use_cuda else "cpu")
...@@ -213,8 +210,8 @@ def main(): ...@@ -213,8 +210,8 @@ def main():
test(model, device, test_loader, args.use_fp8) test(model, device, test_loader, args.use_fp8)
scheduler.step() scheduler.step()
if args.use_fp8_infer: if args.use_fp8_infer and not args.use_fp8:
calibrate(model, device, test_loader) calibrate(model, device, test_loader, args.use_fp8)
if args.save_model or args.use_fp8_infer: if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt") torch.save(model.state_dict(), "mnist_cnn.pt")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment