OPT/BioGPT: Improved attention mask shape exception (#23270)

Commit 466af1a3 (unverified), authored May 16, 2023 by Joao Gante, committed by GitHub on May 16, 2023
Parent: 21741e8c
Showing 3 changed files, with 20 additions and 0 deletions:

  src/transformers/models/biogpt/modeling_biogpt.py   +6  -0
  src/transformers/models/opt/modeling_opt.py          +5  -0
  src/transformers/models/opt/modeling_tf_opt.py       +9  -0
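All three changes enforce the same invariant: when a 2D attention mask is supplied, its second dimension must cover both the cached (past) tokens and the tokens of the current forward pass. A minimal sketch of the expected shapes, using invented lengths that are not taken from the commit:

import torch

# Hypothetical lengths, for illustration only.
past_key_values_length = 4   # tokens already in the cache
current_length = 1           # new tokens in this forward pass

# The mask must have one column per past + current token: 4 + 1 = 5 here.
attention_mask = torch.ones(1, past_key_values_length + current_length, dtype=torch.bool)

# A mask covering only the new token, e.g. torch.ones(1, 1), now triggers the
# explicit exception added in this commit instead of flowing further into the model.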
src/transformers/models/biogpt/modeling_biogpt.py
@@ -546,6 +546,12 @@ class BioGptModel(BioGptPreTrainedModel):

        if attention_mask is None:
            attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
        elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
            )

        # embed positions
        positions = self.embed_positions(attention_mask, past_key_values_length)
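As a hedged illustration of how the new BioGPT check surfaces to users (the checkpoint name and prompt below are only examples, not part of the commit):

import torch
from transformers import AutoTokenizer, BioGptModel

# Example checkpoint and prompt; any BioGPT checkpoint would behave the same way.
tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
model = BioGptModel.from_pretrained("microsoft/biogpt")

inputs = tokenizer("Bicalutamide is an antiandrogen", return_tensors="pt")
seq_len = inputs["input_ids"].shape[1]

# A mask that is one position too short now hits the ValueError above:
# "The provided attention mask has length ..., but its length should be ..."
short_mask = torch.ones(1, seq_len - 1, dtype=torch.bool)
model(input_ids=inputs["input_ids"], attention_mask=short_mask)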
src/transformers/models/opt/modeling_opt.py
@@ -642,6 +642,11 @@ class OPTDecoder(OPTPreTrainedModel):

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        elif attention_mask.shape[1] != mask_seq_length:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
            )
        causal_attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
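The same situation for the PyTorch OPT decoder, sketched with the facebook/opt-125m checkpoint (chosen here only as a small example, not referenced by the commit); here mask_seq_length is past_key_values_length plus the current input length, so with no cache the mask must match the input exactly:

import torch
from transformers import AutoTokenizer, OPTModel

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = OPTModel.from_pretrained("facebook/opt-125m")

inputs = tokenizer("Hello world", return_tensors="pt")
seq_len = inputs["input_ids"].shape[1]

# With no past key values, mask_seq_length equals the input length, so a
# longer-than-expected mask is rejected with the new ValueError.
long_mask = torch.ones(1, seq_len + 3, dtype=torch.long)
model(input_ids=inputs["input_ids"], attention_mask=long_mask)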
src/transformers/models/opt/modeling_tf_opt.py
@@ -645,6 +645,15 @@ class TFOPTDecoder(tf.keras.layers.Layer):

        if attention_mask is None:
            attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
        else:
            tf.debugging.assert_equal(
                attention_mask.shape[1],
                past_key_values_length + input_shape[1],
                message=(
                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                    f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
                ),
            )

        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
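On the TensorFlow side, the check uses tf.debugging.assert_equal rather than raising a ValueError directly. A standalone sketch of that behaviour, with invented lengths instead of a real model call:

import tensorflow as tf

# Invented lengths, for illustration only.
past_key_values_length = 2
input_length = 3                                  # plays the role of input_shape[1]
attention_mask = tf.ones((1, 4), dtype=tf.bool)   # one position too short

# In eager mode this raises tf.errors.InvalidArgumentError carrying the
# message string, mirroring the ValueError in the PyTorch models.
tf.debugging.assert_equal(
    attention_mask.shape[1],
    past_key_values_length + input_length,
    message=(
        f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
        f"{past_key_values_length + input_length} (sum of the lengths of current and past inputs)"
    ),
)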