Commit 91a073f8 authored by lukovnikov

schedule fix

parent b64cc63a
@@ -38,11 +38,12 @@ class LRSchedule(object):
         :param kw:
         """
         super(LRSchedule, self).__init__(**kw)
-        self.warmup, self.t_total = warmup, t_total
         if t_total <= 0:
             logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
         if not 0.0 <= warmup < 1.0 and not warmup == -1:
             raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
+        warmup = max(warmup, 0)
+        self.warmup, self.t_total = warmup, t_total
         self.warned_for_t_total_at_progress = -1
 
     def get_lr(self, step, nowarn=False):
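
For orientation, a minimal standalone sketch of the reordered constructor logic above (logger stubbed in; the real class carries more state). The point of the reordering: warmup is validated in its raw form first, and only then is the -1 sentinel clamped to 0 before being stored, so subclasses never see a negative warmup fraction.

    import logging

    logger = logging.getLogger(__name__)

    class LRSchedule(object):
        def __init__(self, warmup=0.002, t_total=-1, **kw):
            super(LRSchedule, self).__init__(**kw)
            if t_total <= 0:
                logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
            # validate the raw value: anything outside [0, 1) other than the -1 sentinel is rejected
            if not 0.0 <= warmup < 1.0 and not warmup == -1:
                raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
            warmup = max(warmup, 0)   # clamp the -1 sentinel to 0 only after validation
            self.warmup, self.t_total = warmup, t_total
            self.warned_for_t_total_at_progress = -1

    assert LRSchedule(warmup=-1, t_total=1000).warmup == 0   # sentinel clamped after validation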
@@ -51,6 +52,8 @@ class LRSchedule(object):
         :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
         :return: learning rate multiplier for current update
         """
+        if self.t_total < 0:
+            return 1.
         progress = step / self.t_total
         ret = self.get_lr_(progress)
         # warning for exceeding t_total (only active with warmup_linear
@@ -87,9 +90,6 @@ class WarmupCosineSchedule(LRSchedule):
         self.cycles = cycles
 
     def get_lr_(self, progress):
-        """ get learning rate multiplier """
-        if self.t_total <= 0:
-            return 1.
         if progress < self.warmup:
             return progress / self.warmup
         else:
@@ -106,8 +106,6 @@ class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
         assert(cycles >= 1.)
 
     def get_lr_(self, progress):
-        if self.t_total <= 0:
-            return 1.
         if progress < self.warmup:
             return progress / self.warmup
         else:
@@ -124,11 +122,10 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
     """
     def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
         assert(cycles >= 1.)
-        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup*cycles, t_total=t_total, cycles=cycles, **kw)
+        warmup = warmup * cycles if warmup >= 0 else warmup
+        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
 
     def get_lr_(self, progress):
-        if self.t_total <= 0.:
-            return 1.
         progress = progress * self.cycles % 1.
         if progress < self.warmup:
             return progress / self.warmup
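
The conditional in the new `warmup = warmup * cycles if warmup >= 0 else warmup` line is the substantive fix here: warmup is given as a fraction of total training, and multiplying by cycles converts it into a fraction of one cycle (matching the folded `progress * self.cycles % 1.` space), but the -1 "no warmup" sentinel must pass through unscaled or it would fail the base-class range check (e.g. -1 * 5 = -5 is neither in [0.0, 1.0[ nor equal to -1). A small standalone check of both cases:

    def scale_warmup(warmup, cycles):
        # convert a whole-run warmup fraction into a per-cycle fraction,
        # leaving the -1 "no warmup" sentinel untouched
        return warmup * cycles if warmup >= 0 else warmup

    assert abs(scale_warmup(0.05, 5) - 0.25) < 1e-12  # 5% of the run = 25% of each of 5 cycles
    assert scale_warmup(-1, 5) == -1                  # sentinel survives for base-class validation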
@@ -174,7 +171,7 @@ class BertAdam(Optimizer):
         lr: learning rate
         warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
         t_total: total number of training steps for the learning
-            rate schedule, -1 means constant learning rate. Default: -1
+            rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
         schedule: schedule to use for the warmup (see above).
             Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
             Default: 'warmup_linear'
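
A hedged usage sketch of what the amended docstring describes, assuming the package's usual import path: with t_total=-1 the schedule multiplier stays at 1., so training proceeds at the bare lr and the warmup argument has no effect.

    import torch
    from pytorch_pretrained_bert.optimization import BertAdam

    model = torch.nn.Linear(10, 2)  # stand-in model
    # t_total=-1: constant multiplier of 1., so warmup=0.1 is ignored
    # and every step uses the plain lr of 1e-3
    optimizer = BertAdam(model.parameters(), lr=1e-3, warmup=0.1, t_total=-1)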
@@ -51,9 +51,9 @@ class OptimizationTest(unittest.TestCase):
 
 class WarmupCosineWithRestartsTest(unittest.TestCase):
     def test_it(self):
-        m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=5)
-        x = np.arange(0, 1000) / 1000
-        y = [m.get_lr_(xe) for xe in x]
+        m = WarmupCosineWithWarmupRestartsSchedule(warmup=-1, t_total=500, cycles=5)
+        x = np.arange(0, 1000)
+        y = [m.get_lr(xe) for xe in x]
         plt.plot(y)
         plt.show(block=False)
         y = np.asarray(y)
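
Worth noting about the updated test: it now drives the public get_lr with absolute step indices (0..999) against t_total=500, instead of feeding pre-normalized progress into get_lr_, so steps past t_total also exercise the new code path. A standalone equivalent of the loop (import path assumed; the warning is suppressed via the nowarn flag documented above):

    import numpy as np
    from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule

    m = WarmupCosineWithWarmupRestartsSchedule(warmup=-1, t_total=500, cycles=5)
    # absolute steps; get_lr normalizes by t_total internally
    y = np.asarray([m.get_lr(step, nowarn=True) for step in range(1000)])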