Unverified Commit 781e4b13 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding `skip_special_tokens=True` to FillMaskPipeline (#9783)

* We most likely don't want special tokens in this output.

* Adding `skip_special_tokens=True` to FillMaskPipeline

- It's backward incompatible.
- It makes more sense for pipelines to remove references to
  special_tokens (all of the other pipelines do that).
- Keeping special tokens makes it hard for users to actually remove them
  because all models have different tokens (<s>, <cls>, [CLS], ....)

* Fixing `token_str` in the same vein, and actually fixing the tests too!
parent 1867d9a8
...@@ -179,10 +179,10 @@ class FillMaskPipeline(Pipeline): ...@@ -179,10 +179,10 @@ class FillMaskPipeline(Pipeline):
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)] tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
result.append( result.append(
{ {
"sequence": self.tokenizer.decode(tokens), "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
"score": v, "score": v,
"token": p, "token": p,
"token_str": self.tokenizer.convert_ids_to_tokens(p), "token_str": self.tokenizer.decode(p),
} }
) )
......
...@@ -22,32 +22,27 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin ...@@ -22,32 +22,27 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
EXPECTED_FILL_MASK_RESULT = [ EXPECTED_FILL_MASK_RESULT = [
[ [
{"sequence": "<s>My name is John</s>", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"}, {"sequence": "My name is John", "score": 0.00782308354973793, "token": 610, "token_str": " John"},
{"sequence": "<s>My name is Chris</s>", "score": 0.007475061342120171, "token": 1573, "token_str": "ĠChris"}, {"sequence": "My name is Chris", "score": 0.007475061342120171, "token": 1573, "token_str": " Chris"},
], ],
[
{"sequence": "<s>The largest city in France is Paris</s>", "score": 0.3185044229030609, "token": 2201},
{"sequence": "<s>The largest city in France is Lyon</s>", "score": 0.21112334728240967, "token": 12790},
],
]
EXPECTED_FILL_MASK_TARGET_RESULT = [
[ [
{ {
"sequence": "<s>My name is Patrick</s>", "sequence": "The largest city in France is Paris",
"score": 0.004992353264242411, "score": 0.2510891854763031,
"token": 3499, "token": 2201,
"token_str": "ĠPatrick", "token_str": " Paris",
}, },
{ {
"sequence": "<s>My name is Clara</s>", "sequence": "The largest city in France is Lyon",
"score": 0.00019297805556561798, "score": 0.21418564021587372,
"token": 13606, "token": 12790,
"token_str": "ĠClara", "token_str": " Lyon",
}, },
] ],
] ]
EXPECTED_FILL_MASK_TARGET_RESULT = [EXPECTED_FILL_MASK_RESULT[0]]
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "fill-mask" pipeline_task = "fill-mask"
...@@ -138,7 +133,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): ...@@ -138,7 +133,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
self.assertIsInstance(multi_result[0], (dict, list)) self.assertIsInstance(multi_result[0], (dict, list))
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT): for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT):
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result])) for r, e in zip(result, expected):
self.assertEqual(r["sequence"], e["sequence"])
self.assertEqual(r["token_str"], e["token_str"])
self.assertEqual(r["token"], e["token"])
self.assertAlmostEqual(r["score"], e["score"], places=3)
if isinstance(multi_result[0], list): if isinstance(multi_result[0], list):
multi_result = multi_result[0] multi_result = multi_result[0]
...@@ -162,7 +161,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): ...@@ -162,7 +161,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
self.assertIsInstance(multi_result[0], (dict, list)) self.assertIsInstance(multi_result[0], (dict, list))
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT): for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT):
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result])) for r, e in zip(result, expected):
self.assertEqual(r["sequence"], e["sequence"])
self.assertEqual(r["token_str"], e["token_str"])
self.assertEqual(r["token"], e["token"])
self.assertAlmostEqual(r["score"], e["score"], places=3)
if isinstance(multi_result[0], list): if isinstance(multi_result[0], list):
multi_result = multi_result[0] multi_result = multi_result[0]
...@@ -197,7 +200,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): ...@@ -197,7 +200,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
self.assertIsInstance(multi_result[0], (dict, list)) self.assertIsInstance(multi_result[0], (dict, list))
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT): for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT):
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result])) for r, e in zip(result, expected):
self.assertEqual(r["sequence"], e["sequence"])
self.assertEqual(r["token_str"], e["token_str"])
self.assertEqual(r["token"], e["token"])
self.assertAlmostEqual(r["score"], e["score"], places=3)
if isinstance(multi_result[0], list): if isinstance(multi_result[0], list):
multi_result = multi_result[0] multi_result = multi_result[0]
...@@ -221,7 +228,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): ...@@ -221,7 +228,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
self.assertIsInstance(multi_result[0], (dict, list)) self.assertIsInstance(multi_result[0], (dict, list))
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT): for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT):
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result])) for r, e in zip(result, expected):
self.assertEqual(r["sequence"], e["sequence"])
self.assertEqual(r["token_str"], e["token_str"])
self.assertEqual(r["token"], e["token"])
self.assertAlmostEqual(r["score"], e["score"], places=3)
if isinstance(multi_result[0], list): if isinstance(multi_result[0], list):
multi_result = multi_result[0] multi_result = multi_result[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment