Unverified commit 42b60f8b, authored by Joao Gante and committed by GitHub

Generate: Relaxed `max_length` and `max_new_tokens` coexistence (#21347)


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 6eb3c66a
@@ -63,14 +63,12 @@ class GenerationConfig(PushToHubMixin):
         max_length (`int`, *optional*, defaults to 20):
             The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
-            `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in the
-            prompt.
+            `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
         max_new_tokens (`int`, *optional*):
             The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
         min_length (`int`, *optional*, defaults to 0):
             The minimum length of the sequence to be generated. Corresponds to the length of the input prompt +
-            `min_new_tokens`. In general, prefer the use of `min_new_tokens`, which ignores the number of tokens in the
-            prompt.
+            `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
         min_new_tokens (`int`, *optional*):
             The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
         early_stopping (`bool`, *optional*, defaults to `False`):
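
For context on the behaviour the new docstring describes: when both limits are passed, `max_new_tokens` now wins instead of triggering an error. A minimal sketch from the caller's side, assuming a small public checkpoint such as `gpt2` (the checkpoint choice is illustrative, not part of this diff):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Justin Timberlake.", return_tensors="pt")
prompt_length = inputs["input_ids"].shape[-1]

# Both limits are set: with this change, `max_new_tokens` takes precedence and a
# warning is emitted instead of a ValueError being raised.
outputs = model.generate(**inputs, max_new_tokens=10, max_length=20)

# The effective limit is prompt length + 10, regardless of `max_length=20`.
assert outputs.shape[-1] <= prompt_length + 10
```
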
@@ -318,21 +318,21 @@ class FlaxGenerationMixin:
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
-                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
-                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
-                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                 " recommend using `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
-        elif has_default_max_length and generation_config.max_new_tokens is not None:
+        elif generation_config.max_new_tokens is not None:
             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
-        elif not has_default_max_length and generation_config.max_new_tokens is not None:
-            raise ValueError(
-                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
-                " limit to the generated output length. Remove one of those arguments. Please refer to the"
-                " documentation for more information. "
-                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
-            )
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )

         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
@@ -700,21 +700,21 @@ class TFGenerationMixin:
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
-                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
-                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
-                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                 " recommend using `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
-        elif has_default_max_length and generation_config.max_new_tokens is not None:
+        elif generation_config.max_new_tokens is not None:
             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
-        elif not has_default_max_length and generation_config.max_new_tokens is not None:
-            raise ValueError(
-                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
-                " limit to the generated output length. Remove one of those arguments. Please refer to the"
-                " documentation for more information. "
-                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
-            )
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )

         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
@@ -1274,21 +1274,21 @@ class GenerationMixin:
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
-                "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
-                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
-                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                 " recommend using `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
-        elif has_default_max_length and generation_config.max_new_tokens is not None:
+        elif generation_config.max_new_tokens is not None:
             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
-        elif not has_default_max_length and generation_config.max_new_tokens is not None:
-            raise ValueError(
-                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
-                " limit to the generated output length. Remove one of those arguments. Please refer to the"
-                " documentation for more information. "
-                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
-            )
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )

         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
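
The same branching is applied in the Flax, TF, and PyTorch mixins above. As a rough standalone re-statement of that control flow (the function and argument names are hypothetical, and `has_default_max_length` is simplified here to "the user did not pass `max_length`"):

```python
import warnings


def resolve_max_length(max_length, max_new_tokens, input_ids_seq_length, default_max_length=20):
    """Illustrative sketch of the length-resolution logic, not library code."""
    user_set_max_length = max_length is not None
    max_length = max_length if user_set_max_length else default_max_length

    if max_new_tokens is None:
        if not user_set_max_length:
            # Neither limit was set: fall back to the (deprecated) default.
            warnings.warn(
                f"Using `max_length`'s default ({max_length}) to control the generation length.",
                UserWarning,
            )
    else:
        if user_set_max_length:
            # Both limits were set: warn and let `max_new_tokens` take precedence.
            warnings.warn(
                "Both `max_new_tokens` and `max_length` were set; `max_new_tokens` takes precedence.",
                UserWarning,
            )
        # `max_new_tokens` is converted into an absolute `max_length`.
        max_length = max_new_tokens + input_ids_seq_length

    return max_length


# A 5-token prompt with max_new_tokens=10 resolves to max_length=15,
# even though max_length=20 was also passed.
assert resolve_max_length(20, 10, 5) == 15
```
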
@@ -2178,10 +2178,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS + 20 + 3 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])

-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
-
     def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
@@ -2212,12 +2208,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS + 20 + 3 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])

-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            t5_model.generate(
-                decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
-            )
-
     def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
@@ -2250,12 +2240,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS + 20 + 3 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])

-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            bart_model.generate(
-                decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
-            )
-
     def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
         article = """Justin Timberlake."""
         gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
@@ -2279,10 +2263,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS token + 23 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])

-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            gptj_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
-
     def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
         article = """Justin Timberlake."""
         gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -2306,10 +2286,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS token + 23 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])

-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
-
     def test_max_new_tokens_decoder_only(self):
         article = """Justin Timberlake."""
         gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -2333,10 +2309,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # 1 BOS token + 23 new tokens
         self.assertEqual(list(outputs.shape), [1, 24])

-        # max_new_tokens and max_length serve the same purpose and must not be used together.
-        with self.assertRaises(ValueError):
-            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
-
     def test_encoder_decoder_generate_with_inputs_embeds(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
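
With the `assertRaises(ValueError)` checks removed above, a follow-up test for the relaxed behaviour could look like the sketch below (the method is hypothetical and would live in `GenerationIntegrationTests`; `GPT2LMHeadModel` is assumed to be imported in the test module):

```python
    def test_max_new_tokens_precedence_over_max_length(self):
        # Hypothetical check, not part of this diff: passing both limits no longer
        # raises, and `max_new_tokens` is expected to take precedence.
        article = """Justin Timberlake."""
        gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids

        outputs = gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)

        # prompt length + at most 10 newly generated tokens
        self.assertLessEqual(outputs.shape[-1], input_ids.shape[-1] + 10)
```
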