Unverified Commit 3deed1f9 authored by Joao Gante, committed by GitHub

Generate: length validation (#25384)

parent d59b872c
@@ -1245,6 +1245,52 @@ class GenerationMixin:
                 " generate arguments will also show up in this list)"
             )
 
+    def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
+        """Performs validation related to the resulting generated length"""
+
+        # 1. Max length warnings related to poor parameterization
+        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
+            # 20 is the default max_length of the generation config
+            warnings.warn(
+                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
+                "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
+                "generation.",
+                UserWarning,
+            )
+        if input_ids_length >= generation_config.max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            warnings.warn(
+                f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`.",
+                UserWarning,
+            )
+
+        # 2. Min length warnings due to unfeasible parameter combinations
+        min_length_error_suffix = (
+            " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
+            "increase the maximum length."
+        )
+        if has_default_max_length:
+            min_length_error_suffix += (
+                f" Note that `max_length` is set to {generation_config.max_length}, its default value."
+            )
+        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
+            warnings.warn(
+                f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
+                f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
+                UserWarning,
+            )
+        if generation_config.min_new_tokens is not None:
+            min_length = generation_config.min_new_tokens + input_ids_length
+            if min_length > generation_config.max_length:
+                warnings.warn(
+                    f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
+                    f"added to the prompt length ({input_ids_length}), is larger than"
+                    f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
+                    UserWarning,
+                )
+
     @torch.no_grad()
     def generate(
         self,
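Reviewer note: a minimal standalone sketch of the three warning paths in the new helper, using a types.SimpleNamespace stand-in rather than the real GenerationConfig so the logic can be exercised in isolation (all names below are illustrative, not part of the PR):

# Sketch only: mirrors the helper's warning paths with a stand-in config object.
import warnings
from types import SimpleNamespace

def validate_generated_length(generation_config, input_ids_length, has_default_max_length):
    # default max_length in use and not overridden via max_new_tokens
    if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
        warnings.warn("using default `max_length`; prefer `max_new_tokens`", UserWarning)
    # the prompt alone already fills the length budget
    if input_ids_length >= generation_config.max_length:
        warnings.warn("prompt length >= `max_length`", UserWarning)
    # unfeasible min/max combination
    if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
        warnings.warn("`min_length` > `max_length`", UserWarning)

config = SimpleNamespace(max_length=30, max_new_tokens=None, min_length=50, min_new_tokens=None)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    validate_generated_length(config, input_ids_length=35, has_default_max_length=True)
print(len(caught))  # 3: all three warning paths fire for this configuration

The behavioral change worth noting: an unfeasible `min_length` now produces a UserWarning and generation simply stops at `max_length`, where the previous code raised a ValueError.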
@@ -1458,16 +1504,9 @@ class GenerationMixin:
             streamer.put(input_ids.cpu())
 
         # 6. Prepare `max_length` depending on other stopping criteria.
-        input_ids_seq_length = input_ids.shape[-1]
+        input_ids_length = input_ids.shape[-1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
-            # 20 is the default max_length of the generation config
-            warnings.warn(
-                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
-                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
-                UserWarning,
-            )
-        elif generation_config.max_new_tokens is not None:
+        if generation_config.max_new_tokens is not None:
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -1475,20 +1514,8 @@ class GenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
-
-        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
-            raise ValueError(
-                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
-                f" the maximum length ({generation_config.max_length})"
-            )
-        if input_ids_seq_length >= generation_config.max_length:
-            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-            logger.warning(
-                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
-                " increasing `max_new_tokens`."
-            )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
+        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
 
         # 7. determine generation mode
         generation_mode = self._get_generation_mode(generation_config, assistant_model)
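Reviewer note: an illustrative end-to-end check (not part of the PR) that reaches the relocated validation through generate() itself; it assumes transformers and torch are installed, and uses the small public checkpoint sshleifer/tiny-gpt2 only to keep the example fast:

# Sketch only: exercises the new warning path via the public generate() API.
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
inputs = tokenizer("Hello world", return_tensors="pt")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # min_length > max_length: after this PR, a UserWarning instead of a ValueError
    model.generate(**inputs, max_length=8, min_length=16)
print(any("Unfeasible length constraints" in str(w.message) for w in caught))  # True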
@@ -1512,7 +1539,7 @@ class GenerationMixin:
         # 8. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
-            input_ids_seq_length=input_ids_seq_length,
+            input_ids_seq_length=input_ids_length,
             encoder_input_ids=inputs_tensor,
             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             logits_processor=logits_processor,
...
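Reviewer note: a small worked example of the step-6 arithmetic above (numbers are illustrative): when `max_new_tokens` is set, the effective `max_length` becomes the prompt length plus the requested budget of new tokens.

input_ids_length = 12  # e.g. a 12-token prompt
max_new_tokens = 20    # user-requested budget for newly generated tokens
max_length = max_new_tokens + input_ids_length
assert max_length == 32  # generate() stops once 32 total tokens are reached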