chenpangpang / transformers / Commits

Unverified commit 4c7f564f
Authored Jun 09, 2020 by ZhuBaohe; committed by GitHub on Jun 08, 2020

fix (#4839)

Parent: 37be3786

Showing 1 changed file with 3 additions and 5 deletions:
src/transformers/modeling_longformer.py (+3, -5)
...
@@ -153,12 +153,11 @@ class LongformerSelfAttention(nn.Module):
         beginning_mask_2d = input_tensor.new_ones(w, w + 1).tril().flip(dims=[0])
         beginning_mask = beginning_mask_2d[None, :, None, :]
         ending_mask = beginning_mask.flip(dims=(1, 3))
-        seqlen = input_tensor.size(1)
         beginning_input = input_tensor[:, :affected_seqlen, :, : w + 1]
-        beginning_mask = beginning_mask[:, :seqlen].expand(beginning_input.size())
+        beginning_mask = beginning_mask.expand(beginning_input.size())
         beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
         ending_input = input_tensor[:, -affected_seqlen:, :, -(w + 1) :]
-        ending_mask = ending_mask[:, -seqlen:].expand(ending_input.size())
+        ending_mask = ending_mask.expand(ending_input.size())
         ending_input.masked_fill_(ending_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
 
     def _sliding_chunks_matmul_qk(self, q: torch.Tensor, k: torch.Tensor, w: int):
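For context on the hunk above: the commit drops the `[:, :seqlen]` / `[:, -seqlen:]` slices (and the now-unused `seqlen`) and relies on `expand` alone to broadcast the masks onto the affected slices. Below is a minimal, self-contained sketch of the same tensor operations in their new form, with toy sizes (w = 3, seqlen = 8, 2 heads, and `affected_seqlen = w` are illustrative assumptions, not values taken from the surrounding method):

import torch

w = 3
input_tensor = torch.zeros(1, 8, 2, 2 * w + 1)   # (batch, seqlen, heads, 2*w + 1) attention scores
affected_seqlen = w                              # assumption: only the first/last w rows need masking

# Lower-triangular (w, w + 1) band, flipped so row i marks the out-of-range left-context slots
# of position i; ending_mask is the mirror image for the end of the sequence.
beginning_mask_2d = input_tensor.new_ones(w, w + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]   # (1, w, 1, w + 1)
ending_mask = beginning_mask.flip(dims=(1, 3))

# Fill the invalid slots of the first rows with -inf, in place, through the view.
beginning_input = input_tensor[:, :affected_seqlen, :, : w + 1]
beginning_input.masked_fill_(beginning_mask.expand(beginning_input.size()) == 1, -float("inf"))

# Same for the last rows, mirrored.
ending_input = input_tensor[:, -affected_seqlen:, :, -(w + 1) :]
ending_input.masked_fill_(ending_mask.expand(ending_input.size()) == 1, -float("inf"))

print(input_tensor[0, 0, 0])   # tensor([-inf, -inf, -inf, 0., 0., 0., 0.])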
...
@@ -301,7 +300,6 @@ class LongformerSelfAttention(nn.Module):
         k = k.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
         # attn_weights = (batch_size, seqlen, num_heads, window*2+1)
         attn_weights = self._sliding_chunks_matmul_qk(q, k, self.one_sided_attention_window_size)
-        self._mask_invalid_locations(attn_weights, self.one_sided_attention_window_size)
         if remove_from_windowed_attention_mask is not None:
             # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
             # from (batch_size x seqlen) to (batch_size x seqlen x num_heads x hidden_size)
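The comment in this hunk notes that `_sliding_chunks_matmul_qk` returns scores of shape (batch_size, seqlen, num_heads, window*2 + 1). As a reference point only, here is a naive, loop-based sketch that produces the same layout; it is not Longformer's chunked implementation, and the out-of-range slots it leaves at -inf correspond to the positions that the masking code in the first hunk fills with -inf:

import torch

def naive_sliding_window_qk(q: torch.Tensor, k: torch.Tensor, w: int) -> torch.Tensor:
    # q, k: (batch, seqlen, heads, head_dim) -> scores: (batch, seqlen, heads, 2*w + 1)
    batch, seqlen, heads, head_dim = q.shape
    scores = q.new_full((batch, seqlen, heads, 2 * w + 1), float("-inf"))
    for i in range(seqlen):
        lo, hi = max(0, i - w), min(seqlen, i + w + 1)
        # dot query i against the keys inside its window; slots outside the sequence stay -inf
        window = torch.einsum("bhd,bjhd->bjh", q[:, i], k[:, lo:hi])
        scores[:, i, :, lo - (i - w) : hi - (i - w)] = window.transpose(1, 2)
    return scores

q = torch.randn(1, 10, 2, 4)
k = torch.randn(1, 10, 2, 4)
print(naive_sliding_window_qk(q, k, w=2).shape)   # torch.Size([1, 10, 2, 5])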
...
@@ -329,7 +327,7 @@ class LongformerSelfAttention(nn.Module):
             selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
             # (batch_size, seqlen, num_heads, max_num_extra_indices_per_batch)
             selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k))
-            selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
+            selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0
             # concat to attn_weights
             # (batch_size, seqlen, num_heads, extra attention count + 2*window+1)
             attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
...
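Shape-wise, the global-attention branch in this last hunk gathers the keys at the globally attended positions into `selected_k`, scores them against every query with the einsum shown, pushes padded slots to a large negative value, and prepends the result to the windowed scores along the last dimension. A small sketch with illustrative sizes (the padding index used here is an assumption; the real code indexes with `selection_padding_mask_zeros`):

import torch

batch, seqlen, heads, head_dim, w = 1, 10, 2, 4, 2
max_num_extra_indices_per_batch = 3                       # assume at most 3 global tokens

q = torch.randn(batch, seqlen, heads, head_dim)
selected_k = torch.randn(batch, max_num_extra_indices_per_batch, heads, head_dim)
attn_weights = torch.randn(batch, seqlen, heads, 2 * w + 1)   # windowed scores

# every query attends to every globally attended key
selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k))
print(selected_attn_weights.shape)   # (batch, seqlen, heads, max_num_extra_indices_per_batch)

# pretend the last global slot is padding and push it to a large negative score,
# written as a float literal as in the committed change
selected_attn_weights[:, :, :, -1] = -10000.0

attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
print(attn_weights.shape)   # (1, 10, 2, 3 + 2*w + 1) == (1, 10, 2, 8)

As for the one-character change itself: -10000 and -10000.0 assign the same value to a floating-point tensor, so the commit presumably just makes the float literal explicit.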