Unverified Commit 253d43d4 authored by Zach Mueller, committed by GitHub

Fix lr scheduler not being reset on reruns (#24758)

* Try this

* Solved!

* Rm extraneous

* Rm extraneous

* self

* Args'

* Check for if we created the lr scheduler

* Move comment

* Clean
parent 1be0145d
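This change addresses the case where `Trainer.train()` runs more than once on the same `Trainer` instance, e.g. when `auto_find_batch_size=True` retries training with a smaller batch size after an out-of-memory failure. Previously, the LR scheduler created on the first attempt survived into the rerun, so the new run resumed from an already-decayed learning rate with step counts sized for the old run. The fix records whether the `Trainer` itself built the scheduler and, if so, discards it at the start of each run; a scheduler supplied by the user is left untouched. A minimal sketch of the pattern, using a hypothetical `MiniTrainer` stand-in rather than the real `transformers.Trainer` API:

```python
# Minimal sketch of the reset pattern in this commit; MiniTrainer and its
# methods are hypothetical stand-ins, not the real transformers.Trainer API.
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR


class MiniTrainer:
    def __init__(self, model, lr_scheduler=None):
        self.optimizer = SGD(model.parameters(), lr=0.1)
        self.lr_scheduler = lr_scheduler  # may be user-supplied
        # Track whether *we* built the scheduler, so a user's is never discarded.
        self._created_lr_scheduler = False

    def create_scheduler(self, num_training_steps):
        # Only the branch that builds a scheduler sets the flag.
        if self.lr_scheduler is None:
            self.lr_scheduler = LinearLR(self.optimizer, total_iters=num_training_steps)
            self._created_lr_scheduler = True
        return self.lr_scheduler

    def train(self, num_training_steps):
        # Reset only a scheduler we created: a rerun may need different
        # parameters (e.g. a different step count after a batch-size change).
        if self._created_lr_scheduler:
            self.lr_scheduler = None
            self._created_lr_scheduler = False
        self.create_scheduler(num_training_steps)
        for _ in range(num_training_steps):
            self.optimizer.step()
            self.lr_scheduler.step()


trainer = MiniTrainer(nn.Linear(2, 2))
trainer.train(10)  # builds a fresh scheduler
trainer.train(5)   # rerun: the stale scheduler is rebuilt, not resumed
```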
@@ -688,8 +688,9 @@ class Trainer:
         self.can_return_loss = can_return_loss(self.model.__class__)
         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
 
-        # Internal variables to keep track of the original batch size
+        # Internal variables to help with automatic batch size reduction
         self._train_batch_size = args.train_batch_size
+        self._created_lr_scheduler = False
 
         # very last
         self._memory_tracker.stop_and_update_metrics()
@@ -1150,6 +1151,7 @@ class Trainer:
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
             )
+            self._created_lr_scheduler = True
         return self.lr_scheduler
 
     def num_examples(self, dataloader: DataLoader) -> int:
@@ -1613,6 +1615,11 @@ class Trainer:
             or self.fsdp is not None
         )
 
+        # We need to reset the scheduler, as its parameters may be different on subsequent calls
+        if self._created_lr_scheduler:
+            self.lr_scheduler = None
+            self._created_lr_scheduler = False
+
         if self.is_deepspeed_enabled:
             self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
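To see why this reset matters: a PyTorch scheduler carries its own step count, so reusing one across runs resumes mid-schedule instead of starting over. A short plain-PyTorch illustration (not transformers code):

```python
# Why reuse is a bug: the scheduler keeps its internal step count, so a
# second run that reuses it resumes from the already-decayed learning rate.
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR

opt = SGD(nn.Linear(2, 2).parameters(), lr=0.1)
sched = LinearLR(opt, start_factor=1.0, end_factor=0.0, total_iters=10)
for _ in range(10):
    opt.step()
    sched.step()
print(sched.get_last_lr())  # [0.0]: the schedule has fully decayed

# A rerun that reuses `sched` starts from this decayed LR; recreating the
# scheduler (as this commit does when the Trainer built it) starts fresh.
```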