Commit 8bb166db authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Expose more information in the output of TextClassificationPipeline

parent 04b602f9
...@@ -283,7 +283,9 @@ class TextClassificationPipeline(Pipeline): ...@@ -283,7 +283,9 @@ class TextClassificationPipeline(Pipeline):
self._nb_classes = nb_classes self._nb_classes = nb_classes
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs).tolist() outputs = super().__call__(*args, **kwargs)
scores = np.exp(outputs) / np.exp(outputs).sum(-1)
return [{'label': self.model.config.id2label[item.argmax()], 'score': item.max()} for item in scores]
class NerPipeline(Pipeline): class NerPipeline(Pipeline):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment