Unverified Commit e0be053e authored by regisss's avatar regisss Committed by GitHub
Browse files

Add ONNX support for ConvNeXT (#17627)

parent 5323094a
...@@ -55,6 +55,7 @@ Ready-made configurations include the following architectures: ...@@ -55,6 +55,7 @@ Ready-made configurations include the following architectures:
- BlenderbotSmall - BlenderbotSmall
- CamemBERT - CamemBERT
- ConvBERT - ConvBERT
- ConvNeXT
- Data2VecText - Data2VecText
- Data2VecVision - Data2VecVision
- DeiT - DeiT
......
...@@ -27,7 +27,9 @@ from ...utils import ( ...@@ -27,7 +27,9 @@ from ...utils import (
) )
_import_structure = {"configuration_convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig"]} _import_structure = {
"configuration_convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig", "ConvNextOnnxConfig"]
}
try: try:
if not is_vision_available(): if not is_vision_available():
...@@ -63,7 +65,7 @@ else: ...@@ -63,7 +65,7 @@ else:
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig, ConvNextOnnxConfig
try: try:
if not is_vision_available(): if not is_vision_available():
......
...@@ -14,7 +14,13 @@ ...@@ -14,7 +14,13 @@
# limitations under the License. # limitations under the License.
""" ConvNeXT model configuration""" """ ConvNeXT model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -101,3 +107,20 @@ class ConvNextConfig(PretrainedConfig): ...@@ -101,3 +107,20 @@ class ConvNextConfig(PretrainedConfig):
self.layer_scale_init_value = layer_scale_init_value self.layer_scale_init_value = layer_scale_init_value
self.drop_path_rate = drop_path_rate self.drop_path_rate = drop_path_rate
self.image_size = image_size self.image_size = image_size
class ConvNextOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-5
...@@ -193,6 +193,11 @@ class FeaturesManager: ...@@ -193,6 +193,11 @@ class FeaturesManager:
"question-answering", "question-answering",
onnx_config_cls="models.convbert.ConvBertOnnxConfig", onnx_config_cls="models.convbert.ConvBertOnnxConfig",
), ),
"convnext": supported_features_mapping(
"default",
"image-classification",
onnx_config_cls="models.convnext.ConvNextOnnxConfig",
),
"data2vec-text": supported_features_mapping( "data2vec-text": supported_features_mapping(
"default", "default",
"masked-lm", "masked-lm",
......
...@@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
("ibert", "kssteven/ibert-roberta-base"), ("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"), ("camembert", "camembert-base"),
("convbert", "YituTech/conv-bert-base"), ("convbert", "YituTech/conv-bert-base"),
("convnext", "facebook/convnext-tiny-224"),
("distilbert", "distilbert-base-cased"), ("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"), ("electra", "google/electra-base-generator"),
("resnet", "microsoft/resnet-50"), ("resnet", "microsoft/resnet-50"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment