Unverified Commit 3f9cb335 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: fix default max length warning (#25539)

parent e13d5b60
......@@ -377,7 +377,7 @@ class FlaxGenerationMixin:
# Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
# 20 is the default max_length of the generation config
warnings.warn(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
......
......@@ -829,7 +829,7 @@ class TFGenerationMixin:
# 7. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = shape_list(input_ids)[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
# 20 is the default max_length of the generation config
warnings.warn(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
......
......@@ -1249,7 +1249,7 @@ class GenerationMixin:
"""Performs validation related to the resulting generated length"""
# 1. Max length warnings related to poor parameterization
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
# 20 is the default max_length of the generation config
warnings.warn(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the"
......
......@@ -1300,7 +1300,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
# 5. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
logger.warning(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
"to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation.",
......
......@@ -16,6 +16,7 @@
import inspect
import unittest
import warnings
import numpy as np
......@@ -2844,3 +2845,28 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
with self.assertRaises(TypeError):
# FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo"
bart_model.generate(input_ids, foo="bar")
def test_default_max_length_warning(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Default generation config value of 20 -> emits warning
with self.assertWarns(UserWarning):
model.generate(input_ids)
# Explicitly setting max_length to 20 -> no warning
with warnings.catch_warnings(record=True) as warning_list:
model.generate(input_ids, max_length=20)
self.assertEqual(len(warning_list), 0)
# Generation config max_length != 20 -> no warning
with warnings.catch_warnings(record=True) as warning_list:
model.generation_config.max_length = 10
model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence
model.generate(input_ids)
self.assertEqual(len(warning_list), 0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment