chenpangpang / transformers · Commits
"...audio/git@developer.sourcefind.cn:OpenDAS/lightx2v.git" did not exist on "a1ebc651ab830a381e8960029145b557990342d6"
Unverified commit 98109464, authored Jun 28, 2020 by Patrick von Platen; committed by GitHub on Jun 28, 2020

clean reformer reverse sort (#5343)

Parent: 1af58c07
Showing 1 changed file with 21 additions and 38 deletions:

src/transformers/modeling_reformer.py  (+21, -38)
src/transformers/modeling_reformer.py

@@ -384,11 +384,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             )

             # make sure bucket idx is not longer then sequence length
-            sorted_bucket_idx = sorted_bucket_idx % sequence_length
+            sorted_bucket_idx_per_hash = sorted_bucket_idx % sequence_length

             # cluster query key value vectors according to hashed buckets
-            query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes)
-            value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes)
+            query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes)
+            value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes)
             query_key_vectors = self._split_seq_length_dim_to(
                 query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
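A note on the rename in the hunk above: the arg-sorted `sorted_bucket_idx` runs over all `num_hashes * sequence_length` concatenated positions, and the modulo folds it back to positions within one hash round, which is what the new name `sorted_bucket_idx_per_hash` signals. A minimal sketch of that relationship in plain PyTorch (toy sizes and a toy bucket assignment, not the library's helpers):

import torch

# Toy layout: one batch, one head, num_hashes hash rounds concatenated on the last axis.
sequence_length, num_hashes = 8, 2
buckets = torch.randint(0, 4, (1, 1, num_hashes * sequence_length))

# Arg-sorting by bucket id gives indices into the concatenated axis ...
sorted_bucket_idx = torch.argsort(buckets, dim=-1)

# ... and the modulo recovers, for every hash round, the position in the original sequence.
sorted_bucket_idx_per_hash = sorted_bucket_idx % sequence_length
assert int(sorted_bucket_idx_per_hash.max()) < sequence_length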
@@ -403,7 +403,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
         else:
             # get sequence length indices
-            sorted_bucket_idx = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
+            sorted_bucket_idx_per_hash = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
                 batch_size, self.num_attention_heads, 1
             )
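In the `else` branch above, no LSH sorting happens, so the per-hash index is simply the identity ordering broadcast over batch and attention heads. A toy check of the shape this produces (assumed sizes, plain PyTorch):

import torch

batch_size, num_attention_heads, sequence_length = 2, 4, 6

# A 1-D arange repeated over batch and heads yields shape (batch, heads, seq_len).
sorted_bucket_idx_per_hash = torch.arange(sequence_length).repeat(
    batch_size, num_attention_heads, 1
)
print(sorted_bucket_idx_per_hash.shape)  # torch.Size([2, 4, 6])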
@@ -415,7 +415,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             query_vectors=query_key_vectors,
             key_vectors=key_vectors,
             value_vectors=value_vectors,
-            sorted_bucket_idx=sorted_bucket_idx,
+            sorted_bucket_idx_per_hash=sorted_bucket_idx_per_hash,
             attention_mask=attention_mask,
             head_mask=head_mask,
             sequence_length=sequence_length,
@@ -427,9 +427,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         # re-order out_vectors and logits
         if self.chunk_length < sequence_length:
             # sort clusters back to correct ordering
-            out_vectors, logits = ReverseSort.apply(
-                out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
-            )
+            out_vectors, logits = ReverseSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx)

         # sum up all hash rounds
         if num_hashes > 1:
@@ -578,7 +576,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         self.num_buckets = num_buckets

     def _attend(
-        self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, sequence_length
+        self,
+        query_vectors,
+        key_vectors,
+        value_vectors,
+        sorted_bucket_idx_per_hash,
+        attention_mask,
+        head_mask,
+        sequence_length,
     ):

         # look at previous and following chunks if chunked attention
@@ -595,11 +600,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         # if chunked attention split bucket idxs to query and key
         if self.chunk_length < sequence_length:
             query_bucket_idx = self._split_seq_length_dim_to(
-                sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
+                sorted_bucket_idx_per_hash, -1, self.chunk_length, self.num_attention_heads
             )
             key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
         else:
-            query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx
+            query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx_per_hash

         # get correct mask values depending on precision
         if query_key_dots.dtype == torch.float16:
@@ -741,11 +746,10 @@ class ReverseSort(Function):
     """

     @staticmethod
-    def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_hashes):
+    def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx):
         # save sorted_bucket_idx for backprop
         with torch.no_grad():
             ctx.sorted_bucket_idx = sorted_bucket_idx
-            ctx.num_hashes = num_hashes

         # undo sort to have correct order for next layer
         expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape)
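`ReverseSort.forward` above restores the original token order by gathering along the sequence dimension with the inverse permutation, expanded to the hidden size. A self-contained sketch of that sort/undo-sort pattern, assuming a batch x heads x seq_len x head_dim layout (toy tensors, not the module's actual state):

import torch

x = torch.randn(1, 1, 5, 3)
keys = torch.randint(0, 10, (1, 1, 5))

sorted_idx = torch.argsort(keys, dim=-1)
undo_idx = torch.argsort(sorted_idx, dim=-1)  # inverse permutation

# Sorting and undoing the sort are both gathers along dim 2, with the index
# expanded to cover the trailing head dimension.
x_sorted = torch.gather(x, 2, sorted_idx.unsqueeze(-1).expand(x.shape))
x_back = torch.gather(x_sorted, 2, undo_idx.unsqueeze(-1).expand(x.shape))
assert torch.equal(x_back, x)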
@@ -757,35 +761,14 @@ class ReverseSort(Function):
     def backward(ctx, grad_out_vectors, grad_logits):
         # get parameters saved in ctx
         sorted_bucket_idx = ctx.sorted_bucket_idx
-        num_hashes = ctx.num_hashes
-
-        # get real gradient shape
-        # shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes
-        grad_logits_shape = grad_logits.shape
-        # shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes x ChunkLen
-        grad_out_vectors_shape = grad_out_vectors.shape
-
-        # split gradient vectors and sorted bucket idxs by concatenated chunk dimension to gather correct indices
-        # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen
-        grad_logits = grad_logits.view((grad_logits_shape[:2] + (num_hashes, -1)))
-        # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen x ChunkLen
-        grad_out_vectors = grad_out_vectors.view(
-            (grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:])
-        )
-
-        # reshape and expand
-        sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_hashes, -1)))
         expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape)

         # reverse sort of forward
-        grad_out_vectors = torch.gather(grad_out_vectors, 3, expanded_sort_indices)
-        grad_logits = torch.gather(grad_logits, 3, sorted_bucket_idx)
-
-        # reshape into correct shape
-        grad_logits = torch.reshape(grad_logits, grad_logits_shape)
-        grad_out_vectors = torch.reshape(grad_out_vectors, grad_out_vectors_shape)
+        grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices)
+        grad_logits = torch.gather(grad_logits, 2, sorted_bucket_idx)

-        # return grad and `None` fillers for last 3 forward args
-        return grad_out_vectors, grad_logits, None, None, None
+        # return grad and `None` fillers for last 2 forward args
+        return grad_out_vectors, grad_logits, None, None


 class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
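Taken together, the cleanup drops the per-hash reshaping from `backward`: the saved `sorted_bucket_idx` spans the whole hash-concatenated sequence axis, so a single gather along dimension 2 re-sorts the incoming gradients, and `None` is returned for the two index inputs, which need no gradient. A minimal sketch of the same pattern as a standalone autograd `Function` (toy class, assumed shapes, not the library implementation):

import torch
from torch.autograd import Function


class ToyReverseSort(Function):
    @staticmethod
    def forward(ctx, out_vectors, sorted_bucket_idx, undo_sorted_bucket_idx):
        # save the sort order for backprop, as the real class does
        with torch.no_grad():
            ctx.sorted_bucket_idx = sorted_bucket_idx
            # undo the sort: gather along the sequence dim with the inverse permutation
            undo = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape)
            out_vectors = torch.gather(out_vectors, 2, undo)
        return out_vectors

    @staticmethod
    def backward(ctx, grad_out_vectors):
        # re-apply the forward sort to the gradient: a single gather along dim 2
        sort = ctx.sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape)
        grad_out_vectors = torch.gather(grad_out_vectors, 2, sort)
        # one return value per forward input: the tensor grad, then None for each index
        return grad_out_vectors, None, None

With mutually inverse index tensors (e.g. `undo_sorted_bucket_idx = torch.argsort(sorted_bucket_idx, dim=-1)`), the gather in `backward` is exactly the inverse of the gather in `forward`, which is presumably why the per-hash `num_hashes` bookkeeping removed in this commit was not needed for correctness.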