Unverified Commit 2351729f authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding `top_k` argument to `text-classification` pipeline. (#17606)

* Adding `top_k` and `sort` arguments to `text-classification` pipeline.

- Deprecate `return_all_scores` as `top_k` is more uniform with other
  pipelines, and a superset of what `return_all_scores` can do.
  BC is maintained though.
  `return_all_scores=True` -> `top_k=None`
  `return_all_scores=False` -> `top_k=1`

* Using `top_k` implies sorting the results, while passing no argument
  keeps the results unsorted for backward compatibility.

* Remove `sort`.

* Fixing the test.

* Remove bad doc.
parent 29080643
import warnings
from typing import Dict from typing import Dict
import numpy as np import numpy as np
...@@ -72,15 +73,26 @@ class TextClassificationPipeline(Pipeline): ...@@ -72,15 +73,26 @@ class TextClassificationPipeline(Pipeline):
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
) )
def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, **tokenizer_kwargs): def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, top_k="", **tokenizer_kwargs):
# Using "" as default argument because we're going to use `top_k=None` in user code to declare
# "No top_k"
preprocess_params = tokenizer_kwargs preprocess_params = tokenizer_kwargs
postprocess_params = {} postprocess_params = {}
if hasattr(self.model.config, "return_all_scores") and return_all_scores is None: if hasattr(self.model.config, "return_all_scores") and return_all_scores is None:
return_all_scores = self.model.config.return_all_scores return_all_scores = self.model.config.return_all_scores
if return_all_scores is not None: if isinstance(top_k, int) or top_k is None:
postprocess_params["return_all_scores"] = return_all_scores postprocess_params["top_k"] = top_k
postprocess_params["_legacy"] = False
elif return_all_scores is not None:
warnings.warn(
"`return_all_scores` is now deprecated, use `top_k=1` if you want similar functionnality", UserWarning
)
if return_all_scores:
postprocess_params["top_k"] = None
else:
postprocess_params["top_k"] = 1
if isinstance(function_to_apply, str): if isinstance(function_to_apply, str):
function_to_apply = ClassificationFunction[function_to_apply.upper()] function_to_apply = ClassificationFunction[function_to_apply.upper()]
...@@ -97,8 +109,8 @@ class TextClassificationPipeline(Pipeline): ...@@ -97,8 +109,8 @@ class TextClassificationPipeline(Pipeline):
args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`): args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`):
One or several texts to classify. In order to use text pairs for your classification, you can send a One or several texts to classify. In order to use text pairs for your classification, you can send a
dictionnary containing `{"text", "text_pair"}` keys, or a list of those. dictionnary containing `{"text", "text_pair"}` keys, or a list of those.
return_all_scores (`bool`, *optional*, defaults to `False`): top_k (`int`, *optional*, defaults to `1`):
Whether to return scores for all labels. How many results to return.
function_to_apply (`str`, *optional*, defaults to `"default"`): function_to_apply (`str`, *optional*, defaults to `"default"`):
The function to apply to the model outputs in order to retrieve the scores. Accepts four different The function to apply to the model outputs in order to retrieve the scores. Accepts four different
values: values:
...@@ -121,10 +133,10 @@ class TextClassificationPipeline(Pipeline): ...@@ -121,10 +133,10 @@ class TextClassificationPipeline(Pipeline):
- **label** (`str`) -- The label predicted. - **label** (`str`) -- The label predicted.
- **score** (`float`) -- The corresponding probability. - **score** (`float`) -- The corresponding probability.
If `self.return_all_scores=True`, one such dictionary is returned per label. If `top_k` is used, one such dictionary is returned per label.
""" """
result = super().__call__(*args, **kwargs) result = super().__call__(*args, **kwargs)
if isinstance(args[0], str): if isinstance(args[0], str) and isinstance(result, dict):
# This pipeline is odd, and return a list when single item is run # This pipeline is odd, and return a list when single item is run
return [result] return [result]
else: else:
...@@ -150,7 +162,10 @@ class TextClassificationPipeline(Pipeline): ...@@ -150,7 +162,10 @@ class TextClassificationPipeline(Pipeline):
def _forward(self, model_inputs): def _forward(self, model_inputs):
return self.model(**model_inputs) return self.model(**model_inputs)
def postprocess(self, model_outputs, function_to_apply=None, return_all_scores=False): def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True):
# `_legacy` is used to determine if we're running the naked pipeline and in backward
# compatibility mode, or if running the pipeline with `pipeline(..., top_k=1)` we're running
# the more natural result containing the list.
# Default value before `set_parameters` # Default value before `set_parameters`
if function_to_apply is None: if function_to_apply is None:
if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1: if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
...@@ -174,7 +189,14 @@ class TextClassificationPipeline(Pipeline): ...@@ -174,7 +189,14 @@ class TextClassificationPipeline(Pipeline):
else: else:
raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}") raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")
if return_all_scores: if top_k == 1 and _legacy:
return [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)]
else:
return {"label": self.model.config.id2label[scores.argmax().item()], "score": scores.max().item()} return {"label": self.model.config.id2label[scores.argmax().item()], "score": scores.max().item()}
dict_scores = [
{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)
]
if not _legacy:
dict_scores.sort(key=lambda x: x["score"], reverse=True)
if top_k is not None:
dict_scores = dict_scores[:top_k]
return dict_scores
...@@ -39,6 +39,27 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC ...@@ -39,6 +39,27 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
outputs = text_classifier("This is great !") outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
outputs = text_classifier("This is great !", top_k=2)
self.assertEqual(
nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]
)
outputs = text_classifier(["This is great !", "This is bad"], top_k=2)
self.assertEqual(
nested_simplify(outputs),
[
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
],
)
outputs = text_classifier("This is great !", top_k=1)
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
# Legacy behavior
outputs = text_classifier("This is great !", return_all_scores=False)
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_torch @require_torch
def test_accepts_torch_device(self): def test_accepts_torch_device(self):
import torch import torch
...@@ -108,6 +129,15 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC ...@@ -108,6 +129,15 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
self.assertTrue(outputs[0]["label"] in model.config.id2label.values()) self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
self.assertTrue(outputs[1]["label"] in model.config.id2label.values()) self.assertTrue(outputs[1]["label"] in model.config.id2label.values())
# Forcing to get all results with `top_k=None`
# This is NOT the legacy format
outputs = text_classifier(valid_inputs, top_k=None)
N = len(model.config.id2label.values())
self.assertEqual(
nested_simplify(outputs),
[[{"label": ANY(str), "score": ANY(float)}] * N, [{"label": ANY(str), "score": ANY(float)}] * N],
)
valid_inputs = {"text": "HuggingFace is in ", "text_pair": "Paris is in France"} valid_inputs = {"text": "HuggingFace is in ", "text_pair": "Paris is in France"}
outputs = text_classifier(valid_inputs) outputs = text_classifier(valid_inputs)
self.assertEqual( self.assertEqual(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment