Unverified Commit 6f5ab9da authored by Lysandre Debut, committed by GitHub

Add MBART to models exportable with ONNX (#13049)

* Add MBART to models exportable with ONNX

* unittest mock

* Add tests

* Misc fixes
parent 13a9c9a3
src/transformers/models/mbart/__init__.py
@@ -28,7 +28,7 @@ from ...file_utils import (
 _import_structure = {
-    "configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig"],
+    "configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"],
 }

 if is_sentencepiece_available():
@@ -66,7 +66,7 @@ if is_flax_available():
 if TYPE_CHECKING:
-    from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
+    from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig, MBartOnnxConfig

     if is_sentencepiece_available():
         from .tokenization_mbart import MBartTokenizer
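The re-export above is what the tests further down rely on; after this change both names resolve from the model subpackage:

```python
# Both names are re-exported by the mbart subpackage after this change:
from transformers.models.mbart import MBartConfig, MBartOnnxConfig
```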
src/transformers/models/mbart/configuration_mbart.py
@@ -13,6 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ MBART model configuration """
+from collections import OrderedDict
+from typing import Mapping
+
+from transformers.onnx import OnnxConfigWithPast

 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
@@ -171,3 +175,32 @@ class MBartConfig(PretrainedConfig):
     @property
     def hidden_size(self) -> int:
         return self.d_model
+
+
+class MBartOnnxConfig(OnnxConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.use_past:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("past_keys", {0: "batch", 2: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
+        else:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
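As a quick sanity check of what the new class exposes, here is a short sketch using the `from_model_config` constructor exercised by the updated tests below (all names come from this diff):

```python
from transformers import MBartConfig
from transformers.models.mbart import MBartOnnxConfig

onnx_config = MBartOnnxConfig.from_model_config(MBartConfig())

# Dynamic axes declared by the `inputs` / `outputs` properties above:
print(list(onnx_config.inputs.keys()))   # ['input_ids', 'attention_mask']
print(list(onnx_config.outputs.keys()))  # ['last_hidden_state', 'encoder_last_hidden_state']
```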
src/transformers/onnx/features.py
@@ -9,6 +9,7 @@ from ..models.distilbert import DistilBertOnnxConfig
 from ..models.gpt2 import GPT2OnnxConfig
 from ..models.gpt_neo import GPTNeoOnnxConfig
 from ..models.longformer import LongformerOnnxConfig
+from ..models.mbart import MBartOnnxConfig
 from ..models.roberta import RobertaOnnxConfig
 from ..models.t5 import T5OnnxConfig
 from ..models.xlm_roberta import XLMRobertaOnnxConfig
@@ -58,6 +59,7 @@ class FeaturesManager:
     _SUPPORTED_MODEL_KIND = {
         "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
         "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
+        "mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
         "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
         "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
         "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
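A hedged sketch of what this registry entry enables. `_SUPPORTED_MODEL_KIND` is internal, and the sketch assumes each feature name maps to a constructor that accepts a model config (consistent with `supported_features_mapping` above and the `from_model_config` calls in the tests below):

```python
from transformers import MBartConfig
from transformers.onnx.features import FeaturesManager

# Assumption: the "default" feature resolves to a callable equivalent to
# MBartOnnxConfig.from_model_config.
ctor = FeaturesManager._SUPPORTED_MODEL_KIND["mbart"]["default"]
onnx_config = ctor(MBartConfig())
print(type(onnx_config).__name__)  # MBartOnnxConfig
```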
src/transformers/testing_utils.py
@@ -25,6 +25,7 @@ from distutils.util import strtobool
 from io import StringIO
 from pathlib import Path
 from typing import Iterator, Union
+from unittest import mock

 from transformers import logging as transformers_logging
@@ -1007,7 +1008,7 @@ def mockenv(**kwargs):
     use_tf = os.getenv("USE_TF", False)
     """
-    return unittest.mock.patch.dict(os.environ, kwargs)
+    return mock.patch.dict(os.environ, kwargs)


 # from https://stackoverflow.com/a/34333710/9201239
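The change above is behavior-preserving: it only routes through the `mock` name imported at the top of the file. For reference, a usage sketch of the helper in decorator form (the `os.getenv` line is taken from its docstring):

```python
import os

@mockenv(USE_TF="0", RUN_SLOW="1")
def test_reads_env():
    # mock.patch.dict restores os.environ once the test returns.
    use_tf = os.getenv("USE_TF", False)
    assert use_tf == "0"
```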
tests/test_onnx_v2.py
@@ -10,6 +10,7 @@ from transformers import ( # LongformerConfig,; T5Config,
     DistilBertConfig,
     GPT2Config,
     GPTNeoConfig,
+    MBartConfig,
     RobertaConfig,
     XLMRobertaConfig,
     is_torch_available,
@@ -22,6 +23,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig

 # from transformers.models.longformer import LongformerOnnxConfig
 from transformers.models.gpt2 import GPT2OnnxConfig
 from transformers.models.gpt_neo import GPTNeoOnnxConfig
+from transformers.models.mbart import MBartOnnxConfig
 from transformers.models.roberta import RobertaOnnxConfig

 # from transformers.models.t5 import T5OnnxConfig
@@ -154,7 +156,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
         )

         self.assertTrue(
-            OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.default() should use_past"
+            OnnxConfigWithPast.with_past(config()).use_past,
+            "OnnxConfigWithPast.from_model_config() should use_past",
         )

     @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
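For context on the rename this message reflects: `from_model_config` (formerly `default`) builds the config with the past-key-values switch off, while `with_past` turns it on, which is what the assertion checks. A hedged illustration using the MBART classes from this PR (the `use_past=False` default is an assumption inferred from these paired assertions):

```python
from transformers import MBartConfig
from transformers.models.mbart import MBartOnnxConfig

cfg = MBartConfig()
assert MBartOnnxConfig.from_model_config(cfg).use_past is False  # assumed default
assert MBartOnnxConfig.with_past(cfg).use_past is True           # mirrors the assertion above
```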
@@ -190,6 +193,7 @@ if is_torch_available():
         DistilBertModel,
         GPT2Model,
         GPTNeoModel,
+        MBartModel,
         RobertaModel,
         XLMRobertaModel,
     )
@@ -204,6 +208,7 @@ if is_torch_available():
         # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
         ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
         ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
+        ("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
         # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
     }
@@ -226,11 +231,11 @@ class OnnxExportTestCaseV2(TestCase):
         for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
             with self.subTest(name):
-                self.assertTrue(hasattr(onnx_config_class, "default"))
+                self.assertTrue(hasattr(onnx_config_class, "from_model_config"))

                 tokenizer = AutoTokenizer.from_pretrained(model)
                 model = model_class(config_class.from_pretrained(model))
-                onnx_config = onnx_config_class.default(model.config)
+                onnx_config = onnx_config_class.from_model_config(model.config)

                 with NamedTemporaryFile("w") as output:
                     onnx_inputs, onnx_outputs = export(
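Putting the pieces together, a sketch of the export path the new test row exercises. The `export()` call is truncated in the hunk above; the argument order shown here (tokenizer, model, config, opset, output path) is an assumption about the `transformers.onnx.export` signature of this era, and opset 11 is an illustrative choice:

```python
from pathlib import Path
from tempfile import TemporaryDirectory

from transformers import AutoTokenizer, MBartConfig, MBartModel
from transformers.models.mbart import MBartOnnxConfig
from transformers.onnx import export

checkpoint = "sshleifer/tiny-mbart"  # same tiny checkpoint as the test above
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = MBartModel(MBartConfig.from_pretrained(checkpoint))
onnx_config = MBartOnnxConfig.from_model_config(model.config)

with TemporaryDirectory() as tmp:
    # Assumed signature: export(tokenizer, model, config, opset, output) -> (inputs, outputs)
    onnx_inputs, onnx_outputs = export(
        tokenizer, model, onnx_config, 11, Path(tmp) / "mbart.onnx"
    )
    print(onnx_inputs, onnx_outputs)
```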