Unverified Commit 46d0e26a authored by Dhruv Karan, committed by GitHub

Adds OWLViT to models exportable with ONNX (#18588)

* onnx conversion for owlvit

* .T to .t()

* dynamic shapes for pixel values
parent b83796de
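
With this commit, OWL-ViT can be exported through the existing transformers.onnx command line entry point. A minimal sketch, using the google/owlvit-base-patch32 checkpoint registered in the tests below (the onnx/ output directory is illustrative):

    python -m transformers.onnx --model=google/owlvit-base-patch32 onnx/
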
@@ -83,6 +83,7 @@ Ready-made configurations include the following architectures:
- MobileViT
- MT5
- OpenAI GPT-2
- OWL-ViT
- Perceiver
- PLBart
- ResNet
@@ -32,6 +32,7 @@ _import_structure = {
    "configuration_owlvit": [
        "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "OwlViTConfig",
        "OwlViTOnnxConfig",
        "OwlViTTextConfig",
        "OwlViTVisionConfig",
    ],
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
    from .configuration_owlvit import (
        OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        OwlViTConfig,
        OwlViTOnnxConfig,
        OwlViTTextConfig,
        OwlViTVisionConfig,
    )
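
With the public API extended as above, the new ONNX config is importable next to the existing configuration classes, e.g.:

    from transformers.models.owlvit import OwlViTConfig, OwlViTOnnxConfig
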
@@ -16,9 +16,16 @@
import copy
import os
from typing import Dict, Union
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union

if TYPE_CHECKING:
    from ...processing_utils import ProcessorMixin
    from ...utils import TensorType

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
@@ -334,3 +341,44 @@ class OwlViTConfig(PretrainedConfig):
        output["vision_config"] = self.vision_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output


class OwlViTOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("logits_per_image", {0: "batch"}),
                ("logits_per_text", {0: "batch"}),
                ("text_embeds", {0: "batch"}),
                ("image_embeds", {0: "batch"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        return 1e-4

    def generate_dummy_inputs(
        self,
        processor: "ProcessorMixin",
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
        image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
        return {**text_input_dict, **image_input_dict}

    @property
    def default_onnx_opset(self) -> int:
        return 14
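
Beyond the CLI, the config can drive a programmatic export. A minimal sketch, assuming transformers.onnx.export accepts the OWL-ViT processor as its preprocessor; the owlvit.onnx output path is illustrative:

    from pathlib import Path

    from transformers import OwlViTModel, OwlViTProcessor
    from transformers.models.owlvit import OwlViTOnnxConfig
    from transformers.onnx import export

    ckpt = "google/owlvit-base-patch32"
    model = OwlViTModel.from_pretrained(ckpt)
    processor = OwlViTProcessor.from_pretrained(ckpt)

    onnx_config = OwlViTOnnxConfig(model.config)
    # default_onnx_opset is 14, the opset the argmax cast in modeling targets
    export(processor, model, onnx_config, onnx_config.default_onnx_opset, Path("owlvit.onnx"))
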
@@ -687,7 +687,10 @@ class OwlViTTextTransformer(nn.Module):
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # take features from the end of tokens embedding (end of token is the highest number in each sequence)
        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
        ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
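
The comment in the diff gives the reason for the cast: at opset 14, argmax over int64 input_ids is not exportable, while casting to int32 leaves the selected positions unchanged. A self-contained sketch of the same gather, with illustrative tensor sizes:

    import torch

    batch, seq_len, hidden = 2, 8, 4  # illustrative sizes
    last_hidden_state = torch.randn(batch, seq_len, hidden)
    input_ids = torch.randint(0, 100, (batch, seq_len), dtype=torch.int64)

    # the end-of-text token has the highest id, so argmax finds its position;
    # casting to int32 keeps the argmax exportable at opset 14
    eos_positions = input_ids.to(torch.int).argmax(dim=-1)
    pooled_output = last_hidden_state[torch.arange(batch), eos_positions]
    print(pooled_output.shape)  # torch.Size([2, 4])
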
@@ -1066,7 +1069,7 @@ class OwlViTModel(OwlViTPreTrainedModel):
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.T
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
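For the 2-D logits_per_text tensor, Tensor.T and Tensor.t() produce the same transpose, so the ".T to .t()" swap from the commit message keeps results identical while using the 2-D-specific transpose op. A quick equivalence check with illustrative shapes:

    import torch

    logits_per_text = torch.randn(3, 5)  # illustrative [num_texts, num_images] logits
    assert torch.equal(logits_per_text.T, logits_per_text.t())
    logits_per_image = logits_per_text.t()  # shape [5, 3]
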
@@ -416,6 +416,10 @@ class FeaturesManager:
            "seq2seq-lm-with-past",
            onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
        ),
        "owlvit": supported_features_mapping(
            "default",
            onnx_config_cls="models.owlvit.OwlViTOnnxConfig",
        ),
        "perceiver": supported_features_mapping(
            "image-classification",
            "masked-lm",
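The new entry can be inspected at runtime; a minimal sketch querying the feature table:

    from transformers.onnx.features import FeaturesManager

    # "owlvit" now maps the "default" feature to OwlViTOnnxConfig
    print(FeaturesManager.get_supported_features_for_model_type("owlvit"))
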
@@ -205,6 +205,7 @@ PYTORCH_EXPORT_MODELS = {
    ("layoutlm", "microsoft/layoutlm-base-uncased"),
    ("layoutlmv3", "microsoft/layoutlmv3-base"),
    ("levit", "facebook/levit-128S"),
    ("owlvit", "google/owlvit-base-patch32"),
    ("vit", "google/vit-base-patch16-224"),
    ("deit", "facebook/deit-small-patch16-224"),
    ("beit", "microsoft/beit-base-patch16-224"),
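With google/owlvit-base-patch32 registered in PYTORCH_EXPORT_MODELS, the export is exercised by the existing ONNX test suite. A sketch of a targeted run, assuming the usual test layout and that slow tests are enabled (the pytest selector and test path are illustrative):

    RUN_SLOW=1 python -m pytest tests/onnx/test_onnx_v2.py -k "owlvit" -v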