Unverified commit 9ca48573 authored by Patrick von Platen, committed by GitHub

[Reformer] Improved memory if input is shorter than chunk length (#4720)

* improve handling of short inputs for reformer

* correct typo in assert statement

* fix other tests
parent b231a413
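The change threads a single check, `self.chunk_length < sequence_length`, through `LSHSelfAttention`, `LocalSelfAttention`, and the padding logic in `ReformerModel`: when the whole input fits into one chunk, bucketing, adjacent-chunk gathering, re-sorting, and padding are skipped and plain full attention is used instead. A minimal standalone sketch of that branch is shown below (the function, tensor shapes, and helper logic are illustrative assumptions, not the library code):

```python
import torch

def toy_chunked_attention(query, key, value, chunk_length):
    """Minimal sketch of the idea behind this commit (not the actual
    LSHSelfAttention / LocalSelfAttention code): chunk the sequence
    only when it is longer than a single chunk."""
    seq_len, head_dim = query.shape[-2], query.shape[-1]

    if chunk_length < seq_len:
        # long input: fold the sequence dim into (num_chunks, chunk_length)
        # and attend within each chunk, so memory scales with chunk_length
        assert seq_len % chunk_length == 0, "sketch assumes seq_len % chunk_length == 0"
        chunked_shape = query.shape[:-2] + (seq_len // chunk_length, chunk_length, head_dim)
        q, k, v = (t.reshape(chunked_shape) for t in (query, key, value))
    else:
        # short input (the case this commit optimizes): skip bucketing,
        # chunking, and re-sorting entirely and run ordinary full attention
        q, k, v = query, key, value

    scores = torch.matmul(q, k.transpose(-1, -2)) / head_dim ** 0.5
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v).reshape(query.shape)

# usage: (batch, heads, seq_len, head_dim)
x = torch.randn(1, 2, 16, 8)
assert toy_chunked_attention(x, x, x, chunk_length=4).shape == x.shape  # chunked path
s = torch.randn(1, 2, 3, 8)
assert toy_chunked_attention(s, s, s, chunk_length=4).shape == s.shape  # short-input path
```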
@@ -351,41 +351,48 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
value_vectors.shape[-1], self.attention_head_size
)
# set `num_buckets` on the fly, recommended way to do it
if self.num_buckets is None:
self._set_num_buckets(sequence_length)
# LSH attention only makes sense if chunked attention should be performed
if self.chunk_length < sequence_length:
# set `num_buckets` on the fly, recommended way to do it
if self.num_buckets is None:
self._set_num_buckets(sequence_length)
# use cached buckets for backprop only
if buckets is None:
# hash query key vectors into buckets
buckets = self._hash_vectors(query_key_vectors, num_hashes)
# use cached buckets for backprop only
if buckets is None:
# hash query key vectors into buckets
buckets = self._hash_vectors(query_key_vectors, num_hashes)
assert (
int(buckets.shape[-1]) == num_hashes * sequence_length
), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length)
assert (
int(buckets.shape[-1]) == num_hashes * sequence_length
), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length)
sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx(
sequence_length, buckets, num_hashes
)
sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx(
sequence_length, buckets, num_hashes
)
# make sure bucket idx is not longer than sequence length
sorted_bucket_idx = sorted_bucket_idx % sequence_length
# make sure bucket idx is not longer than sequence length
sorted_bucket_idx = sorted_bucket_idx % sequence_length
# cluster query key value vectors according to hashed buckets
query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes)
value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes)
# cluster query key value vectors according to hashed buckets
query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes)
value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes)
query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
if self.chunk_length is None:
assert (
self.num_chunks_before == 0 and self.num_chunks_after == 0
), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
if self.chunk_length is None:
assert (
self.num_chunks_before == 0 and self.num_chunks_after == 0
), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
else:
# get sequence length indices
sorted_bucket_idx = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
batch_size, self.num_attention_heads, 1
)
# scale key vectors
key_vectors = self._len_and_dim_norm(query_key_vectors)
@@ -398,31 +405,35 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
sorted_bucket_idx=sorted_bucket_idx,
attention_mask=attention_mask,
head_mask=head_mask,
sequence_length=sequence_length,
)
# free memory
del query_key_vectors, key_vectors, value_vectors
# sort clusters back to correct ordering
out_vectors, logits = ReverseSort.apply(
out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
)
# sum up all hash rounds
if num_hashes > 1:
out_vectors = self._split_seq_length_dim_to(
out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
# re-order out_vectors and logits
if self.chunk_length < sequence_length:
# sort clusters back to correct ordering
out_vectors, logits = ReverseSort.apply(
out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
)
logits = self._split_seq_length_dim_to(
logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
).unsqueeze(-1)
probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True))
out_vectors = torch.sum(out_vectors * probs_vectors, dim=2)
# free memory
del probs_vectors
# sum up all hash rounds
if num_hashes > 1:
out_vectors = self._split_seq_length_dim_to(
out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
)
logits = self._split_seq_length_dim_to(
logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
).unsqueeze(-1)
# free memory
del logits
probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True))
out_vectors = torch.sum(out_vectors * probs_vectors, dim=2)
# free memory
del probs_vectors
# free memory
del logits
assert out_vectors.shape == (
batch_size,
@@ -554,10 +565,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
self.num_buckets = num_buckets
def _attend(
self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask,
self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, sequence_length
):
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
# look at previous and following chunks if chunked attention
if self.chunk_length < sequence_length:
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
# get logits and dots
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
@@ -565,10 +579,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# free memory
del query_vectors, key_vectors
query_bucket_idx = self._split_seq_length_dim_to(
sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
)
key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
# if chunked attention, split bucket indices into query and key indices
if self.chunk_length < sequence_length:
query_bucket_idx = self._split_seq_length_dim_to(
sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
)
key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
else:
query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx
# get correct mask values depending on precision
if query_key_dots.dtype == torch.float16:
@@ -578,7 +596,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
self_mask_value = self.self_mask_value_float32
mask_value = self.mask_value_float32
mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask)
mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask, sequence_length)
if mask is not None:
query_key_dots = torch.where(mask, query_key_dots, mask_value)
@@ -627,12 +645,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
del value_vectors
# merge chunk length
logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1)
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
if self.chunk_length < sequence_length:
logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1)
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
return out_vectors, logits, attention_probs
def _compute_attn_mask(self, query_indices, key_indices, attention_mask):
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, sequence_length):
mask = None
# Causal mask
@@ -642,15 +661,27 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# Attention mask: chunk, look up correct mask value from key_value_bucket_idx
# IMPORTANT: official trax code does not use a mask for LSH Attention. Not sure why.
if attention_mask is not None:
attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
# expand attn_mask to fit with key_value_bucket_idx shape
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
key_attn_mask = torch.gather(attention_mask, -1, key_indices)
query_attn_mask = torch.gather(attention_mask, -1, query_indices)
# expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
# if chunked attention, the attention mask has to correspond to LSH order
if sequence_length > self.chunk_length:
attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
# expand attn_mask to fit with key_value_bucket_idx shape
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
key_attn_mask = torch.gather(attention_mask, -1, key_indices)
query_attn_mask = torch.gather(attention_mask, -1, query_indices)
# expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
# free memory
del query_attn_mask, key_attn_mask
else:
# usual attention mask creation
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
attn_mask = (attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(-2)).expand(
query_indices.shape + attention_mask.shape[-1:]
)
# free memory
del query_attn_mask, key_attn_mask, attention_mask
del attention_mask
# multiply by causal mask if necessary
if mask is not None:
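In the new non-chunked branch of `_compute_attn_mask`, the mask reduces to the usual pairwise product of the padding mask with itself, broadcast to the attention-score shape. A tiny illustration of that construction (the shapes and tensors below are assumed for the example, not taken from the test suite):

```python
import torch

# padding mask for one sequence of length 4: 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.uint8)

# pairwise mask: position (i, j) survives only if both tokens i and j are real
attn_mask = attention_mask[:, None, :, None] * attention_mask[:, None, None, :]

assert attn_mask.shape == (1, 1, 4, 4)
assert attn_mask[0, 0, 0, 3] == 0  # a real query attending to a padded key is masked out
```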
@@ -810,36 +841,45 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype)
)
# chunk vectors
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
query_vectors = self._split_seq_length_dim_to(
query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
key_vectors = self._split_seq_length_dim_to(
key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
# chunk indices
# get sequence length indices
indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
batch_size, self.num_attention_heads, 1
)
query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
# append chunks before and after
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after)
# if input should be chunked
if self.chunk_length < sequence_length:
# chunk vectors
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
query_vectors = self._split_seq_length_dim_to(
query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
key_vectors = self._split_seq_length_dim_to(
key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
# chunk indices
query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
# append chunks before and after
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after)
else:
query_indices = key_indices = indices
# query-key matmul: QK^T
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
# free memory
del query_vectors, key_vectors
mask = self._compute_attn_mask(query_indices, key_indices, attention_mask, query_key_dots.shape)
mask = self._compute_attn_mask(
query_indices, key_indices, attention_mask, query_key_dots.shape, sequence_length
)
if mask is not None:
# get mask tensor depending on half precision or not
@@ -874,7 +914,8 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
del value_vectors
# merge chunk length
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
if self.chunk_length < sequence_length:
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,)
@@ -885,14 +926,18 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape):
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length):
mask = None
# chunk attention mask and look before and after
if attention_mask is not None:
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
if self.chunk_length < sequence_length:
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
else:
attention_mask_key = attention_mask
# Causal mask
if self.is_decoder is True:
@@ -1564,7 +1609,9 @@ class ReformerModel(ReformerPreTrainedModel):
# check whether padding is needed
least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config)
must_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0
must_pad_to_match_chunk_length = (
input_shape[-1] % least_common_mult_chunk_length != 0 and input_shape[-1] > least_common_mult_chunk_length
)
if must_pad_to_match_chunk_length:
padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length
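The hunk above also stops padding inputs that are shorter than the least-common-multiple chunk length, since such inputs are now attended to without chunking and therefore need no padding. A hedged sketch of that decision (the helper and the 64/64 chunk lengths below are illustrative assumptions, not the `ReformerModel` implementation):

```python
from math import gcd

def lcm(a: int, b: int) -> int:
    return a * b // gcd(a, b)

def toy_padding_length(seq_len: int, lsh_chunk_len: int = 64, local_chunk_len: int = 64) -> int:
    """How many pad tokens a Reformer-style model would add (sketch only)."""
    least_common_mult = lcm(lsh_chunk_len, local_chunk_len)
    must_pad = seq_len % least_common_mult != 0 and seq_len > least_common_mult
    return least_common_mult - seq_len % least_common_mult if must_pad else 0

# a 100-token input is still padded up to 128, but a 40-token input is now
# left untouched because it is attended to without chunking
assert toy_padding_length(100) == 28
assert toy_padding_length(40) == 0
```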
@@ -2188,7 +2188,7 @@ def apply_chunking_to_forward(
assert (
input_tensors[0].shape[chunk_dim] % chunk_size == 0
), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
input_tensors[0][chunk_dim], chunk_size
input_tensors[0].shape[chunk_dim], chunk_size
)
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
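For context, `apply_chunking_to_forward` slices its input along `chunk_dim`, applies the forward function to each slice, and concatenates the results; the change above only fixes the assert message to index `.shape`. A simplified sketch of that chunking pattern (not the actual utility from the library):

```python
import torch

def toy_apply_chunking(forward_fn, chunk_size, chunk_dim, input_tensor):
    """Simplified feed-forward chunking: trade a bit of speed for memory by
    running forward_fn on slices of the input along chunk_dim."""
    if chunk_size == 0:
        return forward_fn(input_tensor)
    dim_len = input_tensor.shape[chunk_dim]
    assert dim_len % chunk_size == 0, (
        "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(dim_len, chunk_size)
    )
    chunks = input_tensor.chunk(dim_len // chunk_size, dim=chunk_dim)
    return torch.cat([forward_fn(chunk) for chunk in chunks], dim=chunk_dim)

# usage: chunk a (2, 8, 16) tensor into 4 slices of length 2 along dim 1
out = toy_apply_chunking(torch.tanh, 2, 1, torch.randn(2, 8, 16))
assert out.shape == (2, 8, 16)
```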
@@ -388,6 +388,16 @@ class ReformerModelTester:
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
# force chunk length to be longer than the input sequence
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
output_logits = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask,) = config_and_inputs
@@ -433,6 +443,10 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
def test_reformer_no_chunking(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
@slow
def test_dropout_random_seed_is_changing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -772,6 +786,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_lsh_layer_forward(self):
config = self._get_basic_config_and_input()
config["lsh_num_chunks_before"] = 0
config["attn_layers"] = ["lsh"]
config["is_decoder"] = False
hidden_states = self._get_hidden_states()
@@ -787,6 +802,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_lsh_layer_forward_complex(self):
config = self._get_basic_config_and_input()
config["lsh_num_chunks_before"] = 0
config["attn_layers"] = ["lsh"]
config["num_buckets"] = [2, 4]
attn_mask = self._get_attn_mask()
@@ -805,6 +821,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_local_layer_forward(self):
config = self._get_basic_config_and_input()
config["local_num_chunks_before"] = 0
config["attn_layers"] = ["local"]
config["is_decoder"] = False
hidden_states = self._get_hidden_states()
@@ -820,6 +837,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_local_layer_forward_complex(self):
config = self._get_basic_config_and_input()
config["local_num_chunks_before"] = 0
config["attn_layers"] = ["local"]
attn_mask = self._get_attn_mask()
hidden_states = self._get_hidden_states()
@@ -829,7 +847,7 @@ class ReformerIntegrationTests(unittest.TestCase):
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor(
[1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device,
[1.4750, -2.0235, -0.9743, 1.4463, -0.1269], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))