Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
63f4d8ca
Unverified
Commit
63f4d8ca
authored
Mar 26, 2020
by
Sam Shleifer
Committed by
GitHub
Mar 26, 2020
Browse files
[Bart/Memory] SelfAttention only returns weights if config.output_attentions (#3369)
parent
2b2a2f8d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
4 deletions
+13
-4
src/transformers/modeling_bart.py
src/transformers/modeling_bart.py
+13
-4
No files found.
src/transformers/modeling_bart.py
View file @
63f4d8ca
...
...
@@ -217,7 +217,9 @@ class EncoderLayer(nn.Module):
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual
=
x
x
,
attn_weights
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
key_padding_mask
=
encoder_padding_mask
,)
x
,
attn_weights
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
key_padding_mask
=
encoder_padding_mask
,
need_weights
=
self
.
output_attentions
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
self_attn_layer_norm
(
x
)
...
...
@@ -316,6 +318,7 @@ class DecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
BartConfig
):
super
().
__init__
()
self
.
embed_dim
=
config
.
d_model
self
.
output_attentions
=
config
.
output_attentions
self
.
self_attn
=
SelfAttention
(
embed_dim
=
self
.
embed_dim
,
num_heads
=
config
.
decoder_attention_heads
,
dropout
=
config
.
attention_dropout
,
)
...
...
@@ -343,14 +346,16 @@ class DecoderLayer(nn.Module):
if
layer_state
is
None
:
layer_state
=
{}
# next line mutates layer state
x
,
self_attn_weights
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
layer_state
=
layer_state
,
attn_mask
=
attention_mask
,)
x
,
self_attn_weights
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
layer_state
=
layer_state
,
attn_mask
=
attention_mask
,
need_weights
=
self
.
output_attentions
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
self_attn_layer_norm
(
x
)
residual
=
x
assert
self
.
encoder_attn
.
cache_key
!=
self
.
self_attn
.
cache_key
x
,
encoder_attn_weights
=
self
.
encoder_attn
(
x
,
_
=
self
.
encoder_attn
(
query
=
x
,
key
=
encoder_hidden_states
,
key_padding_mask
=
encoder_attn_mask
,
...
...
@@ -527,6 +532,7 @@ class SelfAttention(nn.Module):
key_padding_mask
:
Optional
[
Tensor
]
=
None
,
layer_state
:
Optional
[
Dict
[
str
,
Optional
[
Tensor
]]]
=
None
,
attn_mask
:
Optional
[
Tensor
]
=
None
,
need_weights
=
False
,
)
->
Tuple
[
Tensor
,
Optional
[
Tensor
]]:
"""Input shape: Time(SeqLen) x Batch x Channel"""
static_kv
=
self
.
encoder_decoder_attention
# type: bool
...
...
@@ -598,7 +604,10 @@ class SelfAttention(nn.Module):
assert
attn_output
.
size
()
==
(
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
attn_output
=
attn_output
.
transpose
(
0
,
1
).
contiguous
().
view
(
tgt_len
,
bsz
,
embed_dim
)
attn_output
=
self
.
out_proj
(
attn_output
)
attn_weights
=
attn_weights
.
view
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
if
need_weights
:
attn_weights
=
attn_weights
.
view
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
else
:
attn_weights
=
None
return
attn_output
,
attn_weights
def
_use_saved_state
(
self
,
k
,
v
,
saved_state
,
key_padding_mask
,
static_kv
,
bsz
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment