Unverified commit 7396c527, authored by Kirthi Shankar Sivamani, committed by GitHub

Switch to upstream flash-attn (#151)



* use upstream flash-attn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* get correct FA for linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7c9fb403
@@ -30,6 +30,7 @@ jobs:
         uses: actions/checkout@v3
       - name: 'Lint'
         run: |
+          pip install flash-attn==1.0.2
           export PYTHON_ONLY=1
           export TE_PATH=.
           bash ./qa/L0_lint/test.sh
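The lint job now installs the upstream flash-attn wheel, presumably so the attention sources lint cleanly against the new API (per the "get correct FA for linting" commit message). As a standalone illustration, and not part of the repository's test.sh, a pre-lint check like the following could confirm the pin; the script is hypothetical:

# Illustrative pre-lint check, assuming flash-attn was installed with
# `pip install flash-attn==1.0.2` as in the workflow step above.
from importlib.metadata import version, PackageNotFoundError

try:
    installed = version("flash-attn")
except PackageNotFoundError:
    raise SystemExit("flash-attn is not installed; linting the attention module would fail")

print(f"flash-attn {installed} is available for linting")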
@@ -157,7 +157,7 @@ class PyTorchBuilder(FrameworkBuilderBase):
     @staticmethod
     def install_requires():
-        return ["flash-attn @ git+https://github.com/ksivaman/flash-attention.git@hopper",]
+        return ["flash-attn>=1.0.2",]


 class TensorFlowBuilder(FrameworkBuilderBase):
     def cmake_flags(self):
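The new requirement string "flash-attn>=1.0.2" is an ordinary version specifier that pip resolves at install time. A small standalone sketch of which versions satisfy it, using the packaging library purely for illustration (the diff itself only passes the string through install_requires):

# Illustrative only: how a ">=1.0.2" specifier admits or rejects candidate versions.
from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=1.0.2")
for candidate in ["1.0.1", "1.0.2", "1.0.4", "2.0.0"]:
    verdict = "satisfies" if candidate in spec else "does not satisfy"
    print(f"{candidate} {verdict} {spec}")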
@@ -7,6 +7,7 @@
 import os
 import math
 import warnings
 from importlib.metadata import version
+from distutils.version import LooseVersion
 from contextlib import nullcontext
 from typing import Any, Callable, Optional, Tuple, Union
@@ -43,7 +44,8 @@ from transformer_engine.pytorch.distributed import (
 )
 from transformer_engine.pytorch.export import is_in_onnx_export_mode

-_flash_attn_version = version("flash-attn")
+_flash_attn_version = LooseVersion(version("flash-attn"))
+_flash_attn_version_required = LooseVersion("1.0.2")

 warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
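Parsing the version with LooseVersion rather than comparing raw strings matters once components have more than one digit. A standalone sketch, separate from the diff, contrasting the two behaviors (note that distutils is deprecated since Python 3.10 and removed in 3.12, so this only runs on older interpreters):

# Why the version string is parsed before comparing: lexicographic string
# comparison mis-orders multi-digit components, LooseVersion compares numerically.
from distutils.version import LooseVersion  # deprecated in 3.10+, removed with distutils in 3.12

print("1.0.10" >= "1.0.2")                              # False: plain string comparison is lexicographic
print(LooseVersion("1.0.10") >= LooseVersion("1.0.2"))  # True: component-wise numeric comparison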
@@ -215,20 +217,18 @@ class FlashAttention(torch.nn.Module):
     ) -> None:
         super().__init__()
-        if "dev" not in _flash_attn_version:
-            raise ImportError(
-                'Please install correct version of flash-attn with ' \
-                'pip install git+https://github.com/ksivaman/flash-attention.git@hopper. ' \
-                'If running on Hopper, ' \
-                'please install from source with compute capability 9.0.')
+        assert (
+            _flash_attn_version >= _flash_attn_version_required
+        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."

         assert (
             attn_mask_type == "causal"
         ), 'FlashAttention currently only supports causal attention mask.'

         self.attn_causal_mask = attn_mask_type == "causal"
         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
         self.attention_dropout = attention_dropout
+        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

     def forward(
         self,
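The new deterministic flag is derived from the NVTE_ALLOW_NONDETERMINISTIC_ALGO environment variable: the default "1" leaves non-deterministic algorithms allowed, while "0" requests the deterministic path. A standalone sketch of that mapping (the helper name is made up for illustration):

# Sketch of how NVTE_ALLOW_NONDETERMINISTIC_ALGO maps onto the deterministic flag.
import os

def _use_deterministic_flash_attn() -> bool:
    # Mirrors the expression added in the constructor above.
    return not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
print(_use_deterministic_flash_attn())   # False: default allows non-deterministic algorithms

os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
print(_use_deterministic_flash_attn())   # True: request the deterministic path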
@@ -274,7 +274,8 @@ class FlashAttention(torch.nn.Module):
         output = flash_attn_unpadded_func(
             query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
             self.attention_dropout if self.training else 0.0,
-            softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask
+            softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask,
+            deterministic=self.deterministic,
         )

         # [(b sq), np, hn] -> [sq, b, (np hn)]
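For context on the call above: flash_attn_unpadded_func consumes a packed, variable-length batch, with q/k/v flattened to [total_tokens, heads, head_dim] and cu_seqlens holding cumulative sequence boundaries as int32. A standalone sketch of how such inputs are typically constructed (names and sizes are placeholders, not taken from the diff):

# Sketch of the packed layout consumed by the varlen flash-attn entry point above.
import torch

seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)          # per-sequence lengths in the batch
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)                  # -> tensor([0, 3, 8, 10])
max_seqlen = int(seqlens.max())                                # -> 5

n_heads, head_dim = 16, 64
total_tokens = int(cu_seqlens[-1])
query_layer = torch.randn(total_tokens, n_heads, head_dim)     # packed [total_tokens, np, hn]

print(cu_seqlens.tolist(), max_seqlen, tuple(query_layer.shape))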