Unverified commit 6b655cc6 authored by lewtun, committed by GitHub

Add ONNX support for MarianMT models (#14586)

* First commit to add MarianMT to ONNX

* Now MarianModel.forward() automatically generates decoder_input_ids, like BartModel.forward()

* Adjusted MarianOnnxConfig.inputs and outputs to work with seq2seq-lm feature

* Style fix

* Added support for other features for already supported models

* Partial support for causal and seq2seq models

* Add default task for MarianMT ONNX

* Remove automatic creation of decoder_input_ids

* Extend inputs and outputs for MarianMT ONNX config

* Add MarianMT to ONNX unit tests

* Refactor

* OnnxSeq2SeqConfigWithPast to support seq2seq models

* Parameterized the onnx tests

* Restored run_mlm.py

* [WIP] BART update

* BART and MBART

* Add past_key_values and fix dummy decoder inputs

Using a sequence length of 1 in generate_dummy_inputs() produces large discrepancies, presumably due to some hidden optimisations (see the validation sketch before the diffs below).

* Refactor MarianOnnxConfig to remove custom past_key_values logic

* Fix quality

* Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)"

This reverts commit 0f4e39c5.

* Add is_torch_available check to avoid failing imports

* Sort parameterized test parameters to fix pytest-xdist worker collection errors (ERROR gw0 gw1)

* Test fixes

* GPT2 with past fix

* Fixed stateful class attribute change that was breaking things when converting multiple models sequentially

* Removed onnx file

* Refactor Marian export to account for base changes

* Fix copies

* Implemented suggestions

* Extend support for causal LM

* Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)"

This reverts commit 0f4e39c5.

* is_torch_available test to avoid failing imports

* sorting parameterize parameters to solve ERROR gw0 gw1

* tests fix

* tests fix

* GPT2 with past fix

* Fixed stateful class attribute change that was breaking things when converting multiple models sequentially

* Removed onnx file

* Implemented suggestions

* Fixed __init__ to resolve conflict with master

* Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)"

This reverts commit 0f4e39c5

.

* is_torch_available test to avoid failing imports

* sorting parameterize parameters to solve ERROR gw0 gw1

* tests fix

* tests fix

* GPT2 with past fix

* Fixed stateful class attribute change that was breaking things when converting multiple models sequentially

* Removed onnx file

* Implemented suggestions

* Fixed __init__ to resolve conflict with master

* Remove commented import

* Remove ONNX model

* Remove redundant class method

* Tidy up imports

* Fix quality

* Refactor dummy input function

* Add copied from statements to Marian config functions

* Remove false copied from comments

* Fix copy from comment
Co-authored-by: Massimiliano Bruni <massimiliano.bruni@hcl.com>
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
parent 6a7b9da2
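The commit note above about sequence-length-1 dummy inputs refers to discrepancies between PyTorch and ONNX Runtime outputs during export validation. A minimal sketch of how such discrepancies surface, assuming a pre-exported marian.onnx file and the transformers.onnx validation helper of this era (treat the exact signature as an approximation):

from pathlib import Path

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.models.marian import MarianOnnxConfig
from transformers.onnx import validate_model_outputs

checkpoint = "Helsinki-NLP/opus-mt-en-de"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
onnx_config = MarianOnnxConfig(model.config, task="seq2seq-lm")

# Runs the PyTorch model and the ONNX graph on the same dummy inputs and
# raises if any output element differs by more than atol; too-short dummy
# sequences made this check fail with large discrepancies.
validate_model_outputs(
    onnx_config,
    tokenizer,
    model,
    Path("marian.onnx"),  # assumed pre-exported file
    list(onnx_config.outputs.keys()),
    atol=1e-4,
)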
@@ -42,6 +42,7 @@ Ready-made configurations include the following models:
 - GPT Neo
 - LayoutLM
 - Longformer
+- Marian
 - mBART
 - OpenAI GPT-2
 - RoBERTa
@@ -28,7 +28,7 @@ from ...file_utils import (
 _import_structure = {
-    "configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig"],
+    "configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig", "MarianOnnxConfig"],
 }

 if is_sentencepiece_available():
@@ -49,7 +49,7 @@ if is_tf_available():
 if is_flax_available():
     _import_structure["modeling_flax_marian"] = ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel"]

 if TYPE_CHECKING:
-    from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig
+    from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig, MarianOnnxConfig

     if is_sentencepiece_available():
         from .tokenization_marian import MarianTokenizer
@@ -13,8 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Marian model configuration """
+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from ... import PreTrainedTokenizer
 from ...configuration_utils import PretrainedConfig
+from ...file_utils import TensorType, is_torch_available
+from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from ...onnx.utils import compute_effective_axis_dimension
 from ...utils import logging

@@ -160,3 +166,226 @@ class MarianConfig(PretrainedConfig):
             forced_eos_token_id=forced_eos_token_id,
             **kwargs,
         )
+
+
+class MarianOnnxConfig(OnnxSeq2SeqConfigWithPast):
+    @property
+    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.inputs
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task in ["default", "seq2seq-lm"]:
+            common_inputs = OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+                ]
+            )
+
+            if self.use_past:
+                common_inputs["decoder_input_ids"] = {0: "batch"}
+                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+            else:
+                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+            if self.use_past:
+                self.fill_with_past_key_values_(common_inputs, direction="inputs")
+        elif self.task == "causal-lm":
+            # TODO: figure this case out.
+            common_inputs = OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+                ]
+            )
+            if self.use_past:
+                num_encoder_layers, _ = self.num_layers
+                for i in range(num_encoder_layers):
+                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
+                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+        else:
+            common_inputs = OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
+                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
+                ]
+            )
+
+        return common_inputs
+
+    @property
+    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.outputs
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task in ["default", "seq2seq-lm"]:
+            common_outputs = super().outputs
+        else:
+            common_outputs = super(OnnxConfigWithPast, self).outputs
+            if self.use_past:
+                num_encoder_layers, _ = self.num_layers
+                for i in range(num_encoder_layers):
+                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
+                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+        return common_outputs
+
+    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        encoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
+            tokenizer, batch_size, seq_length, is_pair, framework
+        )
+
+        # Generate decoder inputs
+        decoder_seq_length = seq_length if not self.use_past else 1
+        decoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
+            tokenizer, batch_size, decoder_seq_length, is_pair, framework
+        )
+        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
+        common_inputs = dict(**encoder_inputs, **decoder_inputs)
+
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+            batch, encoder_seq_length = common_inputs["input_ids"].shape
+            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
+            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
+            encoder_shape = (
+                batch,
+                num_encoder_attention_heads,
+                encoder_seq_length,
+                self._config.hidden_size // num_encoder_attention_heads,
+            )
+            decoder_past_length = decoder_seq_length + 3
+            decoder_shape = (
+                batch,
+                num_decoder_attention_heads,
+                decoder_past_length,
+                self._config.hidden_size // num_decoder_attention_heads,
+            )
+
+            common_inputs["decoder_attention_mask"] = torch.cat(
+                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
+            )
+
+            common_inputs["past_key_values"] = []
+            # If the number of encoder and decoder layers are present in the model configuration, both are considered
+            num_encoder_layers, num_decoder_layers = self.num_layers
+            min_num_layers = min(num_encoder_layers, num_decoder_layers)
+            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
+            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
+
+            for _ in range(min_num_layers):
+                common_inputs["past_key_values"].append(
+                    (
+                        torch.zeros(decoder_shape),
+                        torch.zeros(decoder_shape),
+                        torch.zeros(encoder_shape),
+                        torch.zeros(encoder_shape),
+                    )
+                )
+            # TODO: test this.
+            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
+            for _ in range(min_num_layers, max_num_layers):
+                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
+        return common_inputs
+
+    def _generate_dummy_inputs_for_causal_lm(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        common_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
+            tokenizer, batch_size, seq_length, is_pair, framework
+        )
+
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+            batch, seqlen = common_inputs["input_ids"].shape
+            # Not using the same length for past_key_values
+            past_key_values_length = seqlen + 2
+            num_encoder_layers, _ = self.num_layers
+            num_encoder_attention_heads, _ = self.num_attention_heads
+            past_shape = (
+                batch,
+                num_encoder_attention_heads,
+                past_key_values_length,
+                self._config.hidden_size // num_encoder_attention_heads,
+            )
+
+            common_inputs["attention_mask"] = torch.cat(
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+            )
+            common_inputs["past_key_values"] = [
+                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
+            ]
+        return common_inputs
+
+    # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
+    # We renamed this function because Marian models do not have a sequence classification or question answering head
+    def _generate_dummy_inputs_for_encoder_and_decoder(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        # Copied from OnnxConfig.generate_dummy_inputs
+        # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
+        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+        batch_size = compute_effective_axis_dimension(
+            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
+        )
+
+        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
+        seq_length = compute_effective_axis_dimension(
+            seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
+        )
+
+        # Generate dummy inputs according to compute batch and sequence
+        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
+        common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
+        return common_inputs
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        if self.task in ["default", "seq2seq-lm"]:
+            common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
+                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+            )
+        else:
+            common_inputs = self._generate_dummy_inputs_for_causal_lm(
+                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+            )
+        return common_inputs
+
+    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._flatten_past_key_values_
+    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
+        if self.task in ["default", "seq2seq-lm"]:
+            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
+        else:
+            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
+                flattened_output, name, idx, t
+            )
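For illustration, a minimal sketch of exercising the new config above directly (the checkpoint name is illustrative; with_past is the OnnxConfigWithPast helper assumed here to construct a use_past=True config, and generating PyTorch dummy tensors requires torch to be installed):

from transformers import AutoTokenizer, MarianConfig
from transformers.file_utils import TensorType
from transformers.models.marian import MarianOnnxConfig

checkpoint = "Helsinki-NLP/opus-mt-en-de"
config = MarianConfig.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# use_past=True exercises the seq2seq-lm-with-past feature
onnx_config = MarianOnnxConfig.with_past(config, task="seq2seq-lm")

print(list(onnx_config.inputs.keys()))   # dynamic-axis spec per input tensor
print(list(onnx_config.outputs.keys()))  # includes present.* entries when use_past is set

# Dummy tensors used to trace the export; with use_past enabled the decoder side
# gets a short sequence plus zero-filled past_key_values, per the code above
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print(sorted(dummy_inputs.keys()))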
@@ -310,7 +310,7 @@ class MarianTokenizer(PreTrainedTokenizer):
         self.current_spm = self.spm_source
         self._setup_normalizer()

-    def num_special_tokens_to_add(self, **unused):
+    def num_special_tokens_to_add(self, *args, **kwargs):
         """Just EOS"""
         return 1
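The signature change above is needed because the new dummy-input helper calls tokenizer.num_special_tokens_to_add(is_pair) with a positional argument, and a keyword-only **unused signature rejects that call. A minimal standalone illustration:

class OldTokenizer:
    def num_special_tokens_to_add(self, **unused):
        """Just EOS"""
        return 1

class NewTokenizer:
    def num_special_tokens_to_add(self, *args, **kwargs):
        """Just EOS"""
        return 1

NewTokenizer().num_special_tokens_to_add(False)  # OK: positional is_pair accepted
# OldTokenizer().num_special_tokens_to_add(False) would raise
# TypeError: num_special_tokens_to_add() takes 1 positional argument but 2 were given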
@@ -11,6 +11,7 @@ from ..models.gpt2 import GPT2OnnxConfig
 from ..models.gpt_neo import GPTNeoOnnxConfig
 from ..models.layoutlm import LayoutLMOnnxConfig
 from ..models.longformer import LongformerOnnxConfig
+from ..models.marian import MarianOnnxConfig
 from ..models.mbart import MBartOnnxConfig
 from ..models.roberta import RobertaOnnxConfig
 from ..models.t5 import T5OnnxConfig

@@ -152,6 +153,15 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=LongformerOnnxConfig,
         ),
+        "marian": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            onnx_config_cls=MarianOnnxConfig,
+        ),
         "roberta": supported_features_mapping(
             "default",
             "masked-lm",
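A sketch of how the feature table above is consumed during export (names come from the transformers.onnx package of this period; treat the exact export signature as a best-effort assumption):

from pathlib import Path

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.onnx import export
from transformers.onnx.features import FeaturesManager

checkpoint = "Helsinki-NLP/opus-mt-en-de"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

# Resolves "marian" + "seq2seq-lm" to MarianOnnxConfig via the mapping above
model_kind, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(
    model, feature="seq2seq-lm"
)
onnx_config = onnx_config_factory(model.config)

# Traces the model on the config's dummy inputs and writes the ONNX graph
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("marian.onnx")
)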
@@ -188,6 +188,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
     ("bart", "facebook/bart-base"),
     ("mbart", "sshleifer/tiny-mbart"),
     ("t5", "t5-small"),
+    ("marian", "Helsinki-NLP/opus-mt-en-de"),
 }
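Several commits above mention sorting the parameterized inputs to fix "ERROR gw0 gw1" failures under pytest-xdist; a hedged sketch of that pattern (the test name and decorator usage are illustrative, not the exact suite):

from parameterized import parameterized

PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
    ("bart", "facebook/bart-base"),
    ("mbart", "sshleifer/tiny-mbart"),
    ("t5", "t5-small"),
    ("marian", "Helsinki-NLP/opus-mt-en-de"),
}

# Set iteration order varies across processes (string hash randomization), so
# unsorted parameters produce differently named test cases on workers gw0 and
# gw1 and pytest-xdist aborts; sorting makes collection deterministic everywhere.
@parameterized.expand(sorted(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
def test_pytorch_export_seq2seq_with_past(name, checkpoint):
    ...  # export `checkpoint` with the seq2seq-lm-with-past feature and validate it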