Unverified Commit b2a41d2b authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Faster zero shot image (#21897)



* Make ZeroShotImageClassificationPipeline faster

The pipeline makes separate calls to model for each candidate label.
This commit combines all labels into one call.
Original code takes more than 60 seconds to process one image and 1000
candidate labels. Updated code takes less than 2 seconds.

* implement batching

* code formatting

* Creating an even faster zero-shot-image-classification.

Unfortunately super tailored towards CLIP.
Co-Authored-By: Yessen Kanapin <yessen@deepinfra.com>

* Quality.

* Cleanup.

* Order different on the CI it seems.

* Cleanup.

* Quality.

---------
Co-authored-by: Yessen Kanapin <yessen@deepinfra.com>
parent 88e5c51a
from collections import UserDict
from typing import List, Union
from ..utils import (
......@@ -8,7 +9,7 @@ from ..utils import (
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
from .base import PIPELINE_INIT_ARGS, Pipeline
if is_vision_available():
......@@ -17,18 +18,16 @@ if is_vision_available():
from ..image_utils import load_image
if is_torch_available():
import torch
pass
if is_tf_available():
import tensorflow as tf
from ..tf_utils import stable_softmax
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotImageClassificationPipeline(ChunkPipeline):
class ZeroShotImageClassificationPipeline(Pipeline):
"""
Zero shot image classification pipeline using `CLIPModel`. This pipeline predicts the class of an image when you
provide an image and a set of `candidate_labels`.
......@@ -107,42 +106,39 @@ class ZeroShotImageClassificationPipeline(ChunkPipeline):
return preprocess_params, {}, {}
def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}."):
    """Turn one image plus its candidate labels into a single model payload.

    All candidate labels are tokenized together (one padded batch) so that the
    model is called once per image instead of once per (image, label) pair.

    Args:
        image: URL, local path, or PIL image accepted by `load_image`.
        candidate_labels: Labels to score the image against.
        hypothesis_template: Format string used to build one text prompt per label.

    Returns:
        Dict with the image processor outputs (e.g. `pixel_values`), plus
        `candidate_labels` and `text_inputs` (tokenizer output wrapped in a
        one-element list so default collation keeps it intact).
    """
    image = load_image(image)
    inputs = self.image_processor(images=[image], return_tensors=self.framework)
    inputs["candidate_labels"] = candidate_labels
    sequences = [hypothesis_template.format(x) for x in candidate_labels]
    text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
    # Wrap in a list so the pipeline's batching/collation step does not try to
    # merge the tokenizer's BatchEncoding with the image tensors.
    inputs["text_inputs"] = [text_inputs]
    return inputs
def _forward(self, model_inputs):
is_last = model_inputs.pop("is_last")
candidate_label = model_inputs.pop("candidate_label")
outputs = self.model(**model_inputs)
candidate_labels = model_inputs.pop("candidate_labels")
text_inputs = model_inputs.pop("text_inputs")
if isinstance(text_inputs[0], UserDict):
text_inputs = text_inputs[0]
else:
# Batching case.
text_inputs = text_inputs[0][0]
# Clip does crossproduct scoring by default, so we're only
# interested in the results where image and text and in the same
# batch position.
diag = torch.diagonal if self.framework == "pt" else tf.linalg.diag_part
logits_per_image = diag(outputs.logits_per_image)
outputs = self.model(**text_inputs, **model_inputs)
model_outputs = {
"is_last": is_last,
"candidate_label": candidate_label,
"logits_per_image": logits_per_image,
"candidate_labels": candidate_labels,
"logits": outputs.logits_per_image,
}
return model_outputs
def postprocess(self, model_outputs):
candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
candidate_labels = model_outputs.pop("candidate_labels")
logits = model_outputs["logits"][0]
if self.framework == "pt":
logits = torch.cat([output["logits_per_image"] for output in model_outputs])
probs = logits.softmax(dim=0)
probs = logits.softmax(dim=-1).squeeze(-1)
scores = probs.tolist()
else:
logits = tf.concat([output["logits_per_image"] for output in model_outputs], axis=0)
probs = stable_softmax(logits, axis=0)
probs = stable_softmax(logits, axis=-1)
scores = probs.numpy().tolist()
result = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment