Unverified Commit 1897874e authored by fxmarty, committed by GitHub

Fix falcon with SDPA, alibi but no passed mask (#30123)



* fix falcon without attention_mask & alibi

* add test

* Update tests/models/falcon/test_modeling_falcon.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 1773afce
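For context, the code path touched by this commit is only reached when the model is configured with alibi, the SDPA attention implementation is in use, and the caller passes no attention_mask. A minimal sketch of such a call, using a small randomly initialized config (the sizes below are illustrative, not taken from this commit):

import torch

from transformers import FalconConfig, FalconForCausalLM

# Small Falcon with alibi enabled; sizes are illustrative only.
config = FalconConfig(
    vocab_size=1000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    new_decoder_architecture=True,
    alibi=True,
)

model = FalconForCausalLM(config).eval()
input_ids = torch.randint(0, config.vocab_size, (2, 16))

with torch.no_grad():
    # No attention_mask is passed; on a recent PyTorch, transformers typically selects the
    # SDPA attention implementation by default, which is the path patched below.
    logits = model(input_ids).logits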
@@ -1098,16 +1098,12 @@ class FalconModel(FalconPreTrainedModel):
             elif head_mask is None:
                 alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
 
-                attention_mask_2d = attention_mask
                 # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
                 attention_mask = _prepare_4d_causal_attention_mask(
                     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
                 )
 
                 # We take care to integrate alibi bias in the attention_mask here.
-                if attention_mask_2d is None:
-                    attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
-                else:
-                    min_dtype = torch.finfo(alibi.dtype).min
-                    attention_mask = torch.masked_fill(
-                        alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+                min_dtype = torch.finfo(alibi.dtype).min
+                attention_mask = torch.masked_fill(
+                    alibi / math.sqrt(self.config.hidden_size // self.num_heads),
...
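The hunk above drops the attention_mask_2d special case: the scaled alibi bias is now always masked with the 4D causal mask, even when the caller supplied no attention_mask. A standalone sketch of that masking step, with toy shapes and slopes rather than the library's helpers:

import math

import torch
import torch.nn.functional as F

# Illustrative sizes, not taken from any Falcon checkpoint.
batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16

# Additive 4D causal mask: 0.0 where attention is allowed, dtype-min where it is not.
min_dtype = torch.finfo(torch.float32).min
causal = torch.triu(torch.full((seq_len, seq_len), min_dtype), diagonal=1)
causal_4d = causal[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)

# Per-head alibi bias over key positions, shaped (batch, num_heads, 1, seq_len) like the
# reshaped alibi in the modeling code (toy slopes here).
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(num_heads)])
alibi = slopes[None, :, None, None] * torch.arange(seq_len, dtype=torch.float32)[None, None, None, :]
alibi = alibi.expand(batch_size, num_heads, 1, seq_len)

# Fold the scaled bias into the mask: disallowed positions become dtype-min, allowed ones
# keep the alibi bias. This mirrors the torch.masked_fill call kept by the fix.
attn_bias = torch.masked_fill(alibi / math.sqrt(head_dim), causal_4d < -1, min_dtype)

q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
v = torch.randn(batch_size, num_heads, seq_len, head_dim)

# SDPA receives the merged float mask, so causality is preserved even without a user mask.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)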
@@ -666,3 +666,27 @@ class FalconLanguageGenerationTest(unittest.TestCase):
         self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1])  # left-padding exists
         self.assertEqual(unpadded_gen_text[0], expected_output)
         self.assertEqual(padded_gen_text[0], expected_output)
+
+    @slow
+    @require_torch_sdpa
+    def test_falcon_alibi_sdpa_matches_eager(self):
+        input_ids = torch.randint(0, 1000, (5, 20))
+
+        config = FalconConfig(
+            vocab_size=1000,
+            hidden_size=64,
+            num_hidden_layers=3,
+            num_attention_heads=4,
+            new_decoder_architecture=True,
+            alibi=True,
+        )
+
+        falcon = FalconForCausalLM(config)
+        falcon = falcon.eval()
+
+        with torch.no_grad():
+            # output_attentions=True dispatches to eager path
+            falcon_output_eager = falcon(input_ids, output_attentions=True)[0]
+            falcon_output_sdpa = falcon(input_ids)[0]
+
+        self.assertTrue(torch.allclose(falcon_output_eager, falcon_output_sdpa, atol=1e-3))
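For reference, the added test can be run on its own once slow tests are enabled, e.g. RUN_SLOW=1 python -m pytest tests/models/falcon/test_modeling_falcon.py -k test_falcon_alibi_sdpa_matches_eager; the RUN_SLOW flag and the require_torch_sdpa guard follow the repository's usual test gating, so the exact invocation may vary with the local setup.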