Unverified commit aaee4038, authored by Krishna Sirumalla, committed by GitHub

Add onnx config for RoFormer (#16861)

* add roformer onnx config
parent 8afaaa26
@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
 - OpenAI GPT-2
 - PLBart
 - RoBERTa
+- RoFormer
 - T5
 - TAPEX
 - ViT
@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokeniz
 _import_structure = {
-    "configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig"],
+    "configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerOnnxConfig"],
     "tokenization_roformer": ["RoFormerTokenizer"],
 }
@@ -73,7 +73,7 @@ if is_flax_available():
 if TYPE_CHECKING:
-    from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
+    from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerOnnxConfig
     from .tokenization_roformer import RoFormerTokenizer

     if is_tokenizers_available():
@@ -14,7 +14,11 @@
 # limitations under the License.
 """ RoFormer model configuration"""
+from collections import OrderedDict
+from typing import Mapping
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
@@ -131,3 +135,20 @@ class RoFormerConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.rotary_value = rotary_value
         self.use_cache = use_cache
+
+
+class RoFormerOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        # Multiple-choice inputs carry an extra "choice" axis between batch and sequence.
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
+            ]
+        )
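
Reviewer note: the `inputs` property above is what `transformers.onnx.export` reads to declare dynamic axes on the exported graph. A minimal export sketch, assuming a transformers version that ships the `transformers.onnx` package used here (the output path `roformer.onnx` is illustrative):

```python
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.roformer import RoFormerOnnxConfig
from transformers.onnx import export

ckpt = "junnyu/roformer_chinese_base"  # same checkpoint as the test entry below
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModel.from_pretrained(ckpt)

# Build the ONNX config for the "default" feature from the model's configuration.
onnx_config = RoFormerOnnxConfig(model.config)

# export() traces the model with dummy inputs and writes the ONNX graph,
# returning the input/output names it matched against onnx_config.inputs.
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("roformer.onnx")
)
```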
@@ -25,6 +25,7 @@ from ..models.m2m_100 import M2M100OnnxConfig
 from ..models.marian import MarianOnnxConfig
 from ..models.mbart import MBartOnnxConfig
 from ..models.roberta import RobertaOnnxConfig
+from ..models.roformer import RoFormerOnnxConfig
 from ..models.t5 import T5OnnxConfig
 from ..models.vit import ViTOnnxConfig
 from ..models.xlm_roberta import XLMRobertaOnnxConfig
@@ -333,6 +334,16 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=Data2VecTextOnnxConfig,
         ),
+        "roformer": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "causal-lm",
+            "sequence-classification",
+            "token-classification",
+            "multiple-choice",
+            "question-answering",
+            onnx_config_cls=RoFormerOnnxConfig,
+        ),
     }

     AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
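
With the mapping registered, `FeaturesManager` can resolve RoFormer like any other supported architecture. A short sketch of the lookup path (the feature name and checkpoint are illustrative):

```python
from transformers import AutoModelForSequenceClassification
from transformers.onnx.features import FeaturesManager

# Features registered for RoFormer by this change (keys of the mapping above).
print(sorted(FeaturesManager.get_supported_features_for_model_type("roformer")))

# Resolve the ONNX config constructor for a concrete model + feature pair;
# raises if the combination is not in _SUPPORTED_MODEL_TYPE.
model = AutoModelForSequenceClassification.from_pretrained("junnyu/roformer_chinese_base")
model_kind, config_ctor = FeaturesManager.check_supported_model_or_raise(
    model, feature="sequence-classification"
)
onnx_config = config_ctor(model.config)
```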
@@ -179,6 +179,7 @@ PYTORCH_EXPORT_MODELS = {
     ("distilbert", "distilbert-base-cased"),
     ("electra", "google/electra-base-generator"),
     ("roberta", "roberta-base"),
+    ("roformer", "junnyu/roformer_chinese_base"),
     ("xlm-roberta", "xlm-roberta-base"),
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
     ("vit", "google/vit-base-patch16-224"),