Unverified commit 7f552e28, authored by Raushan Turganbay, committed by GitHub

Gemma2 and flash-attention (#32188)

* enable flash-attn & static cache

* this works, not the previous approach

* fix for sliding window layers

* not needed anymore
parent a3264332
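
For context, the usage this change enables mirrors the integration test added at the bottom of the diff; a minimal sketch (the prompt is illustrative, while the model id, dtype and generation settings are taken from that test):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "google/gemma-2-9b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # flash_attention_2 now works together with Gemma2's static (hybrid) cache
    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation="flash_attention_2", torch_dtype="float16"
    ).to("cuda")

    inputs = tokenizer(["Hello I am doing"], return_tensors="pt", padding=True).to("cuda")
    output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    print(tokenizer.batch_decode(output, skip_special_tokens=False)[0])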
@@ -327,6 +327,11 @@ class Gemma2FlashAttention2(Gemma2Attention):
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        if attention_mask is not None:
+            seq_len = attention_mask.shape[1]
+            key_states = key_states[:, :, :seq_len]
+            value_states = value_states[:, :, :seq_len]
+
         # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
         query_states = query_states.transpose(1, 2)
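
To see why the new slicing is needed, here is a standalone sketch with hypothetical shapes (not taken from the diff): a static cache returns key/value tensors pre-allocated to the maximum cache length, so the 2D attention mask's length is used to cut off the trailing, not-yet-written slots before transposing to the flash-attention layout.

    import torch

    # Hypothetical sizes: batch=2, kv_heads=4, max_cache_len=16, head_dim=8.
    # A static cache hands back tensors padded to max_cache_len; only the first
    # seq_len positions (given by the 2D attention mask) contain real entries.
    key_states = torch.zeros(2, 4, 16, 8)
    value_states = torch.zeros(2, 4, 16, 8)
    attention_mask = torch.ones(2, 5)           # 2D flash-attn mask: [batch, seq_len]

    seq_len = attention_mask.shape[1]           # 5
    key_states = key_states[:, :, :seq_len]     # cut trailing, unwritten slots -> [2, 4, 5, 8]
    value_states = value_states[:, :, :seq_len]

    # Flash Attention wants [batch, seq_len, num_heads, head_dim], hence the transposes
    key_states = key_states.transpose(1, 2)     # [2, 5, 4, 8]
    value_states = value_states.transpose(1, 2)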
@@ -510,16 +515,18 @@ class Gemma2DecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        if (
-            self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
-        ):  # efficient SDPA and no padding
-            min_dtype = torch.finfo(hidden_states.dtype).min
-            sliding_window_mask = torch.tril(
-                torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
-            )
-            attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-            if attention_mask.shape[-1] <= 1:  # when decoding
-                attention_mask = attention_mask[:, :, :, -self.sliding_window :]
+        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
+            # With flash-attn, the attention mask is a 2D tensor
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask[:, -self.sliding_window :]
+            else:
+                min_dtype = torch.finfo(hidden_states.dtype).min
+                sliding_window_mask = torch.tril(
+                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
+                )
+                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
+                if attention_mask.shape[-1] <= 1:  # when decoding
+                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
 
         residual = hidden_states
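
A standalone illustration of the non-flash branch above, with toy sizes (not from the diff): torch.tril with diagonal=-sliding_window selects key positions that are at least sliding_window tokens behind the query, and torch.where fills them with the dtype minimum so SDPA effectively ignores them.

    import torch

    sliding_window = 2
    # Toy 4D additive mask [batch=1, heads=1, q_len=5, kv_len=5]; zeros mean "attend"
    attention_mask = torch.zeros(1, 1, 5, 5)
    min_dtype = torch.finfo(attention_mask.dtype).min

    # True where the key is at least `sliding_window` positions behind the query
    sliding_window_mask = torch.tril(
        torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
    )
    # Those positions get the dtype minimum so softmax gives them ~zero weight
    attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
    print(attention_mask[0, 0])  # row i has min_dtype in columns j <= i - sliding_window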
@@ -824,10 +831,12 @@ class Gemma2Model(Gemma2PreTrainedModel):
         past_key_values: Cache,
         output_attentions: bool,
     ):
+        # Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
+        # So we will pass in the attention mask as is in any case, not only when there's padding. Then we'll use its
+        # shape to cut out the keys/values' trailing zeros used in the static cache. This workaround should be
+        # compile-compatible as it doesn't cause dynamic control issues.
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
+            return attention_mask
 
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
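
A minimal before/after sketch of the flash-attention branch of _update_causal_mask (the helper names here are hypothetical; the bodies are copied from the diff): previously the 2D mask was forwarded only when it contained padding and was otherwise dropped, which lost the sequence length; now it is always forwarded so the attention layer can trim the static cache's trailing slots, as in the Gemma2FlashAttention2 hunk above.

    import torch

    def _fa2_mask_old(attention_mask):
        # old behaviour: forward the 2D mask only when there is real padding
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None

    def _fa2_mask_new(attention_mask):
        # new behaviour: always forward the mask; its length is later used to cut
        # the static cache's trailing zero slots (see the KV slicing hunk above)
        return attention_mask

    mask = torch.ones(1, 5)                    # no padding
    assert _fa2_mask_old(mask) is None         # old path: length information is lost
    assert _fa2_mask_new(mask).shape[1] == 5   # new path: length survives for the KV cut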
@@ -16,8 +16,11 @@
 import unittest
 
+from pytest import mark
+
 from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
 from transformers.testing_utils import (
+    require_flash_attn,
     require_read_token,
     require_torch,
     require_torch_gpu,
@@ -161,3 +164,28 @@ class Gemma2IntegrationTest(unittest.TestCase):
         self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
         self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
 
+    @require_read_token
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_model_9b_flash_attn(self):
+        # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
+        model_id = "google/gemma-2-9b"
+        EXPECTED_TEXTS = [
+            '<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
+            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the",
+        ]  # fmt: skip
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, attn_implementation="flash_attention_2", torch_dtype="float16"
+        ).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+        output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
+
+        print(output_text)
+        self.assertEqual(output_text, EXPECTED_TEXTS)