Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3488ef5a
Unverified
Commit
3488ef5a
authored
Jul 07, 2021
by
shabie
Committed by
GitHub
Jul 07, 2021
Browse files
[trainer] add option to ignore keys for the train function too (#11719) (#12551)
parent
45dcfdec
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
4 deletions
+8
-4
src/transformers/trainer.py
src/transformers/trainer.py
+8
-4
No files found.
src/transformers/trainer.py
View file @
3488ef5a
...
...
@@ -989,6 +989,7 @@ class Trainer:
self
,
resume_from_checkpoint
:
Optional
[
Union
[
str
,
bool
]]
=
None
,
trial
:
Union
[
"optuna.Trial"
,
Dict
[
str
,
Any
]]
=
None
,
ignore_keys_for_eval
:
Optional
[
List
[
str
]]
=
None
,
**
kwargs
,
):
"""
...
...
@@ -1002,6 +1003,9 @@ class Trainer:
training will resume from the model/optimizer/scheduler states loaded here.
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
The trial run or the hyperparameter dictionary for hyperparameter search.
ignore_keys_for_eval (:obj:`List[str]`, `optional`)
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions for evaluation during the training.
kwargs:
Additional keyword arguments used to hide deprecated arguments
"""
...
...
@@ -1322,13 +1326,13 @@ class Trainer:
self
.
state
.
epoch
=
epoch
+
(
step
+
1
)
/
steps_in_epoch
self
.
control
=
self
.
callback_handler
.
on_step_end
(
args
,
self
.
state
,
self
.
control
)
self
.
_maybe_log_save_evaluate
(
tr_loss
,
model
,
trial
,
epoch
)
self
.
_maybe_log_save_evaluate
(
tr_loss
,
model
,
trial
,
epoch
,
ignore_keys_for_eval
)
if
self
.
control
.
should_epoch_stop
or
self
.
control
.
should_training_stop
:
break
self
.
control
=
self
.
callback_handler
.
on_epoch_end
(
args
,
self
.
state
,
self
.
control
)
self
.
_maybe_log_save_evaluate
(
tr_loss
,
model
,
trial
,
epoch
)
self
.
_maybe_log_save_evaluate
(
tr_loss
,
model
,
trial
,
epoch
,
ignore_keys_for_eval
)
if
DebugOption
.
TPU_METRICS_DEBUG
in
self
.
args
.
debug
:
if
is_torch_tpu_available
():
...
...
@@ -1405,7 +1409,7 @@ class Trainer:
if
len
(
load_result
.
unexpected_keys
)
!=
0
:
logger
.
warn
(
f
"There were unexpected keys in the checkpoint model loaded:
{
load_result
.
unexpected_keys
}
."
)
def
_maybe_log_save_evaluate
(
self
,
tr_loss
,
model
,
trial
,
epoch
):
def
_maybe_log_save_evaluate
(
self
,
tr_loss
,
model
,
trial
,
epoch
,
ignore_keys_for_eval
):
if
self
.
control
.
should_log
:
logs
:
Dict
[
str
,
float
]
=
{}
tr_loss_scalar
=
tr_loss
.
item
()
...
...
@@ -1423,7 +1427,7 @@ class Trainer:
metrics
=
None
if
self
.
control
.
should_evaluate
:
metrics
=
self
.
evaluate
()
metrics
=
self
.
evaluate
(
ignore_keys
=
ignore_keys_for_eval
)
self
.
_report_to_hp_search
(
trial
,
epoch
,
metrics
)
if
self
.
control
.
should_save
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment