"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "51397336234a56ed2169413385c097fa1db4532d"
Unverified commit 220da3b8, authored by Dhruv Karan, committed by GitHub

Adds GroupViT to models exportable with ONNX (#18628)

* groupvit to onnx

* dynamic shape for pixel values dim
parent 46d0e26a
@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
 - FlauBERT
 - GPT Neo
 - GPT-J
+- GroupViT
 - I-BERT
 - LayoutLM
 - LayoutLMv3
@@ -24,6 +24,7 @@ _import_structure = {
     "configuration_groupvit": [
         "GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "GroupViTConfig",
+        "GroupViTOnnxConfig",
         "GroupViTTextConfig",
         "GroupViTVisionConfig",
     ],
@@ -47,6 +48,7 @@ if TYPE_CHECKING:
     from .configuration_groupvit import (
         GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
         GroupViTConfig,
+        GroupViTOnnxConfig,
         GroupViTTextConfig,
         GroupViTVisionConfig,
     )
@@ -16,12 +16,19 @@
 import copy
 import os
-from typing import Union
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Mapping, Optional, Union

 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging

+
+if TYPE_CHECKING:
+    from ...processing_utils import ProcessorMixin
+    from ...utils import TensorType
+

 logger = logging.get_logger(__name__)

 GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -343,3 +350,44 @@ class GroupViTConfig(PretrainedConfig):
         output["vision_config"] = self.vision_config.to_dict()
         output["model_type"] = self.__class__.model_type
         return output
+
+
+class GroupViTOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("logits_per_image", {0: "batch"}),
+                ("logits_per_text", {0: "batch"}),
+                ("text_embeds", {0: "batch"}),
+                ("image_embeds", {0: "batch"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+    def generate_dummy_inputs(
+        self,
+        processor: "ProcessorMixin",
+        framework: Optional["TensorType"] = None,
+    ) -> Mapping[str, Any]:
+        text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
+        image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
+        return {**text_input_dict, **image_input_dict}
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 14
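
For reference, a minimal export sketch built on this new config. The export and validate_model_outputs helpers are the generic transformers.onnx API of this era, and the use of AutoModel/AutoProcessor for the checkpoint is an assumption on my part, not something this diff adds:

from pathlib import Path

from transformers import AutoModel, AutoProcessor
from transformers.models.groupvit import GroupViTOnnxConfig
from transformers.onnx import export, validate_model_outputs

ckpt = "nvidia/groupvit-gcc-yfcc"  # checkpoint registered in the test below
processor = AutoProcessor.from_pretrained(ckpt)  # assumption: resolves to a tokenizer + feature-extractor processor
model = AutoModel.from_pretrained(ckpt)

onnx_config = GroupViTOnnxConfig(model.config)
output_path = Path("groupvit.onnx")

# Export with the opset this config declares (14), then check the ONNX
# outputs against PyTorch within the declared tolerance (1e-4).
onnx_inputs, onnx_outputs = export(processor, model, onnx_config, onnx_config.default_onnx_opset, output_path)
validate_model_outputs(onnx_config, processor, model, output_path, onnx_outputs, onnx_config.atol_for_validation)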
@@ -1542,7 +1542,7 @@ class GroupViTModel(GroupViTPreTrainedModel):
         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
-        logits_per_image = logits_per_text.T
+        logits_per_image = logits_per_text.t()

         seg_logits = None
         if output_segmentation:
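
The one modeling change swaps logits_per_text.T for logits_per_text.t(). For a 2-D tensor the two are interchangeable, and .t() traces to a plain transpose op that the ONNX exporter handles; presumably that is the motivation here, though the diff itself does not state it. A quick equivalence check:

import torch

logits_per_text = torch.randn(4, 8)  # (text_batch, image_batch)
# .T and .t() agree on 2-D tensors; .t() is the export-friendly spelling.
assert torch.equal(logits_per_text.T, logits_per_text.t())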
@@ -326,6 +326,10 @@ class FeaturesManager:
             "sequence-classification",
             onnx_config_cls="models.gpt_neo.GPTNeoOnnxConfig",
         ),
+        "groupvit": supported_features_mapping(
+            "default",
+            onnx_config_cls="models.groupvit.GroupViTOnnxConfig",
+        ),
         "ibert": supported_features_mapping(
             "default",
             "masked-lm",
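
With this mapping in place, GroupViT becomes discoverable through the features manager, and thereby the python -m transformers.onnx CLI. A small sketch querying the new entry:

from transformers.onnx.features import FeaturesManager

# Only the "default" feature is registered for GroupViT by this change.
features = FeaturesManager.get_supported_features_for_model_type("groupvit")
print(sorted(features))  # expected: ['default']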
@@ -204,6 +204,7 @@ PYTORCH_EXPORT_MODELS = {
     ("xlm-roberta", "xlm-roberta-base"),
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
     ("layoutlmv3", "microsoft/layoutlmv3-base"),
+    ("groupvit", "nvidia/groupvit-gcc-yfcc"),
     ("levit", "facebook/levit-128S"),
     ("owlvit", "google/owlvit-base-patch32"),
     ("vit", "google/vit-base-patch16-224"),
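
Adding the ("groupvit", "nvidia/groupvit-gcc-yfcc") pair to PYTORCH_EXPORT_MODELS lets the parameterized ONNX export test cover the new config; running something like RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -k groupvit (exact test-id filter is an assumption) should exercise the full export-and-validate path.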