Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c47fcd08
Unverified
Commit
c47fcd08
authored
Mar 15, 2024
by
Joao Gante
Committed by
GitHub
Mar 15, 2024
Browse files
Trainer: fail early in the presence of an unsavable `generation_config` (#29675)
parent
f62407f7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
19 deletions
+53
-19
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+2
-1
src/transformers/trainer_seq2seq.py
src/transformers/trainer_seq2seq.py
+32
-18
tests/trainer/test_trainer_seq2seq.py
tests/trainer/test_trainer_seq2seq.py
+19
-0
No files found.
src/transformers/generation/configuration_utils.py
View file @
c47fcd08
...
@@ -652,7 +652,8 @@ class GenerationConfig(PushToHubMixin):
...
@@ -652,7 +652,8 @@ class GenerationConfig(PushToHubMixin):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
"""
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance.
# This strictness is enforced to prevent bad configurations from being saved and re-used.
try
:
try
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
self
.
validate
()
self
.
validate
()
...
...
src/transformers/trainer_seq2seq.py
View file @
c47fcd08
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
warnings
from
copy
import
deepcopy
from
copy
import
deepcopy
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
...
@@ -88,8 +89,8 @@ class Seq2SeqTrainer(Trainer):
...
@@ -88,8 +89,8 @@ class Seq2SeqTrainer(Trainer):
# GenerationConfig provided, nothing to do
# GenerationConfig provided, nothing to do
if
isinstance
(
gen_config_arg
,
GenerationConfig
):
if
isinstance
(
gen_config_arg
,
GenerationConfig
):
return
deepcopy
(
gen_config_arg
)
gen_config
=
deepcopy
(
gen_config_arg
)
else
:
# str or Path
# str or Path
pretrained_model_name
=
Path
(
gen_config_arg
)
if
isinstance
(
gen_config_arg
,
str
)
else
gen_config_arg
pretrained_model_name
=
Path
(
gen_config_arg
)
if
isinstance
(
gen_config_arg
,
str
)
else
gen_config_arg
config_file_name
=
None
config_file_name
=
None
...
@@ -107,6 +108,19 @@ class Seq2SeqTrainer(Trainer):
...
@@ -107,6 +108,19 @@ class Seq2SeqTrainer(Trainer):
pretrained_model_name
=
gen_config_arg
pretrained_model_name
=
gen_config_arg
gen_config
=
GenerationConfig
.
from_pretrained
(
pretrained_model_name
,
config_file_name
)
gen_config
=
GenerationConfig
.
from_pretrained
(
pretrained_model_name
,
config_file_name
)
# Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws
# an exception if there are warnings at validation time.
try
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
gen_config
.
validate
()
if
len
(
caught_warnings
)
>
0
:
raise
ValueError
(
str
([
w
.
message
for
w
in
caught_warnings
]))
except
ValueError
as
exc
:
raise
ValueError
(
"The loaded generation config instance is invalid -- `GenerationConfig.validate()` throws warnings "
"and/or exceptions. Fix these issues to train your model.
\n\n
Thrown during validation:
\n
"
+
str
(
exc
)
)
return
gen_config
return
gen_config
def
evaluate
(
def
evaluate
(
...
...
tests/trainer/test_trainer_seq2seq.py
View file @
c47fcd08
...
@@ -181,3 +181,22 @@ class Seq2seqTrainerTester(TestCasePlus):
...
@@ -181,3 +181,22 @@ class Seq2seqTrainerTester(TestCasePlus):
assert
(
assert
(
metrics
[
"eval_samples"
]
==
dataset_len
*
num_return_sequences
metrics
[
"eval_samples"
]
==
dataset_len
*
num_return_sequences
),
f
"Got
{
metrics
[
'eval_samples'
]
}
, expected:
{
dataset_len
*
num_return_sequences
}
"
),
f
"Got
{
metrics
[
'eval_samples'
]
}
, expected:
{
dataset_len
*
num_return_sequences
}
"
@
require_torch
def
test_bad_generation_config_fail_early
(
self
):
# Tests that a bad generation config causes the trainer to fail early
model
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
"google-t5/t5-small"
)
tokenizer
=
T5Tokenizer
.
from_pretrained
(
"google-t5/t5-small"
)
data_collator
=
DataCollatorForSeq2Seq
(
tokenizer
,
model
=
model
,
return_tensors
=
"pt"
,
padding
=
"longest"
)
gen_config
=
GenerationConfig
(
do_sample
=
False
,
top_p
=
0.9
)
# bad: top_p is not compatible with do_sample=False
training_args
=
Seq2SeqTrainingArguments
(
"."
,
predict_with_generate
=
True
,
generation_config
=
gen_config
)
with
self
.
assertRaises
(
ValueError
)
as
exc
:
_
=
Seq2SeqTrainer
(
model
=
model
,
args
=
training_args
,
tokenizer
=
tokenizer
,
data_collator
=
data_collator
,
compute_metrics
=
lambda
x
:
{
"samples"
:
x
[
0
].
shape
[
0
]},
)
self
.
assertIn
(
"The loaded generation config instance is invalid"
,
str
(
exc
.
exception
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment