Unverified Commit abf8f54a authored by Daniel Korat, committed by GitHub
Browse files

Raise `Exception` when trying to generate 0 tokens (#28621)



* change warning to exception

* Update src/transformers/generation/utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* validate `max_new_tokens` > 0 in `GenerationConfig`

* fix truncation test parameterization in `TextGenerationPipelineTests`

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent 349a6e85
...@@ -373,6 +373,8 @@ class GenerationConfig(PushToHubMixin): ...@@ -373,6 +373,8 @@ class GenerationConfig(PushToHubMixin):
# Validation of individual attributes # Validation of individual attributes
if self.early_stopping not in {True, False, "never"}: if self.early_stopping not in {True, False, "never"}:
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.") raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
if self.max_new_tokens is not None and self.max_new_tokens <= 0:
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
# Validation of attribute relations: # Validation of attribute relations:
fix_location = "" fix_location = ""
......
...@@ -1138,11 +1138,10 @@ class GenerationMixin: ...@@ -1138,11 +1138,10 @@ class GenerationMixin:
) )
if input_ids_length >= generation_config.max_length: if input_ids_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
warnings.warn( raise ValueError(
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`.", " increasing `max_length` or, better yet, setting `max_new_tokens`."
UserWarning,
) )
# 2. Min length warnings due to unfeasible parameter combinations # 2. Min length warnings due to unfeasible parameter combinations
......
...@@ -93,17 +93,19 @@ class TextGenerationPipelineTests(unittest.TestCase): ...@@ -93,17 +93,19 @@ class TextGenerationPipelineTests(unittest.TestCase):
## -- test tokenizer_kwargs ## -- test tokenizer_kwargs
test_str = "testing tokenizer kwargs. using truncation must result in a different generation." test_str = "testing tokenizer kwargs. using truncation must result in a different generation."
input_len = len(text_generator.tokenizer(test_str)["input_ids"])
output_str, output_str_with_truncation = ( output_str, output_str_with_truncation = (
text_generator(test_str, do_sample=False, return_full_text=False)[0]["generated_text"], text_generator(test_str, do_sample=False, return_full_text=False, min_new_tokens=1)[0]["generated_text"],
text_generator( text_generator(
test_str, test_str,
do_sample=False, do_sample=False,
return_full_text=False, return_full_text=False,
min_new_tokens=1,
truncation=True, truncation=True,
max_length=3, max_length=input_len + 1,
)[0]["generated_text"], )[0]["generated_text"],
) )
assert output_str != output_str_with_truncation # results must be different because one hd truncation assert output_str != output_str_with_truncation # results must be different because one had truncation
# -- what is the point of this test? padding is hardcoded False in the pipeline anyway # -- what is the point of this test? padding is hardcoded False in the pipeline anyway
text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment