Unverified Commit e86c02ea authored by Suraj Patil, committed by GitHub

Fix GPTNeo onnx export (#13524)



Update GPT Neo ONNX config to match the changes implied by the simplification of the local attention
Co-authored-by: Michael Benayoun <michael@huggingface.co>
parent 3fbb55c7
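
With the simplified local attention, every GPT Neo layer caches a plain (key, value) pair of shape (batch, num_heads, past_sequence, head_dim), so the ONNX config in the diff below drops the patching specs and the special `key_value` tensors it previously used for local layers. As a rough illustration only (the layer count and sizes here are made-up stand-ins, not values from the diff), the dummy cache and attention mask built along the lines of the updated `generate_dummy_inputs` look like this:

import torch

# Hypothetical stand-ins for the values the ONNX config reads from the model config.
num_layers, num_heads, hidden_size = 4, 16, 2048
batch, sequence, past_sequence = 2, 3, 1

# Uniform cache shape used for every layer, local or global:
# (batch, num_heads, past_sequence, head_dim)
past_shape = (batch, num_heads, past_sequence, hidden_size // num_heads)

# One (key, value) pair per layer, regardless of attention type.
past_key_values = [(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_layers)]

# The mask is extended with ones for the past positions, mirroring the
# torch.ones(batch, 1) concatenation in the updated generate_dummy_inputs.
attention_mask = torch.cat([torch.ones(batch, sequence), torch.ones(batch, past_sequence)], dim=1)

print(len(past_key_values), past_key_values[0][0].shape, attention_mask.shape)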
@@ -19,7 +19,7 @@ from typing import Any, Dict, Iterable, Mapping, Optional
 from ... import PreTrainedTokenizer, TensorType, is_torch_available
 from ...configuration_utils import PretrainedConfig
-from ...onnx import OnnxConfigWithPast, PatchingSpec
+from ...onnx import OnnxConfigWithPast
 from ...utils import logging
@@ -212,49 +212,17 @@ def custom_get_block_length_and_num_blocks(seq_length, window_size):
 class GPTNeoOnnxConfig(OnnxConfigWithPast):
-    def __init__(self, config: PretrainedConfig, task: str = "default", use_past: bool = False):
-        if is_torch_available():
-            import torch
-            from .modeling_gpt_neo import GPTNeoAttentionMixin
-            patching_specs = [
-                PatchingSpec(torch.Tensor, name="unfold", custom_op=custom_unfold),
-                PatchingSpec(
-                    GPTNeoAttentionMixin,
-                    name="_get_block_length_and_num_blocks",
-                    custom_op=custom_get_block_length_and_num_blocks,
-                    op_wrapper=staticmethod,
-                ),
-            ]
-        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
-        self._num_local_attention = len([type_ for type_ in self._config.attention_layers if type_ == "local"])
-        self._key_values_dynamic_axis = []
-        for i in range(self._config.num_layers):
-            if self._config.attention_layers[i] == "local":
-                self._key_values_dynamic_axis.append({0: "batch", 1: "sequence"})
-            else:
-                self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
-                self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
-    @property
-    def _number_key_values(self):
-        return (self._config.num_layers * 2) - self._num_local_attention
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
         common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
         if self.use_past:
             for i in range(self._config.num_layers):
-                if self._config.attention_layers[i] == "local":
-                    common_inputs[f"past_key_values.{i}.key_value"] = {0: "batch", 1: "sequence"}
-                else:
-                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "sequence"}
-                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "sequence"}
-            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+                common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence"}
+            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
         else:
             common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
         return common_inputs
@@ -263,11 +231,11 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
         common_outputs = super().outputs
         if self.use_past:
             for i in range(self._config.num_layers):
-                if self._config.attention_layers[i] == "local":
-                    common_outputs[f"present.{i}.key_value"] = {0: "batch", 1: "sequence"}
-                else:
-                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "sequence"}
-                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "sequence"}
+                common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
+                common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
         return common_outputs
     def generate_dummy_inputs(
@@ -283,12 +251,6 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
         # We need to order the input in the way they appears in the forward()
         ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
-        batch = common_inputs["input_ids"].shape[0]
-        past_shapes = {
-            "global": (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_attention_heads),
-            "local": (batch, 1, self._config.hidden_size),
-        }
         # Need to add the past_keys
         if self.use_past:
             if not is_torch_available():
@@ -296,23 +258,16 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
             else:
                 import torch
-                ordered_inputs["past_key_values"] = []
-                for i in range(self._config.num_layers):
-                    attention_type = self._config.attention_layers[i]
-                    if attention_type == "global":
-                        ordered_inputs["past_key_values"].append(
-                            (
-                                torch.zeros(past_shapes[attention_type]),
-                                torch.zeros(past_shapes[attention_type]),
-                            )
-                        )
-                    else:
-                        ordered_inputs["past_key_values"].append((torch.zeros(past_shapes[attention_type]),))
+                batch = common_inputs["input_ids"].shape[0]
+                past_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
+                ordered_inputs["past_key_values"] = [
+                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self._config.num_layers)
+                ]
         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
         if self.use_past:
             ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.zeros(batch, 1)], dim=1
+                [ordered_inputs["attention_mask"], torch.ones(batch, 1)], dim=1
             )
         return ordered_inputs
@@ -322,11 +277,8 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
         if name in ["present", "past_key_values"]:
             flatten_output = {}
             for idx, t in enumerate(field):
-                if len(t) == 1:
-                    flatten_output[f"{name}.{idx}.key_value"] = t[0]
-                else:
-                    flatten_output[f"{name}.{idx}.key"] = t[0]
-                    flatten_output[f"{name}.{idx}.value"] = t[1]
+                flatten_output[f"{name}.{idx}.key"] = t[0]
+                flatten_output[f"{name}.{idx}.value"] = t[1]
             return flatten_output
...
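
For reference, with use_past=True the rewritten `inputs` property advertises the same dynamic axes for every layer. A minimal sketch of the mapping it would return for a hypothetical two-layer config (names and axis labels taken from the diff above; the layer count is made up):

from collections import OrderedDict

# Expected ONNX input spec for a hypothetical 2-layer GPT Neo config with use_past=True.
expected_inputs = OrderedDict(
    [
        ("input_ids", {0: "batch", 1: "sequence"}),
        ("past_key_values.0.key", {0: "batch", 2: "past_sequence"}),
        ("past_key_values.0.value", {0: "batch", 2: "past_sequence"}),
        ("past_key_values.1.key", {0: "batch", 2: "past_sequence"}),
        ("past_key_values.1.value", {0: "batch", 2: "past_sequence"}),
        ("attention_mask", {0: "batch", 1: "past_sequence + sequence"}),
    ]
)

for name, axes in expected_inputs.items():
    print(name, axes)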