Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
eb3ded16
Unverified
Commit
eb3ded16
authored
Aug 09, 2023
by
Joao Gante
Committed by
GitHub
Aug 09, 2023
Browse files
Generate: lower severity of parameterization checks (#25407)
parent
ef74da65
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
23 deletions
+50
-23
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+50
-23
No files found.
src/transformers/generation/configuration_utils.py
View file @
eb3ded16
...
@@ -313,7 +313,7 @@ class GenerationConfig(PushToHubMixin):
...
@@ -313,7 +313,7 @@ class GenerationConfig(PushToHubMixin):
raise
err
raise
err
# Validate the values of the attributes
# Validate the values of the attributes
self
.
validate
()
self
.
validate
(
is_init
=
True
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
if
not
isinstance
(
other
,
GenerationConfig
):
if
not
isinstance
(
other
,
GenerationConfig
):
...
@@ -330,7 +330,7 @@ class GenerationConfig(PushToHubMixin):
...
@@ -330,7 +330,7 @@ class GenerationConfig(PushToHubMixin):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
def
validate
(
self
):
def
validate
(
self
,
is_init
=
False
):
"""
"""
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
of parameterization that can be detected as incorrect from the configuration instance alone.
of parameterization that can be detected as incorrect from the configuration instance alone.
...
@@ -344,58 +344,85 @@ class GenerationConfig(PushToHubMixin):
...
@@ -344,58 +344,85 @@ class GenerationConfig(PushToHubMixin):
raise
ValueError
(
f
"`early_stopping` must be a boolean or 'never', but is
{
self
.
early_stopping
}
."
)
raise
ValueError
(
f
"`early_stopping` must be a boolean or 'never', but is
{
self
.
early_stopping
}
."
)
# Validation of attribute relations:
# Validation of attribute relations:
fix_location
=
""
if
is_init
:
fix_location
=
(
" This was detected when initializing the generation config instance, which means the corresponding "
"file may hold incorrect parameterization and should be fixed."
)
# 1. detect sampling-only parameterization when not in sampling mode
# 1. detect sampling-only parameterization when not in sampling mode
if
self
.
do_sample
is
False
:
if
self
.
do_sample
is
False
:
greedy_wrong_parameter_msg
=
(
greedy_wrong_parameter_msg
=
(
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
"in sample-based generation modes. Set `do_sample=True` or unset {flag_name} to continue."
"in sample-based generation modes. You should set `do_sample=True` or unset {flag_name}."
+
fix_location
)
)
if
self
.
temperature
!=
1.0
:
if
self
.
temperature
!=
1.0
:
raise
ValueError
(
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"temperature"
,
flag_value
=
self
.
temperature
)
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"temperature"
,
flag_value
=
self
.
temperature
),
UserWarning
,
)
)
if
self
.
top_p
!=
1.0
:
if
self
.
top_p
!=
1.0
:
raise
ValueError
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"top_p"
,
flag_value
=
self
.
top_p
))
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"top_p"
,
flag_value
=
self
.
top_p
),
UserWarning
,
)
if
self
.
typical_p
!=
1.0
:
if
self
.
typical_p
!=
1.0
:
raise
ValueError
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"typical_p"
,
flag_value
=
self
.
typical_p
))
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"typical_p"
,
flag_value
=
self
.
typical_p
),
UserWarning
,
)
if
self
.
top_k
!=
50
and
self
.
penalty_alpha
is
None
:
# contrastive search uses top_k
if
self
.
top_k
!=
50
and
self
.
penalty_alpha
is
None
:
# contrastive search uses top_k
raise
ValueError
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"top_k"
,
flag_value
=
self
.
top_k
))
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"top_k"
,
flag_value
=
self
.
top_k
),
UserWarning
,
)
if
self
.
epsilon_cutoff
!=
0.0
:
if
self
.
epsilon_cutoff
!=
0.0
:
raise
ValueError
(
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"epsilon_cutoff"
,
flag_value
=
self
.
epsilon_cutoff
)
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"epsilon_cutoff"
,
flag_value
=
self
.
epsilon_cutoff
),
UserWarning
,
)
)
if
self
.
eta_cutoff
!=
0.0
:
if
self
.
eta_cutoff
!=
0.0
:
raise
ValueError
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"eta_cutoff"
,
flag_value
=
self
.
eta_cutoff
))
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"eta_cutoff"
,
flag_value
=
self
.
eta_cutoff
),
UserWarning
,
)
# 2. detect beam-only parameterization when not in beam mode
# 2. detect beam-only parameterization when not in beam mode
if
self
.
num_beams
==
1
:
if
self
.
num_beams
==
1
:
single_beam_wrong_parameter_msg
=
(
single_beam_wrong_parameter_msg
=
(
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
"beam-based generation modes.
S
et `num_beams>1` or unset {flag_name}
to continue."
"beam-based generation modes.
You should s
et `num_beams>1` or unset {flag_name}
."
+
fix_location
)
)
if
self
.
early_stopping
is
not
False
:
if
self
.
early_stopping
is
not
False
:
raise
ValueError
(
warnings
.
warn
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"early_stopping"
,
flag_value
=
self
.
early_stopping
)
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"early_stopping"
,
flag_value
=
self
.
early_stopping
),
UserWarning
,
)
)
if
self
.
num_beam_groups
!=
1
:
if
self
.
num_beam_groups
!=
1
:
raise
ValueError
(
warnings
.
warn
(
single_beam_wrong_parameter_msg
.
format
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"num_beam_groups"
,
flag_value
=
self
.
num_beam_groups
flag_name
=
"num_beam_groups"
,
flag_value
=
self
.
num_beam_groups
)
),
UserWarning
,
)
)
if
self
.
diversity_penalty
!=
0.0
:
if
self
.
diversity_penalty
!=
0.0
:
raise
ValueError
(
warnings
.
warn
(
single_beam_wrong_parameter_msg
.
format
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"diversity_penalty"
,
flag_value
=
self
.
diversity_penalty
flag_name
=
"diversity_penalty"
,
flag_value
=
self
.
diversity_penalty
)
),
UserWarning
,
)
)
if
self
.
length_penalty
!=
1.0
:
if
self
.
length_penalty
!=
1.0
:
raise
ValueError
(
warnings
.
warn
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"length_penalty"
,
flag_value
=
self
.
length_penalty
)
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"length_penalty"
,
flag_value
=
self
.
length_penalty
),
UserWarning
,
)
)
if
self
.
constraints
is
not
None
:
if
self
.
constraints
is
not
None
:
raise
ValueError
(
warnings
.
warn
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"constraints"
,
flag_value
=
self
.
constraints
)
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"constraints"
,
flag_value
=
self
.
constraints
),
UserWarning
,
)
)
# 3. detect incorrect paramaterization specific to advanced beam modes
# 3. detect incorrect paramaterization specific to advanced beam modes
...
@@ -405,7 +432,7 @@ class GenerationConfig(PushToHubMixin):
...
@@ -405,7 +432,7 @@ class GenerationConfig(PushToHubMixin):
constrained_wrong_parameter_msg
=
(
constrained_wrong_parameter_msg
=
(
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
"{flag_name} to continue."
"{flag_name} to continue."
+
fix_location
)
)
if
self
.
do_sample
is
True
:
if
self
.
do_sample
is
True
:
raise
ValueError
(
raise
ValueError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment