"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9129fd0377e4d46cb2d0ea28dc1eb91a15f65b77"
Unverified Commit dc3645dc authored by Manan Dey's avatar Manan Dey Committed by GitHub
Browse files

add `mobilebert` onnx configs (#17029)

* update docs of length_penalty

* Revert "update docs of length_penalty"

This reverts commit 466bf4800b75ec29bd2ff75bad8e8973bd98d01c.

* add mobilebert onnx config

* address suggestions

* Update auto.mdx

* Update __init__.py

* Update features.py
parent a021f2b9
...@@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its ...@@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
[[autodoc]] TFAutoModelForMultipleChoice [[autodoc]] TFAutoModelForMultipleChoice
## TFAutoModelForNextSentencePrediction
[[autodoc]] TFAutoModelForNextSentencePrediction
## TFAutoModelForTableQuestionAnswering ## TFAutoModelForTableQuestionAnswering
[[autodoc]] TFAutoModelForTableQuestionAnswering [[autodoc]] TFAutoModelForTableQuestionAnswering
......
...@@ -68,6 +68,7 @@ Ready-made configurations include the following architectures: ...@@ -68,6 +68,7 @@ Ready-made configurations include the following architectures:
- M2M100 - M2M100
- Marian - Marian
- mBART - mBART
- MobileBert
- OpenAI GPT-2 - OpenAI GPT-2
- PLBart - PLBart
- RoBERTa - RoBERTa
......
...@@ -1798,6 +1798,7 @@ if is_tf_available(): ...@@ -1798,6 +1798,7 @@ if is_tf_available():
"TFAutoModelForSeq2SeqLM", "TFAutoModelForSeq2SeqLM",
"TFAutoModelForSequenceClassification", "TFAutoModelForSequenceClassification",
"TFAutoModelForSpeechSeq2Seq", "TFAutoModelForSpeechSeq2Seq",
"TFAutoModelForNextSentencePrediction",
"TFAutoModelForTableQuestionAnswering", "TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification", "TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq", "TFAutoModelForVision2Seq",
...@@ -3964,6 +3965,7 @@ if TYPE_CHECKING: ...@@ -3964,6 +3965,7 @@ if TYPE_CHECKING:
TFAutoModelForImageClassification, TFAutoModelForImageClassification,
TFAutoModelForMaskedLM, TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice, TFAutoModelForMultipleChoice,
TFAutoModelForNextSentencePrediction,
TFAutoModelForPreTraining, TFAutoModelForPreTraining,
TFAutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM,
......
...@@ -108,6 +108,7 @@ if is_tf_available(): ...@@ -108,6 +108,7 @@ if is_tf_available():
"TFAutoModelForSeq2SeqLM", "TFAutoModelForSeq2SeqLM",
"TFAutoModelForSequenceClassification", "TFAutoModelForSequenceClassification",
"TFAutoModelForSpeechSeq2Seq", "TFAutoModelForSpeechSeq2Seq",
"TFAutoModelForNextSentencePrediction",
"TFAutoModelForTableQuestionAnswering", "TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification", "TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq", "TFAutoModelForVision2Seq",
...@@ -224,6 +225,7 @@ if TYPE_CHECKING: ...@@ -224,6 +225,7 @@ if TYPE_CHECKING:
TFAutoModelForImageClassification, TFAutoModelForImageClassification,
TFAutoModelForMaskedLM, TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice, TFAutoModelForMultipleChoice,
TFAutoModelForNextSentencePrediction,
TFAutoModelForPreTraining, TFAutoModelForPreTraining,
TFAutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM,
......
...@@ -22,7 +22,11 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t ...@@ -22,7 +22,11 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t
_import_structure = { _import_structure = {
"configuration_mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig"], "configuration_mobilebert": [
"MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"MobileBertConfig",
"MobileBertOnnxConfig",
],
"tokenization_mobilebert": ["MobileBertTokenizer"], "tokenization_mobilebert": ["MobileBertTokenizer"],
} }
...@@ -62,7 +66,11 @@ if is_tf_available(): ...@@ -62,7 +66,11 @@ if is_tf_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig from .configuration_mobilebert import (
MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
MobileBertConfig,
MobileBertOnnxConfig,
)
from .tokenization_mobilebert import MobileBertTokenizer from .tokenization_mobilebert import MobileBertTokenizer
if is_tokenizers_available(): if is_tokenizers_available():
......
...@@ -13,8 +13,11 @@ ...@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" MobileBERT model configuration""" """ MobileBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -165,3 +168,20 @@ class MobileBertConfig(PretrainedConfig): ...@@ -165,3 +168,20 @@ class MobileBertConfig(PretrainedConfig):
self.true_hidden_size = hidden_size self.true_hidden_size = hidden_size
self.classifier_dropout = classifier_dropout self.classifier_dropout = classifier_dropout
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->MobileBert
class MobileBertOnnxConfig(OnnxConfig):
    """ONNX export configuration for MobileBERT.

    Mirrors the BERT ONNX config (see the "Copied from" marker above, which is
    enforced by the repo consistency checks): MobileBERT consumes the same
    three text-encoder inputs, so only the dynamic-axis spec is declared here.
    """

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the ordered mapping of model input names to their dynamic axes.

        Each value maps an axis index to a symbolic axis name so the exported
        ONNX graph accepts variable batch/sequence (and, for multiple-choice,
        choice) dimensions.
        """
        # Multiple-choice inputs carry an extra "choice" dimension between
        # batch and sequence; all other tasks use plain (batch, sequence).
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        # OrderedDict: the ONNX exporter relies on input ordering matching the
        # model's forward signature (input_ids, attention_mask, token_type_ids).
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
                ("token_type_ids", dynamic_axis),
            ]
        )
...@@ -25,6 +25,7 @@ from ..models.layoutlm import LayoutLMOnnxConfig ...@@ -25,6 +25,7 @@ from ..models.layoutlm import LayoutLMOnnxConfig
from ..models.m2m_100 import M2M100OnnxConfig from ..models.m2m_100 import M2M100OnnxConfig
from ..models.marian import MarianOnnxConfig from ..models.marian import MarianOnnxConfig
from ..models.mbart import MBartOnnxConfig from ..models.mbart import MBartOnnxConfig
from ..models.mobilebert import MobileBertOnnxConfig
from ..models.roberta import RobertaOnnxConfig from ..models.roberta import RobertaOnnxConfig
from ..models.roformer import RoFormerOnnxConfig from ..models.roformer import RoFormerOnnxConfig
from ..models.t5 import T5OnnxConfig from ..models.t5 import T5OnnxConfig
...@@ -44,6 +45,7 @@ if is_torch_available(): ...@@ -44,6 +45,7 @@ if is_torch_available():
AutoModelForMaskedImageModeling, AutoModelForMaskedImageModeling,
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForMultipleChoice, AutoModelForMultipleChoice,
AutoModelForNextSentencePrediction,
AutoModelForQuestionAnswering, AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
...@@ -55,6 +57,7 @@ if is_tf_available(): ...@@ -55,6 +57,7 @@ if is_tf_available():
TFAutoModelForCausalLM, TFAutoModelForCausalLM,
TFAutoModelForMaskedLM, TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice, TFAutoModelForMultipleChoice,
TFAutoModelForNextSentencePrediction,
TFAutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification, TFAutoModelForSequenceClassification,
...@@ -108,6 +111,7 @@ class FeaturesManager: ...@@ -108,6 +111,7 @@ class FeaturesManager:
"question-answering": AutoModelForQuestionAnswering, "question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification, "image-classification": AutoModelForImageClassification,
"masked-im": AutoModelForMaskedImageModeling, "masked-im": AutoModelForMaskedImageModeling,
"next-sentence-prediction": AutoModelForNextSentencePrediction,
} }
if is_tf_available(): if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = { _TASKS_TO_TF_AUTOMODELS = {
...@@ -119,6 +123,7 @@ class FeaturesManager: ...@@ -119,6 +123,7 @@ class FeaturesManager:
"token-classification": TFAutoModelForTokenClassification, "token-classification": TFAutoModelForTokenClassification,
"multiple-choice": TFAutoModelForMultipleChoice, "multiple-choice": TFAutoModelForMultipleChoice,
"question-answering": TFAutoModelForQuestionAnswering, "question-answering": TFAutoModelForQuestionAnswering,
"next-sentence-prediction": TFAutoModelForNextSentencePrediction,
} }
# Set of model topologies we support associated to the features supported by each topology and the factory # Set of model topologies we support associated to the features supported by each topology and the factory
...@@ -153,6 +158,7 @@ class FeaturesManager: ...@@ -153,6 +158,7 @@ class FeaturesManager:
"multiple-choice", "multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
"next-sentence-prediction",
onnx_config_cls=BertOnnxConfig, onnx_config_cls=BertOnnxConfig,
), ),
"big-bird": supported_features_mapping( "big-bird": supported_features_mapping(
...@@ -316,6 +322,16 @@ class FeaturesManager: ...@@ -316,6 +322,16 @@ class FeaturesManager:
"question-answering", "question-answering",
onnx_config_cls=MBartOnnxConfig, onnx_config_cls=MBartOnnxConfig,
), ),
"mobilebert": supported_features_mapping(
"default",
"masked-lm",
"next-sentence-prediction",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=MobileBertOnnxConfig,
),
"m2m-100": supported_features_mapping( "m2m-100": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
), ),
......
...@@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
("electra", "google/electra-base-generator"), ("electra", "google/electra-base-generator"),
("roberta", "roberta-base"), ("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"), ("roformer", "junnyu/roformer_chinese_base"),
("mobilebert", "google/mobilebert-uncased"),
("xlm-roberta", "xlm-roberta-base"), ("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"), ("vit", "google/vit-base-patch16-224"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment