Unverified Commit 1da1302e authored by fxmarty, committed by GitHub

Flash Attention 2 support for RoCm (#27611)



* support FA2

* fix typo

* fix broken tests

* fix more test errors

* left/right

* fix bug

* more test

* typo

* fix layout flash attention falcon

* do not support this case

* use allclose instead of equal

* fix various bugs with flash attention

* bump

* fix test

* fix mistral

* use skiptest instead of return that may be misleading

* add fix causal arg flash attention

* fix copies

* more explicit comment

* still use self.is_causal

* fix causal argument

* comment

* fixes

* update documentation

* add link

* wrong test

* simplify FA2 RoCm requirements

* update opt

* make flash_attn_uses_top_left_mask attribute private and precise comment

* better error handling

* fix copy & mistral

* Update src/transformers/modeling_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/utils/import_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* use is_flash_attn_greater_or_equal_2_10 instead of is_flash_attn_greater_or_equal_210

* fix merge

* simplify

* inline args

---------
Co-authored-by: Felix Marty <felix@hf.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 4d4febb7
@@ -56,13 +56,9 @@ The `generate()` method can be used to generate text using GPT Neo model.
 ## Combining GPT-Neo and Flash Attention 2
 
-First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
-
-```bash
-pip install -U flash-attn --no-build-isolation
-```
-
-Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``)
+First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature, and make sure your hardware is compatible with Flash-Attention 2. More details are available [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2) concerning the installation.
+
+Make sure as well to load your model in half-precision (e.g. `torch.float16`).
 
 To load and run a model using Flash Attention 2, refer to the snippet below:
......
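The snippet the updated GPT-Neo documentation refers to is collapsed in this diff view. Purely as an illustrative sketch (the checkpoint name, prompt and generation settings are assumptions, not taken from the diff), loading GPT-Neo in half-precision with Flash Attention 2 enabled would look roughly like this:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-neo-2.7B"  # illustrative GPT-Neo checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Half precision plus the `use_flash_attention_2` flag, as the updated docs require.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
).to("cuda")

inputs = tokenizer("Flash Attention 2 now also runs on ROCm", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```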
@@ -38,11 +38,9 @@ FlashAttention-2 is experimental and may change considerably in future versions.
 FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
 
-Before you begin, make sure you have FlashAttention-2 installed (see the [installation](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) guide for more details about prerequisites):
-
-```bash
-pip install flash-attn --no-build-isolation
-```
+Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library is installable through pip: `pip install flash-attn --no-build-isolation`. We strongly suggest to refer to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).
+
+FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest to use the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
 
 To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]:
@@ -62,7 +60,7 @@ model = AutoModelForCausalLM.from_pretrained(
 
 <Tip>
 
-FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`, and it only runs on Nvidia GPUs. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.
+FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.
 
 </Tip>
......
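Since the documentation now limits AMD support to the Instinct MI210 and MI250, a caller may want to gate the `use_flash_attention_2` flag on the detected build and device. The sketch below is a hedged illustration only: the device-name matching is a heuristic, not an official API, and the checkpoint is an arbitrary example.

```python
import torch
from transformers import AutoModelForCausalLM

# torch.version.cuda is set on CUDA builds, torch.version.hip on ROCm builds.
device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else ""
on_nvidia = torch.version.cuda is not None and torch.cuda.is_available()
on_supported_amd = torch.version.hip is not None and any(p in device_name for p in ("MI210", "MI250"))

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # illustrative checkpoint
    torch_dtype=torch.float16,    # FA2 requires fp16 or bf16, per the Tip above
    use_flash_attention_2=on_nvidia or on_supported_amd,
)
```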
@@ -1281,17 +1281,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
 
         if not is_flash_attn_2_available():
-            raise ImportError(
-                "Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
-                " installing it. Make sure to have at least the version 2.1.0"
-            )
-        else:
-            flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
-            is_flash_greater_than_2 = flash_attention_version >= version.parse("2.1.0")
-            if not is_flash_greater_than_2:
-                raise ValueError(
-                    f"You need flash_attn package version to be greater or equal than 2.1. Make sure to have that version installed - detected version {flash_attention_version}"
-                )
+            preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
+            install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
+
+            if torch.version.cuda:
+                if importlib.util.find_spec("flash_attn") is None:
+                    raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
+
+                flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
+                if flash_attention_version < version.parse("2.1.0"):
+                    raise ImportError(
+                        f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
+                    )
+                else:
+                    raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
+            elif torch.version.hip:
+                if importlib.util.find_spec("flash_attn") is None:
+                    raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
+
+                flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
+                if flash_attention_version < version.parse("2.0.4"):
+                    raise ImportError(
+                        f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}"
+                    )
+                else:
+                    raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
 
         _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
......
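For reference, the CUDA-versus-ROCm gating introduced above can be probed in a plain Python session. This is only an illustrative sketch of how PyTorch exposes its build information, not part of the patch:

```python
import torch

# torch.version.hip is None on CUDA builds and a version string on ROCm builds;
# torch.version.cuda behaves the other way around. Both are None on CPU-only builds.
if torch.version.cuda:
    print(f"CUDA build ({torch.version.cuda}): Flash Attention 2 requires flash-attn >= 2.1.0")
elif torch.version.hip:
    print(f"ROCm build ({torch.version.hip}): Flash Attention 2 requires flash-attn >= 2.0.4")
else:
    print("CPU-only build: Flash Attention 2 is unavailable")
```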
@@ -34,6 +34,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     is_accelerate_available,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
 )
 from ..auto import AutoModel
@@ -214,6 +215,15 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
     flash attention and deal with padding tokens in case the input contains any of them.
     """
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
         Splits hidden_size dim into attn_head_size and num_heads
@@ -301,6 +311,12 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -321,13 +337,13 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=self.is_causal,
+                causal=causal,
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
             )
 
         return attn_output
......
@@ -42,6 +42,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
@@ -294,6 +295,15 @@ class BartFlashAttention2(BartAttention):
     flash attention and deal with padding tokens in case the input contains any of them.
     """
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
 
@@ -418,6 +428,12 @@ class BartFlashAttention2(BartAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -438,13 +454,13 @@ class BartFlashAttention2(BartAttention):
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=self.is_causal,
+                causal=causal,
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )
 
         return attn_output
......
@@ -46,6 +46,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
@@ -269,6 +270,15 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
     API of flash attention and deal with padding tokens in case the input contains any of them.
     """
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def forward(
         self,
         query: torch.Tensor,
@@ -363,6 +373,12 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -383,13 +399,13 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=self.is_causal,
+                causal=causal,
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
             )
 
         return attn_output
......
@@ -38,6 +38,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
 )
 from .configuration_falcon import FalconConfig
@@ -516,6 +517,15 @@ class FalconFlashAttention2(FalconAttention):
     flash attention and deal with padding tokens in case the input contains any of them.
     """
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -631,6 +641,12 @@ class FalconFlashAttention2(FalconAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -651,13 +667,13 @@ class FalconFlashAttention2(FalconAttention):
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=self.is_causal,
+                causal=causal,
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
             )
 
         return attn_output
......
@@ -34,6 +34,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
 )
 from .configuration_gpt_bigcode import GPTBigCodeConfig
@@ -292,6 +293,15 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
     API of flash attention and deal with padding tokens in case the input contains any of them.
     """
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -422,6 +432,12 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -442,13 +458,13 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=self.is_causal,
+                causal=causal,
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
             )
 
         return attn_output
......
@@ -42,6 +42,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     is_torch_fx_available,
     logging,
 )
@@ -299,6 +300,15 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
     flash attention and deal with padding tokens in case the input contains any of them.
     """
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def forward(
         self,
         hidden_states,
@@ -400,6 +410,12 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -420,13 +436,13 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=self.is_causal,
+                causal=causal,
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
             )
 
         return attn_output
......
@@ -37,6 +37,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
@@ -442,6 +443,14 @@ class LlamaFlashAttention2(LlamaAttention):
     flash attention and deal with padding tokens in case the input contains any of them.
     """
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -491,6 +500,8 @@ class LlamaFlashAttention2(LlamaAttention):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
@@ -555,6 +566,12 @@ class LlamaFlashAttention2(LlamaAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -575,13 +592,13 @@ class LlamaFlashAttention2(LlamaAttention):
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=self.is_causal,
+                causal=causal,
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
             )
 
         return attn_output
......
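To make the mask-alignment distinction from the comment above concrete, here is a small, library-independent sketch (plain PyTorch, not flash-attn itself) of the two causal masks for 2 new query tokens attending to a cache of 5 keys; `True` marks the positions a query is allowed to attend to:

```python
import torch

q_len, k_len = 2, 5

# Bottom-right aligned causal mask (flash_attn>=2.1): the last query row can see
# every key, and earlier rows lose keys from the right. This is what cached
# autoregressive decoding needs.
bottom_right = torch.ones(q_len, k_len).tril(diagonal=k_len - q_len).bool()

# Top-left aligned causal mask (flash_attn<2.1): row i only sees keys 0..i,
# which is wrong whenever the key/value cache is longer than the new queries.
top_left = torch.ones(q_len, k_len).tril(diagonal=0).bool()

print(bottom_right)
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])
print(top_left)
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False]])
```

This is also why the patched classes set `causal = self.is_causal and query_length != 1` on the older kernels: with a single new query attending to the whole cache, no mask is needed at all, so disabling the causal flag sidesteps the misaligned mask.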
@@ -41,6 +41,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
@@ -289,6 +290,15 @@ class MBartFlashAttention2(MBartAttention):
     flash attention and deal with padding tokens in case the input contains any of them.
     """
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
 
@@ -413,6 +423,12 @@ class MBartFlashAttention2(MBartAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -433,13 +449,13 @@ class MBartFlashAttention2(MBartAttention):
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=self.is_causal,
+                causal=causal,
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
             )
 
         return attn_output
......
@@ -37,6 +37,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
@@ -312,6 +313,15 @@ class MistralFlashAttention2(MistralAttention):
     flash attention and deal with padding tokens in case the input contains any of them.
     """
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -470,6 +480,12 @@ class MistralFlashAttention2(MistralAttention):
             use_sliding_windows (`bool`, *optional*):
                 Whether to activate sliding window attention.
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -491,7 +507,7 @@ class MistralFlashAttention2(MistralAttention):
                     max_seqlen_k=max_seqlen_in_batch_k,
                     dropout_p=dropout,
                     softmax_scale=softmax_scale,
-                    causal=self.is_causal,
+                    causal=causal,
                 )
             else:
                 attn_output_unpad = flash_attn_varlen_func(
@@ -504,7 +520,7 @@ class MistralFlashAttention2(MistralAttention):
                     max_seqlen_k=max_seqlen_in_batch_k,
                     dropout_p=dropout,
                     softmax_scale=softmax_scale,
-                    causal=self.is_causal,
+                    causal=causal,
                     window_size=(self.config.sliding_window, self.config.sliding_window),
                 )
@@ -517,7 +533,7 @@ class MistralFlashAttention2(MistralAttention):
                     value_states,
                     dropout,
                     softmax_scale=softmax_scale,
-                    causal=self.is_causal,
+                    causal=causal,
                 )
             else:
                 attn_output = flash_attn_func(
@@ -526,7 +542,7 @@ class MistralFlashAttention2(MistralAttention):
                     value_states,
                     dropout,
                     softmax_scale=softmax_scale,
-                    causal=self.is_causal,
+                    causal=causal,
                     window_size=(self.config.sliding_window, self.config.sliding_window),
                 )
......
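The `window_size=(self.config.sliding_window, self.config.sliding_window)` argument in the Mistral hunk above restricts each query to a local band of keys. The sketch below is only a rough, library-independent illustration of a causal sliding-window mask; it does not reproduce the exact `window_size` semantics of flash-attn:

```python
import torch

seq_len, window = 6, 3
i = torch.arange(seq_len).unsqueeze(1)  # query positions
j = torch.arange(seq_len).unsqueeze(0)  # key positions

# Causal + sliding window: each query sees itself and at most `window - 1` previous tokens.
mask = (j <= i) & (j > i - window)
print(mask.long())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])
```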
@@ -35,6 +35,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
@@ -288,6 +289,15 @@ class OptFlashAttention2(OPTAttention):
     attention and deal with padding tokens in case the input contains any of them.
     """
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -404,6 +414,12 @@ class OptFlashAttention2(OPTAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -424,13 +440,13 @@ class OptFlashAttention2(OPTAttention):
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=self.is_causal,
+                causal=causal,
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
             )
 
         return attn_output
......
@@ -41,6 +41,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
@@ -478,6 +479,15 @@ class WhisperFlashAttention2(WhisperAttention):
     flash attention and deal with padding tokens in case the input contains any of them.
     """
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
     def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
 
@@ -602,6 +612,12 @@ class WhisperFlashAttention2(WhisperAttention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
@@ -622,13 +638,13 @@ class WhisperFlashAttention2(WhisperAttention):
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=self.is_causal,
+                causal=causal,
            )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
             )
 
         return attn_output
......
@@ -118,6 +118,7 @@ from .import_utils import (
     is_faiss_available,
     is_flash_attn_2_available,
     is_flash_attn_available,
+    is_flash_attn_greater_or_equal_2_10,
     is_flax_available,
     is_fsdp_available,
     is_ftfy_available,
......
@@ -71,9 +71,6 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
 _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
 _apex_available = _is_package_available("apex")
 _bitsandbytes_available = _is_package_available("bitsandbytes")
-_flash_attn_2_available = _is_package_available("flash_attn") and version.parse(
-    importlib.metadata.version("flash_attn")
-) >= version.parse("2.1.0")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
 _bs4_available = importlib.util.find_spec("bs4") is not None
 _coloredlogs_available = _is_package_available("coloredlogs")
@@ -608,10 +605,29 @@ def is_flash_attn_2_available():
     if not is_torch_available():
         return False
 
+    if not _is_package_available("flash_attn"):
+        return False
+
     # Let's add an extra check to see if cuda is available
     import torch
 
-    return _flash_attn_2_available and torch.cuda.is_available()
+    if not torch.cuda.is_available():
+        return False
+
+    if torch.version.cuda:
+        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
+    elif torch.version.hip:
+        # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
+        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
+    else:
+        return False
+
+
+def is_flash_attn_greater_or_equal_2_10():
+    if not _is_package_available("flash_attn"):
+        return False
+
+    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
+
+
 def is_flash_attn_available():
......
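As an aside, both helpers above are re-exported from `transformers.utils` (see the `__init__.py` hunk earlier), so downstream code can combine them. The snippet below is only a hedged usage sketch; the surrounding logic and messages are hypothetical:

```python
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10

if is_flash_attn_2_available():
    # Mirrors what the model classes in this PR do in their __init__:
    # older flash-attn builds (ROCm < 2.1) still produce a top-left aligned causal mask.
    uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
    print(f"FA2 usable; top-left causal-mask workaround needed: {uses_top_left_mask}")
else:
    print("Falling back to the default attention implementation.")
```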
@@ -3087,7 +3087,7 @@ class ModelTesterMixin:
                 dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
             )
 
-            self.assertTrue(torch.equal(out, out_fa))
+            self.assertTrue(torch.allclose(out, out_fa))
 
     @require_flash_attn
     @require_torch_gpu
@@ -3130,7 +3130,7 @@ class ModelTesterMixin:
                 dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
             )
 
-            self.assertTrue(torch.equal(out, out_fa))
+            self.assertTrue(torch.allclose(out, out_fa))
 
     @require_flash_attn
     @require_torch_gpu
......
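A short aside on the test change above: `torch.equal` demands bit-identical tensors, while the eager and Flash Attention 2 paths can legitimately differ by tiny floating-point error, hence the switch to `torch.allclose`. A standalone illustration (the values are made up, not taken from the tests):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-6  # tiny discrepancy, e.g. from a different reduction order in another kernel

print(torch.equal(a, b))                          # False: requires exact equality
print(torch.allclose(a, b))                       # True: within default rtol=1e-5, atol=1e-8
print(torch.allclose(a, b, rtol=0.0, atol=1e-9))  # False once the tolerances are tightened
```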