Unverified Commit 9ca48573 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Reformer] Improved memory if input is shorter than chunk length (#4720)

* improve handling of short inputs for reformer

* correct typo in assert statement

* fix other tests
parent b231a413
...@@ -351,41 +351,48 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -351,41 +351,48 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
value_vectors.shape[-1], self.attention_head_size value_vectors.shape[-1], self.attention_head_size
) )
# set `num_buckets` on the fly, recommended way to do it # LSH attention only makes sense if chunked attention should be performed
if self.num_buckets is None: if self.chunk_length < sequence_length:
self._set_num_buckets(sequence_length) # set `num_buckets` on the fly, recommended way to do it
if self.num_buckets is None:
self._set_num_buckets(sequence_length)
# use cached buckets for backprop only # use cached buckets for backprop only
if buckets is None: if buckets is None:
# hash query key vectors into buckets # hash query key vectors into buckets
buckets = self._hash_vectors(query_key_vectors, num_hashes) buckets = self._hash_vectors(query_key_vectors, num_hashes)
assert ( assert (
int(buckets.shape[-1]) == num_hashes * sequence_length int(buckets.shape[-1]) == num_hashes * sequence_length
), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length) ), "last dim of buckets is {}, but should be {}".format(buckets.shape[-1], num_hashes * sequence_length)
sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx( sorted_bucket_idx, undo_sorted_bucket_idx = self._get_sorted_bucket_idx_and_undo_sorted_bucket_idx(
sequence_length, buckets, num_hashes sequence_length, buckets, num_hashes
) )
# make sure bucket idx is not longer then sequence length # make sure bucket idx is not longer then sequence length
sorted_bucket_idx = sorted_bucket_idx % sequence_length sorted_bucket_idx = sorted_bucket_idx % sequence_length
# cluster query key value vectors according to hashed buckets # cluster query key value vectors according to hashed buckets
query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes) query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx, num_hashes)
value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes) value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx, num_hashes)
query_key_vectors = self._split_seq_length_dim_to( query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
) )
value_vectors = self._split_seq_length_dim_to( value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
) )
if self.chunk_length is None: if self.chunk_length is None:
assert ( assert (
self.num_chunks_before == 0 and self.num_chunks_after == 0 self.num_chunks_before == 0 and self.num_chunks_after == 0
), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
else:
# get sequence length indices
sorted_bucket_idx = torch.arange(sequence_length, device=query_key_vectors.device).repeat(
batch_size, self.num_attention_heads, 1
)
# scale key vectors # scale key vectors
key_vectors = self._len_and_dim_norm(query_key_vectors) key_vectors = self._len_and_dim_norm(query_key_vectors)
...@@ -398,31 +405,35 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -398,31 +405,35 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
sorted_bucket_idx=sorted_bucket_idx, sorted_bucket_idx=sorted_bucket_idx,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
sequence_length=sequence_length,
) )
# free memory # free memory
del query_key_vectors, key_vectors, value_vectors del query_key_vectors, key_vectors, value_vectors
# sort clusters back to correct ordering # re-order out_vectors and logits
out_vectors, logits = ReverseSort.apply( if self.chunk_length < sequence_length:
out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes # sort clusters back to correct ordering
) out_vectors, logits = ReverseSort.apply(
out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx, self.num_hashes
# sum up all hash rounds
if num_hashes > 1:
out_vectors = self._split_seq_length_dim_to(
out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
) )
logits = self._split_seq_length_dim_to(
logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
).unsqueeze(-1)
probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) # sum up all hash rounds
out_vectors = torch.sum(out_vectors * probs_vectors, dim=2) if num_hashes > 1:
# free memory out_vectors = self._split_seq_length_dim_to(
del probs_vectors out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
)
logits = self._split_seq_length_dim_to(
logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
).unsqueeze(-1)
# free memory probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True))
del logits out_vectors = torch.sum(out_vectors * probs_vectors, dim=2)
# free memory
del probs_vectors
# free memory
del logits
assert out_vectors.shape == ( assert out_vectors.shape == (
batch_size, batch_size,
...@@ -554,10 +565,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -554,10 +565,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
self.num_buckets = num_buckets self.num_buckets = num_buckets
def _attend( def _attend(
self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, self, query_vectors, key_vectors, value_vectors, sorted_bucket_idx, attention_mask, head_mask, sequence_length
): ):
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) # look at previous and following chunks if chunked attention
if self.chunk_length < sequence_length:
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
# get logits and dots # get logits and dots
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
...@@ -565,10 +579,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -565,10 +579,14 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# free memory # free memory
del query_vectors, key_vectors del query_vectors, key_vectors
query_bucket_idx = self._split_seq_length_dim_to( # if chunked attention split bucket idxs to query and key
sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads if self.chunk_length < sequence_length:
) query_bucket_idx = self._split_seq_length_dim_to(
key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) sorted_bucket_idx, -1, self.chunk_length, self.num_attention_heads
)
key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after)
else:
query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx
# get correct mask values depending on precision # get correct mask values depending on precision
if query_key_dots.dtype == torch.float16: if query_key_dots.dtype == torch.float16:
...@@ -578,7 +596,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -578,7 +596,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
self_mask_value = self.self_mask_value_float32 self_mask_value = self.self_mask_value_float32
mask_value = self.mask_value_float32 mask_value = self.mask_value_float32
mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask) mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask, sequence_length)
if mask is not None: if mask is not None:
query_key_dots = torch.where(mask, query_key_dots, mask_value) query_key_dots = torch.where(mask, query_key_dots, mask_value)
...@@ -627,12 +645,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -627,12 +645,13 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
del value_vectors del value_vectors
# merge chunk length # merge chunk length
logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1) if self.chunk_length < sequence_length:
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1)
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
return out_vectors, logits, attention_probs return out_vectors, logits, attention_probs
def _compute_attn_mask(self, query_indices, key_indices, attention_mask): def _compute_attn_mask(self, query_indices, key_indices, attention_mask, sequence_length):
mask = None mask = None
# Causal mask # Causal mask
...@@ -642,15 +661,27 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -642,15 +661,27 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# Attention mask: chunk, look up correct mask value from key_value_bucket_idx # Attention mask: chunk, look up correct mask value from key_value_bucket_idx
# IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why. # IMPORTANT: official trax code does not use a mask for LSH Atttention. Not sure why.
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.to(torch.uint8)[:, None, None, :] # if chunked attention, the attention mask has to correspond to LSH order
# expand attn_mask to fit with key_value_bucket_idx shape if sequence_length > self.chunk_length:
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,)) attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
key_attn_mask = torch.gather(attention_mask, -1, key_indices) # expand attn_mask to fit with key_value_bucket_idx shape
query_attn_mask = torch.gather(attention_mask, -1, query_indices) attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
# expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk key_attn_mask = torch.gather(attention_mask, -1, key_indices)
attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2) query_attn_mask = torch.gather(attention_mask, -1, query_indices)
# expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
# free memory
del query_attn_mask, key_attn_mask
else:
# usual attention mask creation
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
attn_mask = (attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(-2)).expand(
query_indices.shape + attention_mask.shape[-1:]
)
# free memory # free memory
del query_attn_mask, key_attn_mask, attention_mask del attention_mask
# multiply by casaul mask if necessary # multiply by casaul mask if necessary
if mask is not None: if mask is not None:
...@@ -810,36 +841,45 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -810,36 +841,45 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype) torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype)
) )
# chunk vectors # get sequence length indices
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
query_vectors = self._split_seq_length_dim_to(
query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
key_vectors = self._split_seq_length_dim_to(
key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
# chunk indices
indices = torch.arange(sequence_length, device=query_vectors.device).repeat( indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
batch_size, self.num_attention_heads, 1 batch_size, self.num_attention_heads, 1
) )
query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
# append chunks before and after # if input should be chunked
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) if self.chunk_length < sequence_length:
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) # chunk vectors
key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after) # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
query_vectors = self._split_seq_length_dim_to(
query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
key_vectors = self._split_seq_length_dim_to(
key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
)
# chunk indices
query_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
key_indices = self._split_seq_length_dim_to(indices, -1, self.chunk_length, self.num_attention_heads)
# append chunks before and after
key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after)
value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after)
key_indices = self._look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after)
else:
query_indices = key_indices = indices
# query-key matmul: QK^T
query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
# free memory # free memory
del query_vectors, key_vectors del query_vectors, key_vectors
mask = self._compute_attn_mask(query_indices, key_indices, attention_mask, query_key_dots.shape) mask = self._compute_attn_mask(
query_indices, key_indices, attention_mask, query_key_dots.shape, sequence_length
)
if mask is not None: if mask is not None:
# get mask tensor depending on half precision or not # get mask tensor depending on half precision or not
...@@ -874,7 +914,8 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -874,7 +914,8 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
del value_vectors del value_vectors
# merge chunk length # merge chunk length
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) if self.chunk_length < sequence_length:
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,)
...@@ -885,14 +926,18 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -885,14 +926,18 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape): def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length):
mask = None mask = None
# chunk attention mask and look before and after # chunk attention mask and look before and after
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.to(torch.uint8)[:, None, :] attention_mask = attention_mask.to(torch.uint8)[:, None, :]
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) if self.chunk_length < sequence_length:
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
else:
attention_mask_key = attention_mask
# Causal mask # Causal mask
if self.is_decoder is True: if self.is_decoder is True:
...@@ -1564,7 +1609,9 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1564,7 +1609,9 @@ class ReformerModel(ReformerPreTrainedModel):
# if needs padding # if needs padding
least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config) least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config)
must_pad_to_match_chunk_length = input_shape[-1] % least_common_mult_chunk_length != 0 must_pad_to_match_chunk_length = (
input_shape[-1] % least_common_mult_chunk_length != 0 and input_shape[-1] > least_common_mult_chunk_length
)
if must_pad_to_match_chunk_length: if must_pad_to_match_chunk_length:
padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length padding_length = least_common_mult_chunk_length - input_shape[-1] % least_common_mult_chunk_length
......
...@@ -2188,7 +2188,7 @@ def apply_chunking_to_forward( ...@@ -2188,7 +2188,7 @@ def apply_chunking_to_forward(
assert ( assert (
input_tensors[0].shape[chunk_dim] % chunk_size == 0 input_tensors[0].shape[chunk_dim] % chunk_size == 0
), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format( ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
input_tensors[0][chunk_dim], chunk_size input_tensors[0].shape[chunk_dim], chunk_size
) )
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
......
...@@ -388,6 +388,16 @@ class ReformerModelTester: ...@@ -388,6 +388,16 @@ class ReformerModelTester:
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False) output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
self.parent.assertFalse(torch.isnan(output).any().item()) self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
# force chunk length to be bigger than input_ids
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
output_logits = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask,) = config_and_inputs (config, input_ids, input_mask,) = config_and_inputs
...@@ -433,6 +443,10 @@ class ReformerTesterMixin: ...@@ -433,6 +443,10 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs) self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
def test_reformer_no_chunking(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
@slow @slow
def test_dropout_random_seed_is_changing(self): def test_dropout_random_seed_is_changing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
...@@ -772,6 +786,7 @@ class ReformerIntegrationTests(unittest.TestCase): ...@@ -772,6 +786,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_lsh_layer_forward(self): def test_lsh_layer_forward(self):
config = self._get_basic_config_and_input() config = self._get_basic_config_and_input()
config["lsh_num_chunks_before"] = 0
config["attn_layers"] = ["lsh"] config["attn_layers"] = ["lsh"]
config["is_decoder"] = False config["is_decoder"] = False
hidden_states = self._get_hidden_states() hidden_states = self._get_hidden_states()
...@@ -787,6 +802,7 @@ class ReformerIntegrationTests(unittest.TestCase): ...@@ -787,6 +802,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_lsh_layer_forward_complex(self): def test_lsh_layer_forward_complex(self):
config = self._get_basic_config_and_input() config = self._get_basic_config_and_input()
config["lsh_num_chunks_before"] = 0
config["attn_layers"] = ["lsh"] config["attn_layers"] = ["lsh"]
config["num_buckets"] = [2, 4] config["num_buckets"] = [2, 4]
attn_mask = self._get_attn_mask() attn_mask = self._get_attn_mask()
...@@ -805,6 +821,7 @@ class ReformerIntegrationTests(unittest.TestCase): ...@@ -805,6 +821,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_local_layer_forward(self): def test_local_layer_forward(self):
config = self._get_basic_config_and_input() config = self._get_basic_config_and_input()
config["local_num_chunks_before"] = 0
config["attn_layers"] = ["local"] config["attn_layers"] = ["local"]
config["is_decoder"] = False config["is_decoder"] = False
hidden_states = self._get_hidden_states() hidden_states = self._get_hidden_states()
...@@ -820,6 +837,7 @@ class ReformerIntegrationTests(unittest.TestCase): ...@@ -820,6 +837,7 @@ class ReformerIntegrationTests(unittest.TestCase):
def test_local_layer_forward_complex(self): def test_local_layer_forward_complex(self):
config = self._get_basic_config_and_input() config = self._get_basic_config_and_input()
config["local_num_chunks_before"] = 0
config["attn_layers"] = ["local"] config["attn_layers"] = ["local"]
attn_mask = self._get_attn_mask() attn_mask = self._get_attn_mask()
hidden_states = self._get_hidden_states() hidden_states = self._get_hidden_states()
...@@ -829,7 +847,7 @@ class ReformerIntegrationTests(unittest.TestCase): ...@@ -829,7 +847,7 @@ class ReformerIntegrationTests(unittest.TestCase):
reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,) reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,)
output_slice = reformer_output.hidden_states[0, 0, :5] output_slice = reformer_output.hidden_states[0, 0, :5]
expected_output_slice = torch.tensor( expected_output_slice = torch.tensor(
[1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device, [1.4750, -2.0235, -0.9743, 1.4463, -0.1269], dtype=torch.float, device=torch_device,
) )
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment