Commit 17971202 authored by Michael Carilli

updating FP16_Optimizer example as well

parent eae8b989
@@ -204,7 +204,6 @@ def main():
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        adjust_learning_rate(optimizer, epoch)
 
         # train for one epoch
         train(train_loader, model, criterion, optimizer, epoch)
@@ -277,6 +276,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
     while input is not None:
         i += 1
 
+        adjust_learning_rate(optimizer, epoch, i, len(train_loader))
+
         if args.prof:
             if i > 10:
                 break
@@ -423,9 +424,22 @@ class AverageMeter(object):
         self.avg = self.sum / self.count
 
 
-def adjust_learning_rate(optimizer, epoch):
-    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    lr = args.lr * (0.1 ** (epoch // 30))
+def adjust_learning_rate(optimizer, epoch, step, len_epoch):
+    """LR schedule that should yield 76% converged accuracy with batch size 256"""
+    factor = epoch // 30
+
+    if epoch >= 80:
+        factor = factor + 1
+
+    lr = args.lr*(0.1**factor)
+
+    """Warmup"""
+    if epoch < 5:
+        lr = lr*float(step + epoch*len_epoch)/(5.*len_epoch)
+
+    if(args.local_rank == 0):
+        print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))
+
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
...
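For reference, a minimal standalone sketch of the schedule this commit moves to: a 10x decay every 30 epochs, one extra 10x drop from epoch 80 onward, and linear warmup over the first 5 epochs, applied per step instead of once per epoch. The function name lr_at and the base_lr/len_epoch values below are illustrative stand-ins for args.lr and len(train_loader); they are not part of the commit.

def lr_at(base_lr, epoch, step, len_epoch):
    # 10x decay every 30 epochs, plus one extra 10x drop from epoch 80 onward
    factor = epoch // 30
    if epoch >= 80:
        factor += 1
    lr = base_lr * (0.1 ** factor)
    # linear warmup over the first 5 epochs, scaled by per-step progress
    if epoch < 5:
        lr = lr * float(step + epoch * len_epoch) / (5. * len_epoch)
    return lr

if __name__ == "__main__":
    base_lr, len_epoch = 0.1, 5005  # illustrative values only
    for epoch in (0, 4, 5, 30, 60, 80):
        print(epoch, lr_at(base_lr, epoch, len_epoch // 2, len_epoch))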