Unverified commit 42b60f8b authored by Joao Gante, committed by GitHub

Generate: Relaxed `max_length` and `max_new_tokens` coexistence (#21347)


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 6eb3c66a
@@ -63,14 +63,12 @@ class GenerationConfig(PushToHubMixin):
         max_length (`int`, *optional*, defaults to 20):
             The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
-            `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in the
-            prompt.
+            `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
         max_new_tokens (`int`, *optional*):
             The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
         min_length (`int`, *optional*, defaults to 0):
             The minimum length of the sequence to be generated. Corresponds to the length of the input prompt +
-            `min_new_tokens`. In general, prefer the use of `min_new_tokens`, which ignores the number of tokens in the
-            prompt.
+            `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
         min_new_tokens (`int`, *optional*):
             The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
         early_stopping (`bool`, *optional*, defaults to `False`):
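
For illustration only (not part of this diff), a minimal sketch of the documented behaviour: `max_new_tokens` counts only newly generated tokens, while `max_length` also counts the prompt. The tiny test checkpoint below is an assumption made for brevity.

```python
# Sketch of the docstring above: `max_new_tokens` is prompt-relative,
# `max_length` bounds prompt + generated tokens together.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

input_ids = tokenizer("Justin Timberlake.", return_tensors="pt").input_ids
prompt_length = input_ids.shape[-1]

# At most 3 tokens are appended, regardless of how long the prompt is
# (generation may stop earlier, e.g. on an EOS token).
outputs = model.generate(input_ids=input_ids, max_new_tokens=3)
assert outputs.shape[-1] <= prompt_length + 3
```
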
@@ -318,20 +318,20 @@ class FlaxGenerationMixin:
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
-                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
-                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
-                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                 " recommend using `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
-        elif has_default_max_length and generation_config.max_new_tokens is not None:
+        elif generation_config.max_new_tokens is not None:
             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
-        elif not has_default_max_length and generation_config.max_new_tokens is not None:
-            raise ValueError(
-                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
-                " limit to the generated output length. Remove one of those arguments. Please refer to the"
-                " documentation for more information. "
-                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
-            )
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
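
The TF and PyTorch mixins below receive the same change. Read as a framework-agnostic sketch (illustrative pseudologic only; the helper name and arguments are invented here, not library API), the relaxed length resolution is:

```python
def resolve_max_length(max_new_tokens, user_max_length, input_ids_seq_length, default_max_length=20):
    """Schematic restatement of the branches above; `user_max_length` is None unless the caller set it."""
    if user_max_length is None and max_new_tokens is None:
        # Neither limit set: fall back to the (deprecated) config default and warn.
        return default_max_length
    if max_new_tokens is not None:
        # Both set: no longer a ValueError -- `max_new_tokens` simply takes
        # precedence (a warning is logged), and `max_length` is recomputed.
        return max_new_tokens + input_ids_seq_length
    return user_max_length


# Example: a 5-token prompt with max_new_tokens=10 yields an effective cap of 15,
# even though max_length=20 was also passed.
assert resolve_max_length(max_new_tokens=10, user_max_length=20, input_ids_seq_length=5) == 15
```
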
@@ -700,20 +700,20 @@ class TFGenerationMixin:
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
-                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
-                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
-                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                 " recommend using `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
-        elif has_default_max_length and generation_config.max_new_tokens is not None:
+        elif generation_config.max_new_tokens is not None:
             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
-        elif not has_default_max_length and generation_config.max_new_tokens is not None:
-            raise ValueError(
-                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
-                " limit to the generated output length. Remove one of those arguments. Please refer to the"
-                " documentation for more information. "
-                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
-            )
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
@@ -1274,20 +1274,20 @@ class GenerationMixin:
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
-                "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
-                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
-                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                 " recommend using `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
-        elif has_default_max_length and generation_config.max_new_tokens is not None:
+        elif generation_config.max_new_tokens is not None:
             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
-        elif not has_default_max_length and generation_config.max_new_tokens is not None:
-            raise ValueError(
-                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
-                " limit to the generated output length. Remove one of those arguments. Please refer to the"
-                " documentation for more information. "
-                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
-            )
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
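
For completeness, a hedged sketch of the unchanged first branch: when neither limit is passed, generation stays capped by the config default and the deprecation warning above is emitted. The tiny checkpoint and its default `max_length` of 20 are assumptions for illustration.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tokenizer("Justin Timberlake.", return_tensors="pt").input_ids

# No `max_length` / `max_new_tokens` passed: the config default applies and a
# deprecation warning recommends switching to `max_new_tokens`.
outputs = model.generate(input_ids=input_ids)
assert outputs.shape[-1] <= 20  # assuming the checkpoint keeps the library default of 20
```
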
@@ -2178,10 +2178,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS + 20 + 3 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])
 
-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
-
     def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
@@ -2212,12 +2208,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS + 20 + 3 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])
 
-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            t5_model.generate(
-                decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
-            )
-
     def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
@@ -2250,12 +2240,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS + 20 + 3 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])
 
-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            bart_model.generate(
-                decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
-            )
-
     def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
         article = """Justin Timberlake."""
         gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
@@ -2279,10 +2263,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS token + 23 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])
 
-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            gptj_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
-
     def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
         article = """Justin Timberlake."""
         gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -2306,10 +2286,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS token + 23 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])
 
-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
-
     def test_max_new_tokens_decoder_only(self):
         article = """Justin Timberlake."""
         gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -2333,10 +2309,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS token + 23 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])
 
-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
-
     def test_encoder_decoder_generate_with_inputs_embeds(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
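
The assertions removed above expected a `ValueError` when both arguments were combined. Under the relaxed behaviour, an equivalent check of the new precedence might look roughly like this (a hypothetical sketch, not a test added by this commit):

```python
import unittest

from transformers import AutoModelForCausalLM, GPT2Tokenizer


class MaxNewTokensPrecedenceTest(unittest.TestCase):
    def test_both_limits_set_prefers_max_new_tokens(self):
        gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        gpt2_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        input_ids = gpt2_tokenizer("Justin Timberlake.", return_tensors="pt").input_ids

        # Previously a ValueError; now only a warning is logged and `max_new_tokens` wins.
        outputs = gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
        self.assertLessEqual(outputs.shape[-1], input_ids.shape[-1] + 10)
```
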