Unverified commit dc420b0e authored by Michael Benayoun, committed by GitHub

T5 with past ONNX export (#13014)



T5 with past ONNX export, and more explicit past_key_values inputs and outputs names for ONNX model
Authored-by: Michael Benayoun <michael@huggingface.co>
parent ee112246
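For readers who want to try the new export path, here is a minimal usage sketch (not part of the commit). The `use_past` keyword on the config constructor and the `export(tokenizer, model, config, opset, output)` call are assumptions based on the `transformers.onnx` API visible in this diff; the checkpoint name and output path are placeholders.

# Hypothetical sketch: export T5 with past_key_values to ONNX.
# Assumes T5OnnxConfig accepts use_past as a keyword argument and that
# transformers.onnx.export(tokenizer, model, config, opset, output) exists with this signature.
from pathlib import Path

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.models.t5.configuration_t5 import T5OnnxConfig
from transformers.onnx import export
from transformers.onnx.config import DEFAULT_ONNX_OPSET

model_name = "t5-small"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Enable past key/values so the exported graph exposes the explicit
# past_key_values.{i}.decoder/encoder.key/value names introduced by this commit.
onnx_config = T5OnnxConfig(model.config, use_past=True)

print(list(onnx_config.inputs.keys()))   # ..., "past_key_values.0.decoder.key", ...
print(list(onnx_config.outputs.keys()))  # ..., "present.0.decoder.key", ...

onnx_path = Path("t5-with-past.onnx")
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, onnx_path)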
@@ -15,7 +15,7 @@
 """ GPT Neo model configuration """
 from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Any, Dict, Iterable, Mapping, Optional

 from ... import PreTrainedTokenizer, TensorType, is_torch_available
 from ...configuration_utils import PretrainedConfig
@@ -253,8 +253,12 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
         common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
         if self.use_past:
-            for i in range(self._number_key_values):
-                common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]
+            for i in range(self._config.num_layers):
+                if self._config.attention_layers[i] == "local":
+                    common_inputs[f"past_key_values.{i}.key_value"] = {0: "batch", 1: "sequence"}
+                else:
+                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "sequence"}
+                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "sequence"}

             common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
@@ -264,9 +268,12 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
     def outputs(self) -> Mapping[str, Mapping[int, str]]:
         common_outputs = super().outputs
         if self.use_past:
-            for i in range(self._number_key_values):
-                common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]
+            for i in range(self._config.num_layers):
+                if self._config.attention_layers[i] == "local":
+                    common_outputs[f"present.{i}.key_value"] = {0: "batch", 1: "sequence"}
+                else:
+                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "sequence"}
+                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "sequence"}

         return common_outputs

     def generate_dummy_inputs(
@@ -315,3 +322,18 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
             )

         return ordered_inputs
+
+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                if len(t) == 1:
+                    flatten_output[f"{name}.{idx}.key_value"] = t[0]
+                else:
+                    flatten_output[f"{name}.{idx}.key"] = t[0]
+                    flatten_output[f"{name}.{idx}.value"] = t[1]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
@@ -14,10 +14,11 @@
 # limitations under the License.
 """ T5 model configuration """
 from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Any, Dict, Iterable, Mapping, Optional

 from transformers import PreTrainedTokenizer, TensorType
+from ... import is_torch_available
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfigWithPast
 from ...utils import logging
@@ -140,9 +141,6 @@ class T5Config(PretrainedConfig):

 class T5OnnxConfig(OnnxConfigWithPast):
-    def __init__(self, config: PretrainedConfig, use_past: bool = False):
-        super().__init__(config, use_past)
-
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
         common_inputs = OrderedDict(
@@ -155,29 +153,30 @@ class T5OnnxConfig(OnnxConfigWithPast):
         )
         if self.use_past:
-            for i in range(self._config.num_layers):
-                common_inputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "past_sequence"},)
-                common_inputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "past_sequence"},)
-                common_inputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "past_sequence"},)
-                common_inputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "past_sequence"},)
+            for i in range(0, self._config.num_layers):
+                common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"}

         return common_inputs

     @property
     def outputs(self) -> Mapping[str, Mapping[int, str]]:
-        common_outputs = OrderedDict(
-            [
-                ("last_hidden_state", {0: "batch", 1: "decoder_sequence"}),
-                ("encoder_last_hidden_state", {0: "batch", 2: "encoder_sequence"}),
-            ]
-        )
+        common_outputs = super().outputs
+
+        if "last_hidden_state" in common_outputs:
+            common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
+
         if self.use_past:
             for i in range(self._config.num_layers):
-                common_outputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "decoder_sequence"},)
-                common_outputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "decoder_sequence"},)
-                common_outputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "encoder_sequence"},)
-                common_outputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "encoder_sequence"},)
+                common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
+                common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
+                common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
+                common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
+
+        if self.task == "default":
+            common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"}

         return common_outputs
@@ -189,8 +188,6 @@ class T5OnnxConfig(OnnxConfigWithPast):
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
-        if self.use_past:
-            raise NotImplementedError()
-
         # Generate encoder inputs
         encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
@@ -199,4 +196,45 @@ class T5OnnxConfig(OnnxConfigWithPast):
         decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
         decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}

-        return dict(**encoder_inputs, **decoder_inputs)
+        ordered_inputs = dict(**encoder_inputs, **decoder_inputs)
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+
+                batch = encoder_inputs["input_ids"].shape[0]
+                encoder_seq_length = encoder_inputs["input_ids"].shape[1]
+                encoder_shape = (
+                    batch,
+                    self._config.num_heads,
+                    encoder_seq_length,
+                    self._config.hidden_size // self._config.num_heads,
+                )
+                decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
+
+                ordered_inputs["past_key_values"] = []
+                for _ in range(self._config.num_layers):
+                    ordered_inputs["past_key_values"].append(
+                        (
+                            torch.zeros(decoder_shape),
+                            torch.zeros(decoder_shape),
+                            torch.zeros(encoder_shape),
+                            torch.zeros(encoder_shape),
+                        )
+                    )
+
+        return ordered_inputs
+
+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
+                flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
+                flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
+                flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
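Illustration (not part of the diff): with the T5 override above, each layer's 4-tuple of past tensors is flattened to the explicit decoder/encoder key/value names declared in `inputs` and `outputs`. The values below are stand-in strings rather than real tensors.

# Hypothetical illustration of the flattening scheme introduced above (stand-in values).
past = [("dk0", "dv0", "ek0", "ev0"), ("dk1", "dv1", "ek1", "ev1")]  # 2 layers, 4 tensors each

flat = {}
for idx, t in enumerate(past):
    flat[f"past_key_values.{idx}.decoder.key"] = t[0]
    flat[f"past_key_values.{idx}.decoder.value"] = t[1]
    flat[f"past_key_values.{idx}.encoder.key"] = t[2]
    flat[f"past_key_values.{idx}.encoder.value"] = t[3]

# flat now maps e.g. "past_key_values.0.decoder.key" -> "dk0", matching the
# ONNX input/output names that T5OnnxConfig.inputs and .outputs declare.
print(flat)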
@@ -429,8 +429,6 @@ class T5Attention(nn.Module):
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
         batch_size, seq_length = hidden_states.shape[:2]

-        int_seq_length = int(seq_length)
-
         real_seq_length = seq_length

         if past_key_value is not None:
@@ -499,7 +497,7 @@ class T5Attention(nn.Module):
             # if key and values are already calculated
             # we want only the last query position bias
             if past_key_value is not None:
-                position_bias = position_bias[:, :, -int_seq_length:, :]
+                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
@@ -629,7 +627,7 @@ class T5Block(nn.Module):
             if len(past_key_value) != expected_num_past_key_values:
                 raise ValueError(
                     f"There should be {expected_num_past_key_values} past states. "
-                    f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}."
+                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                     f"Got {len(past_key_value)} past key / value states"
                 )
...
@@ -14,7 +14,7 @@
 import dataclasses
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any, Callable, List, Mapping, Optional
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional

 from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
@@ -59,6 +59,7 @@ class OnnxConfig(ABC):
     _TASKS_TO_COMMON_OUTPUTS = {
         "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
         "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+        "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
         "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
         "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
         "multiple-choice": OrderedDict({"logits": {0: "batch"}}),
@@ -228,6 +229,24 @@ class OnnxConfig(ABC):
             orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
             setattr(spec.o, spec.name, orig_op)

+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        """
+        Flatten any potential nested structure expanding the name of the field with the index of the element within the
+        structure.
+
+        Args:
+            name: The name of the nested structure
+            field: The structure to, potentially, be flattened
+
+        Returns:
+            (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
+        """
+        from itertools import chain
+
+        return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
+

 class OnnxConfigWithPast(OnnxConfig, ABC):
     def __init__(
@@ -285,3 +304,15 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
         # Generate dummy inputs according to compute batch and sequence
         dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
         return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
+
+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                flatten_output[f"{name}.{idx}.key"] = t[0]
+                flatten_output[f"{name}.{idx}.value"] = t[1]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
@@ -24,7 +24,6 @@ from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedMod
 from ..file_utils import is_torch_onnx_dict_inputs_support_available
 from ..utils import logging
 from .config import OnnxConfig
-from .utils import flatten_output_collection_property

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -163,7 +162,7 @@ def validate_model_outputs(
         if name == "past_key_values":
             name = "present"
         if isinstance(value, (list, tuple)):
-            value = flatten_output_collection_property(name, value)
+            value = config.flatten_output_collection_property(name, value)
             ref_outputs_dict.update(value)
         else:
             ref_outputs_dict[name] = value
@@ -172,7 +171,7 @@ def validate_model_outputs(
     onnx_inputs = {}
    for name, value in reference_model_inputs.items():
        if isinstance(value, (list, tuple)):
-            value = flatten_output_collection_property(name, value)
+            value = config.flatten_output_collection_property(name, value)
            onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
        else:
            onnx_inputs[name] = value.numpy()
...
@@ -21,6 +21,7 @@ if is_torch_available():
         AutoModelForCausalLM,
         AutoModelForMultipleChoice,
         AutoModelForQuestionAnswering,
+        AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification,
         AutoModelForTokenClassification,
     )
@@ -46,6 +47,7 @@ class FeaturesManager:
     _TASKS_TO_AUTOMODELS = {
         "default": AutoModel,
         "causal-lm": AutoModelForCausalLM,
+        "seq2seq-lm": AutoModelForSeq2SeqLM,
         "sequence-classification": AutoModelForSequenceClassification,
         "token-classification": AutoModelForTokenClassification,
         "multiple-choice": AutoModelForMultipleChoice,
@@ -61,7 +63,9 @@ class FeaturesManager:
         "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
         "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
         "roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
-        "t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
+        "t5": supported_features_mapping(
+            "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
+        ),
         "xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
         "gpt-neo": supported_features_mapping(
             "default",
...
@@ -14,7 +14,6 @@
 from ctypes import c_float, sizeof
 from enum import Enum
-from typing import Any, Dict, Iterable


 class ParameterFormat(Enum):
@@ -62,21 +61,3 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
         Size (in byte) taken to save all the parameters
     """
     return num_parameters * dtype.size
-
-
-def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
-    """
-    Flatten any potential nested structure expanding the name of the field with the index of the element within the
-    structure.
-
-    Args:
-        name: The name of the nested structure
-        field: The structure to, potentially, be flattened
-
-    Returns:
-        (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
-    """
-    from itertools import chain
-
-    return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
@@ -34,11 +34,7 @@ from transformers.onnx import (
     validate_model_outputs,
 )
 from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
-from transformers.onnx.utils import (
-    compute_effective_axis_dimension,
-    compute_serialized_parameters_size,
-    flatten_output_collection_property,
-)
+from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
 from transformers.testing_utils import require_onnx, require_torch, slow
@@ -95,7 +91,7 @@ class OnnxUtilsTestCaseV2(TestCase):
         ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
         """
         self.assertEqual(
-            flatten_output_collection_property("past_key", [[0], [1], [2]]),
+            OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]),
             {
                 "past_key.0": 0,
                 "past_key.1": 1,
...