Unverified Commit aaee4038 authored by Krishna Sirumalla, committed by GitHub

Add onnx config for RoFormer (#16861)

* add roformer onnx config
parent 8afaaa26
docs/source/en/serialization.mdx
@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
 - OpenAI GPT-2
 - PLBart
 - RoBERTa
+- RoFormer
 - T5
 - TAPEX
 - ViT
...
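With this entry, RoFormer joins the ready-made configurations, so it can be exported with the documented CLI flow, e.g. `python -m transformers.onnx --model=junnyu/roformer_chinese_base onnx/` (the checkpoint comes from the test added at the end of this commit).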
src/transformers/models/roformer/__init__.py
@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokeniz
 _import_structure = {
-    "configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig"],
+    "configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerOnnxConfig"],
     "tokenization_roformer": ["RoFormerTokenizer"],
 }
@@ -73,7 +73,7 @@ if is_flax_available():
 if TYPE_CHECKING:
-    from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
+    from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerOnnxConfig
     from .tokenization_roformer import RoFormerTokenizer
 if is_tokenizers_available():
...
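Both halves of this change matter: the `_import_structure` entry makes `RoFormerOnnxConfig` reachable at runtime through the `_LazyModule`, while the `TYPE_CHECKING` import keeps static type checkers and IDEs aware of the symbol without importing the module eagerly.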
src/transformers/models/roformer/configuration_roformer.py
@@ -14,7 +14,11 @@
 # limitations under the License.
 """ RoFormer model configuration"""
+from collections import OrderedDict
+from typing import Mapping
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
@@ -131,3 +135,20 @@ class RoFormerConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.rotary_value = rotary_value
         self.use_cache = use_cache
+
+
+class RoFormerOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        # `multiple-choice` inputs carry an extra num-choices dimension.
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
+            ]
+        )
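For context, here is a minimal sketch of driving the new class through the `transformers.onnx` Python API. The checkpoint (borrowed from the test below) and the output path are illustrative assumptions, not part of this commit:

```python
# Sketch only: export RoFormer to ONNX using the RoFormerOnnxConfig added above.
from pathlib import Path

from transformers import RoFormerConfig, RoFormerModel, RoFormerTokenizer
from transformers.models.roformer import RoFormerOnnxConfig
from transformers.onnx import export

ckpt = "junnyu/roformer_chinese_base"  # assumed checkpoint, as in the tests
config = RoFormerConfig.from_pretrained(ckpt)
tokenizer = RoFormerTokenizer.from_pretrained(ckpt)
model = RoFormerModel.from_pretrained(ckpt)

# task="default" exports the bare model; the ONNX input names and dynamic
# axes come from the `inputs` property defined in the diff above.
onnx_config = RoFormerOnnxConfig(config, task="default")
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("roformer.onnx")
)
print(onnx_inputs, onnx_outputs)  # names wired into the exported graph
```

The `multiple-choice` branch declares an extra `choice` axis because those heads consume inputs of shape `(batch, num_choices, sequence)`; all other tasks share the plain `(batch, sequence)` layout.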
src/transformers/onnx/features.py
@@ -25,6 +25,7 @@ from ..models.m2m_100 import M2M100OnnxConfig
 from ..models.marian import MarianOnnxConfig
 from ..models.mbart import MBartOnnxConfig
 from ..models.roberta import RobertaOnnxConfig
+from ..models.roformer import RoFormerOnnxConfig
 from ..models.t5 import T5OnnxConfig
 from ..models.vit import ViTOnnxConfig
 from ..models.xlm_roberta import XLMRobertaOnnxConfig
@@ -333,6 +334,16 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=Data2VecTextOnnxConfig,
         ),
+        "roformer": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "causal-lm",
+            "sequence-classification",
+            "token-classification",
+            "multiple-choice",
+            "question-answering",
+            onnx_config_cls=RoFormerOnnxConfig,
+        ),
     }

     AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
...
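This registry entry is what lets `FeaturesManager` resolve an ONNX config from a model type plus feature name. A small sketch of that lookup, assuming the same checkpoint as the tests:

```python
# Sketch only: resolve and instantiate the ONNX config the way the exporter does.
from transformers import AutoConfig
from transformers.onnx.features import FeaturesManager

config = AutoConfig.from_pretrained("junnyu/roformer_chinese_base")
# get_config returns the constructor registered above (a functools.partial of
# RoFormerOnnxConfig.from_model_config with the task already bound).
ctor = FeaturesManager.get_config("roformer", "sequence-classification")
onnx_config = ctor(config)
print(onnx_config.task)          # sequence-classification
print(dict(onnx_config.inputs))  # input_ids / attention_mask / token_type_ids
```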
tests/onnx/test_onnx_v2.py
@@ -179,6 +179,7 @@ PYTORCH_EXPORT_MODELS = {
     ("distilbert", "distilbert-base-cased"),
     ("electra", "google/electra-base-generator"),
     ("roberta", "roberta-base"),
+    ("roformer", "junnyu/roformer_chinese_base"),
     ("xlm-roberta", "xlm-roberta-base"),
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
     ("vit", "google/vit-base-patch16-224"),
...
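Once exported, the graph can be sanity-checked with `onnxruntime`. A sketch assuming the export sketch above wrote `roformer.onnx` (RoFormer's tokenizer additionally needs the `rjieba` package for Chinese word segmentation):

```python
# Sketch only: run the exported RoFormer graph with onnxruntime.
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
session = ort.InferenceSession("roformer.onnx")

# Feed keys mirror RoFormerOnnxConfig.inputs; batch and sequence axes are
# dynamic, so any batch size and length are accepted.
encoded = tokenizer(["今天天气非常好"], return_tensors="np")
outputs = session.run(None, dict(encoded))
print([o.shape for o in outputs])
```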