Unverified Commit 41c42f85 authored by Younes Belkada, committed by GitHub

[`FA2`] Fix flash attention 2 fine-tuning with Falcon (#26852)

Fix fine-tuning with Flash Attention 2 when attention dropout is enabled: read the dropout probability from `self.config.attention_dropout` instead of `self.attention_dropout`.
parent 4b423e60
@@ -606,7 +606,7 @@ class FalconFlashAttention2(FalconAttention):
         if alibi is not None:
             raise ValueError("`alibi` is not supported when `use_flash_attn` is True")

-        attn_dropout = self.attention_dropout if self.training else 0.0
+        attn_dropout = self.config.attention_dropout if self.training else 0.0

         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
...
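The one-line change above matters because the FA2 path hands `attn_dropout` to flash attention as a dropout *probability*. Below is a minimal sketch of the distinction, assuming (as in `FalconAttention` at the time) that `self.attention_dropout` is an `nn.Dropout` module built from the config rather than a float; the class and helper are illustrative, not the actual Falcon source:

```python
import torch.nn as nn


class DummyConfig:
    attention_dropout = 0.1


class DummyFalconAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Assumption: in FalconAttention this attribute holds an nn.Dropout
        # module built from config.attention_dropout, not the float itself.
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def fa2_dropout_probability(self):
        # Buggy pattern (pre-fix): in train mode this returned the nn.Dropout
        # *module*, which flash attention cannot use as a dropout probability:
        #     attn_dropout = self.attention_dropout if self.training else 0.0
        # Fixed pattern: read the float probability from the config.
        return self.config.attention_dropout if self.training else 0.0


attn = DummyFalconAttention(DummyConfig())
attn.train()
print(attn.fa2_dropout_probability())  # 0.1 while training
attn.eval()
print(attn.fa2_dropout_probability())  # 0.0 at inference time
```

In eval mode both variants collapse to 0.0, which is why the problem only surfaced during fine-tuning.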
@@ -2810,6 +2810,10 @@ class ModelTesterMixin:
         self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))

+        # check with inference + dropout
+        model.train()
+        _ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test
...
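For completeness, a hedged usage sketch of the scenario the new test covers: running a forward pass in train mode so that attention dropout is active on the FA2 path. The checkpoint name and the `use_flash_attention_2` loading flag are assumptions (the exact flag depends on your transformers version); this is an illustration, not part of the commit:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint for illustration; any Falcon checkpoint works.
model_id = "tiiuae/falcon-7b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,  # assumed flag name for the releases of that era
    device_map="auto",
)

inputs = tokenizer("Hello, Falcon!", return_tensors="pt").to(model.device)

# Dropout is only active in train mode; before this fix the forward pass
# failed here because a Dropout module was passed where a float was expected.
model.train()
_ = model(**inputs, output_hidden_states=True)
```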