OpenDAS / Uni-Core · Commits

Commit 689e0b24
authored Aug 11, 2022 by Guolin Ke

fix a bug for return attn_weights

parent a92c6297
Showing 1 changed file with 23 additions and 17 deletions:

unicore/modules/multihead_attention.py (+23, -17)
@@ -92,24 +92,30 @@ class SelfMultiheadAttention(nn.Module):
         )
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
-        attn_probs = softmax_dropout(
-            attn_weights, self.dropout, self.training, bias=attn_bias
-        )
+        if not return_attn:
+            attn = softmax_dropout(
+                attn_weights, self.dropout, self.training, bias=attn_bias,
+            )
+        else:
+            attn_weights += attn_bias
+            attn = softmax_dropout(
+                attn_weights, self.dropout, self.training,
+            )
 
-        attn = torch.bmm(attn_probs, v)
-        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        o = torch.bmm(attn, v)
+        assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
 
-        attn = (
-            attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        o = (
+            o.view(bsz, self.num_heads, tgt_len, self.head_dim)
             .transpose(1, 2)
             .contiguous()
             .view(bsz, tgt_len, embed_dim)
         )
-        attn = self.out_proj(attn)
+        o = self.out_proj(o)
         if not return_attn:
-            return attn
+            return o
         else:
-            return attn, attn_weights, attn_probs
+            return o, attn_weights, attn
 
 
 class CrossMultiheadAttention(nn.Module):
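What the fix changes: before this commit, softmax_dropout was always called with bias=attn_bias, and the bias was folded into the softmax inside that call without modifying attn_weights itself. When return_attn=True, the caller therefore received attn_weights that were missing attn_bias. After the fix, the return_attn path adds the bias into attn_weights explicitly before the softmax, so the returned raw scores include the bias, and the post-softmax matrix (now named attn) is returned in place of attn_probs.

For reference, here is a minimal unfused sketch of what softmax_dropout computes, with the signature inferred from the call sites above; the real Uni-Core helper is presumably a fused kernel, so treat this as illustration only:

import torch.nn.functional as F

def softmax_dropout(x, dropout, is_training, bias=None):
    # Unfused sketch: optionally add a bias to the raw attention scores,
    # softmax over the last dimension, then apply dropout during training.
    # The bias is added to a local copy, so the caller's attn_weights is
    # left untouched -- this is why the pre-fix code handed back bias-free
    # weights when return_attn was True.
    if bias is not None:
        x = x + bias
    return F.dropout(F.softmax(x, dim=-1), p=dropout, training=is_training)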
@@ -201,16 +207,16 @@ class CrossMultiheadAttention(nn.Module):
         )
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
-        attn_probs = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias)
+        attn = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias)
 
-        attn = torch.bmm(attn_probs, v)
-        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        o = torch.bmm(attn, v)
+        assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
 
-        attn = (
-            attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        o = (
+            o.view(bsz, self.num_heads, tgt_len, self.head_dim)
             .transpose(1, 2)
             .contiguous()
             .view(bsz, tgt_len, embed_dim)
         )
-        attn = self.out_proj(attn)
-        return attn
+        o = self.out_proj(o)
+        return o
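CrossMultiheadAttention has no return_attn flag, so the hunk above is a pure rename (attn_probs to attn, and the projected output to o) that mirrors the self-attention path; behavior is unchanged.

For SelfMultiheadAttention, however, the commit does change what callers see. A hypothetical usage sketch follows; the import path, constructor arguments, tensor layout, and attn_bias shape are assumptions for illustration, and only the return_attn tuple is taken from the diff:

import torch
from unicore.modules import SelfMultiheadAttention  # assumed import path

mha = SelfMultiheadAttention(embed_dim=64, num_heads=4, dropout=0.1)
x = torch.randn(2, 10, 64)          # assumed (batch, seq_len, embed_dim)
bias = torch.zeros(2 * 4, 10, 10)   # assumed (bsz * num_heads, tgt, src)

o, attn_weights, attn = mha(x, attn_bias=bias, return_attn=True)
# After this commit: attn_weights are the pre-softmax scores including
# attn_bias, and attn is the post-softmax matrix used in the bmm with v.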