Unverified Commit 8246caf3 authored by Rushi Chaudhari's avatar Rushi Chaudhari Committed by GitHub
Browse files

added deit onnx config (#16887)

* added deit onnx config
parent 9331b379
......@@ -56,6 +56,7 @@ Ready-made configurations include the following architectures:
- ConvBERT
- Data2VecText
- Data2VecVision
- DeiT
- DistilBERT
- ELECTRA
- FlauBERT
......
......@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"],
"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig", "DeiTOnnxConfig"],
}
if is_vision_available():
......@@ -39,7 +39,7 @@ if is_torch_available():
if TYPE_CHECKING:
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig, DeiTOnnxConfig
if is_vision_available():
from .feature_extraction_deit import DeiTFeatureExtractor
......
......@@ -14,7 +14,13 @@
# limitations under the License.
""" DeiT model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
......@@ -120,3 +126,19 @@ class DeiTConfig(PretrainedConfig):
self.num_channels = num_channels
self.qkv_bias = qkv_bias
self.encoder_stride = encoder_stride
class DeiTOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4
......@@ -12,6 +12,7 @@ from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
from ..models.camembert import CamembertOnnxConfig
from ..models.convbert import ConvBertOnnxConfig
from ..models.data2vec import Data2VecTextOnnxConfig
from ..models.deit import DeiTOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig
from ..models.flaubert import FlaubertOnnxConfig
......@@ -38,6 +39,7 @@ if is_torch_available():
AutoModel,
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedImageModeling,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
......@@ -103,6 +105,7 @@ class FeaturesManager:
"multiple-choice": AutoModelForMultipleChoice,
"question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification,
"masked-im": AutoModelForMaskedImageModeling,
}
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
......@@ -294,8 +297,15 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=ElectraOnnxConfig,
),
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
"vit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
),
"beit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=BeitOnnxConfig
),
"deit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
),
"blenderbot": supported_features_mapping(
"default",
"default-with-past",
......
......@@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"),
("deit", "facebook/deit-small-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
("data2vec-text", "facebook/data2vec-text-base"),
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment