"...git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "688448db7547be90203440cfd105703d8a853f39"
Commit 17971202 authored by Michael Carilli

updating FP16_Optimizer example as well

parent eae8b989
@@ -204,7 +204,6 @@ def main():
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        adjust_learning_rate(optimizer, epoch)
 
         # train for one epoch
         train(train_loader, model, criterion, optimizer, epoch)
@@ -277,6 +276,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
     while input is not None:
         i += 1
+        adjust_learning_rate(optimizer, epoch, i, len(train_loader))
+
         if args.prof:
             if i > 10:
                 break
@@ -423,9 +424,22 @@ class AverageMeter(object):
         self.avg = self.sum / self.count
 
 
-def adjust_learning_rate(optimizer, epoch):
-    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    lr = args.lr * (0.1 ** (epoch // 30))
+def adjust_learning_rate(optimizer, epoch, step, len_epoch):
+    """LR schedule that should yield 76% converged accuracy with batch size 256"""
+    factor = epoch // 30
+    if epoch >= 80:
+        factor = factor + 1
+
+    lr = args.lr*(0.1**factor)
+
+    """Warmup"""
+    if epoch < 5:
+        lr = lr*float(step + epoch*len_epoch)/(5.*len_epoch)
+
+    if(args.local_rank == 0):
+        print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))
+
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
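
For reference, the new schedule decays the learning rate by 10x at epochs 30, 60, and 80 and ramps it linearly from zero over the first 5 epochs; the per-iteration ramp is why the call moves out of main() (once per epoch) and into the training loop (once per step). Below is a minimal standalone sketch of the same arithmetic, not the script itself: lr_at, base_lr = 0.1 (the conventional base LR for batch size 256), and len_epoch = 5005 are hypothetical stand-ins for args.lr and len(train_loader).

# Minimal sketch of the schedule above, for illustration only.
def lr_at(base_lr, epoch, step, len_epoch):
    factor = epoch // 30          # decay by 10x at epochs 30, 60, ...
    if epoch >= 80:
        factor = factor + 1       # ... and once more at epoch 80
    lr = base_lr * (0.1 ** factor)
    if epoch < 5:                 # linear warmup over the first 5 epochs
        lr = lr * float(step + epoch * len_epoch) / (5. * len_epoch)
    return lr

len_epoch = 5005  # hypothetical: ~1.28M ImageNet images / batch size 256
for epoch, step in [(0, 0), (2, 2503), (5, 0), (30, 0), (60, 0), (80, 0)]:
    print(epoch, step, lr_at(0.1, epoch, step, len_epoch))
# prints roughly: 0.0, 0.05, 0.1, 0.01, 0.001, 0.0001

The warmup portion follows the common large-minibatch recipe of ramping up to the full learning rate over a few epochs rather than starting at it directly.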