Unverified Commit 65ee1a43 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

fixing BC in `fill-mask` (wasn't tested in theses test suites (#13540)

apparently).
parent 9d60eebe
......@@ -219,4 +219,7 @@ class FillMaskPipeline(Pipeline):
- **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
- **token** (:obj:`str`) -- The predicted token (to replace the masked one).
"""
return super().__call__(inputs, **kwargs)
outputs = super().__call__(inputs, **kwargs)
if isinstance(inputs, list) and len(inputs) == 1:
return outputs[0]
return outputs
......@@ -186,6 +186,18 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
],
)
outputs = fill_masker([f"This is a {tokenizer.mask_token}"])
self.assertEqual(
outputs,
[
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
],
)
outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token} great test."])
self.assertEqual(
outputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment