Unverified Commit 1ac2463d authored by Susnato Dhar, committed by GitHub

[`FA2`] Add flash attention for `DistilBert` (#26489)

* flash attention added for DistilBert

* fixes

* removed padding_masks

* Update modeling_distilbert.py

* Update test_modeling_distilbert.py

* style fix
parent 5964f820
...@@ -133,6 +133,37 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
- A blog post on how to [deploy DistilBERT with Amazon SageMaker](https://huggingface.co/blog/deploy-hugging-face-models-easily-with-amazon-sagemaker).
- A blog post on how to [Deploy BERT with Hugging Face Transformers, Amazon SageMaker and Terraform module](https://www.philschmid.de/terraform-huggingface-amazon-sagemaker).
## Combining DistilBERT and Flash Attention 2
First, make sure to install the latest version of Flash Attention 2:
```bash
pip install -U flash-attn --no-build-isolation
```
Also make sure that your hardware is compatible with Flash Attention 2; read more about it in the official documentation of the flash-attn repository. Finally, make sure to load your model in half-precision (e.g. `torch.float16`).
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModel
>>> device = "cuda" # the device to load the model onto
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> text = "Replace me by any text you'd like."
>>> encoded_input = tokenizer(text, return_tensors='pt').to(device)
>>> model.to(device)
>>> output = model(**encoded_input)
```
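Flash Attention 2 unpads and repads the inputs internally, so batched inputs with padding work the same way. Below is a minimal sketch of batched inference that reuses the `tokenizer` and `model` loaded above (the example sentences are placeholders):
```python
>>> texts = ["Replace me by any text you'd like.", "A second, longer sentence so that the batch requires padding."]
>>> batch = tokenizer(texts, padding=True, return_tensors="pt").to(device)
>>> with torch.no_grad():
...     hidden_states = model(**batch).last_hidden_state  # (batch_size, padded_seq_length, dim)
```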
## DistilBertConfig
[[autodoc]] DistilBertConfig
......
...@@ -24,6 +24,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
...@@ -44,12 +45,18 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
replace_return_docstrings,
)
from .configuration_distilbert import DistilBertConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "distilbert-base-uncased"
_CONFIG_FOR_DOC = "DistilBertConfig"
...@@ -69,6 +76,19 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
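# Illustration only (not part of this diff): for a toy mask
#     attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
# the helper above returns
#     indices             -> tensor([0, 1, 3, 4, 5])  (flat positions of the non-padding tokens)
#     cu_seqlens          -> tensor([0, 2, 5], dtype=torch.int32)  (cumulative sequence lengths)
#     max_seqlen_in_batch -> 3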
def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
if is_deepspeed_zero3_enabled():
import deepspeed
...@@ -141,10 +161,12 @@ class Embeddings(nn.Module):
class MultiHeadSelfAttention(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.n_heads = config.n_heads
self.dim = config.dim
self.dropout = nn.Dropout(p=config.attention_dropout)
self.is_causal = False
# Have an even number of multi heads that divide the dimensions
if self.dim % self.n_heads != 0:
...@@ -240,6 +262,178 @@ class MultiHeadSelfAttention(nn.Module):
return (context,)
class DistilBertFlashAttention2(MultiHeadSelfAttention):
"""
DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module
stay untouched. The only required change is in the forward pass, which needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, ...]:
"""
Parameters:
query: torch.tensor(bs, seq_length, dim)
key: torch.tensor(bs, seq_length, dim)
value: torch.tensor(bs, seq_length, dim)
mask: torch.tensor(bs, seq_length)
Returns:
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights
context: torch.tensor(bs, seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
"""
batch_size, q_length, dim = query.size()
dim_per_head = self.dim // self.n_heads
def reshape(x: torch.Tensor) -> torch.Tensor:
"""separate heads"""
return x.view(batch_size, -1, self.n_heads, dim_per_head)
# Flash attention requires the input to have the shape
# batch_size x seq_length x n_heads x head_dim
query_states = reshape(self.q_lin(query))
key_states = reshape(self.k_lin(key))
value_states = reshape(self.v_lin(value))
attn_dropout = self.config.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
if query_states.dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_lin.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_weights = self._flash_attention_forward(
query_states, key_states, value_states, mask, q_length, dropout=attn_dropout
)
attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head)
attn_output = self.out_lin(attn_weights_reshaped)
if output_attentions:
return (attn_output, attn_weights)
else:
return (attn_output,)
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
first unpad the input, then compute the attention scores, and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
"""
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=self.is_causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
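# Illustration only (not part of this diff): with batch_size=2, kv_seq_len=3 and the padding mask
# [[1, 1, 0], [1, 1, 1]], _upad_input packs the key/value layers into shape (5, n_heads, head_dim)
# and returns cu_seqlens_k == tensor([0, 2, 5], dtype=torch.int32), which is the layout that
# flash_attn_varlen_func expects.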
class FFN(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
...@@ -269,7 +463,11 @@ class TransformerBlock(nn.Module):
if config.dim % config.n_heads != 0:
raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")
self.attention = (
MultiHeadSelfAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else DistilBertFlashAttention2(config)
)
self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
self.ffn = FFN(config)
...@@ -407,6 +605,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
load_tf_weights = None
base_model_prefix = "distilbert"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
def _init_weights(self, module: nn.Module):
"""Initialize the weights."""
...@@ -588,14 +787,17 @@ class DistilBertModel(DistilBertPreTrainedModel):
device = input_ids.device if input_ids is not None else inputs_embeds.device
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)
if getattr(self.config, "_flash_attn_2_enabled", False):
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length)
return self.transformer(
x=embeddings,
attn_mask=attention_mask,
......
...@@ -16,8 +16,10 @@ import os
import tempfile
import unittest
from pytest import mark
from transformers import DistilBertConfig, is_torch_available
from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
...@@ -285,6 +287,114 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
# Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
@require_flash_attn
@require_torch_accelerator
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
import torch
for model_class in self.all_model_classes:
dummy_input = torch.LongTensor(
[
[1, 2, 3, 4],
[1, 2, 8, 9],
[1, 2, 11, 12],
[1, 2, 13, 14],
]
).to(torch_device)
dummy_attention_mask = torch.LongTensor(
[
[0, 1, 1, 1],
[0, 1, 1, 1],
[0, 1, 1, 1],
[0, 1, 1, 1],
]
).to(torch_device)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits_fa = output_fa.hidden_states[-1]
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
# Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
@require_flash_attn
@require_torch_accelerator
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
import torch
for model_class in self.all_model_classes:
dummy_input = torch.LongTensor(
[
[1, 2, 3, 4],
[1, 2, 8, 9],
[1, 2, 11, 12],
[1, 2, 13, 14],
]
).to(torch_device)
dummy_attention_mask = torch.LongTensor(
[
[0, 1, 1, 1],
[0, 1, 1, 1],
[0, 1, 1, 1],
[0, 1, 1, 1],
]
).to(torch_device)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits_fa = output_fa.hidden_states[-1]
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))
@require_torch
class DistilBertModelIntergrationTest(unittest.TestCase):
......