"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e20d8895bdc926babc45e6bfa7ec9047b012aa77"
Unverified commit 5a06118b, authored by Nicolas Patry and committed by GitHub
Browse files

Enabling `TF` on `image-classification` pipeline. (#15030)

parent 9f89fa02
from typing import List, Union from typing import List, Union
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends from ..file_utils import (
add_end_docstrings,
is_tf_available,
is_torch_available,
is_vision_available,
requires_backends,
)
from ..utils import logging from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline from .base import PIPELINE_INIT_ARGS, Pipeline
...@@ -10,6 +16,11 @@ if is_vision_available(): ...@@ -10,6 +16,11 @@ if is_vision_available():
from ..image_utils import load_image from ..image_utils import load_image
if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
...@@ -31,12 +42,12 @@ class ImageClassificationPipeline(Pipeline): ...@@ -31,12 +42,12 @@ class ImageClassificationPipeline(Pipeline):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if self.framework == "tf":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING) self.check_model_type(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
if self.framework == "tf"
else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
)
def _sanitize_parameters(self, top_k=None): def _sanitize_parameters(self, top_k=None):
postprocess_params = {} postprocess_params = {}
...@@ -77,7 +88,7 @@ class ImageClassificationPipeline(Pipeline): ...@@ -77,7 +88,7 @@ class ImageClassificationPipeline(Pipeline):
def preprocess(self, image): def preprocess(self, image):
image = load_image(image) image = load_image(image)
model_inputs = self.feature_extractor(images=image, return_tensors="pt") model_inputs = self.feature_extractor(images=image, return_tensors=self.framework)
return model_inputs return model_inputs
def _forward(self, model_inputs): def _forward(self, model_inputs):
...@@ -87,8 +98,16 @@ class ImageClassificationPipeline(Pipeline): ...@@ -87,8 +98,16 @@ class ImageClassificationPipeline(Pipeline):
def postprocess(self, model_outputs, top_k=5): def postprocess(self, model_outputs, top_k=5):
if top_k > self.model.config.num_labels: if top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels top_k = self.model.config.num_labels
probs = model_outputs.logits.softmax(-1)[0]
scores, ids = probs.topk(top_k) if self.framework == "pt":
probs = model_outputs.logits.softmax(-1)[0]
scores, ids = probs.topk(top_k)
elif self.framework == "tf":
probs = tf.nn.softmax(model_outputs.logits, axis=-1)[0]
topk = tf.math.top_k(probs, k=top_k)
scores, ids = topk.values.numpy(), topk.indices.numpy()
else:
raise ValueError(f"Unsupported framework: {self.framework}")
scores = scores.tolist() scores = scores.tolist()
ids = ids.tolist() ids = ids.tolist()
......
...@@ -14,7 +14,12 @@ ...@@ -14,7 +14,12 @@
import unittest import unittest
from transformers import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, PreTrainedTokenizer, is_vision_available from transformers import (
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
PreTrainedTokenizer,
is_vision_available,
)
from transformers.pipelines import ImageClassificationPipeline, pipeline from transformers.pipelines import ImageClassificationPipeline, pipeline
from transformers.testing_utils import ( from transformers.testing_utils import (
is_pipeline_test, is_pipeline_test,
...@@ -40,9 +45,9 @@ else: ...@@ -40,9 +45,9 @@ else:
@is_pipeline_test @is_pipeline_test
@require_vision @require_vision
@require_torch
class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
def get_test_pipeline(self, model, tokenizer, feature_extractor): def get_test_pipeline(self, model, tokenizer, feature_extractor):
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2) image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
...@@ -145,9 +150,42 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -145,9 +150,42 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
) )
@require_tf @require_tf
@unittest.skip("Image classification is not implemented for TF")
def test_small_model_tf(self): def test_small_model_tf(self):
pass small_model = "lysandre/tiny-vit-random"
image_classifier = pipeline("image-classification", model=small_model)
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
{"score": 0.0014, "label": "trench coat"},
{"score": 0.0014, "label": "handkerchief, hankie, hanky, hankey"},
{"score": 0.0014, "label": "baboon"},
],
)
outputs = image_classifier(
[
"http://images.cocodataset.org/val2017/000000039769.jpg",
"http://images.cocodataset.org/val2017/000000039769.jpg",
],
top_k=2,
)
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
],
[
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
],
],
)
def test_custom_tokenizer(self): def test_custom_tokenizer(self):
tokenizer = PreTrainedTokenizer() tokenizer = PreTrainedTokenizer()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment