Unverified Commit 1967be98 authored by Nouamane Tazi's avatar Nouamane Tazi Committed by GitHub
Browse files

fix BLOOM ONNX config (#19573)



* fix BLOOM ONNX config
- `value` params have `seq_len` as their 2nd axis, as opposed to other models, which have it as the 3rd
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent 4f0337a0
...@@ -152,7 +152,6 @@ class BloomConfig(PretrainedConfig): ...@@ -152,7 +152,6 @@ class BloomConfig(PretrainedConfig):
class BloomOnnxConfig(OnnxConfigWithPast): class BloomOnnxConfig(OnnxConfigWithPast):
torch_onnx_minimum_version = version.parse("1.12") torch_onnx_minimum_version = version.parse("1.12")
def __init__( def __init__(
...@@ -171,7 +170,8 @@ class BloomOnnxConfig(OnnxConfigWithPast): ...@@ -171,7 +170,8 @@ class BloomOnnxConfig(OnnxConfigWithPast):
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past: if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs") # BLOOM stores values on dynamic axis 2. For more details see: https://github.com/huggingface/transformers/pull/18344
self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True)
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
else: else:
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
......
...@@ -486,7 +486,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC): ...@@ -486,7 +486,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
return common_inputs return common_inputs
def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str): def fill_with_past_key_values_(
self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str, inverted_values_shape: bool = False
):
""" """
Fill the input_or_outputs mapping with past_key_values dynamic axes considering. Fill the input_or_outputs mapping with past_key_values dynamic axes considering.
...@@ -494,6 +496,8 @@ class OnnxConfigWithPast(OnnxConfig, ABC): ...@@ -494,6 +496,8 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
inputs_or_outputs: The mapping to fill. inputs_or_outputs: The mapping to fill.
direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the
output mapping, this is important for axes naming. output mapping, this is important for axes naming.
inverted_values_shape:
If `True`, store values on dynamic axis 1, else on axis 2.
""" """
if direction not in ["inputs", "outputs"]: if direction not in ["inputs", "outputs"]:
...@@ -502,7 +506,10 @@ class OnnxConfigWithPast(OnnxConfig, ABC): ...@@ -502,7 +506,10 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
name = "past_key_values" if direction == "inputs" else "present" name = "past_key_values" if direction == "inputs" else "present"
for i in range(self.num_layers): for i in range(self.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} if inverted_values_shape:
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: "past_sequence + sequence"}
else:
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
def _flatten_past_key_values_(self, flattened_output, name, idx, t): def _flatten_past_key_values_(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.key"] = t[0] flattened_output[f"{name}.{idx}.key"] = t[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment