Unverified Commit 81233c06 authored by Raushan Turganbay, committed by GitHub

Flash-Attn: fix generation when no attention mask or no padding (#32241)

* fix

* fix prev test (half of failures)

* [run-slow] llama, gemma2

* [run-slow] llama, gemma2
parent 27c7f971
@@ -264,9 +264,11 @@ def _flash_attention_forward(
         )
         attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
-    # if position_ids is provided and check not all examples (row) contain only 1 sequence,
+    # if position_ids is provided and check not all examples (row) contain only 1 sequence, and is in pre-fill/training stage
     # then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all():
+    elif (
+        position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1
+    ):
         batch_size = query_states.size(0)
         query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids
...
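The substance of the fix is the extra `query_length != 1` guard. Below is a minimal, illustrative sketch (the helper name `takes_varlen_path` is hypothetical) of how the routing condition behaves for a packed pre-fill batch versus a single cached decode step, which is the case that previously slipped through:

import torch

def takes_varlen_path(position_ids, query_length):
    # Illustrative stand-in for the `elif` condition above (helper name is hypothetical).
    return (
        position_ids is not None
        and not (position_ids[:, -1] == position_ids.size(1) - 1).all()
        and query_length != 1
    )

# Pre-fill with two sequences packed into one row: positions restart mid-row,
# so the last position (3) differs from seq_len - 1 (6) and the varlen path is taken.
packed = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])
print(takes_varlen_path(packed, query_length=packed.size(1)))  # True

# Cached decoding of an unpadded sequence: only the new token's position (e.g. 7)
# is passed, so the old condition (without `query_length != 1`) was also True and
# an ordinary decode step was wrongly sent through `prepare_fa2_from_position_ids`.
decode_step = torch.tensor([[7]])
print(takes_varlen_path(decode_step, query_length=decode_step.size(1)))  # False with the fix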
@@ -4270,6 +4270,18 @@ class ModelTesterMixin:
                 use_cache=True,
             )
 
+            # Generate with one batch only to test generation when attention mask will be None
+            # when real inputs are used, because there is no padding. See issue #32237 for more
+            dummy_input = dummy_input[:1, ...]
+            dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
+            _ = model.generate(
+                dummy_input,
+                attention_mask=dummy_attention_mask,
+                max_new_tokens=max_new_tokens,
+                do_sample=False,
+                use_cache=True,
+            )
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
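Outside the test suite, the scenario from issue #32237 is ordinary generation from a single, unpadded prompt: the attention mask is all ones (or omitted entirely), so the flash-attention routing must not depend on it. A rough reproduction sketch, assuming a Flash-Attention-2-capable GPU and checkpoint (the model id below is only an example):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint; any FA2-supported model works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

# Batch of one prompt, no padding -> the attention mask carries no information.
inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=20, do_sample=False, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))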
@@ -4342,6 +4354,8 @@ class ModelTesterMixin:
                 self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
 
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
+                self.skipTest("Model dummy inputs should contain padding in their attention mask")
             dummy_input = inputs_dict[model_class.main_input_name]
             if dummy_input.dtype in [torch.float32, torch.bfloat16]:
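The new skip guard relies on tensor membership: `0 in attention_mask` is True only if at least one position is masked out, i.e. the dummy batch actually contains padding. A small sketch of that check in isolation:

import torch

mask_with_padding = torch.tensor([[1, 1, 1, 0],
                                  [1, 1, 1, 1]])
mask_without_padding = torch.ones(2, 4, dtype=torch.long)

# Mirrors `0 not in inputs_dict.get("attention_mask", [])`: skip when no padding exists.
print(0 in mask_with_padding)     # True  -> the test runs
print(0 in mask_without_padding)  # False -> the test is skipped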
@@ -4356,7 +4370,6 @@ class ModelTesterMixin:
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
 
-                assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs"
                 # ensure left padding, to adapt for some models
                 if 0 in inputs_dict["attention_mask"][:, -1]:
                     inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
...
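The surrounding context shows how the test normalizes padding rather than asserting on it: a zero in the last column of the mask means right padding, and flipping along the sequence dimension turns it into left padding. A standalone sketch of that check, under the same assumptions as the test:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],   # right-padded example
                               [1, 1, 1, 1]])

# A zero in the last column indicates right padding; flip the sequence
# dimension so the padding ends up on the left, as decoder-only generation expects.
if 0 in attention_mask[:, -1]:
    attention_mask = attention_mask.flip(1)

print(attention_mask)  # tensor([[0, 1, 1, 1], [1, 1, 1, 1]])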