Unverified Commit 9de70f21 authored by Jim Rohrer's avatar Jim Rohrer Committed by GitHub
Browse files

Add ONNX export for BeiT (#16498)

* Add beit onnx conversion support

* Updated docs

* Added cross reference to ViT ONNX config
parent bfeff6cc
...@@ -47,6 +47,7 @@ Ready-made configurations include the following architectures: ...@@ -47,6 +47,7 @@ Ready-made configurations include the following architectures:
- ALBERT - ALBERT
- BART - BART
- BEiT
- BERT - BERT
- Blenderbot - Blenderbot
- BlenderbotSmall - BlenderbotSmall
......
...@@ -22,7 +22,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available, is_visi ...@@ -22,7 +22,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available, is_visi
_import_structure = { _import_structure = {
"configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig"], "configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig", "BeitOnnxConfig"],
} }
if is_vision_available(): if is_vision_available():
...@@ -48,7 +48,7 @@ if is_flax_available(): ...@@ -48,7 +48,7 @@ if is_flax_available():
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig, BeitOnnxConfig
if is_vision_available(): if is_vision_available():
from .feature_extraction_beit import BeitFeatureExtractor from .feature_extraction_beit import BeitFeatureExtractor
......
...@@ -13,8 +13,13 @@ ...@@ -13,8 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" BEiT model configuration""" """ BEiT model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -176,3 +181,21 @@ class BeitConfig(PretrainedConfig): ...@@ -176,3 +181,21 @@ class BeitConfig(PretrainedConfig):
self.auxiliary_num_convs = auxiliary_num_convs self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input self.auxiliary_concat_input = auxiliary_concat_input
self.semantic_loss_ignore_index = semantic_loss_ignore_index self.semantic_loss_ignore_index = semantic_loss_ignore_index
# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
class BeitOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4
...@@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union ...@@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
from ..models.albert import AlbertOnnxConfig from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig from ..models.bart import BartOnnxConfig
from ..models.beit import BeitOnnxConfig
from ..models.bert import BertOnnxConfig from ..models.bert import BertOnnxConfig
from ..models.blenderbot import BlenderbotOnnxConfig from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
...@@ -270,6 +271,7 @@ class FeaturesManager: ...@@ -270,6 +271,7 @@ class FeaturesManager:
onnx_config_cls=ElectraOnnxConfig, onnx_config_cls=ElectraOnnxConfig,
), ),
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig), "vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
"blenderbot": supported_features_mapping( "blenderbot": supported_features_mapping(
"default", "default",
"default-with-past", "default-with-past",
......
...@@ -15,14 +15,13 @@ from transformers.onnx import ( ...@@ -15,14 +15,13 @@ from transformers.onnx import (
export, export,
validate_model_outputs, validate_model_outputs,
) )
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
if is_torch_available() or is_tf_available(): if is_torch_available() or is_tf_available():
from transformers.onnx.features import FeaturesManager from transformers.onnx.features import FeaturesManager
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
@require_onnx @require_onnx
class OnnxUtilsTestCaseV2(TestCase): class OnnxUtilsTestCaseV2(TestCase):
...@@ -181,6 +180,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -181,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
("xlm-roberta", "xlm-roberta-base"), ("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"), ("vit", "google/vit-base-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
} }
PYTORCH_EXPORT_WITH_PAST_MODELS = { PYTORCH_EXPORT_WITH_PAST_MODELS = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment