Unverified Commit 87d08afb authored by aaron, committed by GitHub

Add ELECTRA to the ONNX-supported models (#15084)



* Add ELECTRA to the ONNX-supported models

* Add google/electra-base-generator to the ONNX test module

Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
parent 0fe17f37
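Taken together, the changes below register ELECTRA with the transformers.onnx export machinery. As a rough sketch of what this enables, assuming the transformers.onnx.export API of this release (the checkpoint is the one this PR adds to the test suite; the output path is illustrative):

from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.electra import ElectraOnnxConfig
from transformers.onnx import export

checkpoint = "google/electra-base-generator"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

# Build the ONNX config added by this PR and export the model graph;
# "electra.onnx" is an illustrative output path, not anything the PR mandates.
onnx_config = ElectraOnnxConfig(model.config)
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("electra.onnx")
)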
@@ -50,6 +50,7 @@ Ready-made configurations include the following architectures:
 - BERT
 - CamemBERT
 - DistilBERT
+- ELECTRA
 - GPT Neo
 - I-BERT
 - LayoutLM
@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to
 _import_structure = {
-    "configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig"],
+    "configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraOnnxConfig"],
     "tokenization_electra": ["ElectraTokenizer"],
 }
@@ -71,7 +71,7 @@ if is_flax_available():
 if TYPE_CHECKING:
-    from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
+    from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraOnnxConfig
     from .tokenization_electra import ElectraTokenizer

     if is_tokenizers_available():
@@ -15,7 +15,11 @@
 # limitations under the License.
 """ ELECTRA model configuration"""

+from collections import OrderedDict
+from typing import Mapping
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
@@ -170,3 +174,15 @@ class ElectraConfig(PretrainedConfig):
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout
+
+
+class ElectraOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("token_type_ids", {0: "batch", 1: "sequence"}),
+            ]
+        )
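The inputs property above declares the dynamic axes of each input tensor: axis 0 varies with the batch size, axis 1 with the sequence length. The exporter translates this mapping into the dynamic_axes argument of torch.onnx.export; the sketch below mirrors that translation by hand to show what the declaration means (an illustration, not the library's actual export code):

import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "google/electra-base-generator"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# return_dict=False makes the model return a plain tuple, which traces cleanly.
model = AutoModel.from_pretrained(checkpoint, return_dict=False)
model.eval()

dummy = tokenizer("ELECTRA goes to ONNX", return_tensors="pt")

torch.onnx.export(
    model,
    (dummy["input_ids"], dummy["attention_mask"], dummy["token_type_ids"]),
    "electra.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["last_hidden_state"],
    # Mirrors ElectraOnnxConfig.inputs: axis 0 is the batch, axis 1 the sequence.
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "token_type_ids": {0: "batch", 1: "sequence"},
    },
    opset_version=11,
)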
@@ -7,6 +7,7 @@ from ..models.bart import BartOnnxConfig
 from ..models.bert import BertOnnxConfig
 from ..models.camembert import CamembertOnnxConfig
 from ..models.distilbert import DistilBertOnnxConfig
+from ..models.electra import ElectraOnnxConfig
 from ..models.gpt2 import GPT2OnnxConfig
 from ..models.gpt_neo import GPTNeoOnnxConfig
 from ..models.ibert import IBertOnnxConfig
@@ -209,6 +210,15 @@ class FeaturesManager:
             "token-classification",
             onnx_config_cls=LayoutLMOnnxConfig,
         ),
+        "electra": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "causal-lm",
+            "sequence-classification",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=ElectraOnnxConfig,
+        ),
     }

     AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
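With this entry in place, the feature set registered for ELECTRA can be queried through FeaturesManager, assuming the get_supported_features_for_model_type helper of this release:

from transformers.onnx.features import FeaturesManager

# Look up the task -> OnnxConfig constructor mapping registered for ELECTRA above.
electra_features = FeaturesManager.get_supported_features_for_model_type("electra")
print(sorted(electra_features))
# ['causal-lm', 'default', 'masked-lm', 'question-answering',
#  'sequence-classification', 'token-classification']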
@@ -174,6 +174,7 @@ PYTORCH_EXPORT_MODELS = {
     ("ibert", "kssteven/ibert-roberta-base"),
     ("camembert", "camembert-base"),
     ("distilbert", "distilbert-base-cased"),
+    ("electra", "google/electra-base-generator"),
     ("roberta", "roberta-base"),
     ("xlm-roberta", "xlm-roberta-base"),
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
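Each (model_type, checkpoint) pair in this table drives a parameterized test that exports the model and compares ONNX Runtime outputs against PyTorch. A condensed sketch of that round trip, assuming the export and validate_model_outputs helpers in transformers.onnx of this release (the output path is illustrative):

from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.onnx import export, validate_model_outputs
from transformers.onnx.features import FeaturesManager

checkpoint = "google/electra-base-generator"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

# Resolve the OnnxConfig constructor for the default feature, export, then check
# that ONNX Runtime outputs match PyTorch within the config's tolerance.
_, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(model, feature="default")
onnx_config = onnx_config_factory(model.config)
output_path = Path("electra.onnx")
_, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, output_path)
validate_model_outputs(onnx_config, tokenizer, model, output_path, onnx_outputs, onnx_config.atol_for_validation)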