Unverified Commit 0645b07d authored by arampacha, committed by GitHub

propagate "attention_mask" dtype for "use_past" in OnnxConfig.generate_dummy_inputs (#17105)

* propagate attention_mask dtype

* fixup&style
parent 0e6ec2a4
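
For context, here is a minimal, self-contained sketch of the issue this commit fixes (illustrative tensor sizes, not the library code): when "use_past" is enabled, the dummy attention mask is extended with torch.ones(batch, past_key_values_length), which defaults to float32 and therefore no longer matches the (typically int64) mask produced by the tokenizer. Reading the dtype off the existing mask and passing it to torch.ones keeps the dummy input's dtype consistent.

    import torch

    batch, seq_len, past_len = 2, 3, 4

    # Tokenizers typically return an int64 attention mask.
    attention_mask = torch.ones(batch, seq_len, dtype=torch.int64)

    # torch.ones defaults to float32, so appending it directly would change
    # (or, on some PyTorch versions, fail to concatenate with) the mask dtype.
    print(torch.ones(batch, past_len).dtype)  # torch.float32

    # The fix: propagate the existing mask's dtype to the padding block that
    # covers the past_key_values positions.
    mask_dtype = attention_mask.dtype
    attention_mask = torch.cat(
        [attention_mask, torch.ones(batch, past_len, dtype=mask_dtype)], dim=1
    )
    print(attention_mask.dtype, attention_mask.shape)  # torch.int64 torch.Size([2, 7])
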
@@ -337,8 +337,9 @@ class BartOnnxConfig(OnnxSeq2SeqConfigWithPast):
                 self._config.hidden_size // num_encoder_attention_heads,
             )
+            mask_dtype = common_inputs["attention_mask"].dtype
             common_inputs["attention_mask"] = torch.cat(
-                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
             common_inputs["past_key_values"] = [
                 (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
...
@@ -313,8 +313,9 @@ class BlenderbotOnnxConfig(OnnxSeq2SeqConfigWithPast):
                 past_key_values_length,
                 self._config.hidden_size // num_encoder_attention_heads,
             )
+            mask_dtype = common_inputs["attention_mask"].dtype
             common_inputs["attention_mask"] = torch.cat(
-                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
             common_inputs["past_key_values"] = [
                 (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_decoder_layers)
...
@@ -327,8 +327,9 @@ class BlenderbotSmallOnnxConfig(OnnxSeq2SeqConfigWithPast):
                 self._config.hidden_size // num_encoder_attention_heads,
             )
+            mask_dtype = common_inputs["attention_mask"].dtype
             common_inputs["attention_mask"] = torch.cat(
-                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
             common_inputs["past_key_values"] = [
                 (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
...
@@ -262,8 +262,9 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
         if self.use_past:
+            mask_dtype = ordered_inputs["attention_mask"].dtype
             ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
         return ordered_inputs
...
@@ -261,8 +261,9 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
         if self.use_past:
+            mask_dtype = ordered_inputs["attention_mask"].dtype
             ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
         return ordered_inputs
...
@@ -211,8 +211,9 @@ class GPTJOnnxConfig(OnnxConfigWithPast):
         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
         if self.use_past:
+            mask_dtype = ordered_inputs["attention_mask"].dtype
             ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
         return ordered_inputs
...
@@ -327,8 +327,9 @@ class MarianOnnxConfig(OnnxSeq2SeqConfigWithPast):
                 self._config.hidden_size // num_encoder_attention_heads,
             )
+            mask_dtype = common_inputs["attention_mask"].dtype
             common_inputs["attention_mask"] = torch.cat(
-                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
             common_inputs["past_key_values"] = [
                 (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
...
@@ -322,8 +322,9 @@ class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast):
                 self._config.hidden_size // num_encoder_attention_heads,
             )
+            mask_dtype = common_inputs["attention_mask"].dtype
             common_inputs["attention_mask"] = torch.cat(
-                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
             common_inputs["past_key_values"] = [
                 (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
...
@@ -457,8 +457,10 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
            )
            if "attention_mask" in common_inputs:
+                mask_dtype = common_inputs["attention_mask"].dtype
                common_inputs["attention_mask"] = torch.cat(
-                    [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                    [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)],
+                    dim=1,
                )
            common_inputs["past_key_values"] = []
...