chenpangpang/transformers, commit 0645b07d (unverified)
Authored May 11, 2022 by arampacha; committed by GitHub on May 11, 2022
propagate "attention_mask" dtype for "use_past" in OnnxConfig.generate_dummy_inputs (#17105)
* propagate attention_mask dtype
* fixup & style
Parent: 0e6ec2a4
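The commit message only names the change, but the apparent motivation is that the dummy attention_mask built from a tokenizer is an integer tensor, while the extra ones appended for the past key/values length were created with torch.ones's default float32 dtype, so the concatenated dummy input no longer matched the dtype a real tokenizer would feed the exported model. Below is a minimal sketch of the before/after behaviour on recent PyTorch; the shapes and values are illustrative assumptions, not taken from the diff:

import torch

# Illustrative only: tokenizers return integer attention masks.
attention_mask = torch.ones(2, 8, dtype=torch.int64)
past_key_values_length = 3

# Before this commit: the appended ones default to float32, so torch.cat
# promotes the whole dummy mask to float32 (recent PyTorch; very old
# versions may raise instead).
before = torch.cat([attention_mask, torch.ones(2, past_key_values_length)], dim=1)
assert before.dtype == torch.float32

# After this commit: the appended ones reuse the mask's own dtype, so the
# dummy input keeps the integer dtype the tokenizer produced.
mask_dtype = attention_mask.dtype
after = torch.cat(
    [attention_mask, torch.ones(2, past_key_values_length, dtype=mask_dtype)], dim=1
)
assert after.dtype == torch.int64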
Showing 9 changed files with 19 additions and 9 deletions (+19 / -9):
src/transformers/models/bart/configuration_bart.py  (+2 / -1)
src/transformers/models/blenderbot/configuration_blenderbot.py  (+2 / -1)
src/transformers/models/blenderbot_small/configuration_blenderbot_small.py  (+2 / -1)
src/transformers/models/gpt2/configuration_gpt2.py  (+2 / -1)
src/transformers/models/gpt_neo/configuration_gpt_neo.py  (+2 / -1)
src/transformers/models/gptj/configuration_gptj.py  (+2 / -1)
src/transformers/models/marian/configuration_marian.py  (+2 / -1)
src/transformers/models/mbart/configuration_mbart.py  (+2 / -1)
src/transformers/onnx/config.py  (+3 / -1)
src/transformers/models/bart/configuration_bart.py
@@ -337,8 +337,9 @@ class BartOnnxConfig(OnnxSeq2SeqConfigWithPast):
                 self._config.hidden_size // num_encoder_attention_heads,
             )
 
+            mask_dtype = common_inputs["attention_mask"].dtype
             common_inputs["attention_mask"] = torch.cat(
-                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
             common_inputs["past_key_values"] = [
                 (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
src/transformers/models/blenderbot/configuration_blenderbot.py
@@ -313,8 +313,9 @@ class BlenderbotOnnxConfig(OnnxSeq2SeqConfigWithPast):
                 past_key_values_length,
                 self._config.hidden_size // num_encoder_attention_heads,
             )
+            mask_dtype = common_inputs["attention_mask"].dtype
             common_inputs["attention_mask"] = torch.cat(
-                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
             common_inputs["past_key_values"] = [
                 (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_decoder_layers)
src/transformers/models/blenderbot_small/configuration_blenderbot_small.py
@@ -327,8 +327,9 @@ class BlenderbotSmallOnnxConfig(OnnxSeq2SeqConfigWithPast):
                 self._config.hidden_size // num_encoder_attention_heads,
             )
 
+            mask_dtype = common_inputs["attention_mask"].dtype
             common_inputs["attention_mask"] = torch.cat(
-                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
             common_inputs["past_key_values"] = [
                 (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
src/transformers/models/gpt2/configuration_gpt2.py
@@ -262,8 +262,9 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
 
         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
         if self.use_past:
+            mask_dtype = ordered_inputs["attention_mask"].dtype
             ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
 
         return ordered_inputs
src/transformers/models/gpt_neo/configuration_gpt_neo.py
@@ -261,8 +261,9 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
 
         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
         if self.use_past:
+            mask_dtype = ordered_inputs["attention_mask"].dtype
             ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
 
         return ordered_inputs
src/transformers/models/gptj/configuration_gptj.py
@@ -211,8 +211,9 @@ class GPTJOnnxConfig(OnnxConfigWithPast):
 
         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
         if self.use_past:
+            mask_dtype = ordered_inputs["attention_mask"].dtype
             ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
 
         return ordered_inputs
src/transformers/models/marian/configuration_marian.py
@@ -327,8 +327,9 @@ class MarianOnnxConfig(OnnxSeq2SeqConfigWithPast):
                 self._config.hidden_size // num_encoder_attention_heads,
             )
 
+            mask_dtype = common_inputs["attention_mask"].dtype
             common_inputs["attention_mask"] = torch.cat(
-                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
             common_inputs["past_key_values"] = [
                 (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
src/transformers/models/mbart/configuration_mbart.py
@@ -322,8 +322,9 @@ class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast):
                 self._config.hidden_size // num_encoder_attention_heads,
             )
 
+            mask_dtype = common_inputs["attention_mask"].dtype
             common_inputs["attention_mask"] = torch.cat(
-                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
             )
             common_inputs["past_key_values"] = [
                 (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
src/transformers/onnx/config.py
@@ -457,8 +457,10 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
             )
 
             if "attention_mask" in common_inputs:
+                mask_dtype = common_inputs["attention_mask"].dtype
                 common_inputs["attention_mask"] = torch.cat(
-                    [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+                    [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)],
+                    dim=1,
                 )
 
             common_inputs["past_key_values"] = []
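All of the model-specific hunks above repeat the same pattern; the change in src/transformers/onnx/config.py is the base-class version that covers OnnxConfigWithPast subclasses that do not override generate_dummy_inputs. A hedged usage sketch of how these dummy inputs are typically produced for export follows; the choice of GPT2OnnxConfig, the "gpt2" checkpoint, and the exact constructor arguments are assumptions for illustration, not part of this commit:

from transformers import AutoTokenizer, GPT2Config, TensorType
from transformers.models.gpt2.configuration_gpt2 import GPT2OnnxConfig

# Sketch under assumptions: build an ONNX config with past key/values enabled
# and generate PyTorch dummy inputs the way the export pipeline does.
config = GPT2Config()
onnx_config = GPT2OnnxConfig(config, task="causal-lm", use_past=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
# With this commit, the padded attention_mask keeps the tokenizer's integer
# dtype instead of being promoted to float32.
print(dummy_inputs["attention_mask"].dtype)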