"...resnet50_tensorflow.git" did not exist on "69c75593028133907bda21830e422156a2816d3e"
Commit 022525b0 authored by Rémi Louf's avatar Rémi Louf
Browse files

replace LambdaLR scheduler wrappers by function

Custom schedulers are currently instantiated by wrapping PyTorch's LambdaLR
class and passing a method of the wrapping class to the __init__
function of LambdaLR. This approach is not appropriate for several
reasons:

1. there is no need to define a class when it only defines an
__init__() method;
2. instantiating the parent class by passing a method of the child class
creates a reference cycle, which leads to memory leaks. See issues #1742 and #1134.

In this commit we replace the wrapper classes with functions that
instantiate `LambdaLR` with a custom learning rate function. We use a
closure to specify the parameters of the latter. We also do a bit of
renaming within the functions to make the behaviour explicit, and remove
docstrings that are consequently no longer necessary.
parent 1c542df7
...@@ -97,8 +97,8 @@ if is_torch_available(): ...@@ -97,8 +97,8 @@ if is_torch_available():
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
# Optimization # Optimization
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup,
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup)
# TensorFlow # TensorFlow
......
...@@ -23,90 +23,66 @@ from torch.optim.lr_scheduler import LambdaLR ...@@ -23,90 +23,66 @@ from torch.optim.lr_scheduler import LambdaLR
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ConstantLRSchedule(LambdaLR):
""" Constant learning rate schedule. def get_constant_schedule(optimizer, last_epoch=-1):
""" Create a schedule with a constant learning rate.
""" """
def __init__(self, optimizer, last_epoch=-1): return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
class WarmupConstantSchedule(LambdaLR): def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
""" Linear warmup and then constant. """ Create a schedule with a constant learning rate preceded by a warmup
Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step. period during which the learning rate increases linearly between 0 and 1.
Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
Keeps multiplicative variable equal to 1. after warmup_steps.
""" """
def __init__(self, optimizer, warmup_steps, last_epoch=-1): def lr_lambda(current_step):
self.warmup_steps = warmup_steps if current_step < num_warmup_steps:
super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) return float(current_step) / float(max(1.0, num_warmup_steps))
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1.0, self.warmup_steps))
return 1. return 1.
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
class WarmupLinearSchedule(LambdaLR):
""" Linear warmup and then linear decay. def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step. """ Create a schedule with a learning rate that decreases linearly after
Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps. linearly increasing during a warmup period.
Linearly decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps.
"""
def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1, self.warmup_steps))
return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
class WarmupCosineSchedule(LambdaLR):
""" Linear warmup and then cosine decay.
Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
Decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
If `cycles` (default=0.5) is different from default, then the multiplicative variable follows cosine function after warmup.
""" """
def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): def lr_lambda(current_step):
self.warmup_steps = warmup_steps if current_step < num_warmup_steps:
self.t_total = t_total return float(current_step) / float(max(1, num_warmup_steps))
self.cycles = cycles return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
return float(step) / float(max(1.0, self.warmup_steps)) def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1):
# progress after warmup """ Create a schedule with a learning rate that decreases following the
progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) values of the cosine function between 0 and `pi * cycles` after a warmup
return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) period during which it increases linearly between 0 and 1.
class WarmupCosineWithHardRestartsSchedule(LambdaLR):
""" Linear warmup and then cosine cycles with hard restarts.
Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
learning rate (with hard restarts).
""" """
def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): def lr_lambda(current_step):
self.warmup_steps = warmup_steps if current_step < num_warmup_steps:
self.t_total = t_total return float(current_step) / float(max(1, num_warmup_steps))
self.cycles = cycles progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress)))
def lr_lambda(self, step): return LambdaLR(optimizer, lr_lambda, last_epoch)
if step < self.warmup_steps:
return float(step) / float(max(1, self.warmup_steps))
# progress after warmup
progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
if progress >= 1.0:
return 0.0
return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0))))
def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1):
    """ Create a schedule with a learning rate that decreases following the
    values of the cosine function with several hard restarts, after a warmup
    period during which it increases linearly between 0 and 1.

    Args:
        optimizer: the optimizer handed to `LambdaLR`.
        num_warmup_steps: number of steps over which the multiplier grows
            linearly from 0 to 1.
        num_training_steps: total number of steps; the multiplier is 0 once
            this many steps have elapsed.
        num_cycles: number of cosine decays (each followed by a hard restart)
            performed between the end of the warmup and `num_training_steps`.
        last_epoch: forwarded unchanged to `LambdaLR`.

    Returns:
        A `LambdaLR` instance built around the multiplier function below.
    """
    def lr_lambda(current_step):
        # Warmup: linear ramp. `max(1, ...)` guards against a division by
        # zero when `num_warmup_steps` is 0.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # Fraction of the post-warmup phase that has elapsed so far.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.:
            return 0.

        # The modulo wraps the phase back to 0 at the start of every cycle,
        # which is what produces the hard restart.
        return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.))))

    # Pass `last_epoch` by keyword, like the other schedule factories do.
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
class AdamW(Optimizer): class AdamW(Optimizer):
""" Implements Adam algorithm with weight decay fix. """ Implements Adam algorithm with weight decay fix.
......
...@@ -25,8 +25,12 @@ from transformers import is_torch_available ...@@ -25,8 +25,12 @@ from transformers import is_torch_available
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, from transformers import (AdamW,
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup)
else: else:
pytestmark = pytest.mark.skip("Require Torch") pytestmark = pytest.mark.skip("Require Torch")
...@@ -87,59 +91,60 @@ class ScheduleInitTest(unittest.TestCase): ...@@ -87,59 +91,60 @@ class ScheduleInitTest(unittest.TestCase):
self.assertAlmostEqual(a, b, delta=tol) self.assertAlmostEqual(a, b, delta=tol)
def test_constant_scheduler(self): def test_constant_scheduler(self):
scheduler = ConstantLRSchedule(self.optimizer) scheduler = get_constant_schedule(self.optimizer)
lrs = unwrap_schedule(scheduler, self.num_steps) lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [10.] * self.num_steps expected_learning_rates = [10.] * self.num_steps
self.assertEqual(len(lrs[0]), 1) self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates) self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
scheduler = ConstantLRSchedule(self.optimizer) scheduler = get_constant_schedule(self.optimizer)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_constant_scheduler(self): def test_warmup_constant_scheduler(self):
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
lrs = unwrap_schedule(scheduler, self.num_steps) lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
self.assertEqual(len(lrs[0]), 1) self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates) self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_linear_scheduler(self): def test_warmup_linear_scheduler(self):
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs = unwrap_schedule(scheduler, self.num_steps) lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
self.assertEqual(len(lrs[0]), 1) self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates) self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_cosine_scheduler(self): def test_warmup_cosine_scheduler(self):
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs = unwrap_schedule(scheduler, self.num_steps) lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
self.assertEqual(len(lrs[0]), 1) self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_cosine_hard_restart_scheduler(self): def test_warmup_cosine_hard_restart_scheduler(self):
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10)
lrs = unwrap_schedule(scheduler, self.num_steps) lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
self.assertEqual(len(lrs[0]), 1) self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment