Commit 20686b78 authored by lukovnikov's avatar lukovnikov
Browse files

schedule fix

parent 1b4ce76c
......@@ -42,8 +42,8 @@ class LRSchedule(object):
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
if not 0.0 <= warmup < 1.0 and not warmup == -1:
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
warmup = max(warmup, 0)
self.warmup, self.t_total = warmup, t_total
warmup = max(warmup, 0.)
self.warmup, self.t_total = float(warmup), float(t_total)
self.warned_for_t_total_at_progress = -1
def get_lr(self, step, nowarn=False):
......@@ -153,7 +153,7 @@ class WarmupLinearSchedule(LRSchedule):
def get_lr_(self, progress):
if progress < self.warmup:
return progress / self.warmup
return max((progress - 1.) / (self.warmup - 1.), 0)
return max((progress - 1.) / (self.warmup - 1.), 0.)
SCHEDULES = {
......
......@@ -51,7 +51,7 @@ class OptimizationTest(unittest.TestCase):
class WarmupCosineWithRestartsTest(unittest.TestCase):
def test_it(self):
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=5)
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
x = np.arange(0, 1000)
y = [m.get_lr(xe) for xe in x]
# plt.plot(y)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment