chenpangpang / transformers · Commits

Unverified commit 9a70d6e5
Authored Sep 05, 2023 by Joao Gante; committed by GitHub on Sep 05, 2023
Trainer: delegate default generation values to `generation_config` (#25987)
parent aea76149
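In short, the commit stops the `Trainer` from force-feeding `max_length`/`num_beams` defaults into `generate()` and instead lets them come from `model.generation_config`. A minimal sketch of the workflow this enables (illustrative only; the checkpoint name is an arbitrary example, not taken from this commit):

```python
from transformers import AutoModelForSeq2SeqLM, GenerationConfig

# Set generation defaults once, on the model itself...
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # arbitrary example checkpoint
model.generation_config = GenerationConfig(max_new_tokens=64, num_beams=4)

# ...and generate() -- and therefore Seq2SeqTrainer's generation loop --
# picks these up whenever no explicit gen_kwargs override them.
```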
Showing 2 changed files with 32 additions and 21 deletions.
examples/pytorch/question-answering/trainer_seq2seq_qa.py (+7 −6)
src/transformers/trainer_seq2seq.py (+25 −15)
examples/pytorch/question-answering/trainer_seq2seq_qa.py
```diff
@@ -46,12 +46,13 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
         **gen_kwargs,
     ) -> Dict[str, float]:
         gen_kwargs = gen_kwargs.copy()
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
-        )
-        gen_kwargs["num_beams"] = (
-            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
-        )
+
+        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
+        # training args
+        if gen_kwargs.get("max_length") is None and self.args.generation_max_length is not None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
         self._gen_kwargs = gen_kwargs
 
         eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
```
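Reading the hunk above: the two unconditional ternaries (which always wrote the keys, possibly as `None`) become guarded assignments. The same resolution rule in isolation, as a standalone sketch with a hypothetical helper name that is not part of the diff:

```python
from typing import Any, Dict, Optional

def _resolve_legacy_kwargs(
    gen_kwargs: Dict[str, Any],
    generation_max_length: Optional[int],
    generation_num_beams: Optional[int],
) -> Dict[str, Any]:
    # Inject a training-args default only if (a) the caller did not pass the
    # option and (b) the training arg is actually set, mirroring the new code.
    gen_kwargs = gen_kwargs.copy()
    if gen_kwargs.get("max_length") is None and generation_max_length is not None:
        gen_kwargs["max_length"] = generation_max_length
    if gen_kwargs.get("num_beams") is None and generation_num_beams is not None:
        gen_kwargs["num_beams"] = generation_num_beams
    return gen_kwargs

# Nothing set anywhere: the dict stays empty instead of carrying None entries,
# so generate() can fall back to model.generation_config.
assert _resolve_legacy_kwargs({}, None, None) == {}
# An explicit kwarg always beats the training args.
assert _resolve_legacy_kwargs({"num_beams": 8}, None, 4) == {"num_beams": 8}
```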
src/transformers/trainer_seq2seq.py
```diff
@@ -149,11 +149,17 @@ class Seq2SeqTrainer(Trainer):
         """
         gen_kwargs = gen_kwargs.copy()
-        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+
+        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
+        # training args
+        if (
+            gen_kwargs.get("max_length") is None
+            and gen_kwargs.get("max_new_tokens") is None
+            and self.args.generation_max_length is not None
+        ):
             gen_kwargs["max_length"] = self.args.generation_max_length
-        gen_kwargs["num_beams"] = (
-            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
-        )
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
         self._gen_kwargs = gen_kwargs
 
         return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
```
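From the caller's side, the practical effect is that unset training args no longer shadow `model.generation_config`. A hedged sketch, assuming only that the `Seq2SeqTrainingArguments` defaults are unchanged:

```python
from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(output_dir="tmp", predict_with_generate=True)
print(args.generation_max_length, args.generation_num_beams)  # None None

# Because both are None, evaluate() now forwards no legacy length/beam
# values at all, and generation is governed by model.generation_config or
# by explicit kwargs such as evaluate(max_new_tokens=32).
```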
```diff
@@ -206,11 +212,17 @@ class Seq2SeqTrainer(Trainer):
         """
         gen_kwargs = gen_kwargs.copy()
-        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+
+        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
+        # training args
+        if (
+            gen_kwargs.get("max_length") is None
+            and gen_kwargs.get("max_new_tokens") is None
+            and self.args.generation_max_length is not None
+        ):
             gen_kwargs["max_length"] = self.args.generation_max_length
-        gen_kwargs["num_beams"] = (
-            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
-        )
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
         self._gen_kwargs = gen_kwargs
 
         return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
```
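The triple guard above exists because `max_length` and `max_new_tokens` are competing length controls: the former caps the whole generated sequence, the latter only the newly generated tokens. A short illustration (arbitrary example checkpoint, not from the diff):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tok("translate English to German: How are you?", return_tensors="pt")

out_total = model.generate(**inputs, max_length=20)     # caps total output length
out_new = model.generate(**inputs, max_new_tokens=20)   # caps generated tokens only
```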
```diff
@@ -256,16 +268,14 @@ class Seq2SeqTrainer(Trainer):
         # XXX: adapt synced_gpus for fairscale as well
         # Priority (handled in generate):
-        # gen_kwargs > model.generation_config > default GenerationConfig()
+        # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
         if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
             gen_kwargs = self._gen_kwargs.copy()
-
-        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
-            gen_kwargs["max_length"] = self.model.config.max_length
-        gen_kwargs["num_beams"] = (
-            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
-        )
+            if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
+                gen_kwargs.pop("num_beams")
+            if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
+                gen_kwargs.pop("max_length")
+
         default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
         gen_kwargs["synced_gpus"] = (
             gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
         )
```
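The `pop` calls are what make the new priority comment true in practice: a leftover `max_length=None` or `num_beams=None` entry would otherwise be forwarded to `generate()` as if it were an explicit setting. The same logic in isolation, as a self-contained sketch:

```python
# Stale None entries stashed by evaluate()/predict() are dropped rather than
# forwarded, so they cannot mask model.generation_config.
gen_kwargs = {"max_length": None, "num_beams": None, "synced_gpus": True}
for key in ("num_beams", "max_length"):
    if key in gen_kwargs and gen_kwargs[key] is None:
        gen_kwargs.pop(key)
assert gen_kwargs == {"synced_gpus": True}
```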