chenpangpang / transformers / Commits / 98109464

Unverified commit 98109464, authored Jun 28, 2020 by Patrick von Platen, committed by GitHub on Jun 28, 2020

clean reformer reverse sort (#5343)

parent 1af58c07
Showing 1 changed file with 21 additions and 38 deletions (+21 / -38)

src/transformers/modeling_reformer.py
@@ -384,11 +384,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             )

             # make sure bucket idx is not longer then sequence length
-            sorted_bucket_idx = sorted_bucket_idx % sequence_length
+            sorted_bucket_idx_per_hash = sorted_bucket_idx % sequence_length

             # cluster query key value vectors according to hashed buckets
-            query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes)
-            value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes)
+            query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes)
+            value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes)
             query_key_vectors = self._split_seq_length_dim_to(
                 query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
             )
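For readers following the diff, a minimal sketch (illustration only, not the library code) of why the indices are taken modulo `sequence_length`: with several hash rounds the argsort runs over the concatenated buckets of all rounds, so its values lie in `[0, num_hashes * sequence_length)`; the modulo maps them back to original sequence positions, which is what `_gather_by_expansion` needs to pick the right query/key/value vectors for each round. The bucket values below are made up, and batch/head dimensions are dropped.

import torch

# toy sizes; in the model these tensors also carry batch and head dimensions
sequence_length, num_hashes = 4, 2

# hypothetical bucket assignment, concatenated over the two hash rounds
buckets = torch.tensor([1, 0, 1, 0, 0, 1, 1, 0])  # shape: (num_hashes * sequence_length,)

sorted_bucket_idx = torch.argsort(buckets, dim=-1)                # values in [0, num_hashes * sequence_length)
sorted_bucket_idx_per_hash = sorted_bucket_idx % sequence_length  # values in [0, sequence_length)

print(sorted_bucket_idx.shape)                                    # torch.Size([8])
print(sorted_bucket_idx_per_hash.max().item() < sequence_length)  # True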
@@ -403,7 +403,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
         else:
             # get sequence length indices
-            sorted_bucket_idx = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
+            sorted_bucket_idx_per_hash = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
                 batch_size, self.num_attention_heads, 1
             )
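As a side note (not part of the commit), the fallback branch above simply enumerates sequence positions for every batch and attention head. A tiny sketch of the shape it produces, with made-up sizes:

import torch

batch_size, num_attention_heads, sequence_length = 2, 3, 5

idx = torch.arange(sequence_length).repeat(batch_size, num_attention_heads, 1)
print(idx.shape)   # torch.Size([2, 3, 5])
print(idx[0, 0])   # tensor([0, 1, 2, 3, 4])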
@@ -415,7 +415,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             query_vectors=query_key_vectors,
             key_vectors=key_vectors,
             value_vectors=value_vectors,
-            sorted_bucket_idx=sorted_bucket_idx,
+            sorted_bucket_idx_per_hash=sorted_bucket_idx_per_hash,
             attention_mask=attention_mask,
             head_mask=head_mask,
             sequence_length=sequence_length,
@@ -427,9 +427,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         # re-order out_vectors and logits
         if self.chunk_length < sequence_length:
             # sort clusters back to correct ordering
-            out_vectors, logits = ReverseSort.apply(
-                out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
-            )
+            out_vectors, logits = ReverseSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx)

         # sum up all hash rounds
         if num_hashes > 1:
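Illustration only (a toy Function, not `ReverseSort` itself): `torch.autograd.Function.backward` has to return one value per `forward` argument, so passing one argument fewer at the call site, as the commit does by dropping `self.num_hashes`, also means one fewer `None` filler in `backward` (see the last hunk below).

import torch
from torch.autograd import Function

class Scale(Function):
    @staticmethod
    def forward(ctx, x, factor):
        # stash the non-tensor argument for backward
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # one return value per forward argument: grad for x, None for factor
        return grad_output * ctx.factor, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])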
@@ -578,7 +576,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         self.num_buckets = num_buckets

     def _attend(
-        self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, sequence_length
+        self,
+        query_vectors,
+        key_vectors,
+        value_vectors,
+        sorted_bucket_idx_per_hash,
+        attention_mask,
+        head_mask,
+        sequence_length,
     ):
         # look at previous and following chunks if chunked attention
@@ -595,11 +600,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         # if chunked attention split bucket idxs to query and key
         if self.chunk_length < sequence_length:
             query_bucket_idx = self._split_seq_length_dim_to(
-                sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
+                sorted_bucket_idx_per_hash, -1, self.chunk_length, self.num_attention_heads
             )
             key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
         else:
-            query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx
+            query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx_per_hash

         # get correct mask values depending on precision
         if query_key_dots.dtype == torch.float16:
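For orientation, an assumption-laden sketch (not the helper's actual implementation): `_split_seq_length_dim_to` is assumed here to split the sequence dimension of the index tensor into chunks of `chunk_length`, roughly like the reshape below, after which `_look_adjacent` concatenates each chunk with its neighbouring chunks.

import torch

batch_size, num_attention_heads, sequence_length, chunk_length = 2, 3, 8, 4

sorted_bucket_idx_per_hash = torch.arange(sequence_length).repeat(batch_size, num_attention_heads, 1)

# assumed behaviour: (batch, heads, seq_len) -> (batch, heads, num_chunks, chunk_length)
query_bucket_idx = sorted_bucket_idx_per_hash.reshape(batch_size, num_attention_heads, -1, chunk_length)
print(query_bucket_idx.shape)  # torch.Size([2, 3, 2, 4])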
@@ -741,11 +746,10 @@ class ReverseSort(Function):
     """

     @staticmethod
-    def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_hashes):
+    def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx):
         # save sorted_bucket_idx for backprop
         with torch.no_grad():
             ctx.sorted_bucket_idx = sorted_bucket_idx
-            ctx.num_hashes = num_hashes

             # undo sort to have correct order for next layer
             expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape)
@@ -757,35 +761,14 @@ class ReverseSort(Function):
     def backward(ctx, grad_out_vectors, grad_logits):
         # get parameters saved in ctx
         sorted_bucket_idx = ctx.sorted_bucket_idx
-        num_hashes = ctx.num_hashes

-        # get real gradient shape
-        # shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes
-        grad_logits_shape = grad_logits.shape
-        # shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes x ChunkLen
-        grad_out_vectors_shape = grad_out_vectors.shape
-
-        # split gradient vectors and sorted bucket idxs by concatenated chunk dimension to gather correct indices
-        # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen
-        grad_logits = grad_logits.view((grad_logits_shape[:2] + (num_hashes, -1)))
-        # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen x ChunkLen
-        grad_out_vectors = grad_out_vectors.view(
-            (grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:])
-        )
-
-        # reshape and expand
-        sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_hashes, -1)))
         expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape)
         # reverse sort of forward
-        grad_out_vectors = torch.gather(grad_out_vectors, 3, expanded_sort_indices)
-        grad_logits = torch.gather(grad_logits, 3, sorted_bucket_idx)
-
-        # reshape into correct shape
-        grad_logits = torch.reshape(grad_logits, grad_logits_shape)
-        grad_out_vectors = torch.reshape(grad_out_vectors, grad_out_vectors_shape)
+        grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices)
+        grad_logits = torch.gather(grad_logits, 2, sorted_bucket_idx)

-        # return grad and `None` fillers for last 3 forward args
-        return grad_out_vectors, grad_logits, None, None, None
+        # return grad and `None` fillers for last 2 forward args
+        return grad_out_vectors, grad_logits, None, None


 class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
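A toy demonstration (not the library code) of the gather-based re-sort the simplified backward performs: gathering along dim 2 with `sorted_bucket_idx` puts at slot `i` the gradient that belonged to original position `sorted_bucket_idx[..., i]`; the removed reshape/`num_hashes` bookkeeping appears to have accomplished the same thing along a per-hash split of dim 3. The sort order below is made up.

import torch

batch_size, num_heads, seq_len, head_dim = 1, 1, 6, 2

# made-up sort order; in the model this comes from the LSH bucketing
sorted_bucket_idx = torch.randperm(seq_len).view(1, 1, seq_len).expand(batch_size, num_heads, seq_len)

grad_out_vectors = torch.randn(batch_size, num_heads, seq_len, head_dim)
grad_logits = torch.randn(batch_size, num_heads, seq_len)

expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape)
resorted_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices)
resorted_logits = torch.gather(grad_logits, 2, sorted_bucket_idx)

# slot i now holds the gradient of original position sorted_bucket_idx[..., i]
print(torch.equal(resorted_vectors[0, 0, 0], grad_out_vectors[0, 0, sorted_bucket_idx[0, 0, 0]]))  # True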