"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5d02e6bd2079a3691c5ef7a1c888abe6a16d854b"
Unverified Commit 2ca73e5e authored by Charbel Abi Daher, committed by GitHub

Fixed passing scheduler-specific kwargs via TrainingArguments lr_scheduler_kwargs (#27595)

* Fix passing scheduler-specific kwargs through TrainingArguments `lr_scheduler_kwargs`

* Added test for lr_scheduler_kwargs
parent 0864dd3b
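With this fix, scheduler-specific options reach the scheduler when passed through `TrainingArguments`. A minimal usage sketch (output directory and values are illustrative; before this commit, the same arguments were `**`-spread into `get_scheduler`, which does not accept arbitrary keyword arguments):

```python
from transformers import TrainingArguments

# Scheduler-specific options travel as a dict via `lr_scheduler_kwargs`
# and are applied to the schedule selected by `lr_scheduler_type`.
args = TrainingArguments(
    output_dir="./out",  # illustrative path
    lr_scheduler_type="polynomial",
    lr_scheduler_kwargs={"power": 5.0, "lr_end": 1e-5},
    learning_rate=0.2,
    warmup_steps=2,
)
```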
@@ -1111,7 +1111,7 @@ class Trainer:
                 optimizer=self.optimizer if optimizer is None else optimizer,
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
-                **self.args.lr_scheduler_kwargs,
+                scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
             )
             self._created_lr_scheduler = True
         return self.lr_scheduler
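The one-line change stops spreading the user's dict into `get_scheduler`'s own signature and instead hands it over as the dedicated `scheduler_specific_kwargs` argument, which `get_scheduler` forwards to the concrete schedule function. A minimal sketch of the forwarding pattern, not the transformers source:

```python
from transformers import get_polynomial_decay_schedule_with_warmup

# Illustrative sketch: extra options travel as a named dict and are
# **-spread only at the point where the concrete schedule function
# actually accepts them.
def get_scheduler_sketch(optimizer, num_warmup_steps, num_training_steps,
                         scheduler_specific_kwargs=None):
    scheduler_specific_kwargs = scheduler_specific_kwargs or {}
    return get_polynomial_decay_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        **scheduler_specific_kwargs,  # e.g. {"power": 5.0, "lr_end": 1e-5}
    )
```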
@@ -39,6 +39,7 @@ from transformers import (
     IntervalStrategy,
     PretrainedConfig,
     TrainingArguments,
+    get_polynomial_decay_schedule_with_warmup,
     is_torch_available,
     logging,
 )
@@ -643,6 +644,33 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertFalse(torch.allclose(trainer.model.b, b))
         self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
 
+    def test_lr_scheduler_kwargs(self):
+        # test scheduler kwargs passed via TrainingArguments
+        train_dataset = RegressionDataset()
+        model = RegressionModel()
+        num_steps, num_warmup_steps = 10, 2
+        extra_kwargs = {"power": 5.0, "lr_end": 1e-5}  # Non-default arguments
+        args = TrainingArguments(
+            "./regression",
+            lr_scheduler_type="polynomial",
+            lr_scheduler_kwargs=extra_kwargs,
+            learning_rate=0.2,
+            warmup_steps=num_warmup_steps,
+        )
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
+
+        # Checking that the scheduler was created
+        self.assertIsNotNone(trainer.lr_scheduler)
+
+        # Checking that the correct args were passed
+        sched1 = trainer.lr_scheduler
+        sched2 = get_polynomial_decay_schedule_with_warmup(
+            trainer.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_steps, **extra_kwargs
+        )
+        self.assertEqual(sched1.lr_lambdas[0].args, sched2.lr_lambdas[0].args)
+        self.assertEqual(sched1.lr_lambdas[0].keywords, sched2.lr_lambdas[0].keywords)
+
     def test_reduce_lr_on_plateau_args(self):
         # test passed arguments for a custom ReduceLROnPlateau scheduler
         train_dataset = RegressionDataset(length=64)
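The test can compare `lr_lambdas[0].args` and `.keywords` because the polynomial schedule is a `torch.optim.lr_scheduler.LambdaLR` whose per-group lambda is a `functools.partial` closing over the schedule parameters, so two schedulers built from the same kwargs carry identical partial metadata. A standalone sketch of that property:

```python
from functools import partial

def poly_lambda(step, *, power, lr_end):
    # Placeholder body; only the partial's stored metadata matters here.
    return 0.0

# Two partials built from the same kwargs expose equal .args / .keywords,
# which is exactly what the test asserts on the two schedulers' lambdas.
a = partial(poly_lambda, power=5.0, lr_end=1e-5)
b = partial(poly_lambda, power=5.0, lr_end=1e-5)
assert a.args == b.args and a.keywords == b.keywords
print(a.keywords)  # {'power': 5.0, 'lr_end': 1e-05}
```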