transformers, commit 5bd8c011 (unverified). Authored Aug 08, 2023 by Joao Gante; committed by GitHub on Aug 08, 2023.
Generate: add config-level validation (#25381)
Parent: 9e57e0c0
Showing 2 changed files with 111 additions and 61 deletions:

src/transformers/generation/configuration_utils.py: +111, -2
src/transformers/generation/utils.py: +0, -59
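For illustration, a minimal sketch of how the new config-level validation surfaces to a user (assuming a transformers build that includes this commit; this snippet is not part of the diff):

from transformers import GenerationConfig

# A sampling-only flag combined with greedy decoding is now rejected by the
# configuration object itself, with the message added in this commit.
config = GenerationConfig(do_sample=False, temperature=0.7)
try:
    config.validate()
except ValueError as err:
    print(err)
    # `do_sample` is set to `False`. However, temperature is set to 0.7 -- this
    # flag is only used in sample-based generation modes. Set `do_sample=True`
    # or unset temperature to continue.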
src/transformers/generation/configuration_utils.py
...
...
@@ -332,12 +332,121 @@ class GenerationConfig(PushToHubMixin):
     def validate(self):
         """
-        Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
-        the values are invalid.
+        Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
+        of parameterization that can be detected as incorrect from the configuration instance alone.
+
+        Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
+        model, such as parameters related to the generation length.
         """
+        # Validation of individual attributes
         if self.early_stopping not in {True, False, "never"}:
             raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
 
+        # Validation of attribute relations:
+        # 1. detect sampling-only parameterization when not in sampling mode
+        if self.do_sample is False:
+            greedy_wrong_parameter_msg = (
+                "`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
+                "in sample-based generation modes. Set `do_sample=True` or unset {flag_name} to continue."
+            )
+            if self.temperature != 1.0:
+                raise ValueError(
+                    greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature)
+                )
+            if self.top_p != 1.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p))
+            if self.typical_p != 1.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p))
+            if self.top_k != 50 and self.penalty_alpha is None:  # contrastive search uses top_k
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k))
+            if self.epsilon_cutoff != 0.0:
+                raise ValueError(
+                    greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff)
+                )
+            if self.eta_cutoff != 0.0:
+                raise ValueError(greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff))
+
+        # 2. detect beam-only parameterization when not in beam mode
+        if self.num_beams == 1:
+            single_beam_wrong_parameter_msg = (
+                "`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
+                "beam-based generation modes. Set `num_beams>1` or unset {flag_name} to continue."
+            )
+            if self.early_stopping is not False:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping)
+                )
+            if self.num_beam_groups != 1:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="num_beam_groups", flag_value=self.num_beam_groups)
+                )
+            if self.diversity_penalty != 0.0:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="diversity_penalty", flag_value=self.diversity_penalty)
+                )
+            if self.length_penalty != 1.0:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty)
+                )
+            if self.constraints is not None:
+                raise ValueError(
+                    single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints)
+                )
+
+        # 3. detect incorrect parameterization specific to advanced beam modes
+        else:
+            # constrained beam search
+            if self.constraints is not None:
+                constrained_wrong_parameter_msg = (
+                    "`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
+                    "{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
+                    "{flag_name} to continue."
+                )
+                if self.do_sample is True:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
+                    )
+                if self.num_beam_groups != 1:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(flag_name="num_beam_groups", flag_value=self.num_beam_groups)
+                    )
+            # group beam search
+            if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
+                group_error_prefix = (
+                    "`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
+                    "this generation mode, "
+                )
+                if self.do_sample is True:
+                    raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
+                if self.num_beams % self.num_beam_groups != 0:
+                    raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
+                if self.diversity_penalty == 0.0:
+                    raise ValueError(
+                        group_error_prefix
+                        + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
+                    )
+
+        # 4. check `num_return_sequences`
+        if self.num_return_sequences != 1:
+            if self.num_beams == 1:
+                if self.do_sample is False:
+                    raise ValueError(
+                        "Greedy methods without beam search do not support `num_return_sequences` different than 1 "
+                        f"(got {self.num_return_sequences})."
+                    )
+            elif self.num_return_sequences > self.num_beams:
+                raise ValueError(
+                    f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
+                    f"({self.num_beams})."
+                )
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
...
...
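A second sketch against the validate() shown above (not part of the diff): beam-only flags are now rejected when num_beams stays at its default of 1.

from transformers import GenerationConfig

# `length_penalty` only applies to beam-based modes, so pairing it with
# num_beams=1 now fails fast at the config level.
config = GenerationConfig(num_beams=1, length_penalty=2.0)
config.validate()
# ValueError: `num_beams` is set to 1. However, length_penalty is set to 2.0 --
# this flag is only used in beam-based generation modes. Set `num_beams>1` or
# unset length_penalty to continue.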
src/transformers/generation/utils.py
...
...
@@ -1493,13 +1493,6 @@ class GenerationMixin:
         # 7. determine generation mode
         generation_mode = self._get_generation_mode(generation_config, assistant_model)
 
-        if generation_config.num_beam_groups > generation_config.num_beams:
-            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
-        if generation_mode == GenerationMode.GROUP_BEAM_SEARCH and generation_config.do_sample is True:
-            raise ValueError(
-                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
-            )
-
         if streamer is not None and (generation_config.num_beams > 1):
             raise ValueError(
                 "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
...
...
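The two checks removed above are now covered by GenerationConfig.validate(); a hypothetical repro (not part of the diff):

from transformers import GenerationConfig

# Group beam search combined with sampling, formerly rejected inside
# `generate()`, is now rejected from the configuration alone.
config = GenerationConfig(num_beams=4, num_beam_groups=2, diversity_penalty=0.3, do_sample=True)
config.validate()
# ValueError: `diversity_penalty` is not 0.0 or `num_beam_groups` is not 1,
# triggering group beam search. In this generation mode, `do_sample` must be
# set to `False`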
@@ -1572,12 +1565,6 @@ class GenerationMixin:
             **model_kwargs,
         )
 
         if generation_mode == GenerationMode.GREEDY_SEARCH:
-            if generation_config.num_return_sequences > 1:
-                raise ValueError(
-                    "num_return_sequences has to be 1 when doing greedy search, "
-                    f"but is {generation_config.num_return_sequences}."
-                )
-
             # 11. run greedy search
             return self.greedy_search(
                 input_ids,
...
...
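The greedy-specific check removed above now falls out of rule 4 in validate(); a sketch (not part of the diff):

from transformers import GenerationConfig

# Defaults are do_sample=False and num_beams=1, i.e. greedy search.
config = GenerationConfig(num_return_sequences=2)
config.validate()
# ValueError: Greedy methods without beam search do not support
# `num_return_sequences` different than 1 (got 2).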
@@ -1593,11 +1580,6 @@ class GenerationMixin:
             )
 
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
-            if generation_config.num_return_sequences > 1:
-                raise ValueError(
-                    "num_return_sequences has to be 1 when doing contrastive search, "
-                    f"but is {generation_config.num_return_sequences}."
-                )
-
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")
...
...
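Contrastive search hits the same config-level rule; note the message is the generic greedy-methods one from validate(), not the contrastive-specific text removed above (sketch, not part of the diff):

from transformers import GenerationConfig

# penalty_alpha triggers contrastive search; top_k != 50 is allowed here
# because validate() exempts top_k when penalty_alpha is set.
config = GenerationConfig(penalty_alpha=0.6, top_k=4, num_return_sequences=2)
config.validate()
# ValueError: Greedy methods without beam search do not support
# `num_return_sequences` different than 1 (got 2).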
@@ -1645,12 +1627,6 @@ class GenerationMixin:
             )
 
         elif generation_mode == GenerationMode.BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
...
...
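Likewise for plain beam search, rule 4 of validate() replaces the first check removed above (sketch, not part of the diff):

from transformers import GenerationConfig

config = GenerationConfig(num_beams=2, num_return_sequences=3)
config.validate()
# ValueError: `num_return_sequences` (3) has to be smaller or equal to
# `num_beams` (2).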
@@ -1686,8 +1662,6 @@ class GenerationMixin:
             # 11. prepare logits warper
             logits_warper = self._get_logits_warper(generation_config)
 
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * generation_config.num_return_sequences,
...
...
@@ -1722,24 +1696,6 @@ class GenerationMixin:
             )
 
         elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if generation_config.num_beams % generation_config.num_beam_groups != 0:
-                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
-
-            if generation_config.diversity_penalty == 0.0:
-                raise ValueError(
-                    "`diversity_penalty` should be greater than `0.0`, otherwise your beam groups will be identical."
-                )
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
-            has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
-            if not has_default_typical_p:
-                raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
-
             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
...
...
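The group-beam-search divisibility and diversity checks removed above moved into validate()'s group-beam rules (sketch, not part of the diff):

from transformers import GenerationConfig

# 4 beams cannot be split into 3 groups.
config = GenerationConfig(num_beams=4, num_beam_groups=3, diversity_penalty=0.5)
config.validate()
# ValueError: `diversity_penalty` is not 0.0 or `num_beam_groups` is not 1,
# triggering group beam search. In this generation mode, `num_beams` should be
# divisible by `num_beam_groups`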
@@ -1773,21 +1729,6 @@ class GenerationMixin:
             )
 
         elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
-            if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
-
-            if stopping_criteria.max_length is None:
-                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
-
-            if generation_config.num_beams <= 1:
-                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
-
-            if generation_config.do_sample:
-                raise ValueError("`do_sample` needs to be false for constrained generation.")
-
-            if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
-                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
-
             final_constraints = []
             if generation_config.constraints is not None:
                 final_constraints = generation_config.constraints
...
...
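And for constrained beam search, the do_sample incompatibility removed above is now caught from the config (sketch, not part of the diff; PhrasalConstraint([42]) is just a dummy single-token constraint):

from transformers import GenerationConfig, PhrasalConstraint

config = GenerationConfig(num_beams=2, do_sample=True, constraints=[PhrasalConstraint([42])])
config.validate()
# ValueError: `constraints` is not `None`, triggering constrained beam search.
# However, do_sample is set to True, which is incompatible with this generation
# mode. Set `constraints=None` or unset do_sample to continue.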