Unverified Commit faefacb8 authored by Kirthi Shankar Sivamani, committed by GitHub

FlashAttention 2.0 support (#329)



* FA v2.0 support
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ded8b9bd
@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=1.0.7"])
+        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.0.post1"])
         add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
     if "jax" in frameworks():
         if not found_pybind11():
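A quick sanity check of the widened pin (a sketch only, not part of the diff; it assumes flash-attn is installed and that its version is readable via importlib.metadata, mirroring how the module below calls version("flash-attn")):

# Verify the installed flash-attn falls inside "flash-attn>=1.0.6, <=2.0.0.post1".
from importlib.metadata import version
from pkg_resources import packaging

fa_version = packaging.version.Version(version("flash-attn"))
assert packaging.version.Version("1.0.6") <= fa_version <= packaging.version.Version("2.0.0.post1"), \
    f"Unsupported flash-attn version: {fa_version}"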
@@ -12,8 +12,6 @@ from pkg_resources import packaging
 import torch
-from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-
 import transformer_engine_extensions as tex
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     fused_attn_fwd_qkvpacked,
@@ -47,6 +45,12 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("1.0.6")
+_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
+
+if _flash_attn_2_available:
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
+else:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func # pylint: disable=no-name-in-module
 
 __all__ = ["DotProductAttention"]
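For reference, a small standalone sketch (assuming only pkg_resources, as imported above) of why the >= Version("2") gate also matches post-releases such as the 2.0.0.post1 allowed in setup.py:

from pkg_resources import packaging

v2 = packaging.version.Version("2")
assert packaging.version.Version("2.0.0.post1") >= v2   # FA2 post-release -> flash_attn_varlen_func
assert not packaging.version.Version("1.0.7") >= v2     # FA1 -> flash_attn_unpadded_func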
@@ -397,11 +401,14 @@ class FlashAttention(torch.nn.Module):
                                       device=query_layer.device)
 
         with self.attention_dropout_ctx():
-            output = flash_attn_unpadded_func(
+            fa_optional_forward_kwargs = {}
+            if not _flash_attn_2_available:
+                fa_optional_forward_kwargs["deterministic"] = self.deterministic
+            output = flash_attn_forward_func(
                 query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
                 self.attention_dropout if self.training else 0.0,
                 softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask,
-                deterministic=self.deterministic,
+                **fa_optional_forward_kwargs
             )
 
         # [(b sq), np, hn] -> [sq, b, (np hn)]
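The kwargs dict keeps a single call site valid for both APIs, since the FlashAttention 2.x entry point no longer accepts the deterministic argument here. A self-contained sketch of the same pattern using a hypothetical stand-in (toy_attn) rather than the real kernel:

import torch

_fa2_available = False  # pretend FlashAttention 1.x is installed for this sketch

def toy_attn(q, k, v, dropout_p, softmax_scale=1.0, causal=False, deterministic=False):
    # Stand-in with an FA1-style signature; an FA2-style API would not take `deterministic`.
    scores = softmax_scale * q @ k.transpose(-2, -1)
    return torch.softmax(scores, dim=-1) @ v

optional_kwargs = {}
if not _fa2_available:
    optional_kwargs["deterministic"] = True  # only pass the kwarg when the API supports it

q = k = v = torch.randn(2, 4, 8)
out = toy_attn(q, k, v, 0.0, softmax_scale=8 ** -0.5, **optional_kwargs)
print(out.shape)  # torch.Size([2, 4, 8])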
@@ -700,11 +707,10 @@ class DotProductAttention(torch.nn.Module):
     .. warning::
 
-        For the default attention mechanism, this module executes a non-deterministic version of
-        `flash-attn <https://github.com/ksivaman/flash-attention>`_ whenever possible in order to
-        achieve optimal performance. To observe deterministic behavior, set the environment
-        variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable
-        `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
+        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
+        deterministic behavior at the cost of performance, use FlashAttention version < `2.0.0`
+        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
+        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
 
     Parameters
     ----------
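As the updated warning describes, deterministic execution now also requires a pre-2.0.0 FlashAttention. A usage sketch of the two environment variables; when exactly the library reads them is not shown in this diff, so setting them before the import is the safe order:

import os

# Deterministic flash-attn (requires flash-attn < 2.0.0 per the warning above).
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
# Or disable flash-attn entirely:
# os.environ["NVTE_FLASH_ATTN"] = "0"

import transformer_engine.pytorch as te  # noqa: E402  (import after setting the flags)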