Unverified Commit 8c2618e6 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing t2t pipelines lists outputs. (#15008)

Backward compatibility broken in
https://github.com/huggingface/transformers/pull/14988
parent 8f6373c6
...@@ -136,8 +136,8 @@ class Text2TextGenerationPipeline(Pipeline): ...@@ -136,8 +136,8 @@ class Text2TextGenerationPipeline(Pipeline):
""" """
result = super().__call__(*args, **kwargs) result = super().__call__(*args, **kwargs)
if isinstance(result, dict): if isinstance(args[0], list) and all(isinstance(el, str) for el in args[0]):
return [result] return [res[0] for res in result]
return result return result
def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs): def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs):
......
...@@ -47,6 +47,12 @@ class TranslationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta ...@@ -47,6 +47,12 @@ class TranslationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta
outputs = translator("Some string") outputs = translator("Some string")
self.assertEqual(outputs, [{"translation_text": ANY(str)}]) self.assertEqual(outputs, [{"translation_text": ANY(str)}])
outputs = translator(["Some string"])
self.assertEqual(outputs, [{"translation_text": ANY(str)}])
outputs = translator(["Some string", "other string"])
self.assertEqual(outputs, [{"translation_text": ANY(str)}, {"translation_text": ANY(str)}])
@require_torch @require_torch
def test_small_model_pt(self): def test_small_model_pt(self):
translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="pt") translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="pt")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment