Unverified Commit 9ca48573 authored by Patrick von Platen, committed by GitHub

[Reformer] Improved memory if input is shorter than chunk length (#4720)

* improve handling of short inputs for reformer

* correct typo in assert statement

* fix other tests
parent b231a413
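
The gist of the change, as a minimal standalone sketch (hypothetical helper `chunked_attention_fn`, not the actual `LSHSelfAttention`/`LocalSelfAttention` code): chunked attention is only applied when the sequence is longer than the configured chunk length, and short inputs fall back to ordinary full self-attention instead of being padded up to a full chunk.

```python
import torch


def attend_short_or_chunked(query, key, value, chunk_length, chunked_attention_fn):
    """Sketch of the control flow only; names and shapes are illustrative."""
    # query/key/value: (batch, heads, seq_len, head_dim)
    sequence_length = query.shape[-2]
    if chunk_length < sequence_length:
        # long input: delegate to chunked (LSH or local) attention
        return chunked_attention_fn(query, key, value)
    # short input: plain full attention over the whole sequence, no chunking or padding
    scores = torch.matmul(query, key.transpose(-1, -2)) / query.shape[-1] ** 0.5
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, value)
```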
@@ -351,6 +351,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             value_vectors.shape[-1], self.attention_head_size
         )
+        # LSH attention only makes sense if chunked attention should be performed
+        if self.chunk_length < sequence_length:
             # set `num_buckets` on the fly, recommended way to do it
             if self.num_buckets is None:
                 self._set_num_buckets(sequence_length)
@@ -386,6 +388,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             assert (
                 self.num_chunks_before == 0 and self.num_chunks_after == 0
             ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
+        else:
+            # get sequence length indices
+            sorted_bucket_idx = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
+                batch_size, self.num_attention_heads, 1
+            )

         # scale key vectors
         key_vectors = self._len_and_dim_norm(query_key_vectors)
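
When no chunking happens, `sorted_bucket_idx` is just the identity position ordering per head, so nothing has to be un-sorted later. A quick standalone shape check (toy sizes, not from the test suite):

```python
import torch

batch_size, num_attention_heads, sequence_length = 2, 4, 7
sorted_bucket_idx = torch.arange(sequence_length).repeat(batch_size, num_attention_heads, 1)

print(sorted_bucket_idx.shape)  # torch.Size([2, 4, 7])
print(sorted_bucket_idx[0, 0])  # tensor([0, 1, 2, 3, 4, 5, 6]) -> identity ordering, no ReverseSort needed
```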
@@ -398,10 +405,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             sorted_bucket_idx=sorted_bucket_idx,
             attention_mask=attention_mask,
             head_mask=head_mask,
+            sequence_length=sequence_length,
         )

         # free memory
         del query_key_vectors, key_vectors, value_vectors

+        # re-order out_vectors and logits
+        if self.chunk_length < sequence_length:
             # sort clusters back to correct ordering
             out_vectors, logits = ReverseSort.apply(
                 out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
@@ -554,8 +565,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         self.num_buckets = num_buckets

     def _attend(
-        self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask,
+        self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, sequence_length
     ):
+        # look at previous and following chunks if chunked attention
+        if self.chunk_length < sequence_length:
             key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
             value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
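
`_look_adjacent` is what makes skipping this step necessary for short inputs: it concatenates each chunk with its neighbouring chunks, which presupposes a chunk dimension. A simplified roll-and-concatenate stand-in (not the exact implementation) to show the shapes involved:

```python
import torch


def look_adjacent(chunked, num_chunks_before, num_chunks_after):
    # chunked: (batch, heads, num_chunks, chunk_len, head_dim)
    # roll along the chunk axis so every chunk also sees its previous/following chunks
    neighbours = [
        torch.roll(chunked, shifts=i, dims=2)
        for i in range(num_chunks_before, -num_chunks_after - 1, -1)
    ]
    # concatenate along the in-chunk axis:
    # (batch, heads, num_chunks, chunk_len * (before + 1 + after), head_dim)
    return torch.cat(neighbours, dim=3)


key_vectors = torch.randn(1, 2, 4, 8, 16)      # 4 chunks of length 8
print(look_adjacent(key_vectors, 1, 0).shape)  # torch.Size([1, 2, 4, 16, 16])
```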
@@ -565,10 +579,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         # free memory
         del query_vectors, key_vectors

+        # if chunked attention split bucket idxs to query and key
+        if self.chunk_length < sequence_length:
             query_bucket_idx = self._split_seq_length_dim_to(
                 sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
             )
             key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
+        else:
+            query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx

         # get correct mask values depending on precision
         if query_key_dots.dtype == torch.float16:
@@ -578,7 +596,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             self_mask_value = self.self_mask_value_float32
             mask_value = self.mask_value_float32

-        mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask)
+        mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask, sequence_length)

         if mask is not None:
             query_key_dots = torch.where(mask, query_key_dots, mask_value)
@@ -627,12 +645,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         del value_vectors

         # merge chunk length
+        if self.chunk_length < sequence_length:
             logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1)
             out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)

         return out_vectors, logits, attention_probs

-    def _compute_attn_mask(self, query_indices, key_indices, attention_mask):
+    def _compute_attn_mask(self, query_indices, key_indices, attention_mask, sequence_length):
         mask = None

         # Causal mask
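
The guarded `flatten(start_dim=2, end_dim=3)` above merges the chunk and in-chunk dimensions back into one sequence dimension; when attention was never chunked, those dimensions do not exist, hence the new condition. Shape-wise (toy sizes):

```python
import torch

batch, heads, num_chunks, chunk_len, head_dim = 2, 4, 8, 16, 32
out_vectors = torch.randn(batch, heads, num_chunks, chunk_len, head_dim)
logits = torch.randn(batch, heads, num_chunks, chunk_len, 1)

# merge (num_chunks, chunk_len) -> sequence_length = num_chunks * chunk_len
print(out_vectors.flatten(start_dim=2, end_dim=3).shape)         # torch.Size([2, 4, 128, 32])
print(logits.flatten(start_dim=2, end_dim=3).squeeze(-1).shape)  # torch.Size([2, 4, 128])
```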
@@ -642,6 +661,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         # Attention mask: chunk, look up correct mask value from key_value_bucket_idx
         # IMPORTANT: official trax code does not use a mask for LSH Attention. Not sure why.
         if attention_mask is not None:
+            # if chunked attention, the attention mask has to correspond to LSH order
+            if sequence_length > self.chunk_length:
                 attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
                 # expand attn_mask to fit with key_value_bucket_idx shape
                 attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
@@ -649,8 +670,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
                 query_attn_mask = torch.gather(attention_mask, -1, query_indices)
                 # expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
                 attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
+                # free memory
+                del query_attn_mask, key_attn_mask
+            else:
+                # usual attention mask creation
+                attention_mask = attention_mask.to(torch.uint8)[:, None, :]
+                attn_mask = (attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(-2)).expand(
+                    query_indices.shape + attention_mask.shape[-1:]
+                )
             # free memory
-            del query_attn_mask, key_attn_mask, attention_mask
+            del attention_mask

         # multiply by causal mask if necessary
         if mask is not None:
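
In the new non-chunked branch the padding mask becomes an ordinary pairwise query/key mask via broadcasting. A standalone illustration with a toy mask (sizes are made up):

```python
import torch

# 1 = real token, 0 = padding; one sequence of length 4 with one padded position
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.uint8)[:, None, :]  # (batch, 1, seq_len)

# outer product of the mask with itself: entry (i, j) is 1 only if both the
# query position i and the key position j are real tokens
attn_mask = attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(-2)       # (batch, 1, seq_len, seq_len)
print(attn_mask[0, 0])
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [0, 0, 0, 0]], dtype=torch.uint8)
```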
@@ -810,6 +841,13 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
             torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype)
         )

+        # get sequence length indices
+        indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
+            batch_size, self.num_attention_heads, 1
+        )
+
+        # if input should be chunked
+        if self.chunk_length < sequence_length:
             # chunk vectors
             # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
             query_vectors = self._split_seq_length_dim_to(
@@ -823,9 +861,6 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
             )

             # chunk indices
-            indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
-                batch_size, self.num_attention_heads, 1
-            )
             query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
             key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
@@ -833,13 +868,18 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
             key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
             value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
             key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after)
+        else:
+            query_indices = key_indices = indices

+        # query-key matmul: QK^T
         query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))

         # free memory
         del query_vectors, key_vectors

-        mask = self._compute_attn_mask(query_indices, key_indices, attention_mask, query_key_dots.shape)
+        mask = self._compute_attn_mask(
+            query_indices, key_indices, attention_mask, query_key_dots.shape, sequence_length
+        )

         if mask is not None:
             # get mask tensor depending on half precision or not
@@ -874,6 +914,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
         del value_vectors

         # merge chunk length
+        if self.chunk_length < sequence_length:
             out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)

         assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,)
@@ -885,14 +926,18 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
         return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)

-    def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape):
+    def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length):
         mask = None

         # chunk attention mask and look before and after
         if attention_mask is not None:
             attention_mask = attention_mask.to(torch.uint8)[:, None, :]
+            if self.chunk_length < sequence_length:
                 attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
                 attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
+            else:
+                attention_mask_key = attention_mask

         # Causal mask
         if self.is_decoder is True:
@@ -1564,7 +1609,9 @@ class ReformerModel(ReformerPreTrainedModel):
         # if needs padding
         least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config)
-        must_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0
+        must_pad_to_match_chunk_length = (
+            input_shape[-1] % least_common_mult_chunk_length != 0 and input_shape[-1] > least_common_mult_chunk_length
+        )

         if must_pad_to_match_chunk_length:
             padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length
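
With the extra `input_shape[-1] > least_common_mult_chunk_length` condition, inputs that fit into a single chunk are no longer padded up to the least common multiple of the chunk lengths, which is the memory saving the commit title refers to. A small sketch of the decision using the same names (standalone, not the model code):

```python
def compute_padding_length(input_length: int, least_common_mult_chunk_length: int) -> int:
    """Return how many positions to pad on the right; 0 means no padding."""
    must_pad_to_match_chunk_length = (
        input_length % least_common_mult_chunk_length != 0
        and input_length > least_common_mult_chunk_length
    )
    if not must_pad_to_match_chunk_length:
        return 0
    return least_common_mult_chunk_length - input_length % least_common_mult_chunk_length


# e.g. with a least common chunk length of 64: a 16-token input is no longer
# padded to 64, while a 100-token input is still padded up to 128
assert compute_padding_length(16, 64) == 0
assert compute_padding_length(100, 64) == 28
```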
...
@@ -2188,7 +2188,7 @@ def apply_chunking_to_forward(
     assert (
         input_tensors[0].shape[chunk_dim] % chunk_size == 0
     ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
-        input_tensors[0][chunk_dim], chunk_size
+        input_tensors[0].shape[chunk_dim], chunk_size
     )

     num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
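
This hunk is the "correct typo in assert statement" item from the commit message: the error message should interpolate the size of the chunked dimension (`.shape[chunk_dim]`), whereas `input_tensors[0][chunk_dim]` would index into the tensor and embed a whole sub-tensor in the message. A minimal reproduction of the corrected message (toy tensor, not the library call):

```python
import torch

hidden_states = torch.randn(2, 6, 8)
chunk_dim, chunk_size = 1, 4

try:
    # hidden_states[chunk_dim] would be a (6, 8) sub-tensor;
    # hidden_states.shape[chunk_dim] is the integer 6 the message should report
    assert hidden_states.shape[chunk_dim] % chunk_size == 0, (
        "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
            hidden_states.shape[chunk_dim], chunk_size
        )
    )
except AssertionError as err:
    print(err)  # The dimension to be chunked 6 has to be a multiple of the chunk size 4
```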
...
@@ -388,6 +388,16 @@ class ReformerModelTester:
         output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
         self.parent.assertFalse(torch.isnan(output).any().item())

+    def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
+        # force chunk length to be bigger than input_ids
+        config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
+        config.local_attn_chunk_length = 2 * input_ids.shape[-1]
+        model = ReformerModelWithLMHead(config=config)
+        model.to(torch_device)
+        model.eval()
+        output_logits = model(input_ids, attention_mask=input_mask)[0]
+        self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (config, input_ids, input_mask,) = config_and_inputs
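
The new `create_and_check_reformer_no_chunking` check covers exactly the short-input path added above. From a user's perspective the same path is hit whenever the input is shorter than the configured chunk lengths; a hedged usage sketch, assuming the public `google/reformer-crime-and-punishment` checkpoint (chunk lengths of 64) is available:

```python
import torch
from transformers import ReformerModelWithLMHead, ReformerTokenizer

tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment")
model.eval()

# fewer than 64 tokens: with this change the input is no longer padded to the
# chunk length, and attention runs unchunked over the short sequence
input_ids = tokenizer.encode("A few short words.", return_tensors="pt")
with torch.no_grad():
    logits = model(input_ids)[0]
print(logits.shape)  # (1, number_of_input_tokens, vocab_size)
```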
@@ -433,6 +443,10 @@ class ReformerTesterMixin:
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)

+    def test_reformer_no_chunking(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
+
     @slow
     def test_dropout_random_seed_is_changing(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -772,6 +786,7 @@ class ReformerIntegrationTests(unittest.TestCase):
     def test_lsh_layer_forward(self):
         config = self._get_basic_config_and_input()
+        config["lsh_num_chunks_before"] = 0
         config["attn_layers"] = ["lsh"]
         config["is_decoder"] = False
         hidden_states = self._get_hidden_states()
@@ -787,6 +802,7 @@ class ReformerIntegrationTests(unittest.TestCase):
     def test_lsh_layer_forward_complex(self):
         config = self._get_basic_config_and_input()
+        config["lsh_num_chunks_before"] = 0
         config["attn_layers"] = ["lsh"]
         config["num_buckets"] = [2, 4]
         attn_mask = self._get_attn_mask()
@@ -805,6 +821,7 @@ class ReformerIntegrationTests(unittest.TestCase):
     def test_local_layer_forward(self):
         config = self._get_basic_config_and_input()
+        config["local_num_chunks_before"] = 0
         config["attn_layers"] = ["local"]
         config["is_decoder"] = False
         hidden_states = self._get_hidden_states()
@@ -820,6 +837,7 @@ class ReformerIntegrationTests(unittest.TestCase):
     def test_local_layer_forward_complex(self):
         config = self._get_basic_config_and_input()
+        config["local_num_chunks_before"] = 0
         config["attn_layers"] = ["local"]
         attn_mask = self._get_attn_mask()
         hidden_states = self._get_hidden_states()
@@ -829,7 +847,7 @@ class ReformerIntegrationTests(unittest.TestCase):
         reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
         output_slice = reformer_output.hidden_states[0, 0, :5]
         expected_output_slice = torch.tensor(
-            [1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device,
+            [1.4750, -2.0235, -0.9743, 1.4463, -0.1269], dtype=torch.float, device=torch_device,
         )
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))