Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3deed1f9
Unverified
Commit
3deed1f9
authored
Aug 09, 2023
by
Joao Gante
Committed by
GitHub
Aug 09, 2023
Browse files
Generate: length validation (#25384)
parent
d59b872c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
24 deletions
+51
-24
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+51
-24
No files found.
src/transformers/generation/utils.py
View file @
3deed1f9
...
@@ -1245,6 +1245,52 @@ class GenerationMixin:
...
@@ -1245,6 +1245,52 @@ class GenerationMixin:
" generate arguments will also show up in this list)"
" generate arguments will also show up in this list)"
)
)
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
    """Warn about length settings that are poorly parameterized or cannot be satisfied.

    This method never raises and never modifies `generation_config`; every problem it
    detects is reported through `warnings.warn(..., UserWarning)`.

    Args:
        generation_config: Object exposing `max_length`, `max_new_tokens`, `min_length`
            and `min_new_tokens`. `max_length` is compared numerically, so it must
            already be resolved to a number when this is called.
        input_ids_length: Length of the prompt (the last dimension of the input ids).
        has_default_max_length: Whether `max_length` was left at its default instead of
            being passed explicitly by the caller.
    """
    # 1. Max length warnings related to poor parameterization
    # 20 is the default max_length of the generation config, so the warning applies only
    # when that model-agnostic value is what ends up bounding the generation.
    if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
        warnings.warn(
            f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
            "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
            "generation.",
            UserWarning,
        )
    if input_ids_length >= generation_config.max_length:
        # The prompt already fills (or exceeds) the length budget.
        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        warnings.warn(
            f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`.",
            UserWarning,
        )

    # 2. Min length warnings due to unfeasible parameter combinations
    min_length_error_suffix = (
        " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
        "increase the maximum length."
    )
    if has_default_max_length:
        # Point out that the limit being hit is a default, not something the caller chose.
        min_length_error_suffix += (
            f" Note that `max_length` is set to {generation_config.max_length}, its default value."
        )
    if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
        warnings.warn(
            f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
            f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
            UserWarning,
        )
    if generation_config.min_new_tokens is not None:
        # `min_new_tokens` counts generated tokens only, so add the prompt length to
        # compare against `max_length`.
        min_length = generation_config.min_new_tokens + input_ids_length
        if min_length > generation_config.max_length:
            warnings.warn(
                f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
                f"added to the prompt length ({input_ids_length}), is larger than"
                f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                UserWarning,
            )
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
generate
(
def
generate
(
self
,
self
,
...
@@ -1458,16 +1504,9 @@ class GenerationMixin:
...
@@ -1458,16 +1504,9 @@ class GenerationMixin:
streamer
.
put
(
input_ids
.
cpu
())
streamer
.
put
(
input_ids
.
cpu
())
# 6. Prepare `max_length` depending on other stopping criteria.
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_
seq_
length
=
input_ids
.
shape
[
-
1
]
input_ids_length
=
input_ids
.
shape
[
-
1
]
has_default_max_length
=
kwargs
.
get
(
"max_length"
)
is
None
and
generation_config
.
max_length
is
not
None
has_default_max_length
=
kwargs
.
get
(
"max_length"
)
is
None
and
generation_config
.
max_length
is
not
None
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
and
generation_config
.
max_length
!=
20
:
if
generation_config
.
max_new_tokens
is
not
None
:
# 20 is the default max_length of the generation config
warnings
.
warn
(
f
"Using the model-agnostic default `max_length` (=
{
generation_config
.
max_length
}
) "
"to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation."
,
UserWarning
,
)
elif
generation_config
.
max_new_tokens
is
not
None
:
if
not
has_default_max_length
:
if
not
has_default_max_length
:
logger
.
warning
(
logger
.
warning
(
f
"Both `max_new_tokens` (=
{
generation_config
.
max_new_tokens
}
) and `max_length`(="
f
"Both `max_new_tokens` (=
{
generation_config
.
max_new_tokens
}
) and `max_length`(="
...
@@ -1475,20 +1514,8 @@ class GenerationMixin:
...
@@ -1475,20 +1514,8 @@ class GenerationMixin:
"Please refer to the documentation for more information. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
)
generation_config
.
max_length
=
generation_config
.
max_new_tokens
+
input_ids_seq_length
generation_config
.
max_length
=
generation_config
.
max_new_tokens
+
input_ids_length
self
.
_validate_generated_length
(
generation_config
,
input_ids_length
,
has_default_max_length
)
if
generation_config
.
min_length
is
not
None
and
generation_config
.
min_length
>
generation_config
.
max_length
:
raise
ValueError
(
f
"Unfeasible length constraints: the minimum length (
{
generation_config
.
min_length
}
) is larger than"
f
" the maximum length (
{
generation_config
.
max_length
}
)"
)
if
input_ids_seq_length
>=
generation_config
.
max_length
:
input_ids_string
=
"decoder_input_ids"
if
self
.
config
.
is_encoder_decoder
else
"input_ids"
logger
.
warning
(
f
"Input length of
{
input_ids_string
}
is
{
input_ids_seq_length
}
, but `max_length` is set to"
f
"
{
generation_config
.
max_length
}
. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
# 7. determine generation mode
# 7. determine generation mode
generation_mode
=
self
.
_get_generation_mode
(
generation_config
,
assistant_model
)
generation_mode
=
self
.
_get_generation_mode
(
generation_config
,
assistant_model
)
...
@@ -1512,7 +1539,7 @@ class GenerationMixin:
...
@@ -1512,7 +1539,7 @@ class GenerationMixin:
# 8. prepare distribution pre_processing samplers
# 8. prepare distribution pre_processing samplers
logits_processor
=
self
.
_get_logits_processor
(
logits_processor
=
self
.
_get_logits_processor
(
generation_config
=
generation_config
,
generation_config
=
generation_config
,
input_ids_seq_length
=
input_ids_
seq_
length
,
input_ids_seq_length
=
input_ids_length
,
encoder_input_ids
=
inputs_tensor
,
encoder_input_ids
=
inputs_tensor
,
prefix_allowed_tokens_fn
=
prefix_allowed_tokens_fn
,
prefix_allowed_tokens_fn
=
prefix_allowed_tokens_fn
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment