Unverified commit e162cebf, authored by Bibhabasu Mohapatra and committed by GitHub
Browse files

add ONNX support for swin transformer (#19390)



* swin transformer onnx support

* Updated image dimensions as dynamic
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent 969534af
......@@ -94,6 +94,7 @@ Ready-made configurations include the following architectures:
- RoFormer
- SegFormer
- SqueezeBERT
- Swin Transformer
- T5
- ViT
- XLM
......
......@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"]}
_import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig", "SwinOnnxConfig"]}
try:
......@@ -53,7 +53,7 @@ else:
]
if TYPE_CHECKING:
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig, SwinOnnxConfig
try:
if not is_torch_available():
......
......@@ -14,7 +14,13 @@
# limitations under the License.
""" Swin Transformer model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
......@@ -145,3 +151,20 @@ class SwinConfig(PretrainedConfig):
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
class SwinOnnxConfig(OnnxConfig):
    """ONNX export configuration for the Swin Transformer model."""

    # NOTE(review): presumably the lowest torch release the exporter accepts
    # for this architecture; confirm against how OnnxConfig consumes it.
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Describe the model inputs and the symbolic name of each axis.

        Returns:
            An ordered mapping with a single entry, ``"pixel_values"``, whose
            value maps every axis index (0-3) to a symbolic axis name.
        """
        # All four axes are named, including height and width, so the image
        # dimensions are not pinned to a fixed size.
        pixel_value_axes = {0: "batch", 1: "num_channels", 2: "height", 3: "width"}
        return OrderedDict(pixel_values=pixel_value_axes)

    @property
    def atol_for_validation(self) -> float:
        """Return the tolerance value ``1e-4``.

        NOTE(review): presumably the absolute tolerance used when comparing
        exported-model outputs with the reference outputs; confirm in OnnxConfig.
        """
        tolerance = 1e-4
        return tolerance
......@@ -471,6 +471,9 @@ class FeaturesManager:
"question-answering",
onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
),
"swin": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls="models.swin.SwinOnnxConfig"
),
"t5": supported_features_mapping(
"default",
"default-with-past",
......
......@@ -217,6 +217,7 @@ PYTORCH_EXPORT_MODELS = {
("longformer", "allenai/longformer-base-4096"),
("yolos", "hustvl/yolos-tiny"),
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
("swin", "microsoft/swin-tiny-patch4-window7-224"),
}
PYTORCH_EXPORT_WITH_PAST_MODELS = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment