Unverified Commit 1897874e authored by fxmarty, committed by GitHub

Fix falcon with SDPA, alibi but no passed mask (#30123)



* fix falcon without attention_mask & alibi

* add test

* Update tests/models/falcon/test_modeling_falcon.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 1773afce
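The case this commit addresses is a Falcon model configured with alibi, dispatched to the SDPA attention path, and called without an attention_mask. Below is a minimal sketch of that call pattern, assuming a recent PyTorch install so that SDPA is selected as the default attention implementation; the tiny config values are illustrative and not taken from this commit.

import torch

from transformers import FalconConfig, FalconForCausalLM

# Illustrative tiny config; the relevant setting for this fix is alibi=True.
config = FalconConfig(
    vocab_size=1000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    alibi=True,
)
model = FalconForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 10))
with torch.no_grad():
    # No attention_mask is passed: after this fix, the alibi bias is always
    # combined with the prepared 4D causal mask on the SDPA path.
    logits = model(input_ids).logits
print(logits.shape)  # torch.Size([2, 10, 1000])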
@@ -1098,27 +1098,23 @@ class FalconModel(FalconPreTrainedModel):
             elif head_mask is None:
                 alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
 
-                attention_mask_2d = attention_mask
                 # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
                 attention_mask = _prepare_4d_causal_attention_mask(
                     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
                 )
 
                 # We take care to integrate alibi bias in the attention_mask here.
-                if attention_mask_2d is None:
-                    attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
-                else:
-                    min_dtype = torch.finfo(alibi.dtype).min
-                    attention_mask = torch.masked_fill(
-                        alibi / math.sqrt(self.config.hidden_size // self.num_heads),
-                        attention_mask < -1,
-                        min_dtype,
-                    )
+                min_dtype = torch.finfo(alibi.dtype).min
+                attention_mask = torch.masked_fill(
+                    alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+                    attention_mask < -1,
+                    min_dtype,
+                )
 
-                    # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
-                    # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
-                    if seq_length > 1 and attention_mask.device.type == "cuda":
-                        attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
+                # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
+                # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
+                if seq_length > 1 and attention_mask.device.type == "cuda":
+                    attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
             else:
                 # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
                 attention_mask = _prepare_4d_causal_attention_mask(
...
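For reference, the masking pattern used in the hunk above can be shown standalone: an additive alibi-style bias is passed to scaled_dot_product_attention as attn_mask, with the positions the 4D causal mask forbids overwritten by the dtype minimum via torch.masked_fill. This is a toy sketch with hand-built tensors, not the transformers helpers; the shapes and the single alibi slope are made up for illustration.

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8

# Toy stand-in for an alibi-style additive bias (one slope shared by all heads).
alibi = -torch.arange(kv_len, dtype=torch.float32).view(1, 1, 1, kv_len)
alibi = alibi.expand(batch, heads, q_len, kv_len)

# Hand-built 4D causal mask: 0.0 where attention is allowed, dtype minimum elsewhere.
min_dtype = torch.finfo(alibi.dtype).min
causal_4d = torch.full((batch, 1, q_len, kv_len), min_dtype).triu(diagonal=1)

# Same pattern as the diff: keep the bias where attention is allowed,
# force the dtype minimum where the causal mask disallows attention.
attn_bias = torch.masked_fill(alibi, causal_4d < -1, min_dtype)

out = F.scaled_dot_product_attention(
    torch.randn(batch, heads, q_len, head_dim),
    torch.randn(batch, heads, kv_len, head_dim),
    torch.randn(batch, heads, kv_len, head_dim),
    attn_mask=attn_bias,
)
print(out.shape)  # torch.Size([1, 2, 4, 8])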
@@ -666,3 +666,27 @@ class FalconLanguageGenerationTest(unittest.TestCase):
         self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1])  # left-padding exists
         self.assertEqual(unpadded_gen_text[0], expected_output)
         self.assertEqual(padded_gen_text[0], expected_output)
+
+    @slow
+    @require_torch_sdpa
+    def test_falcon_alibi_sdpa_matches_eager(self):
+        input_ids = torch.randint(0, 1000, (5, 20))
+
+        config = FalconConfig(
+            vocab_size=1000,
+            hidden_size=64,
+            num_hidden_layers=3,
+            num_attention_heads=4,
+            new_decoder_architecture=True,
+            alibi=True,
+        )
+
+        falcon = FalconForCausalLM(config)
+        falcon = falcon.eval()
+
+        with torch.no_grad():
+            # output_attentions=True dispatches to eager path
+            falcon_output_eager = falcon(input_ids, output_attentions=True)[0]
+            falcon_output_sdpa = falcon(input_ids)[0]
+
+        self.assertTrue(torch.allclose(falcon_output_eager, falcon_output_sdpa, atol=1e-3))