Unverified Commit faefacb8 authored by Kirthi Shankar Sivamani, committed by GitHub

FlashAttention 2.0 support (#329)



* FA v2.0 support
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ded8b9bd
@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=1.0.7"])
+        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.0.post1"])
         add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
     if "jax" in frameworks():
         if not found_pybind11():
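For reference, a quick way to check whether a locally installed flash-attn falls inside the new range declared above. This snippet is illustrative and not part of the commit; it assumes flash-attn is already installed and uses the standalone packaging library.

from importlib.metadata import version
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Range taken from the updated setup.py requirement above.
supported = SpecifierSet(">=1.0.6,<=2.0.0.post1")
installed = Version(version("flash-attn"))
print(f"flash-attn {installed} supported: {installed in supported}")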
@@ -12,8 +12,6 @@ from pkg_resources import packaging
 import torch
-from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 import transformer_engine_extensions as tex
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     fused_attn_fwd_qkvpacked,
@@ -47,6 +45,12 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("1.0.6")
+_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
+if _flash_attn_2_available:
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
+else:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func # pylint: disable=no-name-in-module
 
 __all__ = ["DotProductAttention"]
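A quick sanity check of the comparison used in the gate above, illustrative only: under packaging version semantics, any 2.x release, including the 2.0.0.post1 pinned in setup.py, satisfies >= Version("2").

from packaging.version import Version

print(Version("1.0.7") >= Version("2"))        # False -> flash_attn_unpadded_func branch
print(Version("2.0.0.post1") >= Version("2"))  # True  -> flash_attn_varlen_func branch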
@@ -397,11 +401,14 @@ class FlashAttention(torch.nn.Module):
                                       device=query_layer.device)
         with self.attention_dropout_ctx():
-            output = flash_attn_unpadded_func(
+            fa_optional_forward_kwargs = {}
+            if not _flash_attn_2_available:
+                fa_optional_forward_kwargs["deterministic"] = self.deterministic
+            output = flash_attn_forward_func(
                 query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
                 self.attention_dropout if self.training else 0.0,
                 softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask,
-                deterministic=self.deterministic,
+                **fa_optional_forward_kwargs
             )
         # [(b sq), np, hn] -> [sq, b, (np hn)]
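The hunk above keeps a single call site and moves the version-only argument into a dict that is unpacked into the call. A self-contained toy sketch of that pattern; fake_attention is a made-up stand-in, not TE's or flash-attn's API.

def fake_attention(q, k, v, dropout_p=0.0, causal=False, deterministic=False):
    # Stand-in for the flash-attn kernel; just echoes the arguments it received.
    return {"dropout_p": dropout_p, "causal": causal, "deterministic": deterministic}

_flash_attn_2_available = False  # pretend flash-attn 1.x is installed

fa_optional_forward_kwargs = {}
if not _flash_attn_2_available:
    # Only the flash-attn 1.x path accepts `deterministic` in this code path.
    fa_optional_forward_kwargs["deterministic"] = True

print(fake_attention("q", "k", "v", dropout_p=0.1, causal=True, **fa_optional_forward_kwargs))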
@@ -700,11 +707,10 @@ class DotProductAttention(torch.nn.Module):
     .. warning::
 
-        For the default attention mechanism, this module executes a non-deterministic version of
-        `flash-attn <https://github.com/ksivaman/flash-attention>`_ whenever possible in order to
-        achieve optimal performance. To observe deterministic behavior, set the environment
-        variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable
-        `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
+        FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
+        deterministic behavior at the cost of performance, use FlashAttention version < `2.0.0`
+        and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
+        to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
 
     Parameters
     ----------
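A minimal usage sketch based on the docstring above. When exactly Transformer Engine reads these variables is not stated here, so setting them before importing transformer_engine (or exporting them in the shell) is the safe assumption.

import os

os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # request deterministic behavior (flash-attn < 2.0.0)
# os.environ["NVTE_FLASH_ATTN"] = "0"                 # uncomment to disable flash-attn entirely

import transformer_engine.pytorch  # noqa: E402  (import after setting the variables)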