Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fe6ff4a9
Unverified
Commit
fe6ff4a9
authored
Jul 30, 2021
by
wulu473
Committed by
GitHub
Jul 30, 2021
Browse files
Add substep callbacks (#12951)
Co-authored-by:
Lukas Wutschitz
<
lukas.wutschitz@microsoft.com
>
parent
f84226b7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
0 deletions
+11
-0
src/transformers/trainer.py
src/transformers/trainer.py
+2
-0
src/transformers/trainer_callback.py
src/transformers/trainer_callback.py
+9
-0
No files found.
src/transformers/trainer.py
View file @
fe6ff4a9
...
@@ -1334,6 +1334,8 @@ class Trainer:
...
@@ -1334,6 +1334,8 @@ class Trainer:
self
.
control
=
self
.
callback_handler
.
on_step_end
(
args
,
self
.
state
,
self
.
control
)
self
.
control
=
self
.
callback_handler
.
on_step_end
(
args
,
self
.
state
,
self
.
control
)
self
.
_maybe_log_save_evaluate
(
tr_loss
,
model
,
trial
,
epoch
,
ignore_keys_for_eval
)
self
.
_maybe_log_save_evaluate
(
tr_loss
,
model
,
trial
,
epoch
,
ignore_keys_for_eval
)
else
:
self
.
control
=
self
.
callback_handler
.
on_substep_end
(
args
,
self
.
state
,
self
.
control
)
if
self
.
control
.
should_epoch_stop
or
self
.
control
.
should_training_stop
:
if
self
.
control
.
should_epoch_stop
or
self
.
control
.
should_training_stop
:
break
break
...
...
src/transformers/trainer_callback.py
View file @
fe6ff4a9
...
@@ -242,6 +242,12 @@ class TrainerCallback:
...
@@ -242,6 +242,12 @@ class TrainerCallback:
"""
"""
pass
pass
def
on_substep_end
(
self
,
args
:
TrainingArguments
,
state
:
TrainerState
,
control
:
TrainerControl
,
**
kwargs
):
"""
Event called at the end of an substep during gradient accumulation.
"""
pass
def
on_step_end
(
self
,
args
:
TrainingArguments
,
state
:
TrainerState
,
control
:
TrainerControl
,
**
kwargs
):
def
on_step_end
(
self
,
args
:
TrainingArguments
,
state
:
TrainerState
,
control
:
TrainerControl
,
**
kwargs
):
"""
"""
Event called at the end of a training step. If using gradient accumulation, one training step might take
Event called at the end of a training step. If using gradient accumulation, one training step might take
...
@@ -355,6 +361,9 @@ class CallbackHandler(TrainerCallback):
...
@@ -355,6 +361,9 @@ class CallbackHandler(TrainerCallback):
control
.
should_save
=
False
control
.
should_save
=
False
return
self
.
call_event
(
"on_step_begin"
,
args
,
state
,
control
)
return
self
.
call_event
(
"on_step_begin"
,
args
,
state
,
control
)
def
on_substep_end
(
self
,
args
:
TrainingArguments
,
state
:
TrainerState
,
control
:
TrainerControl
):
return
self
.
call_event
(
"on_substep_end"
,
args
,
state
,
control
)
def
on_step_end
(
self
,
args
:
TrainingArguments
,
state
:
TrainerState
,
control
:
TrainerControl
):
def
on_step_end
(
self
,
args
:
TrainingArguments
,
state
:
TrainerState
,
control
:
TrainerControl
):
return
self
.
call_event
(
"on_step_end"
,
args
,
state
,
control
)
return
self
.
call_event
(
"on_step_end"
,
args
,
state
,
control
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment