chenpangpang / transformers / Commits / 3deed1f9

Unverified commit 3deed1f9, authored Aug 09, 2023 by Joao Gante, committed by GitHub on Aug 09, 2023
Generate: length validation (#25384)
parent d59b872c
Showing 1 changed file with 51 additions and 24 deletions (+51, -24): src/transformers/generation/utils.py
src/transformers/generation/utils.py (view file @ 3deed1f9)
```diff
@@ -1245,6 +1245,52 @@ class GenerationMixin:
                 " generate arguments will also show up in this list)"
             )
 
+    def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
+        """Performs validation related to the resulting generated length"""
+
+        # 1. Max length warnings related to poor parameterization
+        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
+            # 20 is the default max_length of the generation config
+            warnings.warn(
+                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the"
+                "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
+                "generation.",
+                UserWarning,
+            )
+        if input_ids_length >= generation_config.max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            warnings.warn(
+                f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`.",
+                UserWarning,
+            )
+
+        # 2. Min length warnings due to unfeasible parameter combinations
+        min_length_error_suffix = (
+            " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
+            "increase the maximum length."
+        )
+        if has_default_max_length:
+            min_length_error_suffix += (
+                f" Note that `max_length` is set to {generation_config.max_length}, its default value."
+            )
+        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
+            warnings.warn(
+                f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
+                f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
+                UserWarning,
+            )
+        if generation_config.min_new_tokens is not None:
+            min_length = generation_config.min_new_tokens + input_ids_length
+            if min_length > generation_config.max_length:
+                warnings.warn(
+                    f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
+                    f"added to the prompt length ({input_ids_length}), is larger than"
+                    f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
+                    UserWarning,
+                )
+
     @torch.no_grad()
     def generate(
         self,
```
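The helper added above only warns; nothing is clamped or raised. A minimal standalone sketch of its `min_new_tokens` branch (not library code: `SimpleNamespace` stands in for a real `GenerationConfig`, and all values are made up):

```python
# Minimal sketch of the `min_new_tokens` feasibility check introduced above.
# `SimpleNamespace` is a stand-in for `GenerationConfig`; values are hypothetical.
import warnings
from types import SimpleNamespace

generation_config = SimpleNamespace(max_length=20, min_new_tokens=30)
input_ids_length = 8  # hypothetical prompt length

# The effective minimum is prompt length + min_new_tokens; it can never be reached
# if it exceeds max_length, so the helper warns about the unfeasible combination.
min_length = generation_config.min_new_tokens + input_ids_length
if min_length > generation_config.max_length:
    warnings.warn(
        f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
        f"added to the prompt length ({input_ids_length}), is larger than"
        f" the maximum possible length ({generation_config.max_length}).",
        UserWarning,
    )
```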
```diff
@@ -1458,16 +1504,9 @@ class GenerationMixin:
             streamer.put(input_ids.cpu())
 
         # 6. Prepare `max_length` depending on other stopping criteria.
-        input_ids_seq_length = input_ids.shape[-1]
+        input_ids_length = input_ids.shape[-1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
-            # 20 is the default max_length of the generation config
-            warnings.warn(
-                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
-                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
-                UserWarning,
-            )
-        elif generation_config.max_new_tokens is not None:
+        if generation_config.max_new_tokens is not None:
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
```
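One subtlety in the hunk above: `has_default_max_length` is what lets the validation distinguish "the caller never set `max_length`" from "the caller explicitly chose it". A small sketch of that flag, with a plain dict standing in for the kwargs that reach `generate()` and the default of 20 taken from the comment in the diff:

```python
# Sketch of the flag computed above; `user_kwargs` and its values are hypothetical.
DEFAULT_MAX_LENGTH = 20  # the generation config default mentioned in the diff

def has_default_max_length(user_kwargs, config_max_length=DEFAULT_MAX_LENGTH):
    # True only when the caller never passed `max_length`, so the value in the
    # generation config is the untouched default.
    return user_kwargs.get("max_length") is None and config_max_length is not None

print(has_default_max_length({"max_new_tokens": 64}))  # True: length driven by max_new_tokens
print(has_default_max_length({"max_length": 128}, config_max_length=128))  # False: explicit
```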
```diff
@@ -1475,20 +1514,8 @@ class GenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
+        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
 
-        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
-            raise ValueError(
-                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
-                f" the maximum length ({generation_config.max_length})"
-            )
-        if input_ids_seq_length >= generation_config.max_length:
-            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-            logger.warning(
-                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
-                " increasing `max_new_tokens`."
-            )
-
 
         # 7. determine generation mode
         generation_mode = self._get_generation_mode(generation_config, assistant_model)
```
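The user-visible effect of this hunk: a `min_length` larger than `max_length` used to raise a `ValueError` here; it is now routed through `_validate_generated_length`, which warns and lets generation stop at `max_length`. A usage sketch (the checkpoint name is only an example; any causal LM would do):

```python
# Usage sketch: after this change, unfeasible length constraints warn instead of raising.
import warnings

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello world", return_tensors="pt")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # min_length (50) exceeds max_length (10): previously a ValueError, now a UserWarning
    output = model.generate(**inputs, min_length=50, max_length=10)

print(output.shape)  # bounded by max_length, e.g. torch.Size([1, 10])
print(any("Unfeasible length constraints" in str(w.message) for w in caught))  # True
```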
```diff
@@ -1512,7 +1539,7 @@ class GenerationMixin:
         # 8. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
-            input_ids_seq_length=input_ids_seq_length,
+            input_ids_seq_length=input_ids_length,
             encoder_input_ids=inputs_tensor,
             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             logits_processor=logits_processor,
```
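This last hunk is rename-only at the call site: the keyword `input_ids_seq_length` belongs to `_get_logits_processor`'s signature and is unchanged, while the local variable in `generate()` is now `input_ids_length`. A hypothetical stand-in (not the library's implementation) to make that distinction concrete:

```python
# Hypothetical stand-in for _get_logits_processor: the parameter name stays
# `input_ids_seq_length`, so callers are free to rename their local variables.
def _get_logits_processor(generation_config=None, input_ids_seq_length=None):
    return f"processors for a prompt of length {input_ids_seq_length}"

input_ids_length = 8  # the renamed local from the diff; value is made up
print(_get_logits_processor(generation_config=None, input_ids_seq_length=input_ids_length))
```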