Unverified Commit be3fd8a2 authored by bytebarde, committed by GitHub

[Flash Attention 2] Add flash attention 2 for GPT-J (#28295)



* initial implementation of flash attention for gptj

* modify flash attention and overwrite test_flash_attn_2_generate_padding_right

* update flash attention support list

* remove the copy line in the `CodeGenBlock`

* address copy mechanism

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add GPTJ attention classes

* add expected outputs in the gptj test

* Ensure repo consistency with 'make fix-copies'

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent d522afea
......@@ -44,6 +44,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
......
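For context (not part of the diff), here is a minimal sketch of how the new backend can be requested at load time. It mirrors the slow test added at the bottom of this PR; `device_map="auto"` additionally assumes `accelerate` is installed, and a compatible `flash-attn` build must be available on a supported GPU.

```python
import torch
from transformers import AutoTokenizer, GPTJForCausalLM

# Request the Flash Attention 2 backend added in this PR; fp16/bf16 weights are required.
model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6b",
    revision="float16",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```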
......@@ -266,6 +266,7 @@ class CodeGenMLP(nn.Module):
# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
class CodeGenBlock(nn.Module):
# Ignore copy
def __init__(self, config):
super().__init__()
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
......
......@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
import torch
import torch.fx
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
......@@ -35,6 +36,8 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_fx_proxy,
logging,
)
......@@ -42,6 +45,11 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_gptj import GPTJConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj"
......@@ -55,6 +63,19 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
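# Given a 2D padding mask of shape (batch_size, seq_len) with 1 for real tokens and 0 for padding,
# return the flat indices of the non-padding tokens, the cumulative sequence lengths expected by
# flash_attn_varlen_func, and the length of the longest sequence in the batch.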
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
......@@ -82,7 +103,7 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten
class GPTJAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
......@@ -96,6 +117,8 @@ class GPTJAttention(nn.Module):
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.is_causal = True
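# Flag read by the Flash Attention 2 path below to decide whether the kernel should apply a causal mask.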
self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
......@@ -269,6 +292,256 @@ class GPTJAttention(nn.Module):
return outputs # a, present, (attentions)
class GPTJFlashAttention2(GPTJAttention):
"""
GPTJ flash attention module. This module inherits from `GPTJAttention`, as the weights of the module stay
untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
Flash Attention and deal with padding tokens in case the input contains any.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.
# flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
# The logic to conditionally copy to GPU could not be traced, so we do this
# every time in the torch.fx case
embed_positions = get_embed_positions(self.embed_positions, position_ids)
else:
embed_positions = self._get_embed_positions(position_ids)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim :]
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]
k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)
# transpose to have the desired shape
# before transpose: batch_size x seq_length x num_attention_heads x head_dim
# after transpose: batch_size x num_attention_heads x seq_length x head_dim
key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
# value: batch_size x num_attention_heads x seq_length x head_dim
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
# Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
# Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
present = (key.to(hidden_states.dtype), value)
else:
present = None
# Flash Attention requires the input to have the shape
# batch_size x seq_length x num_attention_heads x head_dim,
# therefore we transpose query, key and value back to that layout
# (from batch_size x num_attention_heads x seq_length x head_dim above).
key = key.permute(0, 2, 1, 3).contiguous()
query = query.permute(0, 2, 1, 3).contiguous()
value = value.permute(0, 2, 1, 3).contiguous()
# In PEFT, we usually cast the layer norms to float32 for training stability reasons,
# therefore the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
# to fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seem to have been silently cast to float32; this is likely because"
f" you have upcast embedding or layer norm layers to float32. We will cast the input back to"
f" {target_dtype}."
)
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
attention_dropout = self.config.attn_pdrop if self.training else 0.0 # attn_pdrop in gptj
query_length = query.shape[1]
# Compute attention
attn_weights = self._flash_attention_forward(
query,
key,
value,
attention_mask,
query_length,
dropout=attention_dropout,
)
# Reshape outputs
attn_output = attn_weights.reshape(
attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2] * attn_weights.shape[3]
)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention: if the input hidden states contain at least one padding token,
first unpad the input, then compute the attention scores, and finally pad the final attention scores back.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->num_attention_heads
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
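# Maps config._attn_implementation to the attention class instantiated in GPTJBlock.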
GPTJ_ATTENTION_CLASSES = {
"eager": GPTJAttention,
"flash_attention_2": GPTJFlashAttention2,
}
class GPTJMLP(nn.Module):
def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
super().__init__()
......@@ -293,7 +566,7 @@ class GPTJBlock(nn.Module):
super().__init__()
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config)
self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = GPTJMLP(inner_dim, config)
def forward(
......@@ -343,6 +616,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GPTJBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......@@ -496,6 +770,8 @@ class GPTJModel(GPTJPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
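# With Flash Attention 2 the 2D padding mask is forwarded to the attention layer as-is,
# so the 4D additive mask below is only built for the eager attention path.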
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
......@@ -600,25 +876,26 @@ class GPTJModel(GPTJPreTrainedModel):
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
if not self._use_flash_attention_2:
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
......
......@@ -17,8 +17,18 @@
import datetime
import unittest
from transformers import GPTJConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, tooslow, torch_device
import pytest
from transformers import BitsAndBytesConfig, GPTJConfig, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
tooslow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
......@@ -518,6 +528,44 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16)
self.assertIsNotNone(model)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
"""
Overriding the common test, as the test is flaky on tiny models
"""
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
texts = ["hi", "Hello this is a very long sentence"]
expected_outputs = [
"hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. I have a question about the",
"Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it",
]
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = GPTJForCausalLM.from_pretrained(
"EleutherAI/gpt-j-6b",
device_map={"": 0},
attn_implementation="flash_attention_2",
revision="float16",
torch_dtype=torch.float16,
quantization_config=quantization_config,
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(expected_outputs, output_fa_2)
@require_torch
class GPTJModelLanguageGenerationTest(unittest.TestCase):
......