Unverified Commit 04822f40 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix calibration in PyTorch example (#322)



* Fix example
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0707552e
......@@ -68,7 +68,7 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
break
def calibrate(model, device, test_loader):
def calibrate(model, device, test_loader, fp8):
"""Calibration function."""
model.eval()
test_loss = 0
......@@ -76,7 +76,7 @@ def calibrate(model, device, test_loader):
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=False, calibrating=True):
with te.fp8_autocast(enabled=fp8, calibrating=True):
output = model(data)
def test(model, device, test_loader, use_fp8):
......@@ -182,9 +182,6 @@ def main():
assert use_cuda, "CUDA needed for FP8 execution."
args.use_te = True
if args.use_fp8_infer:
assert not args.use_fp8, "fp8-infer path currently only supports calibration from a bfloat checkpoint"
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
......@@ -213,8 +210,8 @@ def main():
test(model, device, test_loader, args.use_fp8)
scheduler.step()
if args.use_fp8_infer:
calibrate(model, device, test_loader)
if args.use_fp8_infer and not args.use_fp8:
calibrate(model, device, test_loader, args.use_fp8)
if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment