Unverified Commit fe6ff4a9 authored by wulu473, committed by GitHub

Add substep callbacks (#12951)


Co-authored-by: Lukas Wutschitz <lukas.wutschitz@microsoft.com>
parent f84226b7
@@ -1334,6 +1334,8 @@ class Trainer:
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
 
                     self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+                else:
+                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
 
                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break
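For context, the hunk above sits inside Trainer's gradient-accumulation branch: on_step_end fires only on the micro-batch that completes an optimizer step, and the new else branch emits on_substep_end for every intermediate micro-batch. A minimal standalone sketch of that dispatch pattern (illustrative variable names, not the actual Trainer code):

# Illustrative sketch of when each event fires under gradient accumulation.
# Variable names here are hypothetical, not copied from Trainer.
gradient_accumulation_steps = 4

for step in range(10):  # 10 micro-batches
    # ... forward/backward on one micro-batch, gradients accumulate ...
    if (step + 1) % gradient_accumulation_steps == 0:
        # optimizer step completed -> existing on_step_end event
        print(f"micro-batch {step}: on_step_end")
    else:
        # gradients only accumulated -> new on_substep_end event
        print(f"micro-batch {step}: on_substep_end")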
@@ -242,6 +242,12 @@ class TrainerCallback:
         """
         pass
 
+    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the end of a substep during gradient accumulation.
+        """
+        pass
+
     def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         """
         Event called at the end of a training step. If using gradient accumulation, one training step might take
@@ -355,6 +361,9 @@ class CallbackHandler(TrainerCallback):
         control.should_save = False
         return self.call_event("on_step_begin", args, state, control)
 
+    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_substep_end", args, state, control)
+
     def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
         return self.call_event("on_step_end", args, state, control)
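A hedged usage sketch of the new hook from the user side: the callback class below and its counter are illustrative only, not part of this PR; TrainerCallback itself and the callbacks argument of Trainer are existing Transformers APIs.

from transformers import TrainerCallback

class SubstepCounter(TrainerCallback):
    """Illustrative callback: tracks gradient-accumulation substeps per optimizer step."""

    def __init__(self):
        self.substeps = 0

    def on_substep_end(self, args, state, control, **kwargs):
        # Fires after micro-batches that only accumulate gradients.
        self.substeps += 1

    def on_step_end(self, args, state, control, **kwargs):
        # Fires after the micro-batch that completes the optimizer step.
        self.substeps = 0

Such a callback is registered like any other, e.g. Trainer(model=model, args=training_args, train_dataset=train_dataset, callbacks=[SubstepCounter()]), and gives per-micro-batch bookkeeping a place to run even when no optimizer step occurs.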