Unverified Commit 6f5ab9da authored by Lysandre Debut, committed by GitHub

Add MBART to models exportable with ONNX (#13049)

* Add MBART to models exportable with ONNX

* unittest mock

* Add tests

* Misc fixes
parent 13a9c9a3
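
With this change, MBART goes through the same default-feature export path the new test exercises. A minimal sketch of that path, assuming the `export()` helper from `transformers.onnx` keeps the `(tokenizer, model, onnx_config, opset, output)` signature used by `OnnxExportTestCaseV2`; the opset value and output filename below are arbitrary choices, not part of this commit:

    from pathlib import Path

    from transformers import AutoTokenizer, MBartConfig, MBartModel
    from transformers.models.mbart import MBartOnnxConfig
    from transformers.onnx import export

    # Tiny checkpoint used by the new test; any MBART checkpoint should work.
    checkpoint = "sshleifer/tiny-mbart"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = MBartModel(MBartConfig.from_pretrained(checkpoint))

    # Build the ONNX config for the "default" feature from the model config.
    onnx_config = MBartOnnxConfig.from_model_config(model.config)

    # Export and get back the matched input/output names (opset 12 is an assumption).
    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, 12, Path("mbart.onnx"))
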
@@ -28,7 +28,7 @@ from ...file_utils import (

 _import_structure = {
-    "configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig"],
+    "configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"],
 }

 if is_sentencepiece_available():
@@ -66,7 +66,7 @@ if is_flax_available():

 if TYPE_CHECKING:
-    from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
+    from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig, MBartOnnxConfig

     if is_sentencepiece_available():
         from .tokenization_mbart import MBartTokenizer
......
@@ -13,6 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ MBART model configuration """
+from collections import OrderedDict
+from typing import Mapping
+
+from transformers.onnx import OnnxConfigWithPast

 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
@@ -171,3 +175,32 @@ class MBartConfig(PretrainedConfig):
     @property
     def hidden_size(self) -> int:
         return self.d_model
+
+
+class MBartOnnxConfig(OnnxConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.use_past:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("past_keys", {0: "batch", 2: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
+        else:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
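
For readers unfamiliar with the ONNX configs: the `inputs`/`outputs` mappings above are dynamic-axes declarations, each entry mapping a tensor name to `{axis_index: symbolic_axis_name}` to mark which axes can vary at inference time (batch and sequence length here). Roughly, the exporter turns them into `torch.onnx.export` arguments; the snippet below is a simplified sketch of that plumbing under assumed dummy inputs, not the actual `transformers.onnx` convert code:

    import torch

    def export_sketch(model, dummy_inputs, onnx_config, output_path, opset=12):
        # The config's inputs/outputs double as torch.onnx.export's dynamic_axes:
        # {tensor_name: {axis_index: symbolic_axis_name}}.
        dynamic_axes = {**onnx_config.inputs, **onnx_config.outputs}
        torch.onnx.export(
            model,
            (dict(dummy_inputs),),  # e.g. tokenizer("hello", return_tensors="pt")
            output_path,
            input_names=list(onnx_config.inputs.keys()),
            output_names=list(onnx_config.outputs.keys()),
            dynamic_axes=dynamic_axes,
            opset_version=opset,
            do_constant_folding=True,
        )
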
@@ -9,6 +9,7 @@ from ..models.distilbert import DistilBertOnnxConfig
 from ..models.gpt2 import GPT2OnnxConfig
 from ..models.gpt_neo import GPTNeoOnnxConfig
 from ..models.longformer import LongformerOnnxConfig
+from ..models.mbart import MBartOnnxConfig
 from ..models.roberta import RobertaOnnxConfig
 from ..models.t5 import T5OnnxConfig
 from ..models.xlm_roberta import XLMRobertaOnnxConfig
@@ -58,6 +59,7 @@ class FeaturesManager:
     _SUPPORTED_MODEL_KIND = {
         "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
         "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
+        "mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
         "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
         "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
         "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
......
@@ -25,6 +25,7 @@ from distutils.util import strtobool
 from io import StringIO
 from pathlib import Path
 from typing import Iterator, Union
+from unittest import mock

 from transformers import logging as transformers_logging
@@ -1007,7 +1008,7 @@ def mockenv(**kwargs):
         use_tf = os.getenv("USE_TF", False)

     """
-    return unittest.mock.patch.dict(os.environ, kwargs)
+    return mock.patch.dict(os.environ, kwargs)

 # from https://stackoverflow.com/a/34333710/9201239
......
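
Context for the change above: `mockenv` is a thin wrapper around `mock.patch.dict(os.environ, ...)`, so switching from the fully qualified `unittest.mock` reference to the imported `mock` name is behavior-preserving. A small usage sketch (values kept as strings, since `os.environ` only stores `str`):

    import os

    from transformers.testing_utils import mockenv

    @mockenv(RUN_SLOW="1", USE_TF="0")
    def check_env():
        # Variables are patched only while the decorated function runs.
        assert os.getenv("RUN_SLOW") == "1"
        assert os.getenv("USE_TF") == "0"

    check_env()
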
@@ -10,6 +10,7 @@ from transformers import ( # LongformerConfig,; T5Config,
     DistilBertConfig,
     GPT2Config,
     GPTNeoConfig,
+    MBartConfig,
     RobertaConfig,
     XLMRobertaConfig,
     is_torch_available,
@@ -22,6 +23,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
 # from transformers.models.longformer import LongformerOnnxConfig
 from transformers.models.gpt2 import GPT2OnnxConfig
 from transformers.models.gpt_neo import GPTNeoOnnxConfig
+from transformers.models.mbart import MBartOnnxConfig
 from transformers.models.roberta import RobertaOnnxConfig

 # from transformers.models.t5 import T5OnnxConfig
@@ -154,7 +156,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
                 )

                 self.assertTrue(
-                    OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.default() should use_past"
+                    OnnxConfigWithPast.with_past(config()).use_past,
+                    "OnnxConfigWithPast.from_model_config() should use_past",
                 )

     @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
@@ -190,6 +193,7 @@ if is_torch_available():
         DistilBertModel,
         GPT2Model,
         GPTNeoModel,
+        MBartModel,
         RobertaModel,
         XLMRobertaModel,
     )
@@ -204,6 +208,7 @@ if is_torch_available():
         # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
         ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
         ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
+        ("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
         # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
     }
@@ -226,11 +231,11 @@ class OnnxExportTestCaseV2(TestCase):
         for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
             with self.subTest(name):
-                self.assertTrue(hasattr(onnx_config_class, "default"))
+                self.assertTrue(hasattr(onnx_config_class, "from_model_config"))

                 tokenizer = AutoTokenizer.from_pretrained(model)
                 model = model_class(config_class.from_pretrained(model))
-                onnx_config = onnx_config_class.default(model.config)
+                onnx_config = onnx_config_class.from_model_config(model.config)

                 with NamedTemporaryFile("w") as output:
                     onnx_inputs, onnx_outputs = export(
......