Unverified Commit 8c14b342 authored by gcheron's avatar gcheron Committed by GitHub
Browse files

add ONNX support for LeVit (#18154)


Co-authored-by: default avatarGuilhem Chéron <guilhemc@authentifier.com>
parent c1c79b06
...@@ -72,6 +72,7 @@ Ready-made configurations include the following architectures: ...@@ -72,6 +72,7 @@ Ready-made configurations include the following architectures:
- I-BERT - I-BERT
- LayoutLM - LayoutLM
- LayoutLMv3 - LayoutLMv3
- LeViT
- LongT5 - LongT5
- M2M100 - M2M100
- Marian - Marian
......
...@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING ...@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"]} _import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig", "LevitOnnxConfig"]}
try: try:
if not is_vision_available(): if not is_vision_available():
...@@ -46,7 +46,7 @@ else: ...@@ -46,7 +46,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig, LevitOnnxConfig
try: try:
if not is_vision_available(): if not is_vision_available():
......
...@@ -14,7 +14,13 @@ ...@@ -14,7 +14,13 @@
# limitations under the License. # limitations under the License.
""" LeViT model configuration""" """ LeViT model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -120,3 +126,21 @@ class LevitConfig(PretrainedConfig): ...@@ -120,3 +126,21 @@ class LevitConfig(PretrainedConfig):
["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2], ["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2],
["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2], ["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2],
] ]
# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
class LevitOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4
...@@ -333,6 +333,9 @@ class FeaturesManager: ...@@ -333,6 +333,9 @@ class FeaturesManager:
"token-classification", "token-classification",
onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig", onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig",
), ),
"levit": supported_features_mapping(
"default", "image-classification", onnx_config_cls="models.levit.LevitOnnxConfig"
),
"longt5": supported_features_mapping( "longt5": supported_features_mapping(
"default", "default",
"default-with-past", "default-with-past",
......
...@@ -196,6 +196,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -196,6 +196,7 @@ PYTORCH_EXPORT_MODELS = {
("xlm-roberta", "xlm-roberta-base"), ("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlm", "microsoft/layoutlm-base-uncased"),
("layoutlmv3", "microsoft/layoutlmv3-base"), ("layoutlmv3", "microsoft/layoutlmv3-base"),
("levit", "facebook/levit-128S"),
("vit", "google/vit-base-patch16-224"), ("vit", "google/vit-base-patch16-224"),
("deit", "facebook/deit-small-patch16-224"), ("deit", "facebook/deit-small-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"), ("beit", "microsoft/beit-base-patch16-224"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment