"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ce85686a1f425c8e60d9104522d8626395dd507d"
Unverified Commit 5323094a authored by regisss's avatar regisss Committed by GitHub
Browse files

Add ONNX support for ResNet (#17585)

* Add ONNX support for ResNet

* Add ONNX test

* make fix-copies
parent ca2a55e9
...@@ -72,6 +72,7 @@ Ready-made configurations include the following architectures: ...@@ -72,6 +72,7 @@ Ready-made configurations include the following architectures:
- OpenAI GPT-2 - OpenAI GPT-2
- Perceiver - Perceiver
- PLBart - PLBart
- ResNet
- RoBERTa - RoBERTa
- RoFormer - RoFormer
- SqueezeBERT - SqueezeBERT
......
...@@ -21,7 +21,9 @@ from typing import TYPE_CHECKING ...@@ -21,7 +21,9 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {"configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"]} _import_structure = {
"configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig", "ResNetOnnxConfig"]
}
try: try:
if not is_torch_available(): if not is_torch_available():
...@@ -38,7 +40,7 @@ else: ...@@ -38,7 +40,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig
try: try:
if not is_torch_available(): if not is_torch_available():
......
...@@ -14,7 +14,13 @@ ...@@ -14,7 +14,13 @@
# limitations under the License. # limitations under the License.
""" ResNet model configuration""" """ ResNet model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -89,3 +95,20 @@ class ResNetConfig(PretrainedConfig): ...@@ -89,3 +95,20 @@ class ResNetConfig(PretrainedConfig):
self.layer_type = layer_type self.layer_type = layer_type
self.hidden_act = hidden_act self.hidden_act = hidden_act
self.downsample_in_first_stage = downsample_in_first_stage self.downsample_in_first_stage = downsample_in_first_stage
class ResNetOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-3
...@@ -318,6 +318,11 @@ class FeaturesManager: ...@@ -318,6 +318,11 @@ class FeaturesManager:
"sequence-classification", "sequence-classification",
onnx_config_cls="models.perceiver.PerceiverOnnxConfig", onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
), ),
"resnet": supported_features_mapping(
"default",
"image-classification",
onnx_config_cls="models.resnet.ResNetOnnxConfig",
),
"roberta": supported_features_mapping( "roberta": supported_features_mapping(
"default", "default",
"masked-lm", "masked-lm",
......
...@@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
("convbert", "YituTech/conv-bert-base"), ("convbert", "YituTech/conv-bert-base"),
("distilbert", "distilbert-base-cased"), ("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"), ("electra", "google/electra-base-generator"),
("resnet", "microsoft/resnet-50"),
("roberta", "roberta-base"), ("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"), ("roformer", "junnyu/roformer_chinese_base"),
("squeezebert", "squeezebert/squeezebert-uncased"), ("squeezebert", "squeezebert/squeezebert-uncased"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment