Unverified Commit 368a58e6 authored by Younes Belkada, committed by GitHub

[`core` ] Integrate Flash attention 2 in most used models (#25598)



* v1

* oops

* working v1

* fixup

* add some TODOs

* fixup

* padding support + try with module replacement

* nit

* alternative design

* oops

* add `use_cache` support for llama

* v1 falcon

* nit

* a bit of refactor

* nit

* nits nits

* add v1 padding support falcon (even though it seemed to work before)

* nit

* falcon works

* fixup

* v1 tests

* nit

* fix generation llama flash

* update tests

* fix tests + nits

* fix copies

* fix nit

* test- padding mask

* stype

* add more mem efficient support

* Update src/transformers/modeling_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fixup

* nit

* fixup

* remove it from config when saving

* fixup

* revert docstring

* add more checks

* use values

* oops

* new version

* fixup

* add same trick for falcon

* nit

* add another test

* change tests

* fix issues with GC and also falcon

* fixup

* oops

* Update src/transformers/models/falcon/modeling_falcon.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add init_rope

* updates

* fix copies

* fixup

* fixup

* more clarification

* fixup

* right padding tests

* add docs

* add FA in docker image

* more clarifications

* add some figures

* add todo

* rectify comment

* Change to FA2

* Update docs/source/en/perf_infer_gpu_one.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* split in two lines

* change test name

* add more tests

* some clean up

* remove `rearrange` deps

* add more docs

* revert changes on dockerfile

* Revert "revert changes on dockerfile"

This reverts commit 8d72a66b4b9b771abc3f15a9b9506b4246d62d8e.

* revert changes on dockerfile

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <hi@lysand.re>

* address some comments

* docs

* use inheritance

* Update src/transformers/testing_utils.py
Co-authored-by: Lysandre Debut <hi@lysand.re>

* fixup

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

* final comments

* clean up

* style

* add cast + warning for PEFT models

* fixup

---------
Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
parent dcbfd93d
......@@ -22,6 +22,10 @@ Note: A multi GPU setup can use the majority of the strategies described in the
</Tip>
## Flash Attention 2
The Flash Attention 2 integration also works in a multi-GPU setup; check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2).
## BetterTransformer
[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
......
......@@ -17,6 +17,154 @@ rendered properly in your Markdown viewer.
In addition to this guide, relevant information can be found as well in [the guide for training on a single GPU](perf_train_gpu_one) and [the guide for inference on CPUs](perf_infer_cpu).
## Flash Attention 2
<Tip>
Note that this feature is experimental and might change considerably in future versions. For instance, the Flash Attention 2 API might migrate to the `BetterTransformer` API in the near future.
</Tip>
Flash Attention 2 can considerably speed up the training and inference of transformer-based models. Flash Attention 2 was introduced in the [official Flash Attention repository](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al. The scientific paper on Flash Attention can be found [here](https://arxiv.org/abs/2205.14135).
Make sure to follow the installation guide on the repository mentioned above to properly install Flash Attention 2. Once that package is installed, you can benefit from this feature.
We natively support Flash Attention 2 for the following models:
- Llama
- Falcon
You can request Flash Attention 2 support for more models by opening an issue on GitHub, or even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported by the `BetterTransformer` API below.*
<Tip>
Flash Attention 2 can only be used when the model's dtype is `fp16` or `bf16`, and it runs only on NVIDIA GPUs. Make sure to cast your model to the appropriate dtype and load it on a supported device before using this feature.
</Tip>
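For example, if your checkpoint is stored in `float32`, you can cast it at load time and move it to a GPU afterwards. Below is a minimal sketch; the checkpoint name is just a placeholder, and the `use_flash_attention_2` flag is explained in the next section:

```python
import torch
from transformers import AutoModelForCausalLM

# Load in half precision so the Flash Attention 2 kernels can be used,
# then move the model to a CUDA device before running it.
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
)
model = model.to("cuda")
```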
### Quick usage
To enable Flash Attention 2 in your model, pass `use_flash_attention_2=True` to `from_pretrained`:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_flash_attention_2=True,
)
```
You can then use the model for generation or fine-tuning.
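For instance, a minimal generation call with the model loaded above could look like this (the prompt and generation settings are arbitrary):

```python
# Move the model to GPU first: Flash Attention 2 kernels only run on CUDA devices.
model = model.to("cuda")

inputs = tokenizer("My favourite condiment is", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```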
### Expected speedups
You can benefit from considerable speedups for fine-tuning and inference, especially for long sequences. However, since Flash Attention does not support computing attention scores with padding tokens under the hood, we must manually pad / unpad the attention scores for batched inference when the sequence contains padding tokens. This leads to a significant slowdown for batched generation with padding tokens.
To overcome this, use Flash Attention without padding tokens in the sequence during training (e.g., by packing a dataset, i.e., concatenating sequences until reaching the maximum sequence length). An example is provided [here](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py#L516).
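A minimal sketch of such packing, assuming an already-tokenized 🤗 Datasets dataset named `tokenized_dataset` with an `input_ids` column (the block size is arbitrary; the linked `run_clm.py` script uses the same idea):

```python
block_size = 2048  # pack sequences into fixed-size blocks so no padding is needed

def group_texts(examples):
    # Concatenate all tokenized sequences, then split the result into blocks of `block_size`.
    concatenated = sum(examples["input_ids"], [])
    total_length = (len(concatenated) // block_size) * block_size
    input_ids = [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]
    return {"input_ids": input_ids, "labels": [ids.copy() for ids in input_ids]}

packed_dataset = tokenized_dataset.map(
    group_texts, batched=True, remove_columns=tokenized_dataset.column_names
)
```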
Below is the expected speedup you can get for a simple forward pass on [tiiuae/falcon-7b](https://hf.co/tiiuae/falcon-7b) with a sequence length of 4096 and various batch sizes, without padding tokens:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/falcon-7b-inference-large-seqlen.png">
</div>
Below is the expected speedup you can get for a simple forward pass on [`meta-llama/Llama-7b-hf`](https://hf.co/meta-llama/Llama-7b-hf) with a sequence length of 4096 and various batch sizes, without padding tokens:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-7b-inference-large-seqlen.png">
</div>
For sequences with padding tokens (training with padding tokens or generating with padding tokens), we need to unpad / pad the input sequences to correctly compute the attention scores. For relatively small sequence lengths, a pure forward pass incurs an overhead, leading to only a small speedup (in the benchmark below, 30% of the input is filled with padding tokens).
<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-small-seqlen-padding.png">
</div>
But for large sequence lengths you can still benefit from an interesting speedup for pure inference (and also training).
Note that Flash Attention makes the attention computation more memory efficient, meaning you can train with much larger sequence lengths without facing CUDA OOM issues. It can lead to a memory reduction of up to 20x for large sequence lengths. Check out [the official Flash Attention repository](https://github.com/Dao-AILab/flash-attention) for more details.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
</div>
### Advanced usage
You can combine this feature with many existing features for model optimization. Check out a few examples below:
### Combining Flash Attention 2 and 8-bit models
You can combine this feature with 8-bit quantization:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
use_flash_attention_2=True,
)
```
### Combining Flash Attention 2 and 4-bit models
You can combine this feature with 4-bit quantization:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
)
```
### Combining Flash Attention 2 and PEFT
You can combine this feature with PEFT to train adapters with Flash Attention 2 under the hood:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
)
lora_config = LoraConfig(
r=8,
task_type="CAUSAL_LM"
)
model.add_adapter(lora_config)
... # train your model
```
## BetterTransformer
[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
......
......@@ -228,6 +228,10 @@ For additional information on tf32 vs other precisions, please refer to the foll
[RTX-3090](https://github.com/huggingface/transformers/issues/14608#issuecomment-1004390803) and
[A100](https://github.com/huggingface/transformers/issues/15026#issuecomment-1004543189).
## Flash Attention 2
You can speed up training throughput by using the Flash Attention 2 integration in transformers. Check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2) to learn more about how to load a model with Flash Attention 2 modules.
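Below is a minimal sketch of loading a model with Flash Attention 2 for training with the `Trainer`; the checkpoint name, dataset variable, and hyperparameters are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Assumes `packed_dataset` is a tokenized dataset with `input_ids`/`labels` columns
# and no padding tokens (see the packing example in the single GPU guide).
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="outputs", bf16=True, per_device_train_batch_size=1),
    train_dataset=packed_dataset,
)
trainer.train()
```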
## Optimizer choice
The most common optimizer used to train transformer models is Adam or AdamW (Adam with weight decay). Adam achieves
......
......@@ -855,6 +855,9 @@ class PretrainedConfig(PushToHubMixin):
self.dict_torch_dtype_to_str(serializable_config_dict)
if "_flash_attn_2_enabled" in serializable_config_dict:
del serializable_config_dict["_flash_attn_2_enabled"]
return serializable_config_dict
def to_dict(self) -> Dict[str, Any]:
......@@ -871,6 +874,8 @@ class PretrainedConfig(PushToHubMixin):
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]
if "_flash_attn_2_enabled" in output:
del output["_flash_attn_2_enabled"]
# Transformers version when serializing the model
output["transformers_version"] = __version__
......
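The effect of the two deletions above is that the private flag never ends up in serialized configs. A hedged illustration (the flag is normally set by `_check_and_enable_flash_attn_2`, shown in the next file):

```python
from transformers import FalconConfig

config = FalconConfig()  # any PretrainedConfig subclass behaves the same way
config._flash_attn_2_enabled = True  # normally set internally, not by users

# The flag is stripped during serialization, so saved checkpoints do not
# force Flash Attention 2 on whoever loads them later.
assert "_flash_attn_2_enabled" not in config.to_dict()
assert "_flash_attn_2_enabled" not in config.to_diff_dict()
```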
......@@ -70,6 +70,7 @@ from .utils import (
is_accelerate_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_flash_attn_available,
is_offline_mode,
is_optimum_available,
is_peft_available,
......@@ -1116,6 +1117,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_parallelizable = False
supports_gradient_checkpointing = False
# Flash Attention 2 support
_supports_flash_attn_2 = False
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
......@@ -1239,6 +1243,84 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return False
return True
@classmethod
def _check_and_enable_flash_attn_2(
cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
) -> PretrainedConfig:
"""
If you don't know about Flash Attention, check out the official repository of flash attention:
https://github.com/Dao-AILab/flash-attention
For using Flash Attention 1.0 you can do so directly via the `BetterTransformer` API; have a look at this
specific section of the documentation to learn more about it:
https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
The method checks if the current setup is compatible with Flash Attention, as it requires the model to be in
half precision and not run on CPU.
If all checks pass, the method will create an attribute in the config, `_flash_attn_2_enabled`, so that the model
can initialize the correct attention module.
"""
if not cls._supports_flash_attn_2:
raise ValueError(
"The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
)
if not is_flash_attn_available():
raise ImportError(
"Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
" installing it."
)
else:
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
is_flash_greater_than_2 = flash_attention_version > version.parse("2.0.0")
if not is_flash_greater_than_2:
raise ValueError(
f"You need flash_attn package version to be greater than 2.0. Make sure to have that version installed - detected version {flash_attention_version}"
)
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
raise ValueError(
"Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
)
if torch_dtype is None:
logger.warning(
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
)
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(
f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to"
" unexpected behaviour."
)
if device_map is None:
if torch.cuda.is_available():
logger.warning(
"You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
elif (
device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
config._flash_attn_2_enabled = True
return config
def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
......@@ -2374,6 +2456,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
variant = kwargs.pop("variant", None)
_adapter_model_path = kwargs.pop("_adapter_model_path", None)
adapter_name = kwargs.pop("adapter_name", "default")
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
if is_fsdp_enabled():
low_cpu_mem_usage = True
......@@ -2977,6 +3060,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())
if use_flash_attention_2:
config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)
with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
......
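The checks in `_check_and_enable_flash_attn_2` above surface as load-time errors. A hedged sketch of what a user would see, assuming `flash-attn` is installed and using a placeholder checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM

# float32 is rejected: only torch.float16 and torch.bfloat16 are allowed.
try:
    AutoModelForCausalLM.from_pretrained(
        "tiiuae/falcon-7b", torch_dtype=torch.float32, use_flash_attention_2=True
    )
except ValueError as exc:
    print(exc)

# A device_map that places weights on "cpu" or "disk" is rejected as well.
try:
    AutoModelForCausalLM.from_pretrained(
        "tiiuae/falcon-7b",
        torch_dtype=torch.float16,
        use_flash_attention_2=True,
        device_map={"": "cpu"},
    )
except ValueError as exc:
    print(exc)
```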
......@@ -364,7 +364,6 @@ class OpenLlamaDecoderLayer(nn.Module):
self.input_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
......
......@@ -32,11 +32,21 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_available,
logging,
)
from ..auto.configuration_auto import sanitize_code_revision
from .configuration_falcon import FalconConfig
if is_flash_attn_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
......@@ -67,6 +77,19 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(padding_mask):
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities
class FalconRotaryEmbedding(nn.Module):
"""Implementation of RotaryEmbedding from GPT-NeoX.
......@@ -405,6 +428,7 @@ class FalconAttention(nn.Module):
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
......@@ -519,6 +543,185 @@ class FalconAttention(nn.Module):
return output_tensor, present
class FalconFlashAttention2(FalconAttention):
"""
Falcon flash attention module. This module inherits from `FalconAttention`, as the weights of the module stay
untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
if layer_past is not None and use_cache:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, kv_seq_length, _ = key_layer.shape
torch_dtype = query_layer.dtype
past_key_value = (key_layer, value_layer) if use_cache else None
query_layer = (
query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
)
key_layer = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
value_layer = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
if alibi is not None:
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
attn_dropout = self.attention_dropout if self.training else 0.0
# In PEFT, we usually cast the layer norms to float32 for training stability reasons,
# therefore the input hidden states get silently cast to float32. Hence, we need to
# cast them back to float16 just to be sure everything works as expected.
input_dtype = query_layer.dtype
if input_dtype == torch.float32:
logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
)
query_layer = query_layer.to(torch.float16)
key_layer = key_layer.to(torch.float16)
value_layer = value_layer.to(torch.float16)
attn_output = self._flash_attention_forward(
query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout
)
attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
attn_output = self.dense(attn_weights)
if not output_attentions:
attn_weights = None
return attn_output, past_key_value, attn_weights
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
first unpad the input, then compute the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
padding_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
"""
# Contains at least one padding token in the sequence
if padding_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, padding_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=True,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class FalconMLP(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
......@@ -540,7 +743,12 @@ class FalconDecoderLayer(nn.Module):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config)
self.self_attention = (
FalconAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else FalconFlashAttention2(config)
)
self.mlp = FalconMLP(config)
self.hidden_dropout = config.hidden_dropout
self.config = config
......@@ -565,6 +773,7 @@ class FalconDecoderLayer(nn.Module):
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
):
residual = hidden_states
......@@ -584,6 +793,7 @@ class FalconDecoderLayer(nn.Module):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
padding_mask=padding_mask,
)
attention_output = attn_outputs[0]
......@@ -700,6 +910,7 @@ class FalconPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["FalconDecoderLayer"]
_supports_flash_attn_2 = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......@@ -917,9 +1128,15 @@ class FalconModel(FalconPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
padding_mask = None
else:
attention_mask = attention_mask.to(hidden_states.device)
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
if self.use_alibi:
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
else:
......@@ -964,6 +1181,7 @@ class FalconModel(FalconPreTrainedModel):
causal_mask,
position_ids,
head_mask[i],
padding_mask,
)
else:
outputs = block(
......@@ -975,6 +1193,7 @@ class FalconModel(FalconPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
padding_mask=padding_mask,
)
hidden_states = outputs[0]
......
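To make the unpadding path above concrete, here is a self-contained illustration of what `_get_unpad_data` returns for a small right-padded batch (the helper is copied from the diff; the sample mask is arbitrary):

```python
import torch
import torch.nn.functional as F

def _get_unpad_data(padding_mask):
    # Same logic as the helper added to the Falcon/Llama modeling files.
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

# Two sequences of lengths 3 and 5, right-padded to length 5.
padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(padding_mask)
print(indices)     # tensor([0, 1, 2, 5, 6, 7, 8, 9]) -> flat positions of real tokens
print(cu_seqlens)  # tensor([0, 3, 8], dtype=torch.int32) -> cumulative sequence lengths
print(max_seqlen)  # 5
```

These are exactly the `cu_seqlens_q` / `cu_seqlens_k` and `max_seqlen` arguments that `flash_attn_varlen_func` expects once the batch has been unpadded with `index_first_axis` / `unpad_input`.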
......@@ -31,15 +31,38 @@ from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_llama import LlamaConfig
if is_flash_attn_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
def _get_unpad_data(padding_mask):
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
......@@ -261,6 +284,7 @@ class LlamaAttention(nn.Module):
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self._init_rope()
def _init_rope(self):
......@@ -301,6 +325,7 @@ class LlamaAttention(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
......@@ -343,7 +368,6 @@ class LlamaAttention(nn.Module):
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
......@@ -373,6 +397,7 @@ class LlamaAttention(nn.Module):
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
......@@ -388,11 +413,189 @@ class LlamaAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class LlamaFlashAttention2(LlamaAttention):
"""
Llama flash attention module. This module inherits from `LlamaAttention`, as the weights of the module stay
untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# TODO: llama does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout
# In PEFT, we usually cast the layer norms to float32 for training stability reasons,
# therefore the input hidden states get silently cast to float32. Hence, we need to
# cast them back to float16 just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended to not cast the LayerNorms
# to fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
)
query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
first unpad the input, then compute the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
padding_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
"""
# Contains at least one padding token in the sequence
if padding_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, padding_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=True,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.self_attn = (
LlamaAttention(config=config)
if not getattr(config, "_flash_attn_2_enabled", False)
else LlamaFlashAttention2(config=config)
)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -405,6 +608,7 @@ class LlamaDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
......@@ -432,6 +636,7 @@ class LlamaDecoderLayer(nn.Module):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
......@@ -479,6 +684,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
def _init_weights(self, module):
std = self.config.initializer_range
......@@ -669,6 +875,13 @@ class LlamaModel(LlamaPreTrainedModel):
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
......@@ -698,15 +911,12 @@ class LlamaModel(LlamaPreTrainedModel):
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions)
return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
)
else:
layer_outputs = decoder_layer(
......@@ -716,6 +926,7 @@ class LlamaModel(LlamaPreTrainedModel):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = layer_outputs[0]
......
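Since the decoder layer picks the attention implementation from `config._flash_attn_2_enabled` (see the `LlamaDecoderLayer.__init__` change above), you can verify which module is active by inspecting the loaded model; a hedged sketch mirroring `test_flash_attn_2_conversion` further below (the checkpoint name is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
)

# One LlamaFlashAttention2 module per decoder layer is expected.
flash_modules = [
    name for name, module in model.named_modules()
    if "FlashAttention" in module.__class__.__name__
]
print(f"{len(flash_modules)} flash attention modules found")
```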
......@@ -452,7 +452,6 @@ PERSIMMON_START_DOCSTRING = r"""
"The bare Persimmon Model outputting raw hidden-states without any specific head on top.",
PERSIMMON_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Persimmon
class PersimmonPreTrainedModel(PreTrainedModel):
config_class = PersimmonConfig
base_model_prefix = "model"
......@@ -544,7 +543,6 @@ PERSIMMON_INPUTS_DOCSTRING = r"""
"The bare Persimmon Model outputting raw hidden-states without any specific head on top.",
PERSIMMON_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps
class PersimmonModel(PersimmonPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PersimmonDecoderLayer`]
......@@ -553,6 +551,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
config: PersimmonConfig
"""
# Copied from transformers.models.llama.modeling_llama.LlamaModel.__init__ with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps
def __init__(self, config: PersimmonConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
......
......@@ -60,6 +60,7 @@ from .utils import (
is_detectron2_available,
is_essentia_available,
is_faiss_available,
is_flash_attn_available,
is_flax_available,
is_fsdp_available,
is_ftfy_available,
......@@ -392,6 +393,16 @@ def require_torch(test_case):
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
def require_flash_attn(test_case):
"""
Decorator marking a test that requires Flash Attention.
These tests are skipped when Flash Attention isn't installed.
"""
return unittest.skipUnless(is_flash_attn_available(), "test requires Flash Attention")(test_case)
def require_peft(test_case):
"""
Decorator marking a test that requires PEFT.
......
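A hedged usage sketch of the new decorator, following the pattern used by the model tests later in this diff (the test class and method names are hypothetical):

```python
import unittest

from pytest import mark

from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow


class ExampleFlashAttentionTest(unittest.TestCase):
    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    def test_runs_only_with_flash_attn(self):
        # Skipped automatically when flash-attn is not installed or no GPU is available.
        ...
```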
......@@ -114,6 +114,7 @@ from .import_utils import (
is_detectron2_available,
is_essentia_available,
is_faiss_available,
is_flash_attn_available,
is_flax_available,
is_fsdp_available,
is_ftfy_available,
......
......@@ -71,6 +71,7 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_flash_attn_available = _is_package_available("flash_attn")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
......@@ -570,6 +571,16 @@ def is_bitsandbytes_available():
return _bitsandbytes_available and torch.cuda.is_available()
def is_flash_attn_available():
if not is_torch_available():
return False
# Let's add an extra check to see if cuda is available
import torch
return _flash_attn_available and torch.cuda.is_available()
def is_torchdistx_available():
return _torchdistx_available
......
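The availability helper doubles as an import guard, mirroring the pattern used in the modeling files above; a minimal sketch:

```python
from transformers.utils import is_flash_attn_available

# Only import flash-attn symbols when the package is installed and CUDA is available.
if is_flash_attn_available():
    from flash_attn import flash_attn_func  # noqa: F401

use_flash_attention_2 = is_flash_attn_available()
```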
......@@ -18,9 +18,10 @@
import unittest
from parameterized import parameterized
from pytest import mark
from transformers import LlamaConfig, is_torch_available, set_seed
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from transformers.testing_utils import require_flash_attn, require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
......@@ -375,6 +376,41 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
"""
Overwriting the common test as the test is flaky on tiny models
"""
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
load_in_4bit=True,
device_map={"": 0},
)
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
texts = ["hi", "Hello this is a very long sentence"]
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = tokenizer.batch_decode(output_native)
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(output_native, output_fa_2)
@require_torch
class LlamaIntegrationTest(unittest.TestCase):
......
......@@ -64,6 +64,7 @@ from transformers.testing_utils import (
is_pt_flax_cross_test,
is_pt_tf_cross_test,
require_accelerate,
require_flash_attn,
require_safetensors,
require_torch,
require_torch_gpu,
......@@ -2722,6 +2723,191 @@ class ModelTesterMixin:
num_params < 1000000
), f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
import torch
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
).to(torch_device)
for _, module in model.named_modules():
if "FlashAttention" in module.__class__.__name__:
return
self.assertTrue(False, "FlashAttention2 modules not found in model")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[0, 1, 1, 1, 1]]).to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits_fa = output_fa.hidden_states[-1]
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits_fa = output_fa.hidden_states[-1]
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_left_padding(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
self.assertTrue(torch.equal(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
self.assertTrue(torch.equal(out, out_fa))
global_rng = random.Random()
......