Unverified Commit 6b655cc6 authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Add ONNX support for MarianMT models (#14586)

* First commit to add MarianMT to ONNX

* Now MarianModel.forward() automatically generates decoder_input_ids, like BartModel.forward()

* Adjusted MarianOnnxConfig.inputs and outputs to work with seq2seq-lm feature

* Style fix

* Added support for other features for already supported models

* Partial support for causal and seq2seq models

* Partial support for causal and seq2seq models

* Add default task for MarianMT ONNX

* Remove automatic creation of decoder_input_ids

* Extend inputs and outputs for MarianMT ONNX config

* Add MarianMT to ONNX unit tests

* Refactor

* OnnxSeq2SeqConfigWithPast to support seq2seq models

* Parameterized the onnx tests

* Restored run_mlm.py

* Restored run_mlm.py

* [WIP] BART update

* BART and MBART

* Add past_key_values and fix dummy decoder inputs

Using a sequence length of 1 in generate_dummy_outputs() produces large discrepancies, presumably due to some hidden optimisations.

* Refactor MarianOnnxConfig to remove custom past_key_values logic

* Fix quality

* Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)"

This reverts commit 0f4e39c5.

* is_torch_available test to avoid failing imports

* sorting parameterize parameters to solve ERROR gw0 gw1

* tests fix

* tests fix

* GPT2 with past fix

* Fixed stateful class attribute change that was breaking things when converting multiple models sequentially

* Removed onnx file

* Refactor Marian export to account for base changes

* Fix copies

* Implemented suggestions

* Extend support for causal LM

* Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)"

This reverts commit 0f4e39c5.

* is_torch_available test to avoid failing imports

* sorting parameterize parameters to solve ERROR gw0 gw1

* tests fix

* tests fix

* GPT2 with past fix

* Fixed stateful class attribute change that was breaking things when converting multiple models sequentially

* Removed onnx file

* Implemented suggestions

* Fixed __init__ to resolve conflict with master

* Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)"

This reverts commit 0f4e39c5

.

* is_torch_available test to avoid failing imports

* sorting parameterize parameters to solve ERROR gw0 gw1

* tests fix

* tests fix

* GPT2 with past fix

* Fixed stateful class attribute change that was breaking things when converting multiple models sequentially

* Removed onnx file

* Implemented suggestions

* Fixed __init__ to resolve conflict with master

* Remove commented import

* Remove ONNX model

* Remove redundant class method

* Tidy up imports

* Fix quality

* Refactor dummy input function

* Add copied from statements to Marian config functions

* Remove false copied from comments

* Fix copy from comment
Co-authored-by: default avatarMassimiliano Bruni <massimiliano.bruni@hcl.com>
Co-authored-by: default avatarMichael Benayoun <mickbenayoun@gmail.com>
parent 6a7b9da2
......@@ -42,6 +42,7 @@ Ready-made configurations include the following models:
- GPT Neo
- LayoutLM
- Longformer
- Marian
- mBART
- OpenAI GPT-2
- RoBERTa
......
......@@ -28,7 +28,7 @@ from ...file_utils import (
_import_structure = {
"configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig"],
"configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig", "MarianOnnxConfig"],
}
if is_sentencepiece_available():
......@@ -49,7 +49,7 @@ if is_tf_available():
if is_flax_available():
_import_structure["modeling_flax_marian"] = ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel"]
if TYPE_CHECKING:
from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig
from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig, MarianOnnxConfig
if is_sentencepiece_available():
from .tokenization_marian import MarianTokenizer
......
......@@ -13,8 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Marian model configuration """
from collections import OrderedDict
from typing import Any, Mapping, Optional
from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
from ...file_utils import TensorType, is_torch_available
from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from ...onnx.utils import compute_effective_axis_dimension
from ...utils import logging
......@@ -160,3 +166,226 @@ class MarianConfig(PretrainedConfig):
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
class MarianOnnxConfig(OnnxSeq2SeqConfigWithPast):
@property
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.inputs
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task in ["default", "seq2seq-lm"]:
common_inputs = OrderedDict(
[
("input_ids", {0: "batch", 1: "encoder_sequence"}),
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
]
)
if self.use_past:
common_inputs["decoder_input_ids"] = {0: "batch"}
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
else:
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
elif self.task == "causal-lm":
# TODO: figure this case out.
common_inputs = OrderedDict(
[
("input_ids", {0: "batch", 1: "encoder_sequence"}),
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
]
)
if self.use_past:
num_encoder_layers, _ = self.num_layers
for i in range(num_encoder_layers):
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
else:
common_inputs = OrderedDict(
[
("input_ids", {0: "batch", 1: "encoder_sequence"}),
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
]
)
return common_inputs
@property
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.outputs
def outputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task in ["default", "seq2seq-lm"]:
common_outputs = super().outputs
else:
common_outputs = super(OnnxConfigWithPast, self).outputs
if self.use_past:
num_encoder_layers, _ = self.num_layers
for i in range(num_encoder_layers):
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
return common_outputs
def _generate_dummy_inputs_for_default_and_seq2seq_lm(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
encoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
tokenizer, batch_size, seq_length, is_pair, framework
)
# Generate decoder inputs
decoder_seq_length = seq_length if not self.use_past else 1
decoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
tokenizer, batch_size, decoder_seq_length, is_pair, framework
)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
common_inputs = dict(**encoder_inputs, **decoder_inputs)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, encoder_seq_length = common_inputs["input_ids"].shape
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
encoder_shape = (
batch,
num_encoder_attention_heads,
encoder_seq_length,
self._config.hidden_size // num_encoder_attention_heads,
)
decoder_past_length = decoder_seq_length + 3
decoder_shape = (
batch,
num_decoder_attention_heads,
decoder_past_length,
self._config.hidden_size // num_decoder_attention_heads,
)
common_inputs["decoder_attention_mask"] = torch.cat(
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
)
common_inputs["past_key_values"] = []
# If the number of encoder and decoder layers are present in the model configuration, both are considered
num_encoder_layers, num_decoder_layers = self.num_layers
min_num_layers = min(num_encoder_layers, num_decoder_layers)
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
for _ in range(min_num_layers):
common_inputs["past_key_values"].append(
(
torch.zeros(decoder_shape),
torch.zeros(decoder_shape),
torch.zeros(encoder_shape),
torch.zeros(encoder_shape),
)
)
# TODO: test this.
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
for _ in range(min_num_layers, max_num_layers):
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
return common_inputs
def _generate_dummy_inputs_for_causal_lm(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
common_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
tokenizer, batch_size, seq_length, is_pair, framework
)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
num_encoder_layers, _ = self.num_layers
num_encoder_attention_heads, _ = self.num_attention_heads
past_shape = (
batch,
num_encoder_attention_heads,
past_key_values_length,
self._config.hidden_size // num_encoder_attention_heads,
)
common_inputs["attention_mask"] = torch.cat(
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
)
common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
]
return common_inputs
# Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
# We renamed this function because Marian models do not have a sequence classification or question answering head
def _generate_dummy_inputs_for_encoder_and_decoder(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
# Copied from OnnxConfig.generate_dummy_inputs
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
batch_size = compute_effective_axis_dimension(
batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
)
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
seq_length = compute_effective_axis_dimension(
seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
)
# Generate dummy inputs according to compute batch and sequence
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
return common_inputs
def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
if self.task in ["default", "seq2seq-lm"]:
common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
else:
common_inputs = self._generate_dummy_inputs_for_causal_lm(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
return common_inputs
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._flatten_past_key_values_
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
if self.task in ["default", "seq2seq-lm"]:
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
else:
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
flattened_output, name, idx, t
)
......@@ -310,7 +310,7 @@ class MarianTokenizer(PreTrainedTokenizer):
self.current_spm = self.spm_source
self._setup_normalizer()
def num_special_tokens_to_add(self, **unused):
def num_special_tokens_to_add(self, *args, **kwargs):
"""Just EOS"""
return 1
......
......@@ -11,6 +11,7 @@ from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.layoutlm import LayoutLMOnnxConfig
from ..models.longformer import LongformerOnnxConfig
from ..models.marian import MarianOnnxConfig
from ..models.mbart import MBartOnnxConfig
from ..models.roberta import RobertaOnnxConfig
from ..models.t5 import T5OnnxConfig
......@@ -152,6 +153,15 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=LongformerOnnxConfig,
),
"marian": supported_features_mapping(
"default",
"default-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
"causal-lm",
"causal-lm-with-past",
onnx_config_cls=MarianOnnxConfig,
),
"roberta": supported_features_mapping(
"default",
"masked-lm",
......
......@@ -188,6 +188,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
("bart", "facebook/bart-base"),
("mbart", "sshleifer/tiny-mbart"),
("t5", "t5-small"),
("marian", "Helsinki-NLP/opus-mt-en-de"),
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment