Unverified Commit b56848c8 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Pipeline update & tests (#12207)

parent 700cee34
...@@ -87,7 +87,8 @@ class ImageClassificationPipeline(Pipeline): ...@@ -87,7 +87,8 @@ class ImageClassificationPipeline(Pipeline):
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
images. images.
top_k (:obj:`int`, `optional`, defaults to 5): top_k (:obj:`int`, `optional`, defaults to 5):
The number of top labels that will be returned by the pipeline. The number of top labels that will be returned by the pipeline. If the provided number is higher than
the number of labels available in the model configuration, it will default to the number of labels.
Return: Return:
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
...@@ -106,6 +107,9 @@ class ImageClassificationPipeline(Pipeline): ...@@ -106,6 +107,9 @@ class ImageClassificationPipeline(Pipeline):
images = [self.load_image(image) for image in images] images = [self.load_image(image) for image in images]
if top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels
with torch.no_grad(): with torch.no_grad():
inputs = self.feature_extractor(images=images, return_tensors="pt") inputs = self.feature_extractor(images=images, return_tensors="pt")
outputs = self.model(**inputs) outputs = self.model(**inputs)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
from transformers import ( from transformers import (
AutoConfig,
AutoFeatureExtractor, AutoFeatureExtractor,
AutoModelForImageClassification, AutoModelForImageClassification,
PreTrainedTokenizer, PreTrainedTokenizer,
...@@ -128,3 +129,33 @@ class ImageClassificationPipelineTests(unittest.TestCase): ...@@ -128,3 +129,33 @@ class ImageClassificationPipelineTests(unittest.TestCase):
image_classifier = pipeline("image-classification", model=self.small_models[0], tokenizer=tokenizer) image_classifier = pipeline("image-classification", model=self.small_models[0], tokenizer=tokenizer)
self.assertIs(image_classifier.tokenizer, tokenizer) self.assertIs(image_classifier.tokenizer, tokenizer)
def test_num_labels_inferior_to_topk(self):
    """Check the pipeline output size when the model has very few labels.

    For every checkpoint in ``self.small_models`` a model is instantiated
    from a config overridden to ``num_labels=2``. The pipeline is then run
    on every entry of ``self.valid_inputs`` and each per-image result must
    be a list of exactly ``num_labels`` dictionaries, each holding a
    ``"label"`` and a ``"score"`` key.
    """
    # Deliberately tiny label count; presumably smaller than the top_k the
    # pipeline is called with (per the test name) -- confirm against
    # ``self.valid_inputs`` and the pipeline default.
    num_labels = 2

    def assert_valid_pipeline_output(pipeline_output):
        # One result per label, and no more than the model actually has.
        self.assertIsInstance(pipeline_output, list)
        self.assertEqual(len(pipeline_output), num_labels)
        for label_result in pipeline_output:
            self.assertIsInstance(label_result, dict)
            self.assertIn("label", label_result)
            self.assertIn("score", label_result)

    for small_model in self.small_models:
        # Build the model from a config (not pretrained weights) so that the
        # classification head has ``num_labels`` outputs.
        model = AutoModelForImageClassification.from_config(
            AutoConfig.from_pretrained(small_model, num_labels=num_labels)
        )
        feature_extractor = AutoFeatureExtractor.from_pretrained(small_model)
        image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)

        for valid_input in self.valid_inputs:
            output = image_classifier(**valid_input)

            if isinstance(valid_input["images"], list):
                # When images are batched, pipeline output is a list of lists of dictionaries
                self.assertEqual(len(valid_input["images"]), len(output))
                for individual_output in output:
                    assert_valid_pipeline_output(individual_output)
            else:
                # When images are not batched, pipeline output is a list of dictionaries
                assert_valid_pipeline_output(output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment