chenpangpang / transformers

Unverified commit d31497b1, authored Jan 31, 2023 by regisss, committed via GitHub on Jan 31, 2023
Parent: 98d40fed

Do not log the generation config for each prediction step in TrainerSeq2Seq (#21385)

Do not log the generation config for each iteration.
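Background, as a hedged sketch: around this version of transformers, `generate()` contains a legacy path that rebuilds the generation config from the model config, and may emit a log message, whenever `generation_config._from_model_config` is True. Since `Seq2SeqTrainer` calls `generate()` once per evaluation batch, that message was repeated on every prediction step; the commit clears the flag after the first call so the legacy path runs at most once. The snippet below is a minimal illustration of that pattern only, not the library code; `LegacyConfigModel`, its `generate()` method, and the log message are assumptions made for this example.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("generation")


class LegacyConfigModel:
    """Toy stand-in for a model whose generation config was derived from its model config."""

    def __init__(self):
        # Mirrors the private flag the commit toggles: True means the generation
        # config was built from the model config (the legacy path).
        self._from_model_config = True

    def generate(self, batch):
        # While the flag is still True, the legacy path is taken and a message is
        # logged on every call.
        if self._from_model_config:
            logger.info("Generation config was created from the model config (legacy behaviour).")
        return [token + 1 for token in batch]


model = LegacyConfigModel()
batches = [[1, 2], [3, 4], [5, 6]]

for step, batch in enumerate(batches):
    outputs = model.generate(batch)
    # Same one-time toggle as in the diff below: after the first call, clear the
    # flag so later evaluation steps no longer hit the legacy path or its log line.
    if model._from_model_config:
        model._from_model_config = False
    print(f"step {step}: {outputs}")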
Showing 1 changed file with 5 additions and 0 deletions.

src/transformers/trainer_seq2seq.py (+5, -0)
@@ -199,6 +199,11 @@ class Seq2SeqTrainer(Trainer):
             generation_inputs,
             **gen_kwargs,
         )
+        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
+        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
+        # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
+        if self.model.generation_config._from_model_config:
+            self.model.generation_config._from_model_config = False
         # in case the batch is shorter than max length, the output should be padded
         if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
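The context lines above call `self._pad_tensors_to_max_len`, a trainer helper that right-pads generated sequences when generation stops before `max_length`. Below is a minimal sketch of that idea, assuming a (batch, seq_len) PyTorch tensor of token ids and a known pad token; the name `pad_tensor_to_max_len` and its signature are hypothetical and not the actual library helper.

import torch


def pad_tensor_to_max_len(tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor:
    """Right-pad a (batch, seq_len) tensor of token ids with pad_token_id up to max_length."""
    if tensor.shape[-1] >= max_length:
        return tensor
    # Allocate a full-size tensor filled with the pad token, then copy the real tokens in.
    padded = tensor.new_full((tensor.shape[0], max_length), pad_token_id)
    padded[:, : tensor.shape[-1]] = tensor
    return padded


# Example: generation stopped at length 3, but downstream metrics expect length 5.
generated_tokens = torch.tensor([[101, 7, 102], [101, 9, 102]])
print(pad_tensor_to_max_len(generated_tokens, max_length=5, pad_token_id=0))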