chenpangpang / transformers · Commits

Commit b369e507 (unverified)
Authored May 04, 2023 by Joao Gante; committed via GitHub on May 04, 2023
Generate: text generation pipeline no longer emits `max_length` warning when it is not set (#23139)
Parent: 516dc630

Showing 5 changed files with 56 additions and 14 deletions (+56, -14):
  src/transformers/generation/flax_utils.py           +1  -1
  src/transformers/generation/tf_utils.py             +1  -1
  src/transformers/generation/utils.py                +1  -1
  src/transformers/pipelines/text_generation.py       +22 -10
  tests/pipelines/test_pipelines_text_generation.py   +31 -1
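As context for the change, here is a minimal usage sketch of the behaviour the commit targets. The checkpoint name is taken from the repository's own tests; everything else is illustrative: after this commit, a pipeline call that sets only `max_new_tokens` should no longer log the "Both `max_new_tokens` ... and `max_length`" warning, while setting both limits still warns and lets `max_new_tokens` take precedence.

```python
# Hedged sketch of the user-facing behaviour this commit targets; the checkpoint
# comes from the repository's tests, the prompts are illustrative.
from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

# Only `max_new_tokens` is set: no length warning is expected after this change.
generator("Hello world", max_new_tokens=1)

# Both limits are set explicitly: the warning is still expected, and
# `max_new_tokens` takes precedence over `max_length`.
generator("Hello world", max_length=10, max_new_tokens=1)
```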
src/transformers/generation/flax_utils.py

@@ -385,7 +385,6 @@ class FlaxGenerationMixin:
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -393,6 +392,7 @@ class FlaxGenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
 
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
src/transformers/generation/tf_utils.py

@@ -858,7 +858,6 @@ class TFGenerationMixin:
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -866,6 +865,7 @@ class TFGenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
 
         # If the input length is a tensor (i.e. dynamic length), skip length checks
         if not isinstance(input_ids_seq_length, tf.Tensor):
src/transformers/generation/utils.py

@@ -1348,7 +1348,6 @@ class GenerationMixin:
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -1356,6 +1355,7 @@ class GenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
 
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
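The same one-line move appears in all three framework mixins (Flax, TF, and PyTorch): the `max_length` value derived from `max_new_tokens` is now assigned after the warning check instead of before it, so the warning still fires only when the user explicitly set `max_length`, and when it does fire it can report the user-provided value rather than the already-overwritten one. A simplified, framework-agnostic sketch of the resulting control flow follows; the argument names are taken from the diff, while `warn` is a stand-in for the module logger and is an assumption of this sketch.

```python
# Simplified sketch of the reordered length resolution in the generation mixins.
# `generation_config`, `has_default_max_length`, and `input_ids_seq_length` follow
# the diff above; `warn` is an illustrative stand-in for `logger.warning`.
def resolve_max_length(generation_config, has_default_max_length, input_ids_seq_length, warn):
    if generation_config.max_new_tokens is not None:
        if not has_default_max_length:
            # Reached only when the user also set a non-default `max_length`.
            warn(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and "
                f"`max_length` (={generation_config.max_length}) seem to have been set. "
                "`max_new_tokens` will take precedence."
            )
        # The assignment now happens after the warning, so the message above shows
        # the `max_length` the user passed, not the value derived from `max_new_tokens`.
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
    return generation_config.max_length
```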
src/transformers/pipelines/text_generation.py

+import copy
 import enum
 import warnings

@@ -105,17 +106,8 @@ class TextGenerationPipeline(Pipeline):
             prefix_inputs = self.tokenizer(
                 prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
             )
-            prefix_length = prefix_inputs["input_ids"].shape[-1]
-
-            if "max_new_tokens" in generate_kwargs:
-                pass
-            elif "max_length" in generate_kwargs:
-                generate_kwargs["max_length"] += prefix_length
-            else:
-                generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
-
-            if "min_length" in generate_kwargs:
-                generate_kwargs["min_length"] += prefix_length
+            generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]
 
         if handle_long_generation is not None:
             if handle_long_generation not in {"hole"}:
                 raise ValueError(

@@ -247,6 +239,26 @@ class TextGenerationPipeline(Pipeline):
         else:
             in_b = input_ids.shape[0]
         prompt_text = model_inputs.pop("prompt_text")
+
+        # If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
+        # generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
+        generate_kwargs = copy.deepcopy(generate_kwargs)
+        prefix_length = generate_kwargs.pop("prefix_length", 0)
+        if prefix_length > 0:
+            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
+                "generation_config" in generate_kwargs
+                and generate_kwargs["generation_config"].max_new_tokens is not None
+            )
+            if not has_max_new_tokens:
+                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
+                generate_kwargs["max_length"] += prefix_length
+            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
+                "generation_config" in generate_kwargs
+                and generate_kwargs["generation_config"].min_new_tokens is not None
+            )
+            if not has_min_new_tokens and "min_length" in generate_kwargs:
+                generate_kwargs["min_length"] += prefix_length
+
         # BS x SL
         generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
         out_b = generated_sequence.shape[0]
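In the pipeline, `_sanitize_parameters` now only records the tokenized prefix length, and the length adjustment moves into `_forward`, where it operates on a per-call deep copy of `generate_kwargs` and is skipped whenever `max_new_tokens` / `min_new_tokens` already control the budget. A trimmed-down sketch of that adjustment, with the pipeline plumbing removed (`model_max_length` is a stand-in for `self.model.config.max_length`):

```python
import copy


# Hedged sketch of the new `_forward` length handling; `model_max_length` stands
# in for `self.model.config.max_length`, the rest mirrors the diff above.
def adjust_lengths_for_prefix(generate_kwargs, model_max_length):
    # Work on a copy so kwargs captured at pipeline construction stay untouched.
    generate_kwargs = copy.deepcopy(generate_kwargs)
    prefix_length = generate_kwargs.pop("prefix_length", 0)
    if prefix_length > 0:
        gen_cfg = generate_kwargs.get("generation_config")
        has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
            gen_cfg is not None and gen_cfg.max_new_tokens is not None
        )
        if not has_max_new_tokens:
            # Extend `max_length` so the prefix does not eat into the generation budget.
            generate_kwargs["max_length"] = (generate_kwargs.get("max_length") or model_max_length) + prefix_length
        has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
            gen_cfg is not None and gen_cfg.min_new_tokens is not None
        )
        if not has_min_new_tokens and "min_length" in generate_kwargs:
            generate_kwargs["min_length"] += prefix_length
    return generate_kwargs
```

Because the adjustment now happens on a fresh copy at every call, a `max_length` derived from a default prefix at pipeline-construction time is no longer baked into the stored kwargs, where it could collide with a `max_new_tokens` passed at call time; that collision appears to be the spurious warning the commit title refers to.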
tests/pipelines/test_pipelines_text_generation.py

@@ -14,8 +14,15 @@
 import unittest
 
-from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
+from transformers import (
+    MODEL_FOR_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+    TextGenerationPipeline,
+    logging,
+    pipeline,
+)
 from transformers.testing_utils import (
+    CaptureLogger,
     is_pipeline_test,
     require_accelerate,
     require_tf,

@@ -323,3 +330,26 @@ class TextGenerationPipelineTests(unittest.TestCase):
         pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
         pipe("This is a test", do_sample=True, top_p=0.5)
 
+    def test_pipeline_length_setting_warning(self):
+        prompt = """Hello world"""
+        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
+        if text_generator.model.framework == "tf":
+            logger = logging.get_logger("transformers.generation.tf_utils")
+        else:
+            logger = logging.get_logger("transformers.generation.utils")
+
+        logger_msg = "Both `max_new_tokens`"  # The beginning of the message to be checked in this test
+
+        # Both are set by the user -> log warning
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_length=10, max_new_tokens=1)
+        self.assertIn(logger_msg, cl.out)
+
+        # The user only sets one -> no warning
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_new_tokens=1)
+        self.assertNotIn(logger_msg, cl.out)
+
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_length=10)
+        self.assertNotIn(logger_msg, cl.out)
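The new test asserts on the captured warning text via `CaptureLogger` from `transformers.testing_utils` rather than on generation output, so it stays cheap (a tiny random GPT-2) and framework-aware (it picks the TF or PyTorch generation logger at runtime). It can typically be run on its own with `python -m pytest tests/pipelines/test_pipelines_text_generation.py -k test_pipeline_length_setting_warning`.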