Unverified Commit 8c2618e6 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing t2t pipelines lists outputs. (#15008)

Backward compatibility broken in
https://github.com/huggingface/transformers/pull/14988
parent 8f6373c6
......@@ -136,8 +136,8 @@ class Text2TextGenerationPipeline(Pipeline):
"""
result = super().__call__(*args, **kwargs)
if isinstance(result, dict):
return [result]
if isinstance(args[0], list) and all(isinstance(el, str) for el in args[0]):
return [res[0] for res in result]
return result
def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs):
......
......@@ -47,6 +47,12 @@ class TranslationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta
outputs = translator("Some string")
self.assertEqual(outputs, [{"translation_text": ANY(str)}])
outputs = translator(["Some string"])
self.assertEqual(outputs, [{"translation_text": ANY(str)}])
outputs = translator(["Some string", "other string"])
self.assertEqual(outputs, [{"translation_text": ANY(str)}, {"translation_text": ANY(str)}])
@require_torch
def test_small_model_pt(self):
translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="pt")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment