"vscode:/vscode.git/clone" did not exist on "aee11fe427b2f2fd66c3ef3cd91757ec00420ac9"
Unverified Commit b56848c8 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Pipeline update & tests (#12207)

parent 700cee34
......@@ -87,7 +87,8 @@ class ImageClassificationPipeline(Pipeline):
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
images.
top_k (:obj:`int`, `optional`, defaults to 5):
The number of top labels that will be returned by the pipeline.
The number of top labels that will be returned by the pipeline. If the provided number is higher than
the number of labels available in the model configuration, it will default to the number of labels.
Return:
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
......@@ -106,6 +107,9 @@ class ImageClassificationPipeline(Pipeline):
images = [self.load_image(image) for image in images]
if top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels
with torch.no_grad():
inputs = self.feature_extractor(images=images, return_tensors="pt")
outputs = self.model(**inputs)
......
......@@ -15,6 +15,7 @@
import unittest
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoModelForImageClassification,
PreTrainedTokenizer,
......@@ -128,3 +129,33 @@ class ImageClassificationPipelineTests(unittest.TestCase):
image_classifier = pipeline("image-classification", model=self.small_models[0], tokenizer=tokenizer)
self.assertIs(image_classifier.tokenizer, tokenizer)
def test_num_labels_inferior_to_topk(self):
for small_model in self.small_models:
num_labels = 2
model = AutoModelForImageClassification.from_config(
AutoConfig.from_pretrained(small_model, num_labels=num_labels)
)
feature_extractor = AutoFeatureExtractor.from_pretrained(small_model)
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
for valid_input in self.valid_inputs:
output = image_classifier(**valid_input)
def assert_valid_pipeline_output(pipeline_output):
self.assertTrue(isinstance(pipeline_output, list))
self.assertEqual(len(pipeline_output), num_labels)
for label_result in pipeline_output:
self.assertTrue(isinstance(label_result, dict))
self.assertIn("label", label_result)
self.assertIn("score", label_result)
if isinstance(valid_input["images"], list):
# When images are batched, pipeline output is a list of lists of dictionaries
self.assertEqual(len(valid_input["images"]), len(output))
for individual_output in output:
assert_valid_pipeline_output(individual_output)
else:
# When images are batched, pipeline output is a list of dictionaries
assert_valid_pipeline_output(output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment