Unverified Commit e162cebf authored by Bibhabasu Mohapatra's avatar Bibhabasu Mohapatra Committed by GitHub
Browse files

add ONNX support for swin transformer (#19390)



* swin transformer onnx support

* Updated image dimensions as dynamic
Co-authored-by: default avatarlewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: default avatarlewtun <lewis.c.tunstall@gmail.com>
parent 969534af
...@@ -94,6 +94,7 @@ Ready-made configurations include the following architectures: ...@@ -94,6 +94,7 @@ Ready-made configurations include the following architectures:
- RoFormer - RoFormer
- SegFormer - SegFormer
- SqueezeBERT - SqueezeBERT
- Swin Transformer
- T5 - T5
- ViT - ViT
- XLM - XLM
......
...@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING ...@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"]} _import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig", "SwinOnnxConfig"]}
try: try:
...@@ -53,7 +53,7 @@ else: ...@@ -53,7 +53,7 @@ else:
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig, SwinOnnxConfig
try: try:
if not is_torch_available(): if not is_torch_available():
......
...@@ -14,7 +14,13 @@ ...@@ -14,7 +14,13 @@
# limitations under the License. # limitations under the License.
""" Swin Transformer model configuration""" """ Swin Transformer model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -145,3 +151,20 @@ class SwinConfig(PretrainedConfig): ...@@ -145,3 +151,20 @@ class SwinConfig(PretrainedConfig):
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model # this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
class SwinOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4
...@@ -471,6 +471,9 @@ class FeaturesManager: ...@@ -471,6 +471,9 @@ class FeaturesManager:
"question-answering", "question-answering",
onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig", onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
), ),
"swin": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls="models.swin.SwinOnnxConfig"
),
"t5": supported_features_mapping( "t5": supported_features_mapping(
"default", "default",
"default-with-past", "default-with-past",
......
...@@ -217,6 +217,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -217,6 +217,7 @@ PYTORCH_EXPORT_MODELS = {
("longformer", "allenai/longformer-base-4096"), ("longformer", "allenai/longformer-base-4096"),
("yolos", "hustvl/yolos-tiny"), ("yolos", "hustvl/yolos-tiny"),
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"), ("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
("swin", "microsoft/swin-tiny-patch4-window7-224"),
} }
PYTORCH_EXPORT_WITH_PAST_MODELS = { PYTORCH_EXPORT_WITH_PAST_MODELS = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment