Commit 022525b0 authored by Rémi Louf's avatar Rémi Louf
Browse files

replace LambdaLR scheduler wrappers by function

Custom schedulers are currently initiated by wrapping Pytorch's LambdaLR
class and passing a method of the wrapping class to the __init__
function of LambdaLR. This approach is not appropriate for several
reasons:

1. one does not need to define a class when it only defines a
__init__() method;
2. instantiating the parent class by passing a method of the child class
creates a cyclical reference which leads to memory leaks. See issues #1742 and #1134.

In this commit we replace the wrapper classes with functions that
instantiate `LambdaLR` with a custom learning rate function. We use a
closure to specify the parameter of the latter. We also do a bit of
renaming within the function to explicit the behaviour and removed
docstrings that were subsequently not necessary.
parent 1c542df7
......@@ -97,8 +97,8 @@ if is_torch_available():
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
# Optimization
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup)
# TensorFlow
......
......@@ -23,90 +23,66 @@ from torch.optim.lr_scheduler import LambdaLR
logger = logging.getLogger(__name__)
class ConstantLRSchedule(LambdaLR):
""" Constant learning rate schedule.
def get_constant_schedule(optimizer, last_epoch=-1):
""" Create a schedule with a constant learning rate.
"""
def __init__(self, optimizer, last_epoch=-1):
super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
class WarmupConstantSchedule(LambdaLR):
""" Linear warmup and then constant.
Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
Keeps multiplicative variable equal to 1. after warmup_steps.
def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
""" Create a schedule with a constant learning rate preceded by a warmup
period during which the learning rate increases linearly between 0 and 1.
"""
def __init__(self, optimizer, warmup_steps, last_epoch=-1):
self.warmup_steps = warmup_steps
super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1.0, self.warmup_steps))
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1.0, num_warmup_steps))
return 1.
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
class WarmupLinearSchedule(LambdaLR):
""" Linear warmup and then linear decay.
Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
Linearly decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps.
"""
def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1, self.warmup_steps))
return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
class WarmupCosineSchedule(LambdaLR):
""" Linear warmup and then cosine decay.
Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
Decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
If `cycles` (default=0.5) is different from default, then the multiplicative variable follows cosine function after warmup.
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
""" Create a schedule with a learning rate that decreases linearly after
linearly increasing during a warmup period.
"""
def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
self.cycles = cycles
super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1.0, self.warmup_steps))
# progress after warmup
progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
class WarmupCosineWithHardRestartsSchedule(LambdaLR):
""" Linear warmup and then cosine cycles with hard restarts.
Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
learning rate (with hard restarts).
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1):
""" Create a schedule with a learning rate that decreases following the
values of the cosine function between 0 and `pi * cycles` after a warmup
period during which it increases linearly between 0 and 1.
"""
def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
self.cycles = cycles
super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1, self.warmup_steps))
# progress after warmup
progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
if progress >= 1.0:
return 0.0
return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0))))
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1):
""" Create a schedule with a learning rate that decreases following the
values of the cosine function with several hard restarts, after a warmup
period during which it increases linearly between 0 and 1.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
if progress >= 1.:
return 0.
return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.))))
return LambdaLR(optimizer, lr_lambda, last_epoch)
class AdamW(Optimizer):
""" Implements Adam algorithm with weight decay fix.
......
......@@ -25,8 +25,12 @@ from transformers import is_torch_available
if is_torch_available():
import torch
from transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
from transformers import (AdamW,
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup)
else:
pytestmark = pytest.mark.skip("Require Torch")
......@@ -87,59 +91,60 @@ class ScheduleInitTest(unittest.TestCase):
self.assertAlmostEqual(a, b, delta=tol)
def test_constant_scheduler(self):
scheduler = ConstantLRSchedule(self.optimizer)
scheduler = get_constant_schedule(self.optimizer)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [10.] * self.num_steps
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
scheduler = ConstantLRSchedule(self.optimizer)
scheduler = get_constant_schedule(self.optimizer)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_constant_scheduler(self):
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_linear_scheduler(self):
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_cosine_scheduler(self):
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_cosine_hard_restart_scheduler(self):
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment