Unverified Commit b369e507 authored by Joao Gante, committed by GitHub

Generate: text generation pipeline no longer emits `max_length` warning when it is not set (#23139)

parent 516dc630
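
For context, a hedged sketch of the user-facing behavior this commit targets; the model name is taken from the new test added below, and the before/after behavior is inferred from the commit title and that test rather than stated in the diff itself:

# Sketch only: calls mirroring the new test_pipeline_length_setting_warning test.
from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

# The user sets only `max_new_tokens`. Per the commit title, the pipeline used to
# surface the "Both `max_new_tokens` (...) and `max_length` (...)" warning here even
# though `max_length` was never set; after this change it stays silent.
generator("Hello world", max_new_tokens=1)

# The warning is still expected when the user genuinely sets both limits.
generator("Hello world", max_length=10, max_new_tokens=1)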
@@ -385,7 +385,6 @@ class FlaxGenerationMixin:
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -393,6 +392,7 @@ class FlaxGenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(

@@ -858,7 +858,6 @@ class TFGenerationMixin:
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -866,6 +865,7 @@ class TFGenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
         # If the input length is a tensor (i.e. dynamic length), skip length checks
         if not isinstance(input_ids_seq_length, tf.Tensor):

@@ -1348,7 +1348,6 @@ class GenerationMixin:
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -1356,6 +1355,7 @@ class GenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(

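All three mixins resolve `max_new_tokens` into an absolute `max_length` the same way: by adding the prompt length. A small worked illustration with made-up numbers:

# Illustration of the length arithmetic performed by the mixins above.
input_ids_seq_length = 5    # tokens already present in the prompt
max_new_tokens = 20         # newly generated tokens requested by the user
max_length = max_new_tokens + input_ids_seq_length
assert max_length == 25     # generation stops once the whole sequence reaches 25 tokens
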
+import copy
 import enum
 import warnings

@@ -105,17 +106,8 @@ class TextGenerationPipeline(Pipeline):
             prefix_inputs = self.tokenizer(
                 prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
             )
-            prefix_length = prefix_inputs["input_ids"].shape[-1]
-            if "max_new_tokens" in generate_kwargs:
-                pass
-            elif "max_length" in generate_kwargs:
-                generate_kwargs["max_length"] += prefix_length
-            else:
-                generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
-            if "min_length" in generate_kwargs:
-                generate_kwargs["min_length"] += prefix_length
+            generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]
         if handle_long_generation is not None:
             if handle_long_generation not in {"hole"}:
                 raise ValueError(

@@ -247,6 +239,26 @@ class TextGenerationPipeline(Pipeline):
         else:
             in_b = input_ids.shape[0]
         prompt_text = model_inputs.pop("prompt_text")
+
+        # If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
+        # generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
+        generate_kwargs = copy.deepcopy(generate_kwargs)
+        prefix_length = generate_kwargs.pop("prefix_length", 0)
+        if prefix_length > 0:
+            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
+                "generation_config" in generate_kwargs
+                and generate_kwargs["generation_config"].max_new_tokens is not None
+            )
+            if not has_max_new_tokens:
+                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
+                generate_kwargs["max_length"] += prefix_length
+            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
+                "generation_config" in generate_kwargs
+                and generate_kwargs["generation_config"].min_new_tokens is not None
+            )
+            if not has_min_new_tokens and "min_length" in generate_kwargs:
+                generate_kwargs["min_length"] += prefix_length
+
         # BS x SL
         generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
         out_b = generated_sequence.shape[0]

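Net effect of the pipeline change above: `preprocess` now only records the tokenized prefix length, and `_forward` adjusts the limits on a copy of `generate_kwargs`, and only when the caller relies on `max_length`/`min_length` rather than the `*_new_tokens` arguments. A hedged usage sketch (prefix string and expectations are illustrative, not taken from the commit):

from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

# The prefix is tokenized and prepended to the prompt. Because the caller uses
# `max_length`, _forward() extends it by the prefix length so the budget left for
# the actual continuation is unchanged.
generator("Hello world", prefix="My prefix: ", max_length=20)

# With `max_new_tokens` the limit already counts only newly generated tokens
# (the prefix is part of the input), so no adjustment is made.
generator("Hello world", prefix="My prefix: ", max_new_tokens=5)
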
@@ -14,8 +14,15 @@
 import unittest

-from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
+from transformers import (
+    MODEL_FOR_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+    TextGenerationPipeline,
+    logging,
+    pipeline,
+)
 from transformers.testing_utils import (
+    CaptureLogger,
     is_pipeline_test,
     require_accelerate,
     require_tf,
@@ -323,3 +330,26 @@ class TextGenerationPipelineTests(unittest.TestCase):
         pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
         pipe("This is a test", do_sample=True, top_p=0.5)
+
+    def test_pipeline_length_setting_warning(self):
+        prompt = """Hello world"""
+        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
+        if text_generator.model.framework == "tf":
+            logger = logging.get_logger("transformers.generation.tf_utils")
+        else:
+            logger = logging.get_logger("transformers.generation.utils")
+        logger_msg = "Both `max_new_tokens`"  # The beginning of the message to be checked in this test
+
+        # Both are set by the user -> log warning
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_length=10, max_new_tokens=1)
+        self.assertIn(logger_msg, cl.out)
+
+        # The user only sets one -> no warning
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_new_tokens=1)
+        self.assertNotIn(logger_msg, cl.out)
+
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_length=10)
+        self.assertNotIn(logger_msg, cl.out)