Commit 59bf7d13 authored by Michael Carilli
Browse files

minor cleanup

parent 17971202
...@@ -436,12 +436,12 @@ def adjust_learning_rate(optimizer, epoch, step, len_epoch): ...@@ -436,12 +436,12 @@ def adjust_learning_rate(optimizer, epoch, step, len_epoch):
lr = args.lr*(0.1**factor) lr = args.lr*(0.1**factor)
"""Warmup""" """Warmup"""
if epoch < 5: if epoch < 5:
lr = lr*float(step + epoch*len_epoch)/(5.*len_epoch) lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)
if(args.local_rank == 0): # if(args.local_rank == 0):
print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr)) # print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
param_group['lr'] = lr param_group['lr'] = lr
......
...@@ -433,12 +433,12 @@ def adjust_learning_rate(optimizer, epoch, step, len_epoch): ...@@ -433,12 +433,12 @@ def adjust_learning_rate(optimizer, epoch, step, len_epoch):
lr = args.lr*(0.1**factor) lr = args.lr*(0.1**factor)
"""Warmup""" """Warmup"""
if epoch < 5: if epoch < 5:
lr = lr*float(step + epoch*len_epoch)/(5.*len_epoch) lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)
if(args.local_rank == 0): # if(args.local_rank == 0):
print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr)) # print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
param_group['lr'] = lr param_group['lr'] = lr
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment