Unverified commit 98cb7b2c authored by Thomas Wolf, committed by GitHub

Merge pull request #445 from lukovnikov/master

Learning rate schedules improvement + extension
parents 68a889ee 69850b40
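
For readers skimming the diff below (it touches the BertAdam optimizer module, the OpenAIAdam optimizer module, and the optimization tests), here is a minimal usage sketch of the extended schedule API. The linear layer and hyperparameter values are made up for illustration; the class names, import paths, and constructor arguments are taken from the diff and tests below.

```python
import torch
from pytorch_pretrained_bert import BertAdam
from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule

model = torch.nn.Linear(50, 50)  # toy stand-in model, for illustration only

# As before: select a schedule by name; warmup and t_total are forwarded
# to the schedule object that the optimizer now builds internally.
optimizer = BertAdam(model.parameters(), lr=1e-3,
                     warmup=0.1, t_total=1000, schedule="warmup_linear")

# New in this PR: pass a pre-built _LRSchedule instance directly.
# warmup and t_total are then configured on the schedule itself, and the
# optimizer's own warmup/t_total arguments are ignored (with a warning).
schedule = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000, cycles=5)
optimizer = BertAdam(model.parameters(), lr=1e-3, schedule=schedule)
```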
@@ -20,33 +20,157 @@ from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 import logging
+import abc
+import sys

 logger = logging.getLogger(__name__)

-def warmup_cosine(x, warmup=0.002):
-    if x < warmup:
-        return x/warmup
-    x_ = (x - warmup) / (1 - warmup)  # progress after warmup
-    return 0.5 * (1. + math.cos(math.pi * x_))
-
-def warmup_constant(x, warmup=0.002):
-    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
-        Learning rate is 1. afterwards. """
-    if x < warmup:
-        return x/warmup
-    return 1.0
-
-def warmup_linear(x, warmup=0.002):
-    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
-        After `t_total`-th training step, learning rate is zero. """
-    if x < warmup:
-        return x/warmup
-    return max((x-1.)/(warmup-1.), 0)
+if sys.version_info >= (3, 4):
+    ABC = abc.ABC
+else:
+    ABC = abc.ABCMeta('ABC', (), {})
+
+
+class _LRSchedule(ABC):
+    """ Parent of all LRSchedules here. """
+    warn_t_total = False        # is set to True for schedules where progressing beyond t_total steps doesn't make sense
+    def __init__(self, warmup=0.002, t_total=-1, **kw):
+        """
+        :param warmup:  what fraction of t_total steps will be used for linear warmup
+        :param t_total: how many training steps (updates) are planned
+        :param kw:
+        """
+        super(_LRSchedule, self).__init__(**kw)
+        if t_total < 0:
+            logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
+        if not 0.0 <= warmup < 1.0 and not warmup == -1:
+            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
+        warmup = max(warmup, 0.)
+        self.warmup, self.t_total = float(warmup), float(t_total)
+        self.warned_for_t_total_at_progress = -1
+
+    def get_lr(self, step, nowarn=False):
+        """
+        :param step:    which of t_total steps we're on
+        :param nowarn:  set to True to suppress warning regarding training beyond specified 't_total' steps
+        :return:        learning rate multiplier for current update
+        """
+        if self.t_total < 0:
+            return 1.
+        progress = float(step) / self.t_total
+        ret = self.get_lr_(progress)
+        # warning for exceeding t_total (only active for schedules with warn_t_total=True)
+        if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
+            logger.warning(
+                "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
+                    .format(ret, self.__class__.__name__))
+            self.warned_for_t_total_at_progress = progress
+        # end warning
+        return ret
+
+    @abc.abstractmethod
+    def get_lr_(self, progress):
+        """
+        :param progress:    value between 0 and 1 (unless going beyond t_total steps) specifying training progress
+        :return:            learning rate multiplier for current update
+        """
+        return 1.
+
+
+class ConstantLR(_LRSchedule):
+    def get_lr_(self, progress):
+        return 1.
+
+
+class WarmupCosineSchedule(_LRSchedule):
+    """
+    Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
+    """
+    warn_t_total = True
+    def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
+        """
+        :param warmup:  see LRSchedule
+        :param t_total: see LRSchedule
+        :param cycles:  number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
+        :param kw:
+        """
+        super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
+        self.cycles = cycles
+
+    def get_lr_(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        else:
+            progress = (progress - self.warmup) / (1 - self.warmup)   # progress after warmup
+            return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
+
+
+class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
+    """
+    Cosine learning rate schedule with linear warmup and hard restarts (if cycles > 1).
+    """
+    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
+        super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
+        assert(cycles >= 1.)
+
+    def get_lr_(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        else:
+            progress = (progress - self.warmup) / (1 - self.warmup)   # progress after warmup
+            ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
+            return ret
+
+
+class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
+    """
+    Cosine learning rate schedule with linear warmup and linear warmup restarts.
+    The same warmup rate is used for the warmup restarts as for the initial warmup.
+    Note that the total effective fraction of warmup steps over all cycles is warmup * cycles.
+    """
+    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
+        assert(warmup * cycles < 1.)
+        warmup = warmup * cycles if warmup >= 0 else warmup
+        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
+
+    def get_lr_(self, progress):
+        progress = progress * self.cycles % 1.
+        if progress < self.warmup:
+            return progress / self.warmup
+        else:
+            progress = (progress - self.warmup) / (1 - self.warmup)   # progress after warmup
+            ret = 0.5 * (1. + math.cos(math.pi * progress))
+            return ret
+
+
+class WarmupConstantSchedule(_LRSchedule):
+    """
+    Applies linear warmup. After warmup always returns 1.
+    """
+    def get_lr_(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        return 1.
+
+
+class WarmupLinearSchedule(_LRSchedule):
+    """
+    Linear warmup. Linear decay after warmup.
+    """
+    warn_t_total = True
+    def get_lr_(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        return max((progress - 1.) / (self.warmup - 1.), 0.)
+
+
 SCHEDULES = {
-    'warmup_cosine': warmup_cosine,
-    'warmup_constant': warmup_constant,
-    'warmup_linear': warmup_linear,
+    None:       ConstantLR,
+    "none":     ConstantLR,
+    "warmup_cosine": WarmupCosineSchedule,
+    "warmup_constant": WarmupConstantSchedule,
+    "warmup_linear": WarmupLinearSchedule
 }
@@ -56,8 +180,10 @@ class BertAdam(Optimizer):
         lr: learning rate
         warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
         t_total: total number of training steps for the learning
-            rate schedule, -1 means constant learning rate. Default: -1
-        schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
+            rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
+        schedule: schedule to use for the warmup (see above).
+            Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
+            Default: 'warmup_linear'
         b1: Adams b1. Default: 0.9
         b2: Adams b2. Default: 0.999
         e: Adams epsilon. Default: 1e-6
@@ -65,21 +191,26 @@ class BertAdam(Optimizer):
         max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
     """
     def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
-                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
-                 max_grad_norm=1.0):
+                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if schedule not in SCHEDULES:
+        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
-        if not 0.0 <= warmup < 1.0 and not warmup == -1:
-            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
         if not 0.0 <= b1 < 1.0:
             raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
         if not 0.0 <= b2 < 1.0:
             raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
         if not e >= 0.0:
             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
-        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
+        # initialize schedule object
+        if not isinstance(schedule, _LRSchedule):
+            schedule_type = SCHEDULES[schedule]
+            schedule = schedule_type(warmup=warmup, t_total=t_total)
+        else:
+            if warmup != -1 or t_total != -1:
+                logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
+                               "Please specify custom warmup and t_total in LRSchedule object.")
+        defaults = dict(lr=lr, schedule=schedule,
                         b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                         max_grad_norm=max_grad_norm)
         super(BertAdam, self).__init__(params, defaults)
@@ -91,11 +222,8 @@ class BertAdam(Optimizer):
                 state = self.state[p]
                 if len(state) == 0:
                     return [0]
-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'].get_lr(state['step'])
                 lr.append(lr_scheduled)
         return lr
@@ -110,8 +238,6 @@ class BertAdam(Optimizer):
         if closure is not None:
             loss = closure()

-        warned_for_t_total = False
-
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -153,19 +279,8 @@ class BertAdam(Optimizer):
                 if group['weight_decay'] > 0.0:
                     update += group['weight_decay'] * p.data

-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    progress = state['step']/group['t_total']
-                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
-                    # warning for exceeding t_total (only active with warmup_linear)
-                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
-                        logger.warning(
-                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
-                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
-                        warned_for_t_total = True
-                    # end warning
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'].get_lr(state['step'])

                 update_with_lr = lr_scheduled * update
                 p.data.add_(-update_with_lr)
...
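
As a sanity check of the refactor above (the optimizer now multiplies group['lr'] by schedule.get_lr(step) instead of calling a module-level warmup function), here is a small, illustrative example of the multipliers a WarmupLinearSchedule produces. The base learning rate and step values are made up; the class and import path come from the diff and tests in this commit.

```python
from pytorch_pretrained_bert.optimization import WarmupLinearSchedule

sched = WarmupLinearSchedule(warmup=0.1, t_total=1000)
base_lr = 1e-3

# Warmup phase: step 50 is halfway through the 100 warmup steps.
print(base_lr * sched.get_lr(50))    # 0.0005
# Decay phase: multiplier is (progress - 1) / (warmup - 1) = (0.5 - 1) / (0.1 - 1) ~= 0.556.
print(base_lr * sched.get_lr(500))   # ~0.000556
# End of schedule: multiplier reaches 0.
print(base_lr * sched.get_lr(1000))  # 0.0
```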
@@ -20,35 +20,11 @@ from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
 import logging
+from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \
+    WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule

 logger = logging.getLogger(__name__)

-def warmup_cosine(x, warmup=0.002):
-    if x < warmup:
-        return x/warmup
-    x_ = (x - warmup) / (1 - warmup)  # progress after warmup
-    return 0.5 * (1. + math.cos(math.pi * x_))
-
-def warmup_constant(x, warmup=0.002):
-    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
-        Learning rate is 1. afterwards. """
-    if x < warmup:
-        return x/warmup
-    return 1.0
-
-def warmup_linear(x, warmup=0.002):
-    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
-        After `t_total`-th training step, learning rate is zero. """
-    if x < warmup:
-        return x/warmup
-    return max((x-1.)/(warmup-1.), 0)
-
-SCHEDULES = {
-    'warmup_cosine': warmup_cosine,
-    'warmup_constant': warmup_constant,
-    'warmup_linear': warmup_linear,
-}
-

 class OpenAIAdam(Optimizer):
     """Implements Open AI version of Adam algorithm with weight decay fix.
@@ -58,17 +34,23 @@ class OpenAIAdam(Optimizer):
                  vector_l2=False, max_grad_norm=-1, **kwargs):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if schedule not in SCHEDULES:
+        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
-        if not 0.0 <= warmup < 1.0 and not warmup == -1:
-            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
         if not 0.0 <= b1 < 1.0:
-            raise ValueError("Invalid b1 parameter: {}".format(b1))
+            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
         if not 0.0 <= b2 < 1.0:
-            raise ValueError("Invalid b2 parameter: {}".format(b2))
+            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
         if not e >= 0.0:
-            raise ValueError("Invalid epsilon value: {}".format(e))
-        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
+            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
+        # initialize schedule object
+        if not isinstance(schedule, _LRSchedule):
+            schedule_type = SCHEDULES[schedule]
+            schedule = schedule_type(warmup=warmup, t_total=t_total)
+        else:
+            if warmup != -1 or t_total != -1:
+                logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
+                               "Please specify custom warmup and t_total in LRSchedule object.")
+        defaults = dict(lr=lr, schedule=schedule,
                         b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
                         max_grad_norm=max_grad_norm)
         super(OpenAIAdam, self).__init__(params, defaults)
@@ -80,11 +62,8 @@ class OpenAIAdam(Optimizer):
                 state = self.state[p]
                 if len(state) == 0:
                     return [0]
-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'].get_lr(state['step'])
                 lr.append(lr_scheduled)
         return lr
@@ -99,8 +78,6 @@ class OpenAIAdam(Optimizer):
         if closure is not None:
             loss = closure()

-        warned_for_t_total = False
-
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -136,19 +113,8 @@ class OpenAIAdam(Optimizer):
                 bias_correction1 = 1 - beta1 ** state['step']
                 bias_correction2 = 1 - beta2 ** state['step']

-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    progress = state['step']/group['t_total']
-                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
-                    # warning for exceeding t_total (only active with warmup_linear)
-                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
-                        logger.warning(
-                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
-                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
-                        warned_for_t_total = True
-                    # end warning
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'].get_lr(state['step'])

                 step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
...
@@ -21,6 +21,10 @@ import unittest
 import torch

 from pytorch_pretrained_bert import BertAdam
+from pytorch_pretrained_bert import OpenAIAdam
+from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupCosineWithWarmupRestartsSchedule
+
+import numpy as np


 class OptimizationTest(unittest.TestCase):
@@ -46,5 +50,43 @@ class OptimizationTest(unittest.TestCase):
         self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

+
+class ScheduleInitTest(unittest.TestCase):
+    def test_bert_sched_init(self):
+        m = torch.nn.Linear(50, 50)
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
+        # shouldn't fail
+
+    def test_openai_sched_init(self):
+        m = torch.nn.Linear(50, 50)
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
+        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
+        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
+        # shouldn't fail
+
+
+class WarmupCosineWithRestartsTest(unittest.TestCase):
+    def test_it(self):
+        m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
+        x = np.arange(0, 1000)
+        y = [m.get_lr(xe) for xe in x]
+        y = np.asarray(y)
+        expected_zeros = y[[0, 200, 400, 600, 800]]
+        print(expected_zeros)
+        expected_ones = y[[50, 250, 450, 650, 850]]
+        print(expected_ones)
+        self.assertTrue(np.allclose(expected_ones, 1))
+        self.assertTrue(np.allclose(expected_zeros, 0))
+

 if __name__ == "__main__":
     unittest.main()
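
A note on the expected indices in WarmupCosineWithRestartsTest above: with t_total=1000 and cycles=5, each cosine cycle spans 1000 / 5 = 200 steps, and the constructor rescales the warmup fraction to 0.05 * 5 = 0.25 of a cycle, i.e. 50 steps per cycle. The learning rate multiplier therefore restarts at 0 on steps 0, 200, 400, 600, 800 and peaks at 1 on steps 50, 250, 450, 650, 850, which is exactly what the assertions check.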