Unverified commit ff65beaf, authored by Julien Chaumond and committed by GitHub
Browse files

FillMaskPipeline: support passing top_k on __call__ (#7971)

* FillMaskPipeline: support passing top_k on __call__

Also rename the `topk` argument to `top_k`

* migrate to new param name in tests

* Review from @sgugger
parent 2e5052d4
...@@ -1183,7 +1183,7 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -1183,7 +1183,7 @@ class ZeroShotClassificationPipeline(Pipeline):
@add_end_docstrings( @add_end_docstrings(
PIPELINE_INIT_ARGS, PIPELINE_INIT_ARGS,
r""" r"""
topk (:obj:`int`, defaults to 5): The number of predictions to return. top_k (:obj:`int`, defaults to 5): The number of predictions to return.
""", """,
) )
class FillMaskPipeline(Pipeline): class FillMaskPipeline(Pipeline):
...@@ -1212,8 +1212,9 @@ class FillMaskPipeline(Pipeline): ...@@ -1212,8 +1212,9 @@ class FillMaskPipeline(Pipeline):
framework: Optional[str] = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, args_parser: ArgumentHandler = None,
device: int = -1, device: int = -1,
topk=5, top_k=5,
task: str = "", task: str = "",
**kwargs
): ):
super().__init__( super().__init__(
model=model, model=model,
...@@ -1228,7 +1229,14 @@ class FillMaskPipeline(Pipeline): ...@@ -1228,7 +1229,14 @@ class FillMaskPipeline(Pipeline):
self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING) self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
self.topk = topk if "topk" in kwargs:
warnings.warn(
"The `topk` argument is deprecated and will be removed in a future version, use `top_k` instead.",
FutureWarning,
)
self.top_k = kwargs.pop("topk")
else:
self.top_k = top_k
def ensure_exactly_one_mask_token(self, masked_index: np.ndarray): def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
numel = np.prod(masked_index.shape) numel = np.prod(masked_index.shape)
...@@ -1245,7 +1253,7 @@ class FillMaskPipeline(Pipeline): ...@@ -1245,7 +1253,7 @@ class FillMaskPipeline(Pipeline):
f"No mask_token ({self.tokenizer.mask_token}) found on the input", f"No mask_token ({self.tokenizer.mask_token}) found on the input",
) )
def __call__(self, *args, targets=None, **kwargs): def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
""" """
Fill the masked token in the text(s) given as inputs. Fill the masked token in the text(s) given as inputs.
...@@ -1256,6 +1264,8 @@ class FillMaskPipeline(Pipeline): ...@@ -1256,6 +1264,8 @@ class FillMaskPipeline(Pipeline):
When passed, the model will return the scores for the passed token or tokens rather than the top k When passed, the model will return the scores for the passed token or tokens rather than the top k
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
be tokenized and the first resulting token will be used (with a warning). be tokenized and the first resulting token will be used (with a warning).
top_k (:obj:`int`, `optional`):
When passed, overrides the number of predictions to return.
Return: Return:
A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
...@@ -1303,7 +1313,7 @@ class FillMaskPipeline(Pipeline): ...@@ -1303,7 +1313,7 @@ class FillMaskPipeline(Pipeline):
logits = outputs[i, masked_index.item(), :] logits = outputs[i, masked_index.item(), :]
probs = tf.nn.softmax(logits) probs = tf.nn.softmax(logits)
if targets is None: if targets is None:
topk = tf.math.top_k(probs, k=self.topk) topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy() values, predictions = topk.values.numpy(), topk.indices.numpy()
else: else:
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1))) values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
...@@ -1319,7 +1329,7 @@ class FillMaskPipeline(Pipeline): ...@@ -1319,7 +1329,7 @@ class FillMaskPipeline(Pipeline):
logits = outputs[i, masked_index.item(), :] logits = outputs[i, masked_index.item(), :]
probs = logits.softmax(dim=0) probs = logits.softmax(dim=0)
if targets is None: if targets is None:
values, predictions = probs.topk(self.topk) values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
else: else:
values = probs[..., target_inds] values = probs[..., target_inds]
sort_inds = list(reversed(values.argsort(dim=-1))) sort_inds = list(reversed(values.argsort(dim=-1)))
......
...@@ -226,7 +226,7 @@ class MonoColumnInputTestCase(unittest.TestCase): ...@@ -226,7 +226,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
model=model_name, model=model_name,
tokenizer=model_name, tokenizer=model_name,
framework="pt", framework="pt",
topk=2, top_k=2,
) )
self._test_mono_column_pipeline( self._test_mono_column_pipeline(
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"] nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
...@@ -249,7 +249,7 @@ class MonoColumnInputTestCase(unittest.TestCase): ...@@ -249,7 +249,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
model=model_name, model=model_name,
tokenizer=model_name, tokenizer=model_name,
framework="tf", framework="tf",
topk=2, top_k=2,
) )
self._test_mono_column_pipeline( self._test_mono_column_pipeline(
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"] nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
...@@ -298,7 +298,7 @@ class MonoColumnInputTestCase(unittest.TestCase): ...@@ -298,7 +298,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
model=model_name, model=model_name,
tokenizer=model_name, tokenizer=model_name,
framework="pt", framework="pt",
topk=2, top_k=2,
) )
self._test_mono_column_pipeline( self._test_mono_column_pipeline(
nlp, nlp,
...@@ -326,7 +326,7 @@ class MonoColumnInputTestCase(unittest.TestCase): ...@@ -326,7 +326,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
] ]
valid_targets = [" Patrick", " Clara"] valid_targets = [" Patrick", " Clara"]
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS: for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2) nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)
self._test_mono_column_pipeline( self._test_mono_column_pipeline(
nlp, nlp,
valid_inputs, valid_inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment