"examples/pytorch/text-generation/run_generation.py" did not exist on "a75c64d80c76c3dc71f735d9197a4a601847e0cd"
Unverified commit 5c882535, authored by Dhruv Pai and committed by GitHub

Add on_optimizer_step to callback options (#31095)

* Modified test

* Added on_optimizer_step to callbacks

* Move callback after step is called

* Added on optimizer step callback
parent 4af705c6
@@ -2306,6 +2306,8 @@ class Trainer:
                     self.optimizer.step()
+                    self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
+
                     optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                     if optimizer_was_run:
                         # Delay optimizer scheduling until metrics are generated
...
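The new hook fires immediately after optimizer.step() and before Trainer zeroes the gradients, so parameter .grad tensors are still populated when the callback runs. A minimal sketch of the gradient-monitoring use case the docstring below mentions (assuming the standard callback kwargs, which include model; gradient access may differ under DeepSpeed or other sharded setups):

from transformers import TrainerCallback


class GradientNormCallback(TrainerCallback):
    # Sketch only: log the global L2 gradient norm once per optimization step.
    def on_optimizer_step(self, args, state, control, **kwargs):
        model = kwargs.get("model")  # forwarded by CallbackHandler.call_event
        if model is None:
            return
        total_sq = 0.0
        for p in model.parameters():
            if p.grad is not None:
                total_sq += p.grad.detach().norm(2).item() ** 2
        print(f"step {state.global_step}: grad norm = {total_sq ** 0.5:.4f}")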
@@ -345,6 +345,12 @@ class TrainerCallback:
         """
         pass
 
+    def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
+        """
+        pass
+
     def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         """
         Event called at the end of a substep during gradient accumulation.
@@ -470,6 +476,9 @@ class CallbackHandler(TrainerCallback):
         control.should_save = False
         return self.call_event("on_step_begin", args, state, control)
 
+    def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_optimizer_step", args, state, control)
+
     def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
         return self.call_event("on_substep_end", args, state, control)
...
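A hedged usage sketch: once a callback implementing on_optimizer_step is passed to Trainer, the handler above dispatches the event once per optimization step. model and train_dataset below are placeholders for any standard setup:

from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,                    # placeholder: any PreTrainedModel
    args=TrainingArguments(output_dir="out"),
    train_dataset=train_dataset,    # placeholder dataset
    callbacks=[GradientNormCallback()],
)
trainer.train()  # on_optimizer_step fires between optimizer.step() and zero_grad()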
@@ -78,6 +78,9 @@ class MyTestTrainerCallback(TrainerCallback):
     def on_step_begin(self, args, state, control, **kwargs):
         self.events.append("on_step_begin")
 
+    def on_optimizer_step(self, args, state, control, **kwargs):
+        self.events.append("on_optimizer_step")
+
     def on_step_end(self, args, state, control, **kwargs):
         self.events.append("on_step_end")
@@ -148,7 +151,7 @@ class TrainerCallbackTest(unittest.TestCase):
             expected_events.append("on_epoch_begin")
             for _ in range(train_dl_len):
                 step += 1
-                expected_events += ["on_step_begin", "on_step_end"]
+                expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"]
                 if step % trainer.args.logging_steps == 0:
                     expected_events.append("on_log")
                 if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
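Beyond logging norms, the same hook suits sanity checks that must run while gradients are still available. A hypothetical example (NanGradWatcher is not part of this PR) that warns on non-finite gradients before they are cleared:

import torch
from transformers import TrainerCallback


class NanGradWatcher(TrainerCallback):
    # Hypothetical sketch: flag NaN/Inf gradients before zero_grad() clears them.
    def on_optimizer_step(self, args, state, control, **kwargs):
        model = kwargs.get("model")
        if model is None:
            return
        for name, p in model.named_parameters():
            if p.grad is not None and not torch.isfinite(p.grad).all():
                print(f"step {state.global_step}: non-finite grad in {name}")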