Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b2505f7d
Unverified
Commit
b2505f7d
authored
Jul 14, 2020
by
Sam Shleifer
Committed by
GitHub
Jul 14, 2020
Browse files
Cleanup bart caching logic (#5640)
parent
838950ee
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
30 deletions
+6
-30
src/transformers/modeling_bart.py
src/transformers/modeling_bart.py
+6
-30
No files found.
src/transformers/modeling_bart.py
View file @
b2505f7d
...
...
@@ -628,8 +628,8 @@ class SelfAttention(nn.Module):
self
.
out_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
,
bias
=
bias
)
self
.
cache_key
=
"encoder_decoder"
if
self
.
encoder_decoder_attention
else
"self"
def
_shape
(
self
,
tensor
,
dim_0
,
bsz
):
return
tensor
.
contiguous
().
view
(
dim_0
,
bsz
*
self
.
num_heads
,
self
.
head_dim
).
transpose
(
0
,
1
)
def
_shape
(
self
,
tensor
,
seq_len
,
bsz
):
return
tensor
.
contiguous
().
view
(
seq_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
).
transpose
(
0
,
1
)
def
forward
(
self
,
...
...
@@ -648,10 +648,9 @@ class SelfAttention(nn.Module):
# get here for encoder decoder cause of static_kv
if
layer_state
is
not
None
:
# reuse k,v and encoder_padding_mask
saved_state
=
layer_state
.
get
(
self
.
cache_key
,
{})
if
"prev_key"
in
saved_state
:
if
"prev_key"
in
saved_state
and
static_kv
:
# previous time steps are cached - no need to recompute key and value if they are static
if
static_kv
:
key
=
None
key
=
None
else
:
saved_state
=
None
layer_state
=
{}
...
...
@@ -738,37 +737,14 @@ class SelfAttention(nn.Module):
v
=
torch
.
cat
([
prev_value
,
v
],
dim
=
1
)
assert
k
is
not
None
and
v
is
not
None
prev_key_padding_mask
:
Optional
[
Tensor
]
=
saved_state
.
get
(
"prev_key_padding_mask"
,
None
)
key_padding_mask
=
self
.
_cat_prev_key_padding_mask
(
key_padding_mask
,
prev_key_padding_mask
,
bsz
,
k
.
size
(
1
),
static_kv
)
return
k
,
v
,
key_padding_mask
@
staticmethod
def
_cat_prev_key_padding_mask
(
key_padding_mask
:
Optional
[
Tensor
],
prev_key_padding_mask
:
Optional
[
Tensor
],
batch_size
:
int
,
src_len
:
int
,
static_kv
:
bool
,
)
->
Optional
[
Tensor
]:
# saved key padding masks have shape (bsz, seq_len)
if
prev_key_padding_mask
is
not
None
:
if
static_kv
:
new_key_padding_mask
=
prev_key_padding_mask
else
:
new_key_padding_mask
=
torch
.
cat
([
prev_key_padding_mask
,
key_padding_mask
],
dim
=
1
)
elif
key_padding_mask
is
not
None
:
filler
=
torch
.
zeros
(
batch_size
,
src_len
-
key_padding_mask
.
size
(
1
),
dtype
=
key_padding_mask
.
dtype
,
device
=
key_padding_mask
.
device
,
)
new_key_padding_mask
=
torch
.
cat
([
filler
,
key_padding_mask
],
dim
=
1
)
else
:
new_key_padding_mask
=
prev_
key_padding_mask
return
new_key_padding_mask
new_key_padding_mask
=
key_padding_mask
return
k
,
v
,
new_key_padding_mask
class
BartClassificationHead
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment