Unverified Commit 4f38808e authored by Ruihua Fang's avatar Ruihua Fang Committed by GitHub
Browse files

Add OnnxConfig for SqueezeBert iss17314 (#17315)



* add onnx config for SqueezeBert

* add test for onnx config for SqueezeBert

* add automatically updated doc for onnx config for SqueezeBert

* Update src/transformers/onnx/features.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/transformers/models/squeezebert/configuration_squeezebert.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent ba286fe7
...@@ -73,6 +73,7 @@ Ready-made configurations include the following architectures: ...@@ -73,6 +73,7 @@ Ready-made configurations include the following architectures:
- PLBart - PLBart
- RoBERTa - RoBERTa
- RoFormer - RoFormer
- SqueezeBERT
- T5 - T5
- ViT - ViT
- XLM - XLM
......
...@@ -22,7 +22,11 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_ ...@@ -22,7 +22,11 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_
_import_structure = { _import_structure = {
"configuration_squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig"], "configuration_squeezebert": [
"SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"SqueezeBertConfig",
"SqueezeBertOnnxConfig",
],
"tokenization_squeezebert": ["SqueezeBertTokenizer"], "tokenization_squeezebert": ["SqueezeBertTokenizer"],
} }
...@@ -54,7 +58,11 @@ else: ...@@ -54,7 +58,11 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig from .configuration_squeezebert import (
SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
SqueezeBertConfig,
SqueezeBertOnnxConfig,
)
from .tokenization_squeezebert import SqueezeBertTokenizer from .tokenization_squeezebert import SqueezeBertTokenizer
try: try:
......
...@@ -13,8 +13,11 @@ ...@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" SqueezeBERT model configuration""" """ SqueezeBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -154,3 +157,20 @@ class SqueezeBertConfig(PretrainedConfig): ...@@ -154,3 +157,20 @@ class SqueezeBertConfig(PretrainedConfig):
self.post_attention_groups = post_attention_groups self.post_attention_groups = post_attention_groups
self.intermediate_groups = intermediate_groups self.intermediate_groups = intermediate_groups
self.output_groups = output_groups self.output_groups = output_groups
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->SqueezeBert
class SqueezeBertOnnxConfig(OnnxConfig):
    """ONNX export configuration for SqueezeBERT models.

    Declares the model inputs and their dynamic axes so the exporter can
    produce an ONNX graph that accepts variable batch and sequence lengths.
    """

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the ordered input names mapped to their dynamic-axis labels.

        For the multiple-choice task an extra "choice" axis sits between the
        batch and sequence dimensions; all other tasks use (batch, sequence).
        """
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        # Order matters: the exporter binds positional ONNX inputs in this order.
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
                ("token_type_ids", dynamic_axis),
            ]
        )
...@@ -28,6 +28,7 @@ from ..models.mbart import MBartOnnxConfig ...@@ -28,6 +28,7 @@ from ..models.mbart import MBartOnnxConfig
from ..models.mobilebert import MobileBertOnnxConfig from ..models.mobilebert import MobileBertOnnxConfig
from ..models.roberta import RobertaOnnxConfig from ..models.roberta import RobertaOnnxConfig
from ..models.roformer import RoFormerOnnxConfig from ..models.roformer import RoFormerOnnxConfig
from ..models.squeezebert import SqueezeBertOnnxConfig
from ..models.t5 import T5OnnxConfig from ..models.t5 import T5OnnxConfig
from ..models.vit import ViTOnnxConfig from ..models.vit import ViTOnnxConfig
from ..models.xlm import XLMOnnxConfig from ..models.xlm import XLMOnnxConfig
...@@ -352,6 +353,15 @@ class FeaturesManager: ...@@ -352,6 +353,15 @@ class FeaturesManager:
"token-classification", "token-classification",
onnx_config_cls=RoFormerOnnxConfig, onnx_config_cls=RoFormerOnnxConfig,
), ),
"squeezebert": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=SqueezeBertOnnxConfig,
),
"t5": supported_features_mapping( "t5": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
), ),
......
...@@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
("electra", "google/electra-base-generator"), ("electra", "google/electra-base-generator"),
("roberta", "roberta-base"), ("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"), ("roformer", "junnyu/roformer_chinese_base"),
("squeezebert", "squeezebert/squeezebert-uncased"),
("mobilebert", "google/mobilebert-uncased"), ("mobilebert", "google/mobilebert-uncased"),
("xlm", "xlm-clm-ende-1024"), ("xlm", "xlm-clm-ende-1024"),
("xlm-roberta", "xlm-roberta-base"), ("xlm-roberta", "xlm-roberta-base"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment