Unverified Commit 7e7f7434 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Add SegFormer ONNX support (#18006)



* Add ONNX support

* Make height and width dynamic axes
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 89514f05
...@@ -90,6 +90,7 @@ Ready-made configurations include the following architectures: ...@@ -90,6 +90,7 @@ Ready-made configurations include the following architectures:
- ResNet - ResNet
- RoBERTa - RoBERTa
- RoFormer - RoFormer
- SegFormer
- SqueezeBERT - SqueezeBERT
- T5 - T5
- ViT - ViT
......
...@@ -26,7 +26,9 @@ from ...utils import ( ...@@ -26,7 +26,9 @@ from ...utils import (
) )
_import_structure = {"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"]} _import_structure = {
"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig", "SegformerOnnxConfig"]
}
try: try:
if not is_vision_available(): if not is_vision_available():
...@@ -69,7 +71,7 @@ else: ...@@ -69,7 +71,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig, SegformerOnnxConfig
try: try:
if not is_vision_available(): if not is_vision_available():
......
...@@ -15,8 +15,13 @@ ...@@ -15,8 +15,13 @@
""" SegFormer model configuration""" """ SegFormer model configuration"""
import warnings import warnings
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -148,3 +153,24 @@ class SegformerConfig(PretrainedConfig): ...@@ -148,3 +153,24 @@ class SegformerConfig(PretrainedConfig):
self.decoder_hidden_size = decoder_hidden_size self.decoder_hidden_size = decoder_hidden_size
self.reshape_last_stage = kwargs.get("reshape_last_stage", True) self.reshape_last_stage = kwargs.get("reshape_last_stage", True)
self.semantic_loss_ignore_index = semantic_loss_ignore_index self.semantic_loss_ignore_index = semantic_loss_ignore_index
class SegformerOnnxConfig(OnnxConfig):
    """ONNX export configuration for SegFormer models.

    Declares the model's single image input with fully dynamic axes and the
    tolerance/opset settings used when exporting and validating the model.
    """

    # Exporting SegFormer needs operator support first available in torch 1.11.
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Map each model input name to its dynamic axes (index -> symbolic name)."""
        dynamic_axes = {0: "batch", 1: "num_channels", 2: "height", 3: "width"}
        return OrderedDict({"pixel_values": dynamic_axes})

    @property
    def atol_for_validation(self) -> float:
        """Absolute tolerance when comparing exported outputs against PyTorch."""
        return 1e-4

    @property
    def default_onnx_opset(self) -> int:
        """Default ONNX opset version targeted by the export."""
        return 12
...@@ -456,6 +456,12 @@ class FeaturesManager: ...@@ -456,6 +456,12 @@ class FeaturesManager:
"token-classification", "token-classification",
onnx_config_cls="models.roformer.RoFormerOnnxConfig", onnx_config_cls="models.roformer.RoFormerOnnxConfig",
), ),
"segformer": supported_features_mapping(
"default",
"image-classification",
"semantic-segmentation",
onnx_config_cls="models.segformer.SegformerOnnxConfig",
),
"squeezebert": supported_features_mapping( "squeezebert": supported_features_mapping(
"default", "default",
"masked-lm", "masked-lm",
......
...@@ -216,6 +216,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -216,6 +216,7 @@ PYTORCH_EXPORT_MODELS = {
("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)), ("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
("longformer", "allenai/longformer-base-4096"), ("longformer", "allenai/longformer-base-4096"),
("yolos", "hustvl/yolos-tiny"), ("yolos", "hustvl/yolos-tiny"),
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
} }
PYTORCH_EXPORT_WITH_PAST_MODELS = { PYTORCH_EXPORT_WITH_PAST_MODELS = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment