"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "122fd5d37f9e4d3e77025609e7883c4e0d0fe7ac"
Unverified Commit 776855c7 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing a regression with `return_all_scores` introduced in #17606 (#17906)

Fixing a regression with `return_all_scores` introduced in #17606

- The legacy test actually tested `return_all_scores=False` (the actual
  default) instead of `return_all_scores=True` (the actual weird case).

This commit adds the correct legacy test and fixes it.

Tmp legacy tests.

Actually fix the regression (also contains lists)

Less diffed code.
parent 5f1e67a5
...@@ -136,7 +136,9 @@ class TextClassificationPipeline(Pipeline): ...@@ -136,7 +136,9 @@ class TextClassificationPipeline(Pipeline):
If `top_k` is used, one such dictionary is returned per label. If `top_k` is used, one such dictionary is returned per label.
""" """
result = super().__call__(*args, **kwargs) result = super().__call__(*args, **kwargs)
if isinstance(args[0], str) and isinstance(result, dict): # TODO try and retrieve it in a nicer way from _sanitize_parameters.
_legacy = "top_k" not in kwargs
if isinstance(args[0], str) and _legacy:
# This pipeline is odd, and return a list when single item is run # This pipeline is odd, and return a list when single item is run
return [result] return [result]
else: else:
......
...@@ -60,6 +60,29 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC ...@@ -60,6 +60,29 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
outputs = text_classifier("This is great !", return_all_scores=False) outputs = text_classifier("This is great !", return_all_scores=False)
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
outputs = text_classifier("This is great !", return_all_scores=True)
self.assertEqual(
nested_simplify(outputs), [[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]]
)
outputs = text_classifier(["This is great !", "Something else"], return_all_scores=True)
self.assertEqual(
nested_simplify(outputs),
[
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
],
)
outputs = text_classifier(["This is great !", "Something else"], return_all_scores=False)
self.assertEqual(
nested_simplify(outputs),
[
{"label": "LABEL_0", "score": 0.504},
{"label": "LABEL_0", "score": 0.504},
],
)
@require_torch @require_torch
def test_accepts_torch_device(self): def test_accepts_torch_device(self):
import torch import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment