Unverified Commit 6b22a8f2 authored by Chujie Zheng's avatar Chujie Zheng Committed by GitHub
Browse files

fix bf16 issue in text classification pipeline (#30996)

* fix logits dtype

* Add bf16/fp16 tests for text_classification pipeline

* Update test_pipelines_text_classification.py

* fix

* fix
parent de460e28
......@@ -202,7 +202,7 @@ class TextClassificationPipeline(Pipeline):
function_to_apply = ClassificationFunction.NONE
outputs = model_outputs["logits"][0]
outputs = outputs.numpy()
outputs = outputs.float().numpy()
if function_to_apply == ClassificationFunction.SIGMOID:
scores = sigmoid(outputs)
......
......@@ -14,13 +14,24 @@
import unittest
import torch
from transformers import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TextClassificationPipeline,
pipeline,
)
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow, torch_device
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_torch_bf16,
require_torch_fp16,
slow,
torch_device,
)
from .test_pipelines_common import ANY
......@@ -106,6 +117,32 @@ class TextClassificationPipelineTests(unittest.TestCase):
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_torch_fp16
def test_accepts_torch_fp16(self):
text_classifier = pipeline(
task="text-classification",
model="hf-internal-testing/tiny-random-distilbert",
framework="pt",
device=torch_device,
torch_dtype=torch.float16,
)
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_torch_bf16
def test_accepts_torch_bf16(self):
text_classifier = pipeline(
task="text-classification",
model="hf-internal-testing/tiny-random-distilbert",
framework="pt",
device=torch_device,
torch_dtype=torch.bfloat16,
)
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_tf
def test_small_model_tf(self):
text_classifier = pipeline(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment