Unverified Commit 87d08afb authored by aaron's avatar aaron Committed by GitHub
Browse files

Add ELECTRA to the ONNX-supported models (#15084)



* electra is added to onnx supported model

* add google/electra-base-generator for test onnx module
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
parent 0fe17f37
...@@ -50,6 +50,7 @@ Ready-made configurations include the following architectures: ...@@ -50,6 +50,7 @@ Ready-made configurations include the following architectures:
- BERT - BERT
- CamemBERT - CamemBERT
- DistilBERT - DistilBERT
- ELECTRA
- GPT Neo - GPT Neo
- I-BERT - I-BERT
- LayoutLM - LayoutLM
......
...@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to ...@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to
_import_structure = { _import_structure = {
"configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig"], "configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraOnnxConfig"],
"tokenization_electra": ["ElectraTokenizer"], "tokenization_electra": ["ElectraTokenizer"],
} }
...@@ -71,7 +71,7 @@ if is_flax_available(): ...@@ -71,7 +71,7 @@ if is_flax_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraOnnxConfig
from .tokenization_electra import ElectraTokenizer from .tokenization_electra import ElectraTokenizer
if is_tokenizers_available(): if is_tokenizers_available():
......
...@@ -15,7 +15,11 @@ ...@@ -15,7 +15,11 @@
# limitations under the License. # limitations under the License.
""" ELECTRA model configuration""" """ ELECTRA model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -170,3 +174,15 @@ class ElectraConfig(PretrainedConfig): ...@@ -170,3 +174,15 @@ class ElectraConfig(PretrainedConfig):
self.position_embedding_type = position_embedding_type self.position_embedding_type = position_embedding_type
self.use_cache = use_cache self.use_cache = use_cache
self.classifier_dropout = classifier_dropout self.classifier_dropout = classifier_dropout
class ElectraOnnxConfig(OnnxConfig):
    """ONNX export configuration for ELECTRA models.

    Declares the model's input tensors and which of their axes are
    dynamic, so the ONNX exporter marks batch size and sequence length
    as variable-length dimensions.
    """

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # All three encoder inputs share the same dynamic axes:
        # axis 0 is the batch dimension, axis 1 the sequence dimension.
        # Each entry gets its own dict (a fresh copy per input name).
        input_names = ("input_ids", "attention_mask", "token_type_ids")
        dynamic_axes = {0: "batch", 1: "sequence"}
        return OrderedDict((name, dict(dynamic_axes)) for name in input_names)
...@@ -7,6 +7,7 @@ from ..models.bart import BartOnnxConfig ...@@ -7,6 +7,7 @@ from ..models.bart import BartOnnxConfig
from ..models.bert import BertOnnxConfig from ..models.bert import BertOnnxConfig
from ..models.camembert import CamembertOnnxConfig from ..models.camembert import CamembertOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.ibert import IBertOnnxConfig from ..models.ibert import IBertOnnxConfig
...@@ -209,6 +210,15 @@ class FeaturesManager: ...@@ -209,6 +210,15 @@ class FeaturesManager:
"token-classification", "token-classification",
onnx_config_cls=LayoutLMOnnxConfig, onnx_config_cls=LayoutLMOnnxConfig,
), ),
"electra": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"token-classification",
"question-answering",
onnx_config_cls=ElectraOnnxConfig,
),
} }
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values()))) AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
......
...@@ -174,6 +174,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -174,6 +174,7 @@ PYTORCH_EXPORT_MODELS = {
("ibert", "kssteven/ibert-roberta-base"), ("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"), ("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"), ("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"),
("roberta", "roberta-base"), ("roberta", "roberta-base"),
("xlm-roberta", "xlm-roberta-base"), ("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlm", "microsoft/layoutlm-base-uncased"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment