Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
afc45b13
Unverified
Commit
afc45b13
authored
Jan 12, 2024
by
Joao Gante
Committed by
GitHub
Jan 12, 2024
Browse files
Generate: refuse to save bad generation config files (#28477)
parent
dc01cf9c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
15 deletions
+10
-15
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+4
-7
tests/generation/test_configuration_utils.py
tests/generation/test_configuration_utils.py
+6
-8
No files found.
src/transformers/generation/configuration_utils.py
View file @
afc45b13
...
...
@@ -551,16 +551,13 @@ class GenerationConfig(PushToHubMixin):
try
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
self
.
validate
()
for
w
in
caught_warnings
:
raise
ValueError
(
w
.
message
)
if
len
(
caught_warnings
)
>
0
:
raise
ValueError
(
str
([
w
.
message
for
w
in
caught_warnings
])
)
except
ValueError
as
exc
:
warnings
.
warn
(
raise
ValueError
(
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
"Fix these issues to save the configuration. This warning will be raised to an exception in v4.34."
"
\n\n
Thrown during validation:
\n
"
+
str
(
exc
),
UserWarning
,
"Fix these issues to save the configuration.
\n\n
Thrown during validation:
\n
"
+
str
(
exc
)
)
return
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
...
...
tests/generation/test_configuration_utils.py
View file @
afc45b13
...
...
@@ -152,14 +152,13 @@ class GenerationConfigTest(unittest.TestCase):
"""Tests that we refuse to save a generation config that fails validation."""
# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
# is caught, doesn't save, and raises a
warning
# is caught, doesn't save, and raises a
n exception
config
=
GenerationConfig
()
config
.
temperature
=
0.5
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
captured_warnings
:
with
self
.
assertRaises
(
ValueError
)
as
exc
:
config
.
save_pretrained
(
tmp_dir
)
self
.
assertEqual
(
len
(
captured_warnings
),
1
)
self
.
assertTrue
(
"Fix these issues to save the configuration."
in
str
(
captured_warnings
[
0
].
message
))
self
.
assertTrue
(
"Fix these issues to save the configuration."
in
str
(
exc
.
exception
))
self
.
assertTrue
(
len
(
os
.
listdir
(
tmp_dir
))
==
0
)
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
...
...
@@ -167,13 +166,12 @@ class GenerationConfigTest(unittest.TestCase):
config
=
GenerationConfig
()
config
.
num_return_sequences
=
2
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
captured_warnings
:
with
self
.
assertRaises
(
ValueError
)
as
exc
:
config
.
save_pretrained
(
tmp_dir
)
self
.
assertEqual
(
len
(
captured_warnings
),
1
)
self
.
assertTrue
(
"Fix these issues to save the configuration."
in
str
(
captured_warnings
[
0
].
message
))
self
.
assertTrue
(
"Fix these issues to save the configuration."
in
str
(
exc
.
exception
))
self
.
assertTrue
(
len
(
os
.
listdir
(
tmp_dir
))
==
0
)
# final check: no warnings thrown if it is correct, and file is saved
# final check: no warnings
/exceptions
thrown if it is correct, and file is saved
config
=
GenerationConfig
()
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
captured_warnings
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment