Unverified Commit f8a989cf authored by Nicolas Patry, committed by GitHub

Adding `num_return_sequences` support for text2text generation. (#14988)



* Adding `num_return_sequences` support for text2text generation.
Co-Authored-By: Enze <pu.miao@foxmail.com>

* Update tests/test_pipelines_text2text_generation.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/test_pipelines_text2text_generation.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Enze <pu.miao@foxmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c043ce6c
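For context, a minimal usage sketch of the feature this commit adds (the checkpoint name and input text are illustrative, not taken from this commit):

from transformers import pipeline

# Any seq2seq checkpoint works; "t5-small" is only an illustrative choice.
generator = pipeline("text2text-generation", model="t5-small")

# Passing num_return_sequences (with at least as many beams) now yields one
# dict per generated sequence for the input, instead of a single dict.
outputs = generator(
    "translate English to German: Hello, how are you?",
    num_beams=3,
    num_return_sequences=3,
)
print(outputs)  # e.g. [{"generated_text": "..."}, {"generated_text": "..."}, {"generated_text": "..."}]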
@@ -157,18 +157,20 @@ class Text2TextGenerationPipeline(Pipeline):
         return {"output_ids": output_ids}

     def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
-        record = {}
-        if return_type == ReturnType.TENSORS:
-            record = {f"{self.return_name}_token_ids": model_outputs}
-        elif return_type == ReturnType.TEXT:
-            record = {
-                f"{self.return_name}_text": self.tokenizer.decode(
-                    model_outputs["output_ids"][0],
-                    skip_special_tokens=True,
-                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-                )
-            }
-        return record
+        records = []
+        for output_ids in model_outputs["output_ids"]:
+            if return_type == ReturnType.TENSORS:
+                record = {f"{self.return_name}_token_ids": model_outputs}
+            elif return_type == ReturnType.TEXT:
+                record = {
+                    f"{self.return_name}_text": self.tokenizer.decode(
+                        output_ids,
+                        skip_special_tokens=True,
+                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                    )
+                }
+            records.append(record)
+        return records


 @add_end_docstrings(PIPELINE_INIT_ARGS)
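The change above boils down to postprocess building one record per generated sequence instead of a single record. A standalone sketch of that loop, using a stand-in decode function rather than the pipeline's tokenizer:

# Toy stand-in for tokenizer.decode, purely for illustration.
def toy_decode(ids):
    return " ".join(str(i) for i in ids)

def postprocess_like(output_ids, return_name="generated"):
    # Mirrors the new loop: one dict per returned sequence, collected in a list.
    records = []
    for seq in output_ids:
        records.append({f"{return_name}_text": toy_decode(seq)})
    return records

print(postprocess_like([[1, 2, 3], [4, 5]]))
# [{'generated_text': '1 2 3'}, {'generated_text': '4 5'}]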
@@ -50,6 +50,19 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest
         outputs = generator("Something there", do_sample=False)
         self.assertEqual(outputs, [{"generated_text": ""}])

+        num_return_sequences = 3
+        outputs = generator(
+            "Something there",
+            num_return_sequences=num_return_sequences,
+            num_beams=num_return_sequences,
+        )
+        target_outputs = [
+            {"generated_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide"},
+            {"generated_text": "Beide Beide Beide Beide Beide Beide Beide Beide"},
+            {"generated_text": ""},
+        ]
+        self.assertEqual(outputs, target_outputs)
+
     @require_tf
     def test_small_model_tf(self):
         generator = pipeline("text2text-generation", model="patrickvonplaten/t5-tiny-random", framework="tf")