Commit bb7557d3 authored by lukovnikov

- removed __all__ in optimization

- removed unused plotting code
- using ABC for LRSchedule
- added some schedule object init tests
parent 34ccc8eb
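
In short, `LRSchedule` becomes the abstract base class `_LRSchedule` (with `get_lr_` declared abstract) and a new `ConstantLR` schedule covers the "no schedule" case. A minimal sketch of what the ABC change means for a custom schedule; the `TriangularSchedule` class and its formula below are illustrative only, not part of this commit:

    from pytorch_pretrained_bert.optimization import _LRSchedule

    class TriangularSchedule(_LRSchedule):
        # get_lr_ is now an @abstractmethod, so every concrete schedule must implement it;
        # it receives training progress in [0, 1] and returns a multiplier applied to the base lr.
        def get_lr_(self, progress):
            return 1. - abs(2. * progress - 1.)  # linear ramp up to mid-training, then back down

Since `_LRSchedule` derives from `ABC`, instantiating it directly (or a subclass that does not implement `get_lr_`) raises a `TypeError`.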
@@ -20,15 +20,12 @@ from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 import logging
+from abc import ABC, abstractmethod

 logger = logging.getLogger(__name__)

-__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam",
-           "WarmupCosineWithHardRestartsSchedule", "WarmupCosineWithWarmupRestartsSchedule", "SCHEDULES"]

-class LRSchedule(object):
+class _LRSchedule(ABC):
     """ Parent of all LRSchedules here. """
     warn_t_total = False        # is set to True for schedules where progressing beyond t_total steps doesn't make sense
     def __init__(self, warmup=0.002, t_total=-1, **kw):
@@ -37,7 +34,7 @@ class LRSchedule(object):
         :param t_total: how many training steps (updates) are planned
         :param kw:
         """
-        super(LRSchedule, self).__init__(**kw)
+        super(_LRSchedule, self).__init__(**kw)
         if t_total < 0:
             logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
         if not 0.0 <= warmup < 1.0 and not warmup == -1:
@@ -65,16 +62,21 @@ class LRSchedule(object):
             # end warning
         return ret

+    @abstractmethod
     def get_lr_(self, progress):
         """
         :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
         :return: learning rate multiplier for current update
         """
         return 1.
-        # raise NotImplemented("use subclass")

-class WarmupCosineSchedule(LRSchedule):
+class ConstantLR(_LRSchedule):
+    def get_lr_(self, progress):
+        return 1.
+
+class WarmupCosineSchedule(_LRSchedule):
     """
     Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
     """
@@ -135,7 +137,7 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedul
         return ret

-class WarmupConstantSchedule(LRSchedule):
+class WarmupConstantSchedule(_LRSchedule):
     """
     Applies linear warmup. After warmup always returns 1..
     """
@@ -145,7 +147,7 @@ class WarmupConstantSchedule(LRSchedule):
         return 1.

-class WarmupLinearSchedule(LRSchedule):
+class WarmupLinearSchedule(_LRSchedule):
     """
     Linear warmup. Linear decay after warmup.
     """
@@ -157,8 +159,8 @@ class WarmupLinearSchedule(LRSchedule):
 SCHEDULES = {
-    None:       LRSchedule,
-    "none":     LRSchedule,
+    None:       ConstantLR,
+    "none":     ConstantLR,
     "warmup_cosine":   WarmupCosineSchedule,
     "warmup_constant": WarmupConstantSchedule,
     "warmup_linear":   WarmupLinearSchedule
@@ -185,7 +187,7 @@ class BertAdam(Optimizer):
                  b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
+        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
         if not 0.0 <= b1 < 1.0:
             raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
@@ -194,7 +196,7 @@ class BertAdam(Optimizer):
         if not e >= 0.0:
             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
         # initialize schedule object
-        if not isinstance(schedule, LRSchedule):
+        if not isinstance(schedule, _LRSchedule):
             schedule_type = SCHEDULES[schedule]
             schedule = schedule_type(warmup=warmup, t_total=t_total)
         else:
...
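
With the check above switched to `isinstance(schedule, _LRSchedule)`, the `schedule` argument can be either a name from `SCHEDULES` or a pre-built schedule object. A small usage sketch; the argument values are illustrative, not taken from this commit:

    import torch
    from pytorch_pretrained_bert import BertAdam
    from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule

    model = torch.nn.Linear(50, 50)

    # 1) pass a schedule by name: BertAdam looks it up in SCHEDULES and builds the object itself
    optimizer = BertAdam(model.parameters(), lr=1e-3, warmup=0.1, t_total=1000,
                         schedule="warmup_cosine")

    # 2) pass a schedule object directly, e.g. for schedules with extra arguments such as `cycles`
    restarts = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=5)
    optimizer = BertAdam(model.parameters(), lr=1e-3, schedule=restarts)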
@@ -20,7 +20,8 @@ from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 import logging
-from .optimization import *
+from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \
+    WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule

 logger = logging.getLogger(__name__)
@@ -33,7 +34,7 @@ class OpenAIAdam(Optimizer):
                  vector_l2=False, max_grad_norm=-1, **kwargs):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES:
+        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
         if not 0.0 <= b1 < 1.0:
             raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
@@ -42,7 +43,7 @@ class OpenAIAdam(Optimizer):
         if not e >= 0.0:
             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
         # initialize schedule object
-        if not isinstance(schedule, LRSchedule):
+        if not isinstance(schedule, _LRSchedule):
             schedule_type = SCHEDULES[schedule]
             schedule = schedule_type(warmup=warmup, t_total=t_total)
         else:
...
@@ -21,10 +21,11 @@ import unittest
 import torch

 from pytorch_pretrained_bert import BertAdam
-from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule
-#from matplotlib import pyplot as plt
+from pytorch_pretrained_bert import OpenAIAdam
+from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupCosineWithWarmupRestartsSchedule
 import numpy as np

 class OptimizationTest(unittest.TestCase):
     def assertListAlmostEqual(self, list1, list2, tol):
@@ -49,13 +50,33 @@ class OptimizationTest(unittest.TestCase):
         self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

+class ScheduleInitTest(unittest.TestCase):
+    def test_bert_sched_init(self):
+        m = torch.nn.Linear(50, 50)
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
+        # shouldn't fail
+
+    def test_openai_sched_init(self):
+        m = torch.nn.Linear(50, 50)
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
+        # shouldn't fail
+
 class WarmupCosineWithRestartsTest(unittest.TestCase):
     def test_it(self):
         m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
         x = np.arange(0, 1000)
         y = [m.get_lr(xe) for xe in x]
-        # plt.plot(y)
-        # plt.show(block=False)
         y = np.asarray(y)
         expected_zeros = y[[0, 200, 400, 600, 800]]
         print(expected_zeros)
...