"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "50595a333635cd73c2d10e6135db8ed9201708f3"
Unverified Commit 0eabe492 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing zero-shot backward compatiblity (#13725)

Fixes #13697
parent a2ef9c54
...@@ -150,6 +150,7 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -150,6 +150,7 @@ class ZeroShotClassificationPipeline(Pipeline):
def __call__( def __call__(
self, self,
sequences: Union[str, List[str]], sequences: Union[str, List[str]],
*args,
**kwargs, **kwargs,
): ):
""" """
...@@ -183,6 +184,13 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -183,6 +184,13 @@ class ZeroShotClassificationPipeline(Pipeline):
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels. - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
""" """
if len(args) == 0:
pass
elif len(args) == 1 and "candidate_labels" not in kwargs:
kwargs["candidate_labels"] = args[0]
else:
raise ValueError(f"Unable to understand extra arguments {args}")
result = super().__call__(sequences, **kwargs) result = super().__call__(sequences, **kwargs)
if len(result) == 1: if len(result) == 1:
return result[0] return result[0]
......
...@@ -37,6 +37,10 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT ...@@ -37,6 +37,10 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics") outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics")
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]}) self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
# No kwarg
outputs = classifier("Who are you voting for in 2020?", ["politics"])
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"]) outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"])
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]}) self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment