Unverified Commit 1cdd2ad2 authored by Zhiyu Lin's avatar Zhiyu Lin Committed by GitHub
Browse files

Fix #2941 (#4109)



* Fix of issue #2941

Reshaped score array to avoid `numpy` ValueError.

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py
Co-authored-by: default avatarJulien Chaumond <chaumond@gmail.com>
parent 5f4f6b65
...@@ -656,8 +656,8 @@ class TextClassificationPipeline(Pipeline): ...@@ -656,8 +656,8 @@ class TextClassificationPipeline(Pipeline):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
outputs = super().__call__(*args, **kwargs) outputs = super().__call__(*args, **kwargs)
scores = np.exp(outputs) / np.exp(outputs).sum(-1) scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
return [{"label": self.model.config.id2label[item.argmax()], "score": item.max()} for item in scores] return [{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores]
class FillMaskPipeline(Pipeline): class FillMaskPipeline(Pipeline):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment