Unverified commit b2505f7d in chenpangpang/transformers
Authored Jul 14, 2020 by Sam Shleifer; committed by GitHub on Jul 14, 2020

Cleanup bart caching logic (#5640)

Parent: 838950ee
Showing 1 changed file with 6 additions and 30 deletions.

src/transformers/modeling_bart.py (+6 / -30)
@@ -628,8 +628,8 @@ class SelfAttention(nn.Module):
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
 
-    def _shape(self, tensor, dim_0, bsz):
-        return tensor.contiguous().view(dim_0, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+    def _shape(self, tensor, seq_len, bsz):
+        return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
 
     def forward(
         self,
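The rename from dim_0 to seq_len only clarifies what the leading dimension of the projected tensor is; the behaviour is unchanged. A minimal standalone sketch of the reshape, with sizes that are purely illustrative and not taken from the commit:

import torch

# assumed sizes for illustration; embed_dim == num_heads * head_dim, as in SelfAttention
seq_len, bsz, num_heads, head_dim = 7, 2, 4, 16
embed_dim = num_heads * head_dim

x = torch.randn(seq_len, bsz, embed_dim)  # e.g. the output of a q/k/v projection

# what _shape does: split the heads out of embed_dim and move the combined
# (bsz * num_heads) axis to the front
shaped = x.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)

print(shaped.shape)  # torch.Size([8, 7, 16]) == (bsz * num_heads, seq_len, head_dim)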
@@ -648,9 +648,8 @@ class SelfAttention(nn.Module):
         # get here for encoder decoder cause of static_kv
         if layer_state is not None:  # reuse k,v and encoder_padding_mask
             saved_state = layer_state.get(self.cache_key, {})
-            if "prev_key" in saved_state:
+            if "prev_key" in saved_state and static_kv:
                 # previous time steps are cached - no need to recompute key and value if they are static
-                if static_kv:
-                    key = None
+                key = None
         else:
             saved_state = None
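For context: the merged condition drops key only when a cached state exists and the key/value are static, i.e. the encoder-decoder (cross-attention) case; decoder self-attention still projects k/v for the new token and appends them to the cache. A hedged sketch of just that branch, using an illustrative helper that is not part of the module:

from typing import Optional

import torch


def maybe_drop_key(key: Optional[torch.Tensor], saved_state: dict, static_kv: bool) -> Optional[torch.Tensor]:
    # previous time steps are cached - no need to recompute key and value if they are static
    if "prev_key" in saved_state and static_kv:
        return None
    return key


key = torch.randn(5, 2, 64)
cache = {"prev_key": torch.randn(8, 5, 16)}
print(maybe_drop_key(key, cache, static_kv=True))          # None: reuse the cached encoder k/v
print(maybe_drop_key(key, cache, static_kv=False) is key)  # True: self-attention still uses the new key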
@@ -738,37 +737,14 @@ class SelfAttention(nn.Module):
             v = torch.cat([prev_value, v], dim=1)
         assert k is not None and v is not None
         prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
-        key_padding_mask = self._cat_prev_key_padding_mask(
-            key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
-        )
-        return k, v, key_padding_mask
-
-    @staticmethod
-    def _cat_prev_key_padding_mask(
-        key_padding_mask: Optional[Tensor],
-        prev_key_padding_mask: Optional[Tensor],
-        batch_size: int,
-        src_len: int,
-        static_kv: bool,
-    ) -> Optional[Tensor]:
-        # saved key padding masks have shape (bsz, seq_len)
         if prev_key_padding_mask is not None:
             if static_kv:
                 new_key_padding_mask = prev_key_padding_mask
             else:
                 new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
-        elif key_padding_mask is not None:
-            filler = torch.zeros(
-                batch_size,
-                src_len - key_padding_mask.size(1),
-                dtype=key_padding_mask.dtype,
-                device=key_padding_mask.device,
-            )
-            new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1)
         else:
-            new_key_padding_mask = prev_key_padding_mask
-        return new_key_padding_mask
+            new_key_padding_mask = key_padding_mask
+        return k, v, new_key_padding_mask
 
 
 class BartClassificationHead(nn.Module):
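After this hunk, the padding-mask bookkeeping reduces to three cases: reuse the cached mask when static_kv, otherwise concatenate the cached and current masks along the time axis, or fall back to the current mask; the zero-filler branch of the removed staticmethod is gone. A minimal sketch under those assumptions (merge_key_padding_mask is a hypothetical name, not the module's API):

from typing import Optional

import torch
from torch import Tensor


def merge_key_padding_mask(
    key_padding_mask: Optional[Tensor],
    prev_key_padding_mask: Optional[Tensor],
    static_kv: bool,
) -> Optional[Tensor]:
    if prev_key_padding_mask is not None:
        if static_kv:
            # cross-attention: the encoder-side mask never grows
            return prev_key_padding_mask
        # decoder self-attention: extend the cached mask by the new step(s)
        return torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
    return key_padding_mask


prev = torch.zeros(2, 3, dtype=torch.bool)  # (bsz, cached_len)
cur = torch.zeros(2, 1, dtype=torch.bool)   # (bsz, new_len)
print(merge_key_padding_mask(cur, prev, static_kv=False).shape)  # torch.Size([2, 4])
print(merge_key_padding_mask(cur, prev, static_kv=True).shape)   # torch.Size([2, 3])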