Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
eb3ded16
"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "4c946d5b2136cf183f2b9be17813e84dd7731c3e"
Unverified
Commit
eb3ded16
authored
Aug 09, 2023
by
Joao Gante
Committed by
GitHub
Aug 09, 2023
Browse files
Generate: lower severity of parameterization checks (#25407)
parent
ef74da65
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
23 deletions
+50
-23
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+50
-23
No files found.
src/transformers/generation/configuration_utils.py
View file @
eb3ded16
...
...
@@ -313,7 +313,7 @@ class GenerationConfig(PushToHubMixin):
raise
err
# Validate the values of the attributes
self
.
validate
()
self
.
validate
(
is_init
=
True
)
def
__eq__
(
self
,
other
):
if
not
isinstance
(
other
,
GenerationConfig
):
...
...
@@ -330,7 +330,7 @@ class GenerationConfig(PushToHubMixin):
def
__repr__
(
self
):
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
def
validate
(
self
):
def
validate
(
self
,
is_init
=
False
):
"""
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
of parameterization that can be detected as incorrect from the configuration instance alone.
...
...
@@ -344,58 +344,85 @@ class GenerationConfig(PushToHubMixin):
raise
ValueError
(
f
"`early_stopping` must be a boolean or 'never', but is
{
self
.
early_stopping
}
."
)
# Validation of attribute relations:
fix_location
=
""
if
is_init
:
fix_location
=
(
" This was detected when initializing the generation config instance, which means the corresponding "
"file may hold incorrect parameterization and should be fixed."
)
# 1. detect sampling-only parameterization when not in sampling mode
if
self
.
do_sample
is
False
:
greedy_wrong_parameter_msg
=
(
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
"in sample-based generation modes. Set `do_sample=True` or unset {flag_name} to continue."
"in sample-based generation modes. You should set `do_sample=True` or unset {flag_name}."
+
fix_location
)
if
self
.
temperature
!=
1.0
:
raise
ValueError
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"temperature"
,
flag_value
=
self
.
temperature
)
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"temperature"
,
flag_value
=
self
.
temperature
),
UserWarning
,
)
if
self
.
top_p
!=
1.0
:
raise
ValueError
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"top_p"
,
flag_value
=
self
.
top_p
))
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"top_p"
,
flag_value
=
self
.
top_p
),
UserWarning
,
)
if
self
.
typical_p
!=
1.0
:
raise
ValueError
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"typical_p"
,
flag_value
=
self
.
typical_p
))
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"typical_p"
,
flag_value
=
self
.
typical_p
),
UserWarning
,
)
if
self
.
top_k
!=
50
and
self
.
penalty_alpha
is
None
:
# contrastive search uses top_k
raise
ValueError
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"top_k"
,
flag_value
=
self
.
top_k
))
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"top_k"
,
flag_value
=
self
.
top_k
),
UserWarning
,
)
if
self
.
epsilon_cutoff
!=
0.0
:
raise
ValueError
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"epsilon_cutoff"
,
flag_value
=
self
.
epsilon_cutoff
)
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"epsilon_cutoff"
,
flag_value
=
self
.
epsilon_cutoff
),
UserWarning
,
)
if
self
.
eta_cutoff
!=
0.0
:
raise
ValueError
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"eta_cutoff"
,
flag_value
=
self
.
eta_cutoff
))
warnings
.
warn
(
greedy_wrong_parameter_msg
.
format
(
flag_name
=
"eta_cutoff"
,
flag_value
=
self
.
eta_cutoff
),
UserWarning
,
)
# 2. detect beam-only parameterization when not in beam mode
if
self
.
num_beams
==
1
:
single_beam_wrong_parameter_msg
=
(
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
"beam-based generation modes.
S
et `num_beams>1` or unset {flag_name}
to continue."
"beam-based generation modes.
You should s
et `num_beams>1` or unset {flag_name}
."
+
fix_location
)
if
self
.
early_stopping
is
not
False
:
raise
ValueError
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"early_stopping"
,
flag_value
=
self
.
early_stopping
)
warnings
.
warn
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"early_stopping"
,
flag_value
=
self
.
early_stopping
),
UserWarning
,
)
if
self
.
num_beam_groups
!=
1
:
raise
ValueError
(
warnings
.
warn
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"num_beam_groups"
,
flag_value
=
self
.
num_beam_groups
)
),
UserWarning
,
)
if
self
.
diversity_penalty
!=
0.0
:
raise
ValueError
(
warnings
.
warn
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"diversity_penalty"
,
flag_value
=
self
.
diversity_penalty
)
),
UserWarning
,
)
if
self
.
length_penalty
!=
1.0
:
raise
ValueError
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"length_penalty"
,
flag_value
=
self
.
length_penalty
)
warnings
.
warn
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"length_penalty"
,
flag_value
=
self
.
length_penalty
),
UserWarning
,
)
if
self
.
constraints
is
not
None
:
raise
ValueError
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"constraints"
,
flag_value
=
self
.
constraints
)
warnings
.
warn
(
single_beam_wrong_parameter_msg
.
format
(
flag_name
=
"constraints"
,
flag_value
=
self
.
constraints
),
UserWarning
,
)
# 3. detect incorrect paramaterization specific to advanced beam modes
...
...
@@ -405,7 +432,7 @@ class GenerationConfig(PushToHubMixin):
constrained_wrong_parameter_msg
=
(
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
"{flag_name} to continue."
"{flag_name} to continue."
+
fix_location
)
if
self
.
do_sample
is
True
:
raise
ValueError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment