Unverified Commit 013bdc6d authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing Backward compatiblity for zero-shot (#13855)

Fixes #13846
parent 9f58becc
...@@ -191,10 +191,7 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -191,10 +191,7 @@ class ZeroShotClassificationPipeline(Pipeline):
else: else:
raise ValueError(f"Unable to understand extra arguments {args}") raise ValueError(f"Unable to understand extra arguments {args}")
result = super().__call__(sequences, **kwargs) return super().__call__(sequences, **kwargs)
if len(result) == 1:
return result[0]
return result
def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."): def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template) sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
...@@ -264,4 +261,6 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -264,4 +261,6 @@ class ZeroShotClassificationPipeline(Pipeline):
"scores": scores[iseq, top_inds].tolist(), "scores": scores[iseq, top_inds].tolist(),
} }
) )
if len(result) == 1:
return result[0]
return result return result
...@@ -61,6 +61,24 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT ...@@ -61,6 +61,24 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
) )
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]}) self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
# https://github.com/huggingface/transformers/issues/13846
outputs = classifier(["I am happy"], ["positive", "negative"])
self.assertEqual(
outputs,
[
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
for i in range(1)
],
)
outputs = classifier(["I am happy", "I am sad"], ["positive", "negative"])
self.assertEqual(
outputs,
[
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
for i in range(2)
],
)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
classifier("", candidate_labels="politics") classifier("", candidate_labels="politics")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment