Commit a3a09c8c authored by Michael Carilli's avatar Michael Carilli
Browse files

Fix for #188

parent 371633d5
...@@ -269,7 +269,7 @@ def initialize( ...@@ -269,7 +269,7 @@ def initialize(
https://github.com/NVIDIA/apex/tree/master/examples/imagenet https://github.com/NVIDIA/apex/tree/master/examples/imagenet
""" """
_amp_state.opt_properties = Properties() _amp_state.opt_properties = Properties()
_amp_state.opt_properties.verbosity = verbosity _amp_state.verbosity = verbosity
if not enabled: if not enabled:
return models, optimizers return models, optimizers
......
...@@ -334,8 +334,6 @@ def train(train_loader, model, criterion, optimizer, epoch): ...@@ -334,8 +334,6 @@ def train(train_loader, model, criterion, optimizer, epoch):
optimizer.step() optimizer.step()
if args.prof: torch.cuda.nvtx.range_pop() if args.prof: torch.cuda.nvtx.range_pop()
input, target = prefetcher.next()
if i%args.print_freq == 0: if i%args.print_freq == 0:
# Every print_freq iterations, check the loss accuracy and speed. # Every print_freq iterations, check the loss accuracy and speed.
# For best performance, it doesn't make sense to print these metrics every # For best performance, it doesn't make sense to print these metrics every
...@@ -374,6 +372,8 @@ def train(train_loader, model, criterion, optimizer, epoch): ...@@ -374,6 +372,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
batch_time=batch_time, batch_time=batch_time,
loss=losses, top1=top1, top5=top5)) loss=losses, top1=top1, top5=top5))
input, target = prefetcher.next()
def validate(val_loader, model, criterion): def validate(val_loader, model, criterion):
batch_time = AverageMeter() batch_time = AverageMeter()
......
...@@ -365,6 +365,9 @@ def train(train_loader, model, criterion, optimizer, epoch): ...@@ -365,6 +365,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
batch_time.update(time.time() - end) batch_time.update(time.time() - end)
end = time.time() end = time.time()
# If you decide to refactor this test, like examples/imagenet, to sample the loss every
# print_freq iterations, make sure to move this prefetching below the accuracy calculation.
input, target = prefetcher.next() input, target = prefetcher.next()
if i % args.print_freq == 0 and i > 1: if i % args.print_freq == 0 and i > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment