"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c64c2fc4c2be2f86de092c298c5ee713347bf61d"
Unverified Commit 15bc776f authored by Syed Abdul Gaffar Shakhadri's avatar Syed Abdul Gaffar Shakhadri Committed by GitHub
Browse files

Add Onnx Config for PoolFormer (#20868)



poolformer onnx
Co-authored-by: default avatarsyed <syed.abdul@sandlogic.com>
parent 4a4cd6cd
...@@ -102,6 +102,7 @@ Ready-made configurations include the following architectures: ...@@ -102,6 +102,7 @@ Ready-made configurations include the following architectures:
- OWL-ViT - OWL-ViT
- Perceiver - Perceiver
- PLBart - PLBart
- PoolFormer
- RemBERT - RemBERT
- ResNet - ResNet
- RoBERTa - RoBERTa
......
...@@ -21,7 +21,13 @@ from typing import TYPE_CHECKING ...@@ -21,7 +21,13 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"]} _import_structure = {
"configuration_poolformer": [
"POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"PoolFormerConfig",
"PoolFormerOnnxConfig",
]
}
try: try:
if not is_vision_available(): if not is_vision_available():
...@@ -47,7 +53,11 @@ else: ...@@ -47,7 +53,11 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig from .configuration_poolformer import (
POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
PoolFormerConfig,
PoolFormerOnnxConfig,
)
try: try:
if not is_vision_available(): if not is_vision_available():
......
...@@ -13,8 +13,13 @@ ...@@ -13,8 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PoolFormer model configuration""" """ PoolFormer model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -125,3 +130,20 @@ class PoolFormerConfig(PretrainedConfig): ...@@ -125,3 +130,20 @@ class PoolFormerConfig(PretrainedConfig):
self.layer_scale_init_value = layer_scale_init_value self.layer_scale_init_value = layer_scale_init_value
self.initializer_range = initializer_range self.initializer_range = initializer_range
super().__init__(**kwargs) super().__init__(**kwargs)
class PoolFormerOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
]
)
@property
def atol_for_validation(self) -> float:
return 2e-3
...@@ -447,6 +447,9 @@ class FeaturesManager: ...@@ -447,6 +447,9 @@ class FeaturesManager:
"sequence-classification", "sequence-classification",
onnx_config_cls="models.perceiver.PerceiverOnnxConfig", onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
), ),
"poolformer": supported_features_mapping(
"default", "image-classification", onnx_config_cls="models.poolformer.PoolFormerOnnxConfig"
),
"rembert": supported_features_mapping( "rembert": supported_features_mapping(
"default", "default",
"masked-lm", "masked-lm",
......
...@@ -210,6 +210,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -210,6 +210,7 @@ PYTORCH_EXPORT_MODELS = {
("owlvit", "google/owlvit-base-patch32"), ("owlvit", "google/owlvit-base-patch32"),
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
("poolformer", "sail/poolformer_s12"),
("rembert", "google/rembert"), ("rembert", "google/rembert"),
("resnet", "microsoft/resnet-50"), ("resnet", "microsoft/resnet-50"),
("roberta", "hf-internal-testing/tiny-random-RobertaModel"), ("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment