Unverified Commit 7396c527 authored by Kirthi Shankar Sivamani, committed by GitHub

Switch to upstream flash-attn (#151)



* use upstream flash-attn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* get correct FA for linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7c9fb403
@@ -30,6 +30,7 @@ jobs:
         uses: actions/checkout@v3
       - name: 'Lint'
         run: |
+          pip install flash-attn==1.0.2
           export PYTHON_ONLY=1
           export TE_PATH=.
           bash ./qa/L0_lint/test.sh
...
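The lint job above now installs flash-attn from PyPI before running the checks. A quick way to confirm the pin took effect in that environment, using the same importlib.metadata helper the attention module imports further down — a hedged sketch, not part of the diff:

# Sanity check (assumption: run inside the lint environment, after the
# pip install step above has completed).
from importlib.metadata import version

installed = version("flash-attn")
assert installed == "1.0.2", f"expected flash-attn 1.0.2, found {installed}"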
@@ -157,7 +157,7 @@ class PyTorchBuilder(FrameworkBuilderBase):
     @staticmethod
     def install_requires():
-        return ["flash-attn @ git+https://github.com/ksivaman/flash-attention.git@hopper",]
+        return ["flash-attn>=1.0.2",]

 class TensorFlowBuilder(FrameworkBuilderBase):
     def cmake_flags(self):
...
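The packaging change swaps a pin on a fork branch for an ordinary version floor: any upstream release at or above 1.0.2 now satisfies the requirement. A minimal sketch of how such a specifier behaves, using the third-party packaging library (an assumption for illustration only; the diff itself does not import it):

from packaging.requirements import Requirement

req = Requirement("flash-attn>=1.0.2")
print(req.specifier.contains("1.0.1"))   # False -- below the floor
print(req.specifier.contains("1.0.2"))   # True  -- the floor itself
print(req.specifier.contains("1.0.9"))   # True  -- any newer upstream release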
@@ -7,6 +7,7 @@ import os
 import math
 import warnings
 from importlib.metadata import version
+from distutils.version import LooseVersion
 from contextlib import nullcontext
 from typing import Any, Callable, Optional, Tuple, Union
@@ -43,7 +44,8 @@ from transformer_engine.pytorch.distributed import (
 )
 from transformer_engine.pytorch.export import is_in_onnx_export_mode

-_flash_attn_version = version("flash-attn")
+_flash_attn_version = LooseVersion(version("flash-attn"))
+_flash_attn_version_required = LooseVersion("1.0.2")

 warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
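Wrapping both sides in LooseVersion makes the comparison numeric per release segment rather than lexicographic, which matters once a segment reaches two digits. A standalone sketch of the behavior the constructor guard below depends on:

from distutils.version import LooseVersion

# Plain string comparison sorts "1.0.10" below "1.0.2" ("1" < "2" as
# characters); LooseVersion compares the release segments as integers.
print("1.0.10" >= "1.0.2")                              # False
print(LooseVersion("1.0.10") >= LooseVersion("1.0.2"))  # True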
@@ -215,12 +217,9 @@ class FlashAttention(torch.nn.Module):
     ) -> None:
         super().__init__()

-        if "dev" not in _flash_attn_version:
-            raise ImportError(
-                'Please install correct version of flash-attn with ' \
-                'pip install git+https://github.com/ksivaman/flash-attention.git@hopper. ' \
-                'If running on Hopper, ' \
-                'please install from source with compute capability 9.0.')
+        assert (
+            _flash_attn_version >= _flash_attn_version_required
+        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
         assert (
             attn_mask_type == "causal"
         ), 'FlashAttention currently only supports causal attention mask.'
@@ -229,6 +228,7 @@ class FlashAttention(torch.nn.Module):
         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
         self.attention_dropout = attention_dropout
+        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

     def forward(
         self,
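The new flag defaults to non-deterministic execution: NVTE_ALLOW_NONDETERMINISTIC_ALGO falls back to "1" when unset, so self.deterministic is only True when the user exports the variable as "0". A standalone sketch of that parsing chain:

import os

# Same expression as the diff above; note that a non-integer value in the
# environment variable would raise ValueError at int().
for env in (None, "1", "0"):
    if env is None:
        os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
    else:
        os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = env
    deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
    print(env, "->", deterministic)   # None -> False, "1" -> False, "0" -> True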
@@ -274,7 +274,8 @@ class FlashAttention(torch.nn.Module):
         output = flash_attn_unpadded_func(
             query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
             self.attention_dropout if self.training else 0.0,
-            softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask
+            softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask,
+            deterministic=self.deterministic,
         )

         # [(b sq), np, hn] -> [sq, b, (np hn)]
...
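For context on the call above: flash_attn_unpadded_func consumes tokens packed as [(b sq), np, hn] with cu_seqlens marking where each sequence starts, and the diff passes the same tensor for queries and keys since self-attention uses one set of sequence lengths. A small sketch of that layout (sizes and names here are made up for illustration):

import torch

b, sq, np_, hn = 2, 4, 8, 64                  # hypothetical batch/seq/heads/head-dim
query_layer = torch.randn(b * sq, np_, hn)    # all sequences concatenated on dim 0
# cu_seqlens is the prefix sum of per-sequence lengths; with equal lengths it
# is simply 0, sq, 2*sq, ...
cu_seqlens = torch.arange(0, (b + 1) * sq, step=sq, dtype=torch.int32)
print(cu_seqlens)                             # tensor([0, 4, 8], dtype=torch.int32)
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())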