Unverified Commit 65ee1a43 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

fixing BC in `fill-mask` (wasn't tested in theses test suites (#13540)

apparently).
parent 9d60eebe
...@@ -219,4 +219,7 @@ class FillMaskPipeline(Pipeline): ...@@ -219,4 +219,7 @@ class FillMaskPipeline(Pipeline):
- **token** (:obj:`int`) -- The predicted token id (to replace the masked one). - **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
- **token** (:obj:`str`) -- The predicted token (to replace the masked one). - **token** (:obj:`str`) -- The predicted token (to replace the masked one).
""" """
return super().__call__(inputs, **kwargs) outputs = super().__call__(inputs, **kwargs)
if isinstance(inputs, list) and len(inputs) == 1:
return outputs[0]
return outputs
...@@ -186,6 +186,18 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): ...@@ -186,6 +186,18 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
], ],
) )
outputs = fill_masker([f"This is a {tokenizer.mask_token}"])
self.assertEqual(
outputs,
[
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
],
)
outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token} great test."]) outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token} great test."])
self.assertEqual( self.assertEqual(
outputs, outputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment