Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
afc45b13
Unverified
Commit
afc45b13
authored
Jan 12, 2024
by
Joao Gante
Committed by
GitHub
Jan 12, 2024
Browse files
Generate: refuse to save bad generation config files (#28477)
parent
dc01cf9c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
15 deletions
+10
-15
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+4
-7
tests/generation/test_configuration_utils.py
tests/generation/test_configuration_utils.py
+6
-8
No files found.
src/transformers/generation/configuration_utils.py
View file @
afc45b13
...
...
@@ -551,16 +551,13 @@ class GenerationConfig(PushToHubMixin):
try
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
self
.
validate
()
for
w
in
caught_warnings
:
raise
ValueError
(
w
.
message
)
if
len
(
caught_warnings
)
>
0
:
raise
ValueError
(
str
([
w
.
message
for
w
in
caught_warnings
])
)
except
ValueError
as
exc
:
warnings
.
warn
(
raise
ValueError
(
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
"Fix these issues to save the configuration. This warning will be raised to an exception in v4.34."
"
\n\n
Thrown during validation:
\n
"
+
str
(
exc
),
UserWarning
,
"Fix these issues to save the configuration.
\n\n
Thrown during validation:
\n
"
+
str
(
exc
)
)
return
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
...
...
tests/generation/test_configuration_utils.py
View file @
afc45b13
...
...
@@ -152,14 +152,13 @@ class GenerationConfigTest(unittest.TestCase):
"""Tests that we refuse to save a generation config that fails validation."""
# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
# is caught, doesn't save, and raises a
warning
# is caught, doesn't save, and raises a
n exception
config
=
GenerationConfig
()
config
.
temperature
=
0.5
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
captured_warnings
:
with
self
.
assertRaises
(
ValueError
)
as
exc
:
config
.
save_pretrained
(
tmp_dir
)
self
.
assertEqual
(
len
(
captured_warnings
),
1
)
self
.
assertTrue
(
"Fix these issues to save the configuration."
in
str
(
captured_warnings
[
0
].
message
))
self
.
assertTrue
(
"Fix these issues to save the configuration."
in
str
(
exc
.
exception
))
self
.
assertTrue
(
len
(
os
.
listdir
(
tmp_dir
))
==
0
)
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
...
...
@@ -167,13 +166,12 @@ class GenerationConfigTest(unittest.TestCase):
config
=
GenerationConfig
()
config
.
num_return_sequences
=
2
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
captured_warnings
:
with
self
.
assertRaises
(
ValueError
)
as
exc
:
config
.
save_pretrained
(
tmp_dir
)
self
.
assertEqual
(
len
(
captured_warnings
),
1
)
self
.
assertTrue
(
"Fix these issues to save the configuration."
in
str
(
captured_warnings
[
0
].
message
))
self
.
assertTrue
(
"Fix these issues to save the configuration."
in
str
(
exc
.
exception
))
self
.
assertTrue
(
len
(
os
.
listdir
(
tmp_dir
))
==
0
)
# final check: no warnings thrown if it is correct, and file is saved
# final check: no warnings
/exceptions
thrown if it is correct, and file is saved
config
=
GenerationConfig
()
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
captured_warnings
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment