Unverified Commit 0f4e39c5 authored by lewtun, committed by GitHub

Revert "Added support for other features for already supported models (#14358)" (#14679)

This reverts commit 0c70f145.
parent 0c70f145
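Context for the diff below: the reverted PR touched the `transformers.onnx` export path (the per-model `OnnxConfig` classes, the `OnnxConfigWithPast` base class, and the CLI in `onnx/__main__.py`). As a rough, illustrative sketch only — the checkpoint name, opset and output path are assumptions and are not part of this commit — the export entry point whose call signature appears in `__main__.py` further down can be driven programmatically like this:

    # Minimal sketch, not from the commit: drives the export pipeline touched by this revert.
    from pathlib import Path

    from transformers import AutoModel, AutoTokenizer
    from transformers.models.bert import BertOnnxConfig
    from transformers.onnx import export, validate_model_outputs

    checkpoint = "bert-base-uncased"  # assumed checkpoint, for illustration only
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint)

    onnx_config = BertOnnxConfig(model.config)  # "default" feature
    output = Path("onnx/model.onnx")            # assumed output path

    # Same positional signature as the calls in onnx/__main__.py below.
    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, 12, output)
    validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-4)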
@@ -167,3 +167,7 @@ class AlbertOnnxConfig(OnnxConfig):
                 ("token_type_ids", {0: "batch", 1: "sequence"}),
             ]
         )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
@@ -15,12 +15,10 @@
 """ BART model configuration """
 import warnings
 from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Mapping
 
-from ... import PreTrainedTokenizer
 from ...configuration_utils import PretrainedConfig
-from ...file_utils import TensorType, is_torch_available
-from ...onnx import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from ...onnx import OnnxConfigWithPast
 from ...utils import logging
@@ -182,174 +180,30 @@ class BartConfig(PretrainedConfig):
 )
 
-class BartOnnxConfig(OnnxSeq2SeqConfigWithPast):
+class BartOnnxConfig(OnnxConfigWithPast):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        if self.task in ["default", "seq2seq-lm"]:
-            common_inputs = OrderedDict(
-                [
-                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
-                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
-                ]
-            )
-            if self.use_past:
-                common_inputs["decoder_input_ids"] = {0: "batch"}
-                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
-            else:
-                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
-                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
-
-            if self.use_past:
-                self.fill_with_past_key_values_(common_inputs, direction="inputs")
-        elif self.task == "causal-lm":
-            # TODO: figure this case out.
-            common_inputs = OrderedDict(
-                [
-                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
-                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
-                ]
-            )
-            if self.use_past:
-                num_encoder_layers, _ = self.num_layers
-                for i in range(num_encoder_layers):
-                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
-                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
-        else:
-            common_inputs = OrderedDict(
-                [
-                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
-                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
-                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
-                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
-                ]
-            )
-
-        return common_inputs
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.use_past:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("past_keys", {0: "batch", 2: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
+        else:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task in ["default", "seq2seq-lm"]:
common_outputs = super().outputs
else:
common_outputs = super(OnnxConfigWithPast, self).outputs
if self.use_past:
num_encoder_layers, _ = self.num_layers
for i in range(num_encoder_layers):
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
return common_outputs
def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
if self.task in ["default", "seq2seq-lm"]:
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)
# Generate decoder inputs
decoder_seq_length = seq_length if not self.use_past else 1
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, decoder_seq_length, is_pair, framework
)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
common_inputs = dict(**encoder_inputs, **decoder_inputs)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, encoder_seq_length = common_inputs["input_ids"].shape
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
encoder_shape = (
batch,
num_encoder_attention_heads,
encoder_seq_length,
self._config.hidden_size // num_encoder_attention_heads,
)
decoder_past_length = decoder_seq_length + 3
decoder_shape = (
batch,
num_decoder_attention_heads,
decoder_past_length,
self._config.hidden_size // num_decoder_attention_heads,
)
common_inputs["decoder_attention_mask"] = torch.cat(
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
)
common_inputs["past_key_values"] = []
# If the number of encoder and decoder layers are present in the model configuration, both are considered
num_encoder_layers, num_decoder_layers = self.num_layers
min_num_layers = min(num_encoder_layers, num_decoder_layers)
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
for _ in range(min_num_layers):
common_inputs["past_key_values"].append(
(
torch.zeros(decoder_shape),
torch.zeros(decoder_shape),
torch.zeros(encoder_shape),
torch.zeros(encoder_shape),
)
)
# TODO: test this.
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
for _ in range(min_num_layers, max_num_layers):
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
elif self.task == "causal-lm":
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
num_encoder_layers, _ = self.num_layers
num_encoder_attention_heads, _ = self.num_attention_heads
past_shape = (
batch,
num_encoder_attention_heads,
past_key_values_length,
self._config.hidden_size // num_encoder_attention_heads,
)
common_inputs["attention_mask"] = torch.cat(
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
)
common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
]
else:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)
return common_inputs
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
if self.task in ["default", "seq2seq-lm"]:
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
else:
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
flattened_output, name, idx, t
)
@@ -169,3 +169,7 @@ class BertOnnxConfig(OnnxConfig):
                 ("token_type_ids", {0: "batch", 1: "sequence"}),
             ]
         )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
@@ -142,3 +142,7 @@ class DistilBertOnnxConfig(OnnxConfig):
                 ("attention_mask", {0: "batch", 1: "sequence"}),
             ]
         )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
@@ -15,12 +15,12 @@
 # limitations under the License.
 """ OpenAI GPT-2 configuration """
 from collections import OrderedDict
-from typing import Any, List, Mapping, Optional
+from typing import Any, Mapping, Optional
 
 from transformers import PreTrainedTokenizer, TensorType, is_torch_available
 
 from ...configuration_utils import PretrainedConfig
-from ...onnx import OnnxConfigWithPast, PatchingSpec
+from ...onnx import OnnxConfigWithPast
 from ...utils import logging
@@ -194,36 +194,29 @@ class GPT2Config(PretrainedConfig):
 class GPT2OnnxConfig(OnnxConfigWithPast):
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        task: str = "default",
-        patching_specs: List[PatchingSpec] = None,
-        use_past: bool = False,
-    ):
-        super().__init__(config, task=task, patching_specs=patching_specs)
-        if not getattr(self._config, "pad_token_id", None):
-            # TODO: how to do that better?
-            self._config.pad_token_id = 0
-
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+        common_inputs = OrderedDict({"input_ids": {0: "batch"}})
         if self.use_past:
-            self.fill_with_past_key_values_(common_inputs, direction="inputs")
-            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
+            for i in range(self._config.n_layer * 2):
+                common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"}
+            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
         else:
             common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
 
         return common_inputs
 
     @property
-    def num_layers(self) -> int:
-        return self._config.n_layer
-
-    @property
-    def num_attention_heads(self) -> int:
-        return self._config.n_head
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_outputs = OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}})
+        if self.use_past:
+            for i in range(self._config.n_layer * 2):
+                common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"}
+
+            return common_outputs
+
+        return common_outputs
     def generate_dummy_inputs(
         self,
@@ -233,9 +226,7 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
-        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
-        )
+        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
 
         # We need to order the input in the way they appears in the forward()
         ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
@@ -247,27 +238,14 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
             else:
                 import torch
 
-                batch, seqlen = common_inputs["input_ids"].shape
-                # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
-                past_shape = (
-                    batch,
-                    self.num_attention_heads,
-                    past_key_values_length,
-                    self._config.hidden_size // self.num_attention_heads,
-                )
+                batch = common_inputs["input_ids"].shape[0]
                 ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    (
+                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
+                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
+                    )
+                    for _ in range(self._config.n_layer)
                 ]
 
         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
-        if self.use_past:
-            ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
-            )
         return ordered_inputs
-
-    @property
-    def default_onnx_opset(self) -> int:
-        return 13
@@ -15,7 +15,7 @@
 """ GPT Neo model configuration """
 from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Any, Dict, Iterable, Mapping, Optional
 
 from ... import PreTrainedTokenizer, TensorType, is_torch_available
 from ...configuration_utils import PretrainedConfig
@@ -212,7 +212,10 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
         common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
         if self.use_past:
-            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+            for i in range(self._config.num_layers):
+                common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence"}
             common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
         else:
             common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
@@ -220,8 +223,16 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
         return common_inputs
 
     @property
-    def num_attention_heads(self) -> int:
-        return self._config.num_heads
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_outputs = super().outputs
+        if self.use_past:
+            for i in range(self._config.num_layers):
+                common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
+                common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+
+            return common_outputs
+
+        return common_outputs
 
     def generate_dummy_inputs(
         self,
@@ -231,10 +242,7 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
-        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
-        )
+        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
 
         # We need to order the input in the way they appears in the forward()
         ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
@@ -246,27 +254,28 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
             else:
                 import torch
 
-                batch, seqlen = common_inputs["input_ids"].shape
-                # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
-                past_shape = (
-                    batch,
-                    self.num_attention_heads,
-                    past_key_values_length,
-                    self._config.hidden_size // self.num_attention_heads,
-                )
+                batch = common_inputs["input_ids"].shape[0]
+                past_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
                 ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self._config.num_layers)
                 ]
 
         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
         if self.use_past:
             ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [ordered_inputs["attention_mask"], torch.ones(batch, 1)], dim=1
             )
         return ordered_inputs
 
-    @property
-    def default_onnx_opset(self) -> int:
-        return 13
+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                flatten_output[f"{name}.{idx}.key"] = t[0]
+                flatten_output[f"{name}.{idx}.value"] = t[1]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
@@ -14,12 +14,11 @@
 # limitations under the License.
 """ MBART model configuration """
 from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Mapping
+
+from transformers.onnx import OnnxConfigWithPast
 
-from ... import PreTrainedTokenizer
 from ...configuration_utils import PretrainedConfig
-from ...file_utils import TensorType, is_torch_available
-from ...onnx import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
 from ...utils import logging
@@ -166,175 +165,30 @@ class MBartConfig(PretrainedConfig):
 )
 
-# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig with Bart->MBart
-class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast):
+class MBartOnnxConfig(OnnxConfigWithPast):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        if self.task in ["default", "seq2seq-lm"]:
-            common_inputs = OrderedDict(
-                [
-                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
-                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
-                ]
-            )
-            if self.use_past:
-                common_inputs["decoder_input_ids"] = {0: "batch"}
-                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
-            else:
-                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
-                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
-
-            if self.use_past:
-                self.fill_with_past_key_values_(common_inputs, direction="inputs")
-        elif self.task == "causal-lm":
-            # TODO: figure this case out.
-            common_inputs = OrderedDict(
-                [
-                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
-                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
-                ]
-            )
-            if self.use_past:
-                num_encoder_layers, _ = self.num_layers
-                for i in range(num_encoder_layers):
-                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
-                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
-        else:
-            common_inputs = OrderedDict(
-                [
-                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
-                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
-                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
-                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
-                ]
-            )
-
-        return common_inputs
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.use_past:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("past_keys", {0: "batch", 2: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
+        else:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task in ["default", "seq2seq-lm"]:
common_outputs = super().outputs
else:
common_outputs = super(OnnxConfigWithPast, self).outputs
if self.use_past:
num_encoder_layers, _ = self.num_layers
for i in range(num_encoder_layers):
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
return common_outputs
def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
if self.task in ["default", "seq2seq-lm"]:
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)
# Generate decoder inputs
decoder_seq_length = seq_length if not self.use_past else 1
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, decoder_seq_length, is_pair, framework
)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
common_inputs = dict(**encoder_inputs, **decoder_inputs)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, encoder_seq_length = common_inputs["input_ids"].shape
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
encoder_shape = (
batch,
num_encoder_attention_heads,
encoder_seq_length,
self._config.hidden_size // num_encoder_attention_heads,
)
decoder_past_length = decoder_seq_length + 3
decoder_shape = (
batch,
num_decoder_attention_heads,
decoder_past_length,
self._config.hidden_size // num_decoder_attention_heads,
)
common_inputs["decoder_attention_mask"] = torch.cat(
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
)
common_inputs["past_key_values"] = []
# If the number of encoder and decoder layers are present in the model configuration, both are considered
num_encoder_layers, num_decoder_layers = self.num_layers
min_num_layers = min(num_encoder_layers, num_decoder_layers)
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
for _ in range(min_num_layers):
common_inputs["past_key_values"].append(
(
torch.zeros(decoder_shape),
torch.zeros(decoder_shape),
torch.zeros(encoder_shape),
torch.zeros(encoder_shape),
)
)
# TODO: test this.
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
for _ in range(min_num_layers, max_num_layers):
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
elif self.task == "causal-lm":
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
num_encoder_layers, _ = self.num_layers
num_encoder_attention_heads, _ = self.num_attention_heads
past_shape = (
batch,
num_encoder_attention_heads,
past_key_values_length,
self._config.hidden_size // num_encoder_attention_heads,
)
common_inputs["attention_mask"] = torch.cat(
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
)
common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
]
else:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)
return common_inputs
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
if self.task in ["default", "seq2seq-lm"]:
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
else:
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
flattened_output, name, idx, t
)
@@ -76,3 +76,7 @@ class RobertaOnnxConfig(OnnxConfig):
                 ("attention_mask", {0: "batch", 1: "sequence"}),
             ]
         )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
@@ -13,11 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ T5 model configuration """
-from typing import Mapping
-
-# from ... import is_torch_available
+from collections import OrderedDict
+from typing import Any, Dict, Iterable, Mapping, Optional
+
+from transformers import PreTrainedTokenizer, TensorType
+
+from ... import is_torch_available
 from ...configuration_utils import PretrainedConfig
-from ...onnx import OnnxSeq2SeqConfigWithPast
+from ...onnx import OnnxConfigWithPast
 from ...utils import logging
@@ -122,26 +125,101 @@ class T5Config(PretrainedConfig):
 )
 
-class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+class T5OnnxConfig(OnnxConfigWithPast):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        common_inputs = {
-            "input_ids": {0: "batch", 1: "encoder_sequence"},
-            "attention_mask": {0: "batch", 1: "encoder_sequence"},
-        }
-        if self.use_past:
-            common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
-            common_inputs["decoder_input_ids"] = {0: "batch"}
-            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
-        else:
-            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
-            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+        common_inputs = OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+                ("decoder_input_ids", {0: "batch"}),
+                ("decoder_attention_mask", {0: "batch"}),
+            ]
+        )
 
         if self.use_past:
-            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+            for i in range(0, self._config.num_layers):
+                common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"}
 
         return common_inputs
 
     @property
-    def default_onnx_opset(self) -> int:
-        return 13
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_outputs = super().outputs
+        if "last_hidden_state" in common_outputs:
+            common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
+        if self.use_past:
+            for i in range(self._config.num_layers):
+                common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
+                common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
+                common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
+                common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
+        if self.task == "default":
+            common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"}
+
+        return common_outputs
def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
# Generate encoder inputs
encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
# Generate decoder inputs
decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
ordered_inputs = dict(**encoder_inputs, **decoder_inputs)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch = encoder_inputs["input_ids"].shape[0]
encoder_seq_length = encoder_inputs["input_ids"].shape[1]
encoder_shape = (
batch,
self._config.num_heads,
encoder_seq_length,
self._config.hidden_size // self._config.num_heads,
)
decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
ordered_inputs["past_key_values"] = []
for _ in range(self._config.num_layers):
ordered_inputs["past_key_values"].append(
(
torch.zeros(decoder_shape),
torch.zeros(decoder_shape),
torch.zeros(encoder_shape),
torch.zeros(encoder_shape),
)
)
return ordered_inputs
@staticmethod
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
if name in ["present", "past_key_values"]:
flatten_output = {}
for idx, t in enumerate(field):
flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
return flatten_output
return super().flatten_output_collection_property(name, field)
@@ -53,3 +53,7 @@ class XLMRobertaOnnxConfig(OnnxConfig):
                 ("attention_mask", {0: "batch", 1: "sequence"}),
             ]
         )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
@@ -13,12 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .config import (
-    EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
-    OnnxConfig,
-    OnnxConfigWithPast,
-    OnnxSeq2SeqConfigWithPast,
-    PatchingSpec,
-)
+from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec
 from .convert import export, validate_model_outputs
 from .utils import ParameterFormat, compute_serialized_parameters_size
@@ -32,10 +32,10 @@ def main():
         help="Export the model with some additional feature.",
     )
     parser.add_argument(
-        "--opset", type=int, default=None, help="ONNX opset version to export the model with (default 12)."
+        "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
     )
     parser.add_argument(
-        "--atol", type=float, default=None, help="Absolute difference tolerence when validating the model."
+        "--atol", type=float, default=1e-4, help="Absolute difference tolerence when validating the model."
    )
     parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
@@ -53,9 +53,6 @@ def main():
     onnx_config = model_onnx_config(model.config)
 
     # Ensure the requested opset is sufficient
-    if args.opset is None:
-        args.opset = onnx_config.default_onnx_opset
-
     if args.opset < onnx_config.default_onnx_opset:
         raise ValueError(
             f"Opset {args.opset} is not sufficient to export {model_kind}. "
@@ -64,9 +61,6 @@ def main():
     onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)
 
-    if args.atol is None:
-        args.atol = onnx_config.atol_for_validation
-
     validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
     logger.info(f"All good, model saved at: {args.output.as_posix()}")
@@ -14,9 +14,9 @@
 import dataclasses
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
 
-from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType, is_torch_available
+from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
 
 from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size
@@ -58,7 +58,6 @@ class OnnxConfig(ABC):
     _TASKS_TO_COMMON_OUTPUTS = {
         "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
-        "masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
         "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
         "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
         "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
@@ -120,8 +119,7 @@ class OnnxConfig(ABC):
         Returns:
             For each output: its name associated to the axes symbolic name and the axis position within the tensor
         """
-        common_outputs = self._TASKS_TO_COMMON_OUTPUTS[self.task]
-        return common_outputs
+        return self._TASKS_TO_COMMON_OUTPUTS[self.task]
 
     @property
     def values_override(self) -> Optional[Mapping[str, Any]]:
@@ -167,16 +165,6 @@ class OnnxConfig(ABC):
         """
         return DEFAULT_ONNX_OPSET
 
-    @property
-    def atol_for_validation(self) -> float:
-        """
-        What absolute tolerance value to use during model conversion validation.
-
-        Returns:
-            Float absolute tolerance value.
-        """
-        return 1e-5
-
     @staticmethod
     def use_external_data_format(num_parameters: int) -> bool:
         """
@@ -241,8 +229,8 @@ class OnnxConfig(ABC):
             orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
             setattr(spec.o, spec.name, orig_op)
 
-    @classmethod
-    def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]:
+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
         """
         Flatten any potential nested structure expanding the name of the field with the index of the element within the
         structure.
@@ -284,14 +272,6 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
         """
         return cls(config, task=task, use_past=True)
 
-    @property
-    def outputs(self) -> Mapping[str, Mapping[int, str]]:
-        common_outputs = super().outputs
-        if self.use_past:
-            self.fill_with_past_key_values_(common_outputs, direction="outputs")
-
-        return common_outputs
-
     @property
     def values_override(self) -> Optional[Mapping[str, Any]]:
         if hasattr(self._config, "use_cache"):
@@ -299,30 +279,6 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
         return None
@property
def num_layers(self) -> int:
"""
The number of layers attribute retrieved from the model config. Override this for model configs where the
number of layers attribute is not called `num_layers`.
"""
if not hasattr(self._config, "num_layers"):
raise AttributeError(
"could not find the number of layers attribute in the model configuration, override the num_layers property of the model OnnxConfig to solve this"
)
return self._config.num_layers
@property
def num_attention_heads(self) -> int:
"""
The number of attention heads attribute retrieved from the model config. Override this for model configs where
the number of attention heads attribute is not called `num_attention_heads`.
"""
if not hasattr(self._config, "num_attention_heads"):
raise AttributeError(
"could not find the number of attention heads attribute in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this"
)
return self._config.num_attention_heads
     def generate_dummy_inputs(
         self,
         tokenizer: PreTrainedTokenizer,
@@ -331,217 +287,32 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
-        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
-        batch_size = compute_effective_axis_dimension(
-            batch_size, fixed_dimension=self.default_batch_size, num_token_to_add=0
-        )
-
-        # TODO: should we set seq_length = 1 when self.use_past = True?
-        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
+        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
+
+        # When use_past the caching mechanism requires inputs to be only 1 single token
+        fixed_sequence_length = 1 if self.use_past else self.default_sequence_length
+        seq_length = compute_effective_axis_dimension(
+            seq_length, fixed_dimension=fixed_sequence_length, num_token_to_add=token_to_add
+        )
+
+        # Generate dummy inputs according to compute batch and sequence
+        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
+        return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
shape = (
batch,
self.num_attention_heads,
past_key_values_length,
self._config.hidden_size // self.num_attention_heads,
)
if "attention_mask" in common_inputs:
common_inputs["attention_mask"] = torch.cat(
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
)
common_inputs["past_key_values"] = []
for _ in range(self.num_layers):
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
return common_inputs
def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
"""
Fill the input_or_ouputs mapping with past_key_values dynamic axes considering.
Args:
inputs_or_outputs: The mapping to fill.
direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the
output mapping, this is important for axes naming.
"""
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
-        name = "past_key_values" if direction == "inputs" else "present"
-        for i in range(self.num_layers):
-            inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
-            inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
-
-    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
-        flattened_output[f"{name}.{idx}.key"] = t[0]
-        flattened_output[f"{name}.{idx}.value"] = t[1]
-
-    def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]:
-        flattened_output = {}
-        if name in ["present", "past_key_values"]:
-            for idx, t in enumerate(field):
-                self._flatten_past_key_values_(flattened_output, name, idx, t)
-        else:
-            flattened_output = super().flatten_output_collection_property(name, field)
-
-        return flattened_output
+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                flatten_output[f"{name}.{idx}.key"] = t[0]
+                flatten_output[f"{name}.{idx}.value"] = t[1]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = self._TASKS_TO_COMMON_OUTPUTS[self.task]
# Renaming the outputs axes properly.
for name, axes_names in common_outputs.items():
sequence_name = "encoder_sequence" if "encoder" in name else "decoder_sequence"
for axis_idx, name in axes_names.items():
if "sequence" in name:
axes_names[axis_idx] = sequence_name
# We reset the value as the order in common_outputs (OrderedDict) is lost otherwise
else:
axes_names[axis_idx] = name
if self.use_past:
self.fill_with_past_key_values_(common_outputs, direction="outputs")
return common_outputs
@property
def num_layers(self) -> Tuple[int]:
try:
num_layers = super().num_layers
num_layers = (num_layers, num_layers)
except AttributeError:
if hasattr(self._config, "encoder_layers") and hasattr(self._config, "decoder_layers"):
num_layers = (self._config.encoder_layers, self._config.decoder_layers)
else:
raise AttributeError(
"could not find the number of encoder and decoder layers attributes in the model configuration, override the num_layers property of the model OnnxConfig to solve this"
)
return num_layers
    @property
def num_attention_heads(self) -> Tuple[int]:
try:
num_attention_heads = super().num_attention_heads
num_attention_heads = (num_attention_heads, num_attention_heads)
except AttributeError:
if hasattr(self._config, "encoder_attention_heads") and hasattr(self._config, "decoder_attention_heads"):
num_attention_heads = (self._config.encoder_attention_heads, self._config.decoder_attention_heads)
else:
raise AttributeError(
"could not find the number of attention heads for the encoder and the decoder attributes in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this"
)
return num_attention_heads
def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)
# Generate decoder inputs
decoder_seq_length = seq_length if not self.use_past else 1
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, decoder_seq_length, is_pair, framework
)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
common_inputs = dict(**encoder_inputs, **decoder_inputs)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch = common_inputs["input_ids"].shape[0]
encoder_seq_length = common_inputs["input_ids"].shape[1]
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
encoder_shape = (
batch,
num_encoder_attention_heads,
encoder_seq_length,
self._config.hidden_size // num_encoder_attention_heads,
)
decoder_shape = (
batch,
num_decoder_attention_heads,
# Not using the same length for past_key_values
decoder_seq_length + 3,
self._config.hidden_size // num_decoder_attention_heads,
)
            common_inputs["past_key_values"] = []
# If the number of encoder and decoder layers are present in the model configuration, both are considered
num_encoder_layers, num_decoder_layers = self.num_layers
min_num_layers = min(num_encoder_layers, num_decoder_layers)
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
for _ in range(min_num_layers):
# For encoder-decoder models, past_key_values contains pre-computed values for both the encoder and the
# decoder layers, hence a tuple of 4 tensors instead of 2
common_inputs["past_key_values"].append(
(
torch.zeros(decoder_shape),
torch.zeros(decoder_shape),
torch.zeros(encoder_shape),
torch.zeros(encoder_shape),
)
)
# TODO: test this.
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
for _ in range(min_num_layers, max_num_layers):
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
return common_inputs
def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
name = "past_key_values" if direction == "inputs" else "present"
# If the number of encoder and decoder layers are present in the model configuration, both are considered
num_encoder_layers, num_decoder_layers = self.num_layers
min_num_layers = min(num_encoder_layers, num_decoder_layers)
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
encoder_sequence = "past_encoder_sequence"
decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence"
for i in range(min_num_layers):
inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch", 2: decoder_sequence}
inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch", 2: decoder_sequence}
inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch", 2: encoder_sequence}
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch", 2: encoder_sequence}
for i in range(min_num_layers, max_num_layers):
if remaining_side_name == "encoder":
axes_info = {0: "batch", 2: encoder_sequence}
else:
axes_info = {0: "batch", 2: decoder_sequence}
inputs_or_outputs[f"{name}.{i}.{remaining_side_name}.key"] = axes_info
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.decoder.key"] = t[0]
flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
flattened_output[f"{name}.{idx}.encoder.value"] = t[3]
@@ -191,7 +191,7 @@ def validate_model_outputs(
                 f"{onnx_outputs_set.difference(ref_outputs_set)}"
             )
         else:
-            logger.info(f"\t-[✓] ONNX model outputs' name match reference model ({onnx_outputs_set})")
+            logger.info(f"\t-[✓] ONNX model outputs' name match reference model ({onnx_outputs_set}")
 
         # Check the shape and values match
         for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
 from functools import partial, reduce
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Tuple
 
-from .. import PretrainedConfig, is_torch_available
+from .. import is_torch_available
 from ..models.albert import AlbertOnnxConfig
 from ..models.bart import BartOnnxConfig
 from ..models.bert import BertOnnxConfig
@@ -15,7 +15,6 @@ from ..models.mbart import MBartOnnxConfig
 from ..models.roberta import RobertaOnnxConfig
 from ..models.t5 import T5OnnxConfig
 from ..models.xlm_roberta import XLMRobertaOnnxConfig
-from .config import OnnxConfig
 
 if is_torch_available():
@@ -23,7 +22,6 @@ if is_torch_available():
     from transformers.models.auto import (
         AutoModel,
         AutoModelForCausalLM,
-        AutoModelForMaskedLM,
         AutoModelForMultipleChoice,
         AutoModelForQuestionAnswering,
         AutoModelForSeq2SeqLM,
@@ -32,19 +30,8 @@ if is_torch_available():
     )
-def supported_features_mapping(
-    *supported_features: str, onnx_config_cls: Type[OnnxConfig] = None
-) -> Dict[str, Callable[[PretrainedConfig, str], OnnxConfig]]:
-    """
-    Generate the mapping between supported the features and their corresponding OnnxConfig for a given model.
-
-    Args:
-        *supported_features: The names of the supported features.
-        onnx_config_cls: The OnnxConfig class corresponding to the model.
-
-    Returns:
-        The dictionary mapping a feature to an OnnxConfig constructor.
-    """
+def supported_features_mapping(*supported_features, onnx_config_cls=None):
+    """Generates the mapping between supported features and their corresponding OnnxConfig."""
     if onnx_config_cls is None:
         raise ValueError("A OnnxConfig class must be provided")
@@ -62,7 +49,6 @@ def supported_features_mapping(
 class FeaturesManager:
     _TASKS_TO_AUTOMODELS = {
         "default": AutoModel,
-        "masked-lm": AutoModelForMaskedLM,
         "causal-lm": AutoModelForCausalLM,
         "seq2seq-lm": AutoModelForSeq2SeqLM,
         "sequence-classification": AutoModelForSequenceClassification,
@@ -72,110 +58,27 @@ class FeaturesManager:
     }
 
     # Set of model topologies we support associated to the features supported by each topology and the factory
-    _SUPPORTED_MODEL_TYPE = {
-        "albert": supported_features_mapping(
-            "default",
-            "masked-lm",
-            "sequence-classification",
+    _SUPPORTED_MODEL_KIND = {
+        "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
+        "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
+        "mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
+        "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
# "multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=AlbertOnnxConfig,
),
"bart": supported_features_mapping(
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
"sequence-classification",
"question-answering",
onnx_config_cls=BartOnnxConfig,
),
"mbart": supported_features_mapping(
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
"sequence-classification",
"question-answering",
onnx_config_cls=MBartOnnxConfig,
),
"bert": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
# "multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=BertOnnxConfig,
        ),
         "camembert": supported_features_mapping(
             "default",
-            "masked-lm",
             "causal-lm",
             "sequence-classification",
-            # "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=CamembertOnnxConfig,
         ),
-        "distilbert": supported_features_mapping(
-            "default",
-            "masked-lm",
-            "sequence-classification",
+        "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
+        "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
+        "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
+        "roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
# "multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=DistilBertOnnxConfig,
),
"longformer": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
# "multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=LongformerOnnxConfig,
),
"roberta": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
# "multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=RobertaOnnxConfig,
        ),
         "t5": supported_features_mapping(
             "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
         ),
-        "xlm-roberta": supported_features_mapping(
+        "xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
# "multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=XLMRobertaOnnxConfig,
),
"gpt2": supported_features_mapping(
"default",
"causal-lm",
"sequence-classification",
"token-classification",
"default-with-past",
"causal-lm-with-past",
"sequence-classification-with-past",
"token-classification-with-past",
onnx_config_cls=GPT2OnnxConfig,
        ),
         "gpt-neo": supported_features_mapping(
             "default",
             "causal-lm",
@@ -194,46 +97,23 @@ class FeaturesManager:
         ),
     }
 
-    AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
+    AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
@staticmethod
def get_supported_features_for_model_type(
model_type: str, model_name: Optional[str] = None
) -> Dict[str, Callable[[PretrainedConfig, str], OnnxConfig]]:
"""
Try to retrieve the feature -> OnnxConfig constructor map from the model type.
Args:
model_type: The model type to retrieve the supported features for.
model_name: The name attribute of the model object, only used for the exception message.
Returns:
The dictionary mapping each feature to a corresponding OnnxConfig constructor.
"""
model_type = model_type.lower()
if model_type not in FeaturesManager._SUPPORTED_MODEL_TYPE:
model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type
raise KeyError(
f"{model_type_and_model_name} is not supported yet. "
f"Only {list(FeaturesManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
f"If you want to support {model_type} please propose a PR or open up an issue."
)
return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type]
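# A hedged usage sketch of the helper above; "bert" and the feature names are taken from
# the _SUPPORTED_MODEL_TYPE table earlier in this diff, and the exact feature set depends
# on the installed transformers version:
from transformers.onnx.features import FeaturesManager

bert_features = FeaturesManager.get_supported_features_for_model_type("bert")
print(sorted(bert_features))  # e.g. ['causal-lm', 'default', 'masked-lm', ...]
# Each mapping value can be called with a model config to build the matching OnnxConfig.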
@staticmethod @staticmethod
def feature_to_task(feature: str) -> str: def feature_to_task(feature: str) -> str:
return feature.replace("-with-past", "") return feature.replace("-with-past", "")
@staticmethod @staticmethod
def get_model_class_for_feature(feature: str) -> Type: def get_model_from_feature(feature: str, model: str):
""" """
Attempt to retrieve an AutoModel class from a feature name. Attempt to retrieve a model from a model's name and the feature to be enabled.
Args: Args:
feature: The feature required. feature: The feature required
model: The name of the model to export
Returns: Returns:
The AutoModel class corresponding to the feature.
""" """
task = FeaturesManager.feature_to_task(feature) task = FeaturesManager.feature_to_task(feature)
if task not in FeaturesManager._TASKS_TO_AUTOMODELS: if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
...@@ -241,43 +121,38 @@ class FeaturesManager: ...@@ -241,43 +121,38 @@ class FeaturesManager:
f"Unknown task: {feature}. " f"Unknown task: {feature}. "
f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}" f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
) )
return FeaturesManager._TASKS_TO_AUTOMODELS[task]
def get_model_from_feature(feature: str, model: str) -> PreTrainedModel: return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model)
"""
Attempt to retrieve a model from a model's name and the feature to be enabled.
Args:
feature: The feature required.
model: The name of the model to export.
Returns:
The instance of the model.
"""
model_class = FeaturesManager.get_model_class_for_feature(feature)
return model_class.from_pretrained(model)
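# Sketch of how the helpers above fit together; "gpt2" is only an example checkpoint, and
# get_model_class_for_feature exists only on the side of this revert that defines it above:
from transformers.onnx.features import FeaturesManager

task = FeaturesManager.feature_to_task("causal-lm-with-past")      # -> "causal-lm"
model_class = FeaturesManager.get_model_class_for_feature("causal-lm")
model = FeaturesManager.get_model_from_feature("causal-lm", "gpt2")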
@staticmethod @staticmethod
def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]: def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
""" """
Check whether or not the model has the requested features. Check whether or not the model has the requested features
Args: Args:
model: The model to export. model: The model to export
feature: The name of the feature to check if it is available. feature: The name of the feature to check if it is available
Returns: Returns:
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties. (str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
""" """
model_type = model.config.model_type.replace("_", "-") model_type = model.config.model_type.replace("_", "-")
model_name = getattr(model, "name", "") model_name = getattr(model, "name", "")
model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name) model_name = f"({model_name})" if model_name else ""
if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND:
raise KeyError(
f"{model.config.model_type} ({model_name}) is not supported yet. "
f"Only {FeaturesManager._SUPPORTED_MODEL_KIND} are supported. "
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
)
# Look for the features
model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type]
if feature not in model_features: if feature not in model_features:
raise ValueError( raise ValueError(
f"{model.config.model_type} doesn't support feature {feature}. " f"{model.config.model_type} doesn't support feature {feature}. "
f"Supported values are: {model_features}" f"Supported values are: {list(model_features.keys())}"
) )
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature] return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature]
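# A hedged end-to-end sketch of the check above; "distilbert-base-cased" is only an example
# checkpoint, and whether the second return value is an OnnxConfig instance or a constructor
# differs between the two sides of this revert, so it is handled defensively here:
from transformers import AutoModel
from transformers.onnx.features import FeaturesManager

model = AutoModel.from_pretrained("distilbert-base-cased")
model_kind, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(model, feature="default")
onnx_config = onnx_config_factory(model.config) if callable(onnx_config_factory) else onnx_config_factory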
...@@ -3,8 +3,33 @@ from tempfile import NamedTemporaryFile ...@@ -3,8 +3,33 @@ from tempfile import NamedTemporaryFile
from unittest import TestCase from unittest import TestCase
from unittest.mock import patch from unittest.mock import patch
from parameterized import parameterized from transformers import ( # LongformerConfig,; T5Config,
from transformers import AutoConfig, AutoTokenizer AlbertConfig,
AutoTokenizer,
BartConfig,
DistilBertConfig,
GPT2Config,
GPTNeoConfig,
LayoutLMConfig,
MBartConfig,
RobertaConfig,
XLMRobertaConfig,
is_torch_available,
)
from transformers.models.albert import AlbertOnnxConfig
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.models.layoutlm import LayoutLMOnnxConfig
from transformers.models.mbart import MBartOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
# from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import ( from transformers.onnx import (
EXTERNAL_DATA_FORMAT_SIZE_LIMIT, EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig, OnnxConfig,
...@@ -12,8 +37,7 @@ from transformers.onnx import ( ...@@ -12,8 +37,7 @@ from transformers.onnx import (
export, export,
validate_model_outputs, validate_model_outputs,
) )
from transformers.onnx.config import OnnxConfigWithPast from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.features import FeaturesManager
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_torch, slow from transformers.testing_utils import require_onnx, require_torch, slow
...@@ -115,12 +139,11 @@ class OnnxConfigWithPastTestCaseV2(TestCase): ...@@ -115,12 +139,11 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
Cover the tests for models which have the use_cache feature (i.e. "with_past" for ONNX) Cover the tests for models which have the use_cache feature (i.e. "with_past" for ONNX)
""" """
SUPPORTED_WITH_PAST_CONFIGS = {} SUPPORTED_WITH_PAST_CONFIGS = {
# SUPPORTED_WITH_PAST_CONFIGS = { ("BART", BartConfig),
# ("BART", BartConfig), ("GPT2", GPT2Config),
# ("GPT2", GPT2Config), # ("T5", T5Config)
# # ("T5", T5Config) }
# }
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
def test_use_past(self): def test_use_past(self):
...@@ -164,37 +187,40 @@ class OnnxConfigWithPastTestCaseV2(TestCase): ...@@ -164,37 +187,40 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
) )
PYTORCH_EXPORT_MODELS = { if is_torch_available():
("albert", "hf-internal-testing/tiny-albert"), from transformers import ( # T5Model,
("bert", "bert-base-cased"), AlbertModel,
("camembert", "camembert-base"), BartModel,
("distilbert", "distilbert-base-cased"), BertModel,
# ("longFormer", "longformer-base-4096"), DistilBertModel,
("roberta", "roberta-base"), GPT2Model,
("xlm-roberta", "xlm-roberta-base"), GPTNeoModel,
("layoutlm", "microsoft/layoutlm-base-uncased"), LayoutLMModel,
} MBartModel,
RobertaModel,
PYTORCH_EXPORT_WITH_PAST_MODELS = { XLMRobertaModel,
("gpt2", "gpt2"), )
("gpt-neo", "EleutherAI/gpt-neo-125M"),
} PYTORCH_EXPORT_DEFAULT_MODELS = {
("ALBERT", "hf-internal-testing/tiny-albert", AlbertModel, AlbertConfig, AlbertOnnxConfig),
PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = { ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
("bart", "facebook/bart-base"), ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
("mbart", "sshleifer/tiny-mbart"), ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
("t5", "t5-small"), ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
} ("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
def _get_models_to_test(export_models_list): ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
models_to_test = [] ("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig),
for (name, model) in export_models_list: ("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type( # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
name }
).items():
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor)) PYTORCH_EXPORT_WITH_PAST_MODELS = {
return models_to_test # ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
# ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
}
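# Sketch of what the _get_models_to_test helper above produces for @parameterized.expand;
# the expanded feature list for "distilbert" is illustrative, and the snippet assumes the
# helper is in scope (it is module-level in the test file shown above):
cases = _get_models_to_test({("distilbert", "distilbert-base-cased")})
for test_name, name, model_name, feature, onnx_config_ctor in cases:
    print(test_name)  # e.g. "distilbert_default", "distilbert_masked-lm", ...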
class OnnxExportTestCaseV2(TestCase): class OnnxExportTestCaseV2(TestCase):
...@@ -202,52 +228,52 @@ class OnnxExportTestCaseV2(TestCase): ...@@ -202,52 +228,52 @@ class OnnxExportTestCaseV2(TestCase):
Integration tests ensuring supported models are correctly exported Integration tests ensuring supported models are correctly exported
""" """
def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): @slow
@require_torch
def test_pytorch_export_default(self):
from transformers.onnx import export from transformers.onnx import export
tokenizer = AutoTokenizer.from_pretrained(model_name) for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
config = AutoConfig.from_pretrained(model_name) with self.subTest(name):
self.assertTrue(hasattr(onnx_config_class, "from_model_config"))
# Useful for causal lm models that do not use pad tokens.
if not getattr(config, "pad_token_id", None):
config.pad_token_id = tokenizer.eos_token_id
model_class = FeaturesManager.get_model_class_for_feature(feature)
model = model_class.from_config(config)
onnx_config = onnx_config_class_constructor(model.config)
with NamedTemporaryFile("w") as output:
onnx_inputs, onnx_outputs = export(
tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
)
try:
validate_model_outputs(
onnx_config,
tokenizer,
model,
Path(output.name),
onnx_outputs,
onnx_config.atol_for_validation,
)
except ValueError as ve:
self.fail(f"{name}, {feature} -> {ve}")
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS)) tokenizer = AutoTokenizer.from_pretrained(model)
@slow model = model_class(config_class.from_pretrained(model))
@require_torch onnx_config = onnx_config_class.from_model_config(model.config)
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS)) with NamedTemporaryFile("w") as output:
@slow onnx_inputs, onnx_outputs = export(
@require_torch tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name)
def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor): )
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
try:
validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)
except ValueError as ve:
self.fail(f"{name} -> {ve}")
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
@slow @slow
@require_torch @require_torch
def test_pytorch_export_seq2seq_with_past( def test_pytorch_export_with_past(self):
self, test_name, name, model_name, feature, onnx_config_class_constructor from transformers.onnx import export
):
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor) for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS:
with self.subTest(name):
self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()")
tokenizer = AutoTokenizer.from_pretrained(model)
model = model_class(config_class())
onnx_config = onnx_config_class.with_past(model.config)
self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.")
self.assertTrue(
onnx_config.use_past, "OnnxConfigWithPast.use_past should be True if called with with_past()"
)
with NamedTemporaryFile("w") as output:
output = Path(output.name)
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output)
try:
validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5)
except ValueError as ve:
self.fail(f"{name} -> {ve}")
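# A condensed, hedged version of the export-and-validate flow exercised by the tests above;
# "bert-base-cased" is an illustrative checkpoint, and the export / validate_model_outputs
# signatures are taken from the calls in these tests:
from pathlib import Path
from tempfile import NamedTemporaryFile

from transformers import AutoTokenizer, BertModel
from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig
from transformers.onnx import export, validate_model_outputs
from transformers.onnx.config import DEFAULT_ONNX_OPSET

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = BertModel(BertConfig.from_pretrained("bert-base-cased"))
onnx_config = BertOnnxConfig.from_model_config(model.config)

with NamedTemporaryFile("w") as output:
    # Export to ONNX, then check the ONNX outputs against the PyTorch model's outputs.
    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name))
    validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)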