Unverified commit a958c4a8, authored by jiqing-feng and committed by GitHub
Browse files

fix output data type of image classification (#31444)



* fix output data type of image classification

* add tests for low-precision pipeline

* add bf16 pipeline tests

* fix bf16 tests

* Update tests/pipelines/test_pipelines_image_classification.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix import

* fix import torch

* fix style

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 7e86cb6c
...@@ -23,6 +23,8 @@ if is_tf_available(): ...@@ -23,6 +23,8 @@ if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
if is_torch_available(): if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -180,7 +182,10 @@ class ImageClassificationPipeline(Pipeline): ...@@ -180,7 +182,10 @@ class ImageClassificationPipeline(Pipeline):
top_k = self.model.config.num_labels top_k = self.model.config.num_labels
outputs = model_outputs["logits"][0] outputs = model_outputs["logits"][0]
outputs = outputs.numpy() if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16):
outputs = outputs.to(torch.float32).numpy()
else:
outputs = outputs.numpy()
if function_to_apply == ClassificationFunction.SIGMOID: if function_to_apply == ClassificationFunction.SIGMOID:
scores = sigmoid(outputs) scores = sigmoid(outputs)
......
...@@ -18,6 +18,7 @@ from transformers import ( ...@@ -18,6 +18,7 @@ from transformers import (
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
is_torch_available,
is_vision_available, is_vision_available,
) )
from transformers.pipelines import ImageClassificationPipeline, pipeline from transformers.pipelines import ImageClassificationPipeline, pipeline
...@@ -34,6 +35,9 @@ from transformers.testing_utils import ( ...@@ -34,6 +35,9 @@ from transformers.testing_utils import (
from .test_pipelines_common import ANY from .test_pipelines_common import ANY
if is_torch_available():
import torch
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
else: else:
...@@ -177,6 +181,30 @@ class ImageClassificationPipelineTests(unittest.TestCase): ...@@ -177,6 +181,30 @@ class ImageClassificationPipelineTests(unittest.TestCase):
self.assertIs(image_classifier.tokenizer, tokenizer) self.assertIs(image_classifier.tokenizer, tokenizer)
@require_torch
def test_torch_float16_pipeline(self):
image_classifier = pipeline(
"image-classification", model="hf-internal-testing/tiny-random-vit", torch_dtype=torch.float16
)
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
self.assertEqual(
nested_simplify(outputs, decimals=3),
[{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
)
@require_torch
def test_torch_bfloat16_pipeline(self):
image_classifier = pipeline(
"image-classification", model="hf-internal-testing/tiny-random-vit", torch_dtype=torch.bfloat16
)
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
self.assertEqual(
nested_simplify(outputs, decimals=3),
[{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
)
@slow @slow
@require_torch @require_torch
def test_perceiver(self): def test_perceiver(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment