"convert/vscode:/vscode.git/clone" did not exist on "9876c9faa41c7dd7143fa47727520d353559f81b"
Unverified Commit 253d43d4 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Fix lr scheduler not being reset on reruns (#24758)

* Try this

* Solved!

* Rm extraneous

* Rm extraneous

* self

* Args'

* Check for if we created the lr scheduler

* Move comment

* Clean
parent 1be0145d
...@@ -688,8 +688,9 @@ class Trainer: ...@@ -688,8 +688,9 @@ class Trainer:
self.can_return_loss = can_return_loss(self.model.__class__) self.can_return_loss = can_return_loss(self.model.__class__)
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
# Internal variables to keep track of the original batch size # Internal variables to help with automatic batch size reduction
self._train_batch_size = args.train_batch_size self._train_batch_size = args.train_batch_size
self._created_lr_scheduler = False
# very last # very last
self._memory_tracker.stop_and_update_metrics() self._memory_tracker.stop_and_update_metrics()
...@@ -1150,6 +1151,7 @@ class Trainer: ...@@ -1150,6 +1151,7 @@ class Trainer:
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps, num_training_steps=num_training_steps,
) )
self._created_lr_scheduler = True
return self.lr_scheduler return self.lr_scheduler
def num_examples(self, dataloader: DataLoader) -> int: def num_examples(self, dataloader: DataLoader) -> int:
...@@ -1613,6 +1615,11 @@ class Trainer: ...@@ -1613,6 +1615,11 @@ class Trainer:
or self.fsdp is not None or self.fsdp is not None
) )
# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:
self.lr_scheduler = None
self._created_lr_scheduler = False
if self.is_deepspeed_enabled: if self.is_deepspeed_enabled:
self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment