Unverified Commit 46d0e26a authored by Dhruv Karan, committed by GitHub

Adds OWLViT to models exportable with ONNX (#18588)

* onnx conversion for owlvit

* .T to .t()

* dynamic shapes for pixel values
parent b83796de
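
Not part of the commit: a minimal sketch of how the new ONNX config can be exercised end to end, assuming the transformers.onnx.export(preprocessor, model, config, opset, output) API of this release accepts an OwlViTProcessor as the preprocessor (as it does for CLIP); the checkpoint name and output path are illustrative.

# Hedged export sketch; the output path "owlvit.onnx" is illustrative, not from the commit.
from pathlib import Path

from transformers import OwlViTModel, OwlViTProcessor
from transformers.models.owlvit import OwlViTOnnxConfig
from transformers.onnx import export

ckpt = "google/owlvit-base-patch32"
model = OwlViTModel.from_pretrained(ckpt)
processor = OwlViTProcessor.from_pretrained(ckpt)

onnx_config = OwlViTOnnxConfig(model.config)
onnx_path = Path("owlvit.onnx")

# export() builds dummy inputs via OwlViTOnnxConfig.generate_dummy_inputs, traces the
# model, and returns the matched input names and the ONNX output names.
onnx_inputs, onnx_outputs = export(
    preprocessor=processor,
    model=model,
    config=onnx_config,
    opset=onnx_config.default_onnx_opset,  # 14, as declared in this commit
    output=onnx_path,
)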
@@ -83,6 +83,7 @@ Ready-made configurations include the following architectures:
 - MobileViT
 - MT5
 - OpenAI GPT-2
+- OWL-ViT
 - Perceiver
 - PLBart
 - ResNet
...
@@ -32,6 +32,7 @@ _import_structure = {
     "configuration_owlvit": [
         "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "OwlViTConfig",
+        "OwlViTOnnxConfig",
         "OwlViTTextConfig",
         "OwlViTVisionConfig",
     ],
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
     from .configuration_owlvit import (
         OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
         OwlViTConfig,
+        OwlViTOnnxConfig,
         OwlViTTextConfig,
         OwlViTVisionConfig,
     )
...
@@ -16,9 +16,16 @@
 import copy
 import os
-from typing import Dict, Union
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
+
+
+if TYPE_CHECKING:
+    from ...processing_utils import ProcessorMixin
+    from ...utils import TensorType
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
@@ -334,3 +341,44 @@ class OwlViTConfig(PretrainedConfig):
         output["vision_config"] = self.vision_config.to_dict()
         output["model_type"] = self.__class__.model_type
         return output
+
+
+class OwlViTOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("logits_per_image", {0: "batch"}),
+                ("logits_per_text", {0: "batch"}),
+                ("text_embeds", {0: "batch"}),
+                ("image_embeds", {0: "batch"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+    def generate_dummy_inputs(
+        self,
+        processor: "ProcessorMixin",
+        framework: Optional["TensorType"] = None,
+    ) -> Mapping[str, Any]:
+        text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
+        image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
+        return {**text_input_dict, **image_input_dict}
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 14
...
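
A hypothetical check of the dummy-input merging defined above, assuming OwlViTProcessor exposes the tokenizer and feature_extractor attributes the override relies on; the checkpoint is illustrative.

# Sketch (not from the commit): the override merges the tokenizer's text dummies with
# the feature extractor's pixel_values, so all three declared ONNX inputs show up.
from transformers import OwlViTConfig, OwlViTProcessor
from transformers.models.owlvit import OwlViTOnnxConfig
from transformers.utils import TensorType

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
onnx_config = OwlViTOnnxConfig(OwlViTConfig())

dummy_inputs = onnx_config.generate_dummy_inputs(processor, framework=TensorType.PYTORCH)
print(sorted(dummy_inputs))                       # expected: ['attention_mask', 'input_ids', 'pixel_values']
print(tuple(dummy_inputs["pixel_values"].shape))  # (batch, num_channels, height, width)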
@@ -687,7 +687,10 @@ class OwlViTTextTransformer(nn.Module):
         last_hidden_state = self.final_layer_norm(last_hidden_state)
 
         # take features from the end of tokens embedding (end of token is the highest number in each sequence)
-        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
+        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+        pooled_output = last_hidden_state[
+            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+        ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@@ -1066,7 +1069,7 @@ class OwlViTModel(OwlViTPreTrainedModel):
         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
-        logits_per_image = logits_per_text.T
+        logits_per_image = logits_per_text.t()
 
         loss = None
         if return_loss:
...
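
A standalone illustration of the casting change above, with made-up shapes: tokenizers emit int64 ids, and per the comment in the commit ONNX ArgMax at opset 14 rejects int64 inputs, so the ids are cast to int32 before locating the end-of-text token (the highest id in each sequence).

# Toy reproduction of the pooling pattern; batch size, sequence length, hidden size,
# and vocab size below are illustrative.
import torch

last_hidden_state = torch.randn(2, 16, 512)   # (batch, seq_len, hidden_size)
input_ids = torch.randint(0, 49408, (2, 16))  # int64 token ids, as tokenizers produce

# ONNX-friendly variant used in the commit: cast to int32 before argmax.
eot_positions = input_ids.to(torch.int).argmax(dim=-1)
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), eot_positions]
print(pooled_output.shape)  # torch.Size([2, 512])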
@@ -416,6 +416,10 @@ class FeaturesManager:
             "seq2seq-lm-with-past",
             onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
         ),
+        "owlvit": supported_features_mapping(
+            "default",
+            onnx_config_cls="models.owlvit.OwlViTOnnxConfig",
+        ),
         "perceiver": supported_features_mapping(
             "image-classification",
             "masked-lm",
...
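
How the new FeaturesManager entry can be resolved, sketched on the assumption that FeaturesManager.get_config(model_type, feature) returns a constructor taking a model config, as it does for the other entries in this mapping.

# Sketch (not from the commit): resolving the "owlvit"/"default" registration.
from transformers import OwlViTConfig
from transformers.onnx.features import FeaturesManager

config_constructor = FeaturesManager.get_config("owlvit", feature="default")
onnx_config = config_constructor(OwlViTConfig())

print(type(onnx_config).__name__)      # OwlViTOnnxConfig
print(onnx_config.default_onnx_opset)  # 14
print(list(onnx_config.outputs))       # the four outputs declared above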
@@ -205,6 +205,7 @@ PYTORCH_EXPORT_MODELS = {
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
     ("layoutlmv3", "microsoft/layoutlmv3-base"),
     ("levit", "facebook/levit-128S"),
+    ("owlvit", "google/owlvit-base-patch32"),
     ("vit", "google/vit-base-patch16-224"),
     ("deit", "facebook/deit-small-patch16-224"),
     ("beit", "microsoft/beit-base-patch16-224"),
...
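
The new test entry runs the export through the shared ONNX test harness. Outside the test suite, a comparable check could look like this sketch, reusing model, processor, onnx_config, onnx_path, and onnx_outputs from the export sketch near the top and assuming validate_model_outputs keeps its (config, preprocessor, reference_model, onnx_model, onnx_named_outputs, atol) signature.

# Sketch: compare PyTorch outputs against ONNX Runtime outputs within the tolerance
# declared by OwlViTOnnxConfig.atol_for_validation (1e-4). Requires onnxruntime.
from transformers.onnx import validate_model_outputs

validate_model_outputs(
    onnx_config,
    processor,
    model,
    onnx_path,
    onnx_outputs,
    onnx_config.atol_for_validation,
)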