Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
510270af
Unverified
Commit
510270af
authored
Nov 30, 2023
by
Joao Gante
Committed by
GitHub
Nov 30, 2023
Browse files
Generate: `GenerationConfig` throws an exception when `generate` args are passed (#27757)
parent
fe41647a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
0 deletions
+46
-0
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+18
-0
tests/generation/test_configuration_utils.py
tests/generation/test_configuration_utils.py
+28
-0
No files found.
src/transformers/generation/configuration_utils.py
View file @
510270af
...
...
@@ -497,6 +497,24 @@ class GenerationConfig(PushToHubMixin):
f
"(
{
self
.
num_beams
}
)."
)
# 5. check common issue: passing `generate` arguments inside the generation config
generate_arguments
=
(
"logits_processor"
,
"stopping_criteria"
,
"prefix_allowed_tokens_fn"
,
"synced_gpus"
,
"assistant_model"
,
"streamer"
,
"negative_prompt_ids"
,
"negative_prompt_attention_mask"
,
)
for
arg
in
generate_arguments
:
if
hasattr
(
self
,
arg
):
raise
ValueError
(
f
"Argument `
{
arg
}
` is not a valid argument of `GenerationConfig`. It should be passed to "
"`generate()` (or a pipeline) directly."
)
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
],
...
...
tests/generation/test_configuration_utils.py
View file @
510270af
...
...
@@ -120,6 +120,34 @@ class GenerationConfigTest(unittest.TestCase):
self
.
assertEqual
(
loaded_config
.
do_sample
,
True
)
self
.
assertEqual
(
loaded_config
.
num_beams
,
1
)
# default value
def
test_validate
(
self
):
"""
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
"""
# Case 1: A correct configuration will not throw any warning
with
warnings
.
catch_warnings
(
record
=
True
)
as
captured_warnings
:
GenerationConfig
()
self
.
assertEqual
(
len
(
captured_warnings
),
0
)
# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
with
warnings
.
catch_warnings
(
record
=
True
)
as
captured_warnings
:
GenerationConfig
(
temperature
=
0.5
)
self
.
assertEqual
(
len
(
captured_warnings
),
1
)
# Case 3: Impossible sets of contraints/parameters will raise an exception
with
self
.
assertRaises
(
ValueError
):
GenerationConfig
(
num_return_sequences
=
2
)
# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
with
self
.
assertRaises
(
ValueError
):
GenerationConfig
(
logits_processor
=
"foo"
)
# Case 5: Model-specific parameters will NOT raise an exception or a warning
with
warnings
.
catch_warnings
(
record
=
True
)
as
captured_warnings
:
GenerationConfig
(
foo
=
"bar"
)
self
.
assertEqual
(
len
(
captured_warnings
),
0
)
def
test_refuse_to_save
(
self
):
"""Tests that we refuse to save a generation config that fails validation."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment