Unverified Commit ddf177ee authored by Alvaro Bartolome's avatar Alvaro Bartolome Committed by GitHub
Browse files

Set `inputs` as kwarg in `TextClassificationPipeline` (#29495)



* Set `inputs` as kwarg in `TextClassificationPipeline`

This change has been done to align the `TextClassificationPipeline` with the rest of the pipelines, and to be able to e.g. `pipeline(**{"inputs": "text"})` which wouldn't be possible since the `*args` were being used instead.

* Add `noqa: C409` on `tuple([inputs],)`

Even though is discouraged by the linter, the cast `tuple(list(...),)` is required here, as otherwise the original list in `inputs` will be transformed into a `tuple` and the elements 1...N will be ignored by the `Pipeline`

* Run `ruff format`

* Simplify `tuple` conversion with `(inputs,)`
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>

---------
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>
parent 4ed9ae62
...@@ -118,12 +118,12 @@ class TextClassificationPipeline(Pipeline): ...@@ -118,12 +118,12 @@ class TextClassificationPipeline(Pipeline):
postprocess_params["function_to_apply"] = function_to_apply postprocess_params["function_to_apply"] = function_to_apply
return preprocess_params, {}, postprocess_params return preprocess_params, {}, postprocess_params
def __call__(self, *args, **kwargs): def __call__(self, inputs, **kwargs):
""" """
Classify the text(s) given as inputs. Classify the text(s) given as inputs.
Args: Args:
args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`): inputs (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`):
One or several texts to classify. In order to use text pairs for your classification, you can send a One or several texts to classify. In order to use text pairs for your classification, you can send a
dictionary containing `{"text", "text_pair"}` keys, or a list of those. dictionary containing `{"text", "text_pair"}` keys, or a list of those.
top_k (`int`, *optional*, defaults to `1`): top_k (`int`, *optional*, defaults to `1`):
...@@ -152,10 +152,11 @@ class TextClassificationPipeline(Pipeline): ...@@ -152,10 +152,11 @@ class TextClassificationPipeline(Pipeline):
If `top_k` is used, one such dictionary is returned per label. If `top_k` is used, one such dictionary is returned per label.
""" """
result = super().__call__(*args, **kwargs) inputs = (inputs,)
result = super().__call__(*inputs, **kwargs)
# TODO try and retrieve it in a nicer way from _sanitize_parameters. # TODO try and retrieve it in a nicer way from _sanitize_parameters.
_legacy = "top_k" not in kwargs _legacy = "top_k" not in kwargs
if isinstance(args[0], str) and _legacy: if isinstance(inputs[0], str) and _legacy:
# This pipeline is odd, and return a list when single item is run # This pipeline is odd, and return a list when single item is run
return [result] return [result]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment