Unverified Commit ff7c8168 authored by Stas Bekman, committed by GitHub

[optim] implement AdafactorSchedule (#12123)



* implement AdafactorSchedule

* typo

* fix

* Update src/transformers/optimization.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent fe357648
@@ -420,6 +420,12 @@ class Adafactor(Optimizer):
    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
When using ``lr=None`` with :class:`~transformers.Trainer` you will most likely need to use the :class:`~transformers.optimization.AdafactorSchedule` scheduler as follows::

    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
Usage::
@@ -588,3 +594,52 @@ class Adafactor(Optimizer):
                p.data.copy_(p_data_fp32)

        return loss
class AdafactorSchedule(LambdaLR):
    """
    Since :class:`~transformers.optimization.Adafactor` performs its own scheduling, if the training loop relies on a
    scheduler (e.g., for logging), this class creates a proxy object that retrieves the current lr values from the
    optimizer.

    It returns ``initial_lr`` during startup and the actual ``lr`` during stepping.
    """
    def __init__(self, optimizer, initial_lr=0.0):
        def lr_lambda(_):
            return initial_lr

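        # LambdaLR derives ``base_lrs`` from each param group's ``initial_lr``; since Adafactor
        # runs with ``lr=None``, seed it explicitly here and remove it again after initialization
        # so the optimizer's param groups are left untouched.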
        for group in optimizer.param_groups:
            group["initial_lr"] = initial_lr
        super().__init__(optimizer, lr_lambda)
        for group in optimizer.param_groups:
            del group["initial_lr"]

    def get_lr(self):
        opt = self.optimizer
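        # Ask Adafactor for the lr it computed internally for every param group that has
        # already received a gradient; groups without gradients are skipped.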
        lrs = [
            opt._get_lr(group, opt.state[group["params"][0]])
            for group in opt.param_groups
            if group["params"][0].grad is not None
        ]
        if len(lrs) == 0:
            lrs = self.base_lrs  # if called before stepping
        return lrs
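A minimal sketch of the proxy behavior (not part of this diff), assuming ``torch`` and ``transformers`` are installed; the toy parameter is illustrative only::

    import torch
    from transformers.optimization import Adafactor, AdafactorSchedule

    param = torch.nn.Parameter(torch.randn(4, 4))
    optimizer = Adafactor([param], scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)

    # No parameter has a gradient yet, so the proxy falls back to ``base_lrs``
    # and reports the ``initial_lr`` (0.0 by default).
    print(lr_scheduler.get_lr())  # [0.0]

    loss = (param ** 2).sum()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

    # The proxy now reads the lr Adafactor computed internally for this group.
    print(lr_scheduler.get_lr())  # e.g. roughly [1e-06] after one step with warmup_init=True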
def get_adafactor_schedule(optimizer, initial_lr=0.0):
    """
    Get a proxy schedule for :class:`~transformers.optimization.Adafactor`

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        initial_lr (:obj:`float`, `optional`, defaults to 0.0):
            Initial lr

    Return:
        :class:`~transformers.optimization.Adafactor` proxy schedule object.
    """
    return AdafactorSchedule(optimizer, initial_lr)
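The helper is a thin wrapper around the class, so the Trainer snippet above could equally be written as follows (a sketch, not part of this diff; ``model`` is assumed to be defined)::

    from transformers.optimization import Adafactor, get_adafactor_schedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=0.0)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))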
@@ -589,6 +589,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        self.assertFalse(torch.allclose(trainer.model.b, b))
        self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)

    @require_torch
    def test_adafactor_lr_none(self):
        # test the special case where lr=None, since Trainer always requires an lr_scheduler
        from transformers.optimization import Adafactor, AdafactorSchedule

        train_dataset = RegressionDataset()
        args = TrainingArguments("./regression")
        model = RegressionModel()
        optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
        lr_scheduler = AdafactorSchedule(optimizer)
        trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
        trainer.train()

        (a, b) = self.default_trained_model
        self.assertFalse(torch.allclose(trainer.model.a, a))
        self.assertFalse(torch.allclose(trainer.model.b, b))
        self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)

    def test_model_init(self):
        train_dataset = RegressionDataset()
        args = TrainingArguments("./regression", learning_rate=0.1)
...