"docs/vscode:/vscode.git/clone" did not exist on "788730c67055fb0805051c638be266f0d8f18188"
Unverified commit f8a989cf, authored by Nicolas Patry and committed by GitHub
Browse files

Adding `num_return_sequences` support for text2text generation. (#14988)



* Adding `num_return_sequences` support for text2text generation.
Co-Authored-By: Enze <pu.miao@foxmail.com>

* Update tests/test_pipelines_text2text_generation.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/test_pipelines_text2text_generation.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Enze <pu.miao@foxmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c043ce6c
......@@ -157,18 +157,20 @@ class Text2TextGenerationPipeline(Pipeline):
return {"output_ids": output_ids}
def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
    """Turn generated sequences into one result record per sequence.

    Args:
        model_outputs: Mapping whose ``"output_ids"`` entry is an iterable of
            generated token-id sequences (one per returned sequence).
        return_type: ``ReturnType.TEXT`` to decode each sequence to a string,
            or ``ReturnType.TENSORS`` to return the raw token ids.
        clean_up_tokenization_spaces: Forwarded unchanged to
            ``self.tokenizer.decode``.

    Returns:
        A list with one dict per sequence, keyed by
        ``f"{self.return_name}_text"`` (decoded string) or
        ``f"{self.return_name}_token_ids"`` (that sequence's token ids).
    """
    records = []
    for output_ids in model_outputs["output_ids"]:
        if return_type == ReturnType.TENSORS:
            # Each record carries its own sequence; storing the whole
            # `model_outputs` dict here would repeat every sequence in
            # every record.
            record = {f"{self.return_name}_token_ids": output_ids}
        elif return_type == ReturnType.TEXT:
            record = {
                f"{self.return_name}_text": self.tokenizer.decode(
                    output_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )
            }
        records.append(record)
    return records
@add_end_docstrings(PIPELINE_INIT_ARGS)
......
......@@ -50,6 +50,19 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest
outputs = generator("Something there", do_sample=False)
self.assertEqual(outputs, [{"generated_text": ""}])
num_return_sequences = 3
outputs = generator(
"Something there",
num_return_sequences=num_return_sequences,
num_beams=num_return_sequences,
)
target_outputs = [
{"generated_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide"},
{"generated_text": "Beide Beide Beide Beide Beide Beide Beide Beide"},
{"generated_text": ""},
]
self.assertEqual(outputs, target_outputs)
@require_tf
def test_small_model_tf(self):
generator = pipeline("text2text-generation", model="patrickvonplaten/t5-tiny-random", framework="tf")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment