"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7735e0406fa3d39051cfed9921a3ceef06e6c76e"
Unverified Commit ec81c11a authored by Thomas Chaigneau's avatar Thomas Chaigneau Committed by GitHub
Browse files

Add OnnxConfig for ConvBERT (#16859)



* add OnnxConfig for ConvBert
Co-authored-by: default avatarChainYo <t.chaigneau.tc@gmail.com>
parent 0d1cff11
...@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures: ...@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
- Blenderbot - Blenderbot
- BlenderbotSmall - BlenderbotSmall
- CamemBERT - CamemBERT
- ConvBERT
- Data2VecText - Data2VecText
- Data2VecVision - Data2VecVision
- DistilBERT - DistilBERT
......
...@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t ...@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t
_import_structure = { _import_structure = {
"configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig"], "configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertOnnxConfig"],
"tokenization_convbert": ["ConvBertTokenizer"], "tokenization_convbert": ["ConvBertTokenizer"],
} }
...@@ -58,7 +58,7 @@ if is_tf_available(): ...@@ -58,7 +58,7 @@ if is_tf_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertOnnxConfig
from .tokenization_convbert import ConvBertTokenizer from .tokenization_convbert import ConvBertTokenizer
if is_tokenizers_available(): if is_tokenizers_available():
......
...@@ -14,7 +14,11 @@ ...@@ -14,7 +14,11 @@
# limitations under the License. # limitations under the License.
""" ConvBERT model configuration""" """ ConvBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -138,3 +142,20 @@ class ConvBertConfig(PretrainedConfig): ...@@ -138,3 +142,20 @@ class ConvBertConfig(PretrainedConfig):
self.conv_kernel_size = conv_kernel_size self.conv_kernel_size = conv_kernel_size
self.num_groups = num_groups self.num_groups = num_groups
self.classifier_dropout = classifier_dropout self.classifier_dropout = classifier_dropout
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
class ConvBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)
...@@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig ...@@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig
from ..models.blenderbot import BlenderbotOnnxConfig from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
from ..models.camembert import CamembertOnnxConfig from ..models.camembert import CamembertOnnxConfig
from ..models.convbert import ConvBertOnnxConfig
from ..models.data2vec import Data2VecTextOnnxConfig from ..models.data2vec import Data2VecTextOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig from ..models.electra import ElectraOnnxConfig
...@@ -187,6 +188,15 @@ class FeaturesManager: ...@@ -187,6 +188,15 @@ class FeaturesManager:
"question-answering", "question-answering",
onnx_config_cls=CamembertOnnxConfig, onnx_config_cls=CamembertOnnxConfig,
), ),
"convbert": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=ConvBertOnnxConfig,
),
"distilbert": supported_features_mapping( "distilbert": supported_features_mapping(
"default", "default",
"masked-lm", "masked-lm",
......
...@@ -175,6 +175,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -175,6 +175,7 @@ PYTORCH_EXPORT_MODELS = {
("bigbird", "google/bigbird-roberta-base"), ("bigbird", "google/bigbird-roberta-base"),
("ibert", "kssteven/ibert-roberta-base"), ("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"), ("camembert", "camembert-base"),
("convbert", "YituTech/conv-bert-base"),
("distilbert", "distilbert-base-cased"), ("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"), ("electra", "google/electra-base-generator"),
("roberta", "roberta-base"), ("roberta", "roberta-base"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment