Unverified Commit 98109464 authored by Patrick von Platen, committed by GitHub

clean reformer reverse sort (#5343)

parent 1af58c07
@@ -384,11 +384,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             )
 
             # make sure bucket idx is not longer than sequence length
-            sorted_bucket_idx = sorted_bucket_idx % sequence_length
+            sorted_bucket_idx_per_hash = sorted_bucket_idx % sequence_length
 
             # cluster query key value vectors according to hashed buckets
-            query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes)
-            value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes)
+            query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes)
+            value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes)
 
             query_key_vectors = self._split_seq_length_dim_to(
                 query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
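For context, the clustering step gathers the query/key/value vectors into their hashed-bucket order once per hash round. A minimal sketch of such a gather-by-expansion, with assumed shapes and names (not the library's `_gather_by_expansion` implementation):

```python
import torch

# Assumed illustrative shapes (not taken from the library):
#   vectors:                    (batch, heads, seq_len, head_dim)
#   sorted_bucket_idx_per_hash: (batch, heads, num_hashes * seq_len), values in [0, seq_len)
batch, heads, seq_len, head_dim, num_hashes = 1, 2, 8, 4, 3
vectors = torch.randn(batch, heads, seq_len, head_dim)
sorted_bucket_idx_per_hash = torch.randint(0, seq_len, (batch, heads, num_hashes * seq_len))

# repeat the vectors once per hash round, then gather along the sequence dimension
expanded_idx = sorted_bucket_idx_per_hash.unsqueeze(-1).expand(-1, -1, -1, head_dim)
repeated_vectors = vectors.repeat(1, 1, num_hashes, 1)
clustered = torch.gather(repeated_vectors, 2, expanded_idx)
print(clustered.shape)  # torch.Size([1, 2, 24, 4])
```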
@@ -403,7 +403,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
                ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
         else:
             # get sequence length indices
-            sorted_bucket_idx = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
+            sorted_bucket_idx_per_hash = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
                 batch_size, self.num_attention_heads, 1
             )
@@ -415,7 +415,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             query_vectors=query_key_vectors,
             key_vectors=key_vectors,
             value_vectors=value_vectors,
-            sorted_bucket_idx=sorted_bucket_idx,
+            sorted_bucket_idx_per_hash=sorted_bucket_idx_per_hash,
             attention_mask=attention_mask,
             head_mask=head_mask,
             sequence_length=sequence_length,
@@ -427,9 +427,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         # re-order out_vectors and logits
         if self.chunk_length < sequence_length:
             # sort clusters back to correct ordering
-            out_vectors, logits = ReverseSort.apply(
-                out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
-            )
+            out_vectors, logits = ReverseSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx)
 
         # sum up all hash rounds
         if num_hashes > 1:
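After attention has run on the bucket-sorted vectors, the outputs are put back into the original token order with the precomputed undo indices. A toy example of the sort / undo-sort relationship (assuming, as is standard, that the undo index is the argsort of the sort index; this is not the library code):

```python
import torch

# Toy 1-D example of the sort / undo-sort bookkeeping (not the library code):
# gathering with the argsort of the sort indices restores the original order.
scores = torch.tensor([0.3, 0.9, 0.1, 0.5])
sorted_idx = torch.argsort(scores)           # permutation that sorts `scores`
undo_sorted_idx = torch.argsort(sorted_idx)  # inverse permutation

sorted_scores = scores.gather(0, sorted_idx)
restored = sorted_scores.gather(0, undo_sorted_idx)
assert torch.equal(restored, scores)
```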
@@ -578,7 +576,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         self.num_buckets = num_buckets
 
     def _attend(
-        self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, sequence_length
+        self,
+        query_vectors,
+        key_vectors,
+        value_vectors,
+        sorted_bucket_idx_per_hash,
+        attention_mask,
+        head_mask,
+        sequence_length,
     ):
         # look at previous and following chunks if chunked attention
@@ -595,11 +600,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         # if chunked attention split bucket idxs to query and key
         if self.chunk_length < sequence_length:
             query_bucket_idx = self._split_seq_length_dim_to(
-                sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
+                sorted_bucket_idx_per_hash, -1, self.chunk_length, self.num_attention_heads
             )
             key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
         else:
-            query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx
+            query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx_per_hash
 
         # get correct mask values depending on precision
         if query_key_dots.dtype == torch.float16:
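In the chunked branch, the per-hash bucket indices are split into fixed-length chunks for the queries, and each query chunk is concatenated with its neighbouring chunks to form the key/value indices. A rough sketch of that idea with one chunk of look-back, using assumed shapes rather than the `_split_seq_length_dim_to` / `_look_adjacent` helpers:

```python
import torch

# Rough sketch of chunked indexing with one "chunk before" (assumed shapes only):
batch, heads, seq_len, chunk_length = 1, 2, 16, 4
idx = torch.arange(seq_len).expand(batch, heads, seq_len)

# split the sequence dimension into chunks for the queries
query_chunks = idx.reshape(batch, heads, seq_len // chunk_length, chunk_length)
# let every query chunk also attend to the previous chunk's keys/values (wraps around)
prev_chunks = torch.roll(query_chunks, shifts=1, dims=2)
key_value_chunks = torch.cat([prev_chunks, query_chunks], dim=-1)
print(query_chunks.shape, key_value_chunks.shape)  # (1, 2, 4, 4) (1, 2, 4, 8)
```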
@@ -741,11 +746,10 @@ class ReverseSort(Function):
     """
 
     @staticmethod
-    def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, num_hashes):
+    def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx):
         # save sorted_bucket_idx for backprop
         with torch.no_grad():
             ctx.sorted_bucket_idx = sorted_bucket_idx
-            ctx.num_hashes = num_hashes
 
             # undo sort to have correct order for next layer
             expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape)
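`torch.gather` requires an index tensor of the same rank as the input, which is why the undo indices are expanded over the trailing head dimension before the gather. A small sketch of that expand-then-gather pattern with assumed shapes:

```python
import torch

# Expand-then-gather pattern for re-ordering a 4-D tensor with a 3-D index
# (shapes are assumptions for illustration only):
#   out_vectors:            (batch, heads, num_hashes * seq_len, head_dim)
#   undo_sorted_bucket_idx: (batch, heads, num_hashes * seq_len)
batch, heads, total_len, head_dim = 1, 2, 12, 4
out_vectors = torch.randn(batch, heads, total_len, head_dim)
undo_sorted_bucket_idx = torch.randperm(total_len).expand(batch, heads, total_len)

# broadcast the undo indices over the head dimension, then gather along dim 2
expanded_undo = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape)
reordered = torch.gather(out_vectors, 2, expanded_undo)
print(reordered.shape)  # torch.Size([1, 2, 12, 4])
```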
@@ -757,35 +761,14 @@ class ReverseSort(Function):
     def backward(ctx, grad_out_vectors, grad_logits):
         # get parameters saved in ctx
         sorted_bucket_idx = ctx.sorted_bucket_idx
-        num_hashes = ctx.num_hashes
 
-        # get real gradient shape
-        # shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes
-        grad_logits_shape = grad_logits.shape
-        # shape is BatchSize x NumAttnHeads x ChunkLen * NumHashes x ChunkLen
-        grad_out_vectors_shape = grad_out_vectors.shape
-
-        # split gradient vectors and sorted bucket idxs by concatenated chunk dimension to gather correct indices
-        # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen
-        grad_logits = grad_logits.view((grad_logits_shape[:2] + (num_hashes, -1)))
-        # shape is BatchSize x NumAttnHeads x NumHashes x ChunkLen x ChunkLen
-        grad_out_vectors = grad_out_vectors.view(
-            (grad_out_vectors_shape[:2] + (num_hashes, -1) + grad_out_vectors_shape[-1:])
-        )
-
-        # reshape and expand
-        sorted_bucket_idx = torch.reshape(sorted_bucket_idx, (sorted_bucket_idx.shape[:2] + (num_hashes, -1)))
         expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape)
         # reverse sort of forward
-        grad_out_vectors = torch.gather(grad_out_vectors, 3, expanded_sort_indices)
-        grad_logits = torch.gather(grad_logits, 3, sorted_bucket_idx)
-
-        # reshape into correct shape
-        grad_logits = torch.reshape(grad_logits, grad_logits_shape)
-        grad_out_vectors = torch.reshape(grad_out_vectors, grad_out_vectors_shape)
+        grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices)
+        grad_logits = torch.gather(grad_logits, 2, sorted_bucket_idx)
 
-        # return grad and `None` fillers for last 3 forward args
-        return grad_out_vectors, grad_logits, None, None, None
+        # return grad and `None` fillers for last 2 forward args
+        return grad_out_vectors, grad_logits, None, None
 
 
 class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
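The backward pass mirrors the forward gather: the incoming gradients are routed back with the saved sort indices along the same dimension, so no per-hash reshaping is required, which is presumably why the `num_hashes` bookkeeping could be dropped from `ctx`. A self-contained toy analogue of this sort / undo-sort autograd pattern (1-D, with invented names such as `ToyReverseSort`; not the Reformer implementation):

```python
import torch
from torch.autograd import Function

# Toy 1-D analogue (not the Reformer code): forward un-sorts its input with `undo_idx`;
# backward routes the incoming gradient back to the sorted order with `sort_idx`,
# gathering along the same dimension as the forward gather.
class ToyReverseSort(Function):
    @staticmethod
    def forward(ctx, sorted_values, sort_idx, undo_idx):
        ctx.save_for_backward(sort_idx)
        return torch.gather(sorted_values, 0, undo_idx)

    @staticmethod
    def backward(ctx, grad_output):
        (sort_idx,) = ctx.saved_tensors
        # inverse of the forward gather: send each gradient to its sorted position;
        # `None` fillers for the two index arguments, which need no gradient
        return torch.gather(grad_output, 0, sort_idx), None, None


values = torch.randn(6, requires_grad=True)
sort_idx = torch.randperm(6)
undo_idx = torch.argsort(sort_idx)

restored = ToyReverseSort.apply(values[sort_idx], sort_idx, undo_idx)
weights = torch.arange(6, dtype=torch.float32)
(restored * weights).sum().backward()
print(torch.allclose(values.grad, weights))  # True: each position gets its own weight back
```

Returning `None` for the two index arguments corresponds to the `None` fillers for the non-tensor forward arguments in the hunk above.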