Unverified Commit 82486e59 authored by Joao Gante, committed by GitHub

🚨🚨 TextGenerationPipeline: rely on the tokenizer default kwargs (#31747)

* rely on the tokenizer default kwargs

* fix a few tests
parent a9701953
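
The 🚨🚨 marker flags a breaking change: `preprocess` previously hard-coded `add_special_tokens=False` and `padding=False`, silently overriding whatever the tokenizer was configured to do. The fix switches the defaults to `None` and forwards only the kwargs a caller sets explicitly. A minimal, self-contained sketch of that pattern (the function name is hypothetical, and this uses a dict comprehension where the PR itself uses `locals()`):

```python
# Sketch of the "forward only non-None kwargs" pattern from this PR.
# `collect_tokenizer_kwargs` is a hypothetical helper, not pipeline API.
def collect_tokenizer_kwargs(add_special_tokens=None, truncation=None, padding=None, max_length=None):
    candidates = {
        "add_special_tokens": add_special_tokens,
        "truncation": truncation,
        "padding": padding,
        "max_length": max_length,
    }
    # None means "not set by the caller" -> let the tokenizer's own default win.
    return {name: value for name, value in candidates.items() if value is not None}

print(collect_tokenizer_kwargs())               # {} -> tokenizer defaults apply
print(collect_tokenizer_kwargs(padding=False))  # {'padding': False} -> explicit opt-out survives
```

`None` is the right sentinel here because `False` is itself a meaningful value for `padding` and `add_special_tokens`, so it cannot double as "unset".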
@@ -266,31 +266,33 @@ class TextGenerationPipeline(Pipeline):
         prompt_text,
         prefix="",
         handle_long_generation=None,
-        add_special_tokens=False,
+        add_special_tokens=None,
         truncation=None,
-        padding=False,
+        padding=None,
         max_length=None,
         **generate_kwargs,
     ):
         if isinstance(prompt_text, Chat):
+            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+            tokenizer_kwargs = {}
+            for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]:
+                if locals()[tokenizer_kwarg_name] is not None:
+                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
             inputs = self.tokenizer.apply_chat_template(
                 prompt_text.messages,
-                truncation=truncation,
-                padding=padding,
-                max_length=max_length,
                 add_generation_prompt=True,
                 return_dict=True,
                 return_tensors=self.framework,
+                **tokenizer_kwargs,
             )
         else:
-            inputs = self.tokenizer(
-                prefix + prompt_text,
-                truncation=truncation,
-                padding=padding,
-                max_length=max_length,
-                add_special_tokens=add_special_tokens,
-                return_tensors=self.framework,
-            )
+            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+            tokenizer_kwargs = {}
+            for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]:
+                if locals()[tokenizer_kwarg_name] is not None:
+                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
+            inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)
         inputs["prompt_text"] = prompt_text
         if handle_long_generation == "hole":
...
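
What this means for callers, as a hedged usage sketch (the checkpoint name is illustrative): the pipeline now inherits the tokenizer's own defaults, and the old behavior is one explicit kwarg away.

```python
from transformers import pipeline

# Illustrative checkpoint; any causal LM works the same way here.
generator = pipeline("text-generation", model="gpt2")

# Tokenizer defaults now apply (before this PR, add_special_tokens/padding were forced off):
generator("Hello I believe in")
# Explicitly passing the kwargs restores the previous behavior:
generator("Hello I believe in", add_special_tokens=False, padding=False)
```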
@@ -2087,6 +2087,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
             [1, 18],
         )

+    # TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
     def test_stop_sequence_stopping_criteria(self):
         # PT-only test: TF doesn't have StoppingCriteria
         prompt = """Hello I believe in"""
@@ -2094,17 +2095,11 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
         output = generator(prompt)
         self.assertEqual(
             output,
-            [
-                {
-                    "generated_text": (
-                        "Hello I believe in in in number number number number number number number number number"
-                    )
-                }
-            ],
+            [{"generated_text": ("Hello I believe in we we we we we we we we we")}],
         )

-        output = generator(prompt, stop_sequence=" number")
-        self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
+        output = generator(prompt, stop_sequence=" we")
+        self.assertEqual(output, [{"generated_text": "Hello I believe in we"}])

     def test_generate_non_nlp_input_ids_as_kwarg(self):
         # PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
...
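
On the TODO above: the "more recent `generate` functionality" it refers to is, as far as I can tell, the `stop_strings` argument, which requires passing the tokenizer to `generate` so the stop string can be matched against decoded text. A sketch under that assumption:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the point is the generate-level stopping API.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello I believe in", return_tensors="pt")
# stop_strings needs the tokenizer so generation can detect the string as it is produced.
output_ids = model.generate(**inputs, max_new_tokens=20, stop_strings=[" we"], tokenizer=tokenizer)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```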
@@ -398,7 +398,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
             self.assertEqual(outputs, [{"generated_text": ANY(str)}])
         else:
             with self.assertRaises((ValueError, AssertionError)):
-                outputs = text_generator("")
+                outputs = text_generator("", add_special_tokens=False)

         if text_generator.framework == "tf":
             # TF generation does not support max_new_tokens, and it's impossible
...
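
Why this test change is needed: with tokenizer defaults back in force, an empty string no longer necessarily encodes to an empty sequence. Tokenizers that prepend a BOS token by default produce one token, so the empty-input error only triggers once special tokens are explicitly disabled. A quick check, with an illustrative BOS-prepending tokenizer:

```python
from transformers import AutoTokenizer

# Illustrative tokenizer whose default prepends a BOS token to every encoding.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
print(tokenizer("").input_ids)                            # [2] -- BOS only, not empty
print(tokenizer("", add_special_tokens=False).input_ids)  # [] -- truly empty
```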