Commit 262a9992 authored by lukovnikov
Browse files

class weights

parent b6c1cae6
...@@ -24,7 +24,8 @@ import logging ...@@ -24,7 +24,8 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam", "WarmupCosineWithRestartsSchedule"] __all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
"WarmupMultiCosineSchedule", "WarmupCosineWithRestartsSchedule"]
class LRSchedule(object): class LRSchedule(object):
...@@ -72,10 +73,11 @@ class WarmupCosineSchedule(LRSchedule): ...@@ -72,10 +73,11 @@ class WarmupCosineSchedule(LRSchedule):
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule): class WarmupMultiCosineSchedule(WarmupCosineSchedule):
warn_t_total = True warn_t_total = True
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
super(WarmupCosineWithRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) super(WarmupMultiCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
assert(cycles >= 1.)
def get_lr_(self, progress): def get_lr_(self, progress):
if self.t_total <= 0: if self.t_total <= 0:
...@@ -88,6 +90,19 @@ class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule): ...@@ -88,6 +90,19 @@ class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
return ret return ret
class WarmupCosineWithRestartsSchedule(WarmupMultiCosineSchedule):
    """Cosine schedule whose linear warmup is repeated at the start of every cycle.

    Overall progress is folded into ``self.cycles`` identical cycles. Inside
    each cycle the multiplier first rises linearly while the in-cycle position
    is below ``self.warmup``, then follows half a cosine wave down from 1 to 0.
    """

    def get_lr_(self, progress):
        """Return the learning-rate multiplier for ``progress``.

        :param progress: overall training progress; it is multiplied by
            ``self.cycles`` and wrapped modulo 1 to get the in-cycle position.
        :return: ``1.`` when ``self.t_total`` is not positive; otherwise the
            linear-warmup value or the cosine-decay value for the current
            position inside the cycle.
        """
        # Without a positive total step count the schedule is a constant 1.
        if self.t_total <= 0.:
            return 1.

        # Position inside the current cycle, wrapped into [0, 1).
        cycle_pos = progress * self.cycles % 1.

        # Linear ramp while still inside this cycle's warmup window.
        if cycle_pos < self.warmup:
            return cycle_pos / self.warmup

        # Rescale the remainder of the cycle to [0, 1) and apply half a cosine.
        decay_pos = (cycle_pos - self.warmup) / (1 - self.warmup)
        return 0.5 * (1. + math.cos(math.pi * decay_pos))
class WarmupConstantSchedule(LRSchedule): class WarmupConstantSchedule(LRSchedule):
warn_t_total = False warn_t_total = False
def get_lr_(self, progress): def get_lr_(self, progress):
......
...@@ -20,7 +20,9 @@ import unittest ...@@ -20,7 +20,9 @@ import unittest
import torch import torch
from pytorch_pretrained_bert import BertAdam from pytorch_pretrained_bert import BertAdam, WarmupCosineWithRestartsSchedule
from matplotlib import pyplot as plt
import numpy as np
class OptimizationTest(unittest.TestCase): class OptimizationTest(unittest.TestCase):
...@@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase): ...@@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase):
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
class WarmupCosineWithRestartsTest(unittest.TestCase):
    """Sanity checks for ``WarmupCosineWithRestartsSchedule``."""

    def test_it(self):
        """Sample the schedule on a dense grid and check basic invariants.

        The schedule is evaluated at 1000 evenly spaced progress values in
        [0, 1). The test asserts instead of plotting, so it can run unattended
        (no display needed) and can actually fail.
        """
        m = WarmupCosineWithRestartsSchedule(warmup=0.2, t_total=1, cycles=3)
        # Float divisor: a plain ``/ 1000`` floors every sample to 0 under
        # Python 2 integer division.
        x = np.arange(0, 1000) / 1000.
        y = [m.get_lr_(xe) for xe in x]

        self.assertEqual(len(y), len(x))
        # Progress 0 lies at the very start of the first warmup ramp.
        self.assertEqual(y[0], 0.)
        # Both the linear ramp and the half-cosine stay inside [0, 1].
        for xe, ye in zip(x, y):
            self.assertGreaterEqual(ye, 0., "multiplier below 0 at progress %s" % xe)
            self.assertLessEqual(ye, 1., "multiplier above 1 at progress %s" % xe)
        # The ramp must (nearly) reach its peak somewhere on the grid.
        self.assertGreater(max(y), 0.95)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment