Unverified Commit 5af38953 authored by Ritik Nandwal, committed by GitHub

Added XLM onnx config (#17030)

* Add onnx configuration for xlm

* Add supported features for xlm

* Add xlm to models exportable with onnx

* Add xlm architecture to test file

* Modify docs

* Make code quality fixes
parent 567d9c06
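The commit message above summarizes the change; below is a minimal end-to-end sketch of what it enables, using the Python export API. The checkpoint name matches the one this PR adds to the tests, but the output path is illustrative:

```python
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.xlm import XLMOnnxConfig
from transformers.onnx import export

checkpoint = "xlm-clm-ende-1024"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

# Build the ONNX config added by this PR and export with its default opset.
onnx_config = XLMOnnxConfig(model.config)
output = Path("onnx/model.onnx")
output.parent.mkdir(exist_ok=True)
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, output
)
print(onnx_inputs)   # input names matched between the model and the config
print(onnx_outputs)  # output names stored in the exported graph
```

The same export should also be reachable from the command line via `python -m transformers.onnx --model=xlm-clm-ende-1024 onnx/`.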
@@ -75,6 +75,7 @@ Ready-made configurations include the following architectures:
 - RoFormer
 - T5
 - ViT
+- XLM
 - XLM-RoBERTa
 - XLM-RoBERTa-XL

@@ -22,7 +22,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available
 _import_structure = {
-    "configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig"],
+    "configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMOnnxConfig"],
     "tokenization_xlm": ["XLMTokenizer"],
 }

@@ -64,7 +64,7 @@ else:
 if TYPE_CHECKING:
-    from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
+    from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMOnnxConfig
     from .tokenization_xlm import XLMTokenizer

 try:

@@ -13,8 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ XLM configuration"""
+from collections import OrderedDict
+from typing import Mapping

 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
@@ -228,3 +231,20 @@ class XLMConfig(PretrainedConfig):
             self.n_words = kwargs["n_words"]

         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
+class XLMOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
+            ]
+        )

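Since `inputs` is the only piece XLM has to define, here is a quick sketch of what the property yields for two tasks (names follow the class above; the `task` keyword is the standard `OnnxConfig` constructor argument):

```python
from transformers import XLMConfig
from transformers.models.xlm import XLMOnnxConfig

config = XLMConfig()

# Default task: batch and sequence are the dynamic axes.
print(XLMOnnxConfig(config).inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}), ...])

# Multiple choice inserts a "choice" axis between batch and sequence.
print(XLMOnnxConfig(config, task="multiple-choice").inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'choice', 2: 'sequence'}), ...])
```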
@@ -30,6 +30,7 @@ from ..models.roberta import RobertaOnnxConfig
 from ..models.roformer import RoFormerOnnxConfig
 from ..models.t5 import T5OnnxConfig
 from ..models.vit import ViTOnnxConfig
+from ..models.xlm import XLMOnnxConfig
 from ..models.xlm_roberta import XLMRobertaOnnxConfig
 from ..utils import logging
 from .config import OnnxConfig

@@ -357,6 +358,16 @@ class FeaturesManager:
         "vit": supported_features_mapping(
             "default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
         ),
+        "xlm": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "causal-lm",
+            "sequence-classification",
+            "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=XLMOnnxConfig,
+        ),
         "xlm-roberta": supported_features_mapping(
             "default",
             "masked-lm",

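The new mapping can be sanity-checked through `FeaturesManager`. A small sketch, assuming the internal behavior of `supported_features_mapping` (each feature maps to a callable that builds the ONNX config for that task):

```python
from transformers import XLMConfig
from transformers.onnx.features import FeaturesManager

# Features registered for XLM by this PR, e.g. "default", "masked-lm", "causal-lm", ...
features = FeaturesManager.get_supported_features_for_model_type("xlm")
print(sorted(features))

# Each entry constructs an XLMOnnxConfig bound to the requested task.
onnx_config = features["causal-lm"](XLMConfig())
print(onnx_config.task)  # "causal-lm"
```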
@@ -181,6 +181,7 @@ PYTORCH_EXPORT_MODELS = {
     ("roberta", "roberta-base"),
     ("roformer", "junnyu/roformer_chinese_base"),
     ("mobilebert", "google/mobilebert-uncased"),
+    ("xlm", "xlm-clm-ende-1024"),
     ("xlm-roberta", "xlm-roberta-base"),
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
     ("vit", "google/vit-base-patch16-224"),
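With `xlm-clm-ende-1024` in `PYTORCH_EXPORT_MODELS`, the test suite exports the checkpoint and validates its outputs against the PyTorch model. The sketch below performs the same round-trip manually with onnxruntime, assuming the `onnx/model.onnx` file produced by the export example earlier:

```python
import onnxruntime
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-clm-ende-1024")
session = onnxruntime.InferenceSession("onnx/model.onnx", providers=["CPUExecutionProvider"])

# The exported graph takes input_ids, attention_mask and token_type_ids,
# exactly the names declared by XLMOnnxConfig.inputs.
encoded = tokenizer("Hello from XLM!", return_tensors="np")
outputs = session.run(None, dict(encoded))
print(outputs[0].shape)  # (batch, sequence, hidden_size) for the default feature
```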