"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0bab55d5d52e4d538888980d05d73acc6da6274a"
Unverified commit 82cc0a79, authored by fxmarty, committed by GitHub

Fix flash attention bugs with Mistral and Falcon (#27625)

* fix various bugs with flash attention

* bump

* fix test

* fix mistral

* use skipTest instead of return, which may be misleading

* fix on review
parent f93c1e9e
@@ -564,6 +564,12 @@ class FalconFlashAttention2(FalconAttention):
         past_key_value = (key_layer, value_layer) if use_cache else None

+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_layer = query_layer.transpose(1, 2)
+        key_layer = key_layer.transpose(1, 2)
+        value_layer = value_layer.transpose(1, 2)
+
         if alibi is not None:
             raise ValueError("`alibi` is not supported when `use_flash_attn` is True")

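For readers unfamiliar with the layout issue: the Falcon attention projections produce tensors shaped [batch_size, num_heads, seq_len, head_dim], while the Flash Attention kernels expect [batch_size, seq_len, num_heads, head_dim], so each transpose(1, 2) above swaps the head and sequence dimensions. A minimal sketch with made-up sizes (illustrative only, not the real Falcon dimensions):

import torch

# hypothetical sizes, for illustration only
batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16

# layout coming out of the attention projections: [batch, num_heads, seq_len, head_dim]
query_layer = torch.randn(batch_size, num_heads, seq_len, head_dim)

# layout Flash Attention expects: [batch, seq_len, num_heads, head_dim]
query_layer = query_layer.transpose(1, 2)
print(query_layer.shape)  # torch.Size([2, 8, 4, 16])
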
@@ -838,7 +838,7 @@ class MistralModel(MistralPreTrainedModel):
             attention_mask is not None
             and hasattr(self.config, "_flash_attn_2_enabled")
             and self.config._flash_attn_2_enabled
-            and past_key_values is not None
+            and use_cache
         ):
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:

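The Mistral change narrows the padding check to the case that actually breaks: `past_key_values` is still None on the first forward pass of a generation even when caching is requested, so gating on `use_cache` rejects right-padded batches before a cache is ever built. The `is_padding_right` expression itself can be sanity-checked in isolation with a hypothetical mask (a sketch of the same logic, not the model code):

import torch

# hypothetical batch of two sequences; the second one is right-padded (trailing 0)
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 1, 0]])
batch_size = attention_mask.shape[0]

# if every sequence ended with a real token, the last column would sum to batch_size
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
print(is_padding_right)  # True -> FA2 + use_cache would raise for this batch
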
@@ -22,6 +22,7 @@ from parameterized import parameterized
 from transformers import LlamaConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
+    require_bitsandbytes,
     require_flash_attn,
     require_torch,
     require_torch_accelerator,
@@ -385,6 +386,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     @require_flash_attn
     @require_torch_gpu
+    @require_bitsandbytes
     @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_generate_padding_right(self):

@@ -375,9 +375,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         import torch

         for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                return
-
             config, _ = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
@@ -405,36 +402,49 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_generate_use_cache(self):
         import torch

-        for model_class in self.all_model_classes:
-            if not model_class._supports_flash_attn_2:
-                return
+        max_new_tokens = 30

-            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_generative_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            dummy_input = inputs_dict[model_class.main_input_name]
+            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+                dummy_input = dummy_input.to(torch.float16)
+
+            # make sure that all models have enough positions for generation
+            if hasattr(config, "max_position_embeddings"):
+                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
             model = model_class(config)

             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
-                model_fa = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
-                )
-                model_fa.to(torch_device)

-                model = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
-                )
-                model.to(torch_device)
+                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+                # NOTE: Mistral apparently does not support right padding + use_cache with FA2.
+                dummy_attention_mask[:, -1] = 1

-                dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
-                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device)
+                model = model_class.from_pretrained(
+                    tmpdirname,
+                    torch_dtype=torch.float16,
+                    use_flash_attention_2=True,
+                    low_cpu_mem_usage=True,
+                ).to(torch_device)

-                _ = model(dummy_input, output_hidden_states=True).hidden_states[-1]
-                with self.assertRaises(ValueError):
-                    _ = model_fa(
-                        dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True
-                    ).hidden_states[-1]
+                # Just test that a large cache works as expected
+                _ = model.generate(
+                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+                )
+
+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_inference_padding_right(self):
+        self.skipTest("Mistral flash attention does not support right padding")

     @require_torch

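One detail in the new `test_flash_attn_2_generate_use_cache` test worth spelling out: bumping `config.max_position_embeddings` to `max_new_tokens + dummy_input.shape[1] + 1` guarantees the tiny test config has enough positions for the prompt plus every token `generate` will append, so the large-cache path is actually exercised. A quick check of that arithmetic under assumed sizes (the prompt length here is hypothetical):

# assumed sizes, for illustration only
max_new_tokens = 30
prompt_len = 7  # stands in for dummy_input.shape[1]

# the KV cache grows to prompt_len + max_new_tokens positions during generation
required_positions = prompt_len + max_new_tokens           # 37
max_position_embeddings = max_new_tokens + prompt_len + 1  # 38

assert max_position_embeddings >= required_positions
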
@@ -2835,7 +2835,7 @@ class ModelTesterMixin:
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
-                return
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

             model = model_class(config)
@@ -2860,7 +2860,7 @@ class ModelTesterMixin:
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
-                return
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
@@ -2957,7 +2957,7 @@ class ModelTesterMixin:
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
-                return
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
@@ -3050,7 +3050,7 @@ class ModelTesterMixin:
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_flash_attn_2:
-                return
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
@@ -3093,7 +3093,7 @@ class ModelTesterMixin:
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_flash_attn_2:
-                return
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
@@ -3109,7 +3109,7 @@ class ModelTesterMixin:
                 dummy_input = dummy_input.to(torch.float16)

             dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-            # make sure we do left padding
+            # make sure we do right padding
             dummy_attention_mask[:, :-1] = 1
             dummy_attention_mask[:, -1:] = 0
@@ -3138,7 +3138,7 @@ class ModelTesterMixin:
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_flash_attn_2:
-                return
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -3179,7 +3179,7 @@ class ModelTesterMixin:
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_flash_attn_2:
-                return
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
@@ -3279,7 +3279,7 @@ class ModelTesterMixin:
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_flash_attn_2:
-                return
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

             config, _ = self.model_tester.prepare_config_and_inputs_for_common()
             # TODO: to change it in the future with other relevant auto classes

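The repeated change in the shared tester replaces a bare return with unittest's skipTest: returning early quietly reports an unsupported model as a passing test, whereas self.skipTest raises unittest.SkipTest and the run is reported as skipped. A stripped-down sketch of the difference, independent of the transformers test suite:

import unittest


class _FakeModel:
    # stand-in for a model class that lacks FA2 support
    _supports_flash_attn_2 = False


class ExampleTest(unittest.TestCase):
    def test_with_return(self):
        # counted as PASSED even though nothing was exercised
        if not _FakeModel._supports_flash_attn_2:
            return

    def test_with_skiptest(self):
        # counted as SKIPPED, which reflects what actually happened
        if not _FakeModel._supports_flash_attn_2:
            self.skipTest(f"{_FakeModel.__name__} does not support Flash Attention 2")


if __name__ == "__main__":
    unittest.main()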