Unverified commit 87282cb7, authored by Erin, committed by GitHub
Browse files

Add RemBERT ONNX config (#20520)



* rembert onnx config

* formatting
Co-authored-by: default avatarHo <erincho@bcd0745f972b.ant.amazon.com>
parent afe2a466
...@@ -93,6 +93,7 @@ Ready-made configurations include the following architectures: ...@@ -93,6 +93,7 @@ Ready-made configurations include the following architectures:
- OWL-ViT - OWL-ViT
- Perceiver - Perceiver
- PLBart - PLBart
- RemBERT
- ResNet - ResNet
- RoBERTa - RoBERTa
- RoFormer - RoFormer
......
...@@ -28,7 +28,9 @@ from ...utils import ( ...@@ -28,7 +28,9 @@ from ...utils import (
) )
_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"]} _import_structure = {
"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig", "RemBertOnnxConfig"]
}
try: try:
if not is_sentencepiece_available(): if not is_sentencepiece_available():
...@@ -88,7 +90,7 @@ else: ...@@ -88,7 +90,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig, RemBertOnnxConfig
try: try:
if not is_sentencepiece_available(): if not is_sentencepiece_available():
......
...@@ -13,8 +13,11 @@ ...@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" RemBERT model configuration""" """ RemBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -135,3 +138,23 @@ class RemBertConfig(PretrainedConfig): ...@@ -135,3 +138,23 @@ class RemBertConfig(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache self.use_cache = use_cache
self.tie_word_embeddings = False self.tie_word_embeddings = False
class RemBertOnnxConfig(OnnxConfig):
    """ONNX export configuration for RemBERT: declares the model inputs and their dynamic axes."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Map each model input name to its dynamic axes (axis index -> axis name).

        For the "multiple-choice" task the inputs carry an additional "choice" axis between
        "batch" and "sequence"; every other task uses just "batch" and "sequence".
        """
        axes = (
            {0: "batch", 1: "choice", 2: "sequence"}
            if self.task == "multiple-choice"
            else {0: "batch", 1: "sequence"}
        )
        # All three inputs deliberately share the very same axis mapping object.
        input_names = ("input_ids", "attention_mask", "token_type_ids")
        return OrderedDict((input_name, axes) for input_name in input_names)

    @property
    def atol_for_validation(self) -> float:
        """Absolute tolerance value used for validation: fixed at 1e-4."""
        return 1e-4
...@@ -447,6 +447,16 @@ class FeaturesManager: ...@@ -447,6 +447,16 @@ class FeaturesManager:
"sequence-classification", "sequence-classification",
onnx_config_cls="models.perceiver.PerceiverOnnxConfig", onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
), ),
"rembert": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls="models.rembert.RemBertOnnxConfig",
),
"resnet": supported_features_mapping( "resnet": supported_features_mapping(
"default", "default",
"image-classification", "image-classification",
......
...@@ -210,6 +210,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -210,6 +210,7 @@ PYTORCH_EXPORT_MODELS = {
("owlvit", "google/owlvit-base-patch32"), ("owlvit", "google/owlvit-base-patch32"),
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
("rembert", "google/rembert"),
("resnet", "microsoft/resnet-50"), ("resnet", "microsoft/resnet-50"),
("roberta", "hf-internal-testing/tiny-random-RobertaModel"), ("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
("roformer", "hf-internal-testing/tiny-random-RoFormerModel"), ("roformer", "hf-internal-testing/tiny-random-RoFormerModel"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment