chenpangpang/transformers · Commits · 810079de
Commit 810079de, authored Mar 05, 2020 by sshleifer

no ipdb

parent c203509d
Showing 1 changed file with 1 addition and 4 deletions:

src/transformers/modeling_bart.py (+1, −4)
src/transformers/modeling_bart.py (view file @ 810079de)

@@ -688,6 +688,7 @@ class SelfAttention(nn.Module):
         static_kv: bool,
     ) -> Optional[Tensor]:
+        # saved key padding masks have shape (bsz, seq_len)
         if prev_key_padding_mask is not None and static_kv:
             new_key_padding_mask = prev_key_padding_mask
         elif prev_key_padding_mask is not None and key_padding_mask is not None:
@@ -699,10 +700,6 @@ class SelfAttention(nn.Module):
             if prev_key_padding_mask.is_cuda:
                 filler = filler.to(prev_key_padding_mask.device)
             new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
-            print(new_key_padding_mask.device, new_key_padding_mask.dtype)
-            import ipdb
-            ipdb.set_trace()
         elif key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
             if key_padding_mask.is_cuda:
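For context on the code path this commit touches: the surviving branch extends a previously saved key padding mask of shape (bsz, seq_len) with a zero-valued filler so it matches the new source length, then concatenates along the sequence dimension. A minimal standalone sketch of that step, with example tensor values chosen for illustration (not taken from the commit):

```python
import torch

# Hypothetical saved mask for a batch of 1, seq_len 3; 1.0 marks padded positions.
prev_key_padding_mask = torch.tensor([[1.0, 1.0, 0.0]])
src_len = 5  # assumed new source length for this sketch
batch_size, seq_len = prev_key_padding_mask.shape

# Pad with zeros (i.e. "not padding") up to src_len, matching device if on GPU,
# mirroring the filler logic visible in the diff's context lines.
filler = torch.zeros(batch_size, src_len - seq_len)
if prev_key_padding_mask.is_cuda:
    filler = filler.to(prev_key_padding_mask.device)
new_key_padding_mask = torch.cat(
    [prev_key_padding_mask.float(), filler.float()], dim=1
)

print(new_key_padding_mask)  # tensor([[1., 1., 0., 0., 0.]]), shape (1, 5)
```

The removed `print`/`ipdb.set_trace()` lines sat immediately after this concatenation; the commit simply deletes that leftover debugging without changing the mask logic.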