Unverified commit f2bd53c4, authored by Kirthi Shankar Sivamani and committed by GitHub

Bump FlashAttn version and add deterministic option for FAv2 (#585)



* Deterministic FA, bump minimum supported version

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix MQA/GQA

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e2a75314
@@ -284,7 +284,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6,<=2.3.3,!=2.0.9,!=2.1.0"])
+        add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"])
         add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
     if "jax" in frameworks():
         if not found_pybind11():
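With this change an environment needs a flash-attn build in the range >=2.0.6,<=2.4.2 (excluding 2.0.9 and 2.1.0) before the PyTorch extras resolve. Purely as an illustration (the helper below is hypothetical, not part of the repository), the same requirement string can be checked against an installed package with `packaging`:

```python
# Hypothetical helper: check the installed flash-attn against the pin
# introduced in setup.py above. Requires the `packaging` package.
from importlib.metadata import PackageNotFoundError, version

from packaging.specifiers import SpecifierSet
from packaging.version import Version

FLASH_ATTN_SPEC = SpecifierSet(">=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0")

def flash_attn_version_ok() -> bool:
    """Return True if the installed flash-attn satisfies the pinned range."""
    try:
        installed = Version(version("flash-attn"))
    except PackageNotFoundError:
        return False
    return installed in FLASH_ATTN_SPEC

if __name__ == "__main__":
    print(flash_attn_version_ok())
```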
@@ -54,19 +54,17 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode
 from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
 
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
-_flash_attn_version_required = packaging.version.Version("1.0.6")
-_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
+_flash_attn_version_required = packaging.version.Version("2.0.6")
 _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
 _flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
+_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
+_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1")
 
-if _flash_attn_2_available:
+if _flash_attn_version >= _flash_attn_version_required:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func  # pylint: disable=no-name-in-module
     from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd  # pylint: disable=no-name-in-module
     from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward  # pylint: disable=no-name-in-module,ungrouped-imports
     from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward  # pylint: disable=no-name-in-module
-else:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func  # pylint: disable=no-name-in-module,ungrouped-imports
-    from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
 
 _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None
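The import block now assumes flash-attn 2.x throughout: the old `_flash_attn_2_available` switch and the FAv1 import fallback are gone, and new capability flags (`_flash_attn_2_4_plus`, `_flash_attn_2_4_1_plus`) gate arguments that only newer releases accept. A minimal, self-contained sketch of the same gating pattern (variable names here are illustrative, not the module's):

```python
# Illustrative sketch of the version-flag pattern above: parse the
# installed flash-attn version once, then expose booleans for later code
# to branch on. What each flag gates follows the hunks in this commit.
from importlib.metadata import version

from packaging.version import Version

_fa = Version(version("flash-attn"))

FA_2_1_PLUS = _fa >= Version("2.1")
FA_2_3_PLUS = _fa >= Version("2.3")      # `window_size` argument available
FA_2_4_PLUS = _fa >= Version("2.4")      # `alibi_slopes` argument available
FA_2_4_1_PLUS = _fa >= Version("2.4.1")  # deterministic backward available

print(FA_2_1_PLUS, FA_2_3_PLUS, FA_2_4_PLUS, FA_2_4_1_PLUS)
```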
@@ -442,8 +440,8 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
 
         # [b, s, np, hn] -> [b, 2, s//2, np, hn]
         q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
-        if _flash_attn_2_available:
-            assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
+        assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
 
         # Flash Attn inputs
         q_inputs = [None, None]
         kv_inputs = [None, None]
@@ -480,25 +478,22 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
             kv_inputs[i%2] = p2p_comm_buffers[i]
             if causal:
+                fa_forward_kwargs = {}
+                if _flash_attn_2_3_plus:
+                    fa_forward_kwargs["window_size"] = (-1, -1)
+                if _flash_attn_2_4_plus:
+                    fa_forward_kwargs["alibi_slopes"] = None
+                fa_forward_kwargs["return_softmax"]=False
                 if i == 0:
                     # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                     q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                     # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                     kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
-                    if _flash_attn_2_available:
-                        _, _, _, _, out_per_step[i], \
-                        softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
-                            q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
-                            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                            dropout_p, softmax_scale, causal=True, return_softmax=False,
-                        )
-                    else:
-                        out_per_step[i] = torch.empty_like(q_inputs[i%2])
-                        _, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward(  # pylint: disable=unbalanced-tuple-unpacking
-                            q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
-                            out_per_step[i], cu_seqlens_q, cu_seqlens_k,
-                            max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
-                            causal=True, return_softmax=False,
-                        )
+                    _, _, _, _, out_per_step[i], \
+                    softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
+                        q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
+                        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+                        dropout_p, softmax_scale, causal=True, **fa_forward_kwargs,
+                    )
                 elif i <= rank:
                     # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
@@ -506,40 +501,22 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
                     # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
                     kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
                     kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
-                    if _flash_attn_2_available:
-                        _, _, _, _, out_per_step[i], \
-                        softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
-                            q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
-                            cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
-                            dropout_p, softmax_scale, causal=False, return_softmax=False,
-                        )
-                    else:
-                        out_per_step[i] = torch.empty_like(q_inputs[i%2])
-                        _, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward(  # pylint: disable=unbalanced-tuple-unpacking
-                            q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
-                            out_per_step[i], cu_seqlens_q, cu_seqlens_k//2,
-                            max_seqlen_q, max_seqlen_k//2, dropout_p, softmax_scale,
-                            causal=False, return_softmax=False,
-                        )
+                    _, _, _, _, out_per_step[i], \
+                    softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
+                        q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
+                        cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
+                        dropout_p, softmax_scale, causal=False, **fa_forward_kwargs,
+                    )
                 else:
                     # [b, sq//2, np, hn] -> [b*sq//2, np, hn]
                     q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
                     # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                     kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
-                    if _flash_attn_2_available:
-                        _, _, _, _, out_per_step[i], \
-                        softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
-                            q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
-                            cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
-                            dropout_p, softmax_scale, causal=False, return_softmax=False,
-                        )
-                    else:
-                        out_per_step[i] = torch.empty_like(q_inputs[i%2])
-                        _, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward(  # pylint: disable=unbalanced-tuple-unpacking
-                            q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
-                            out_per_step[i], cu_seqlens_q//2, cu_seqlens_k,
-                            max_seqlen_q//2, max_seqlen_k, dropout_p, softmax_scale,
-                            causal=False, return_softmax=False,
-                        )
+                    _, _, _, _, out_per_step[i], \
+                    softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
+                        q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
+                        cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
+                        dropout_p, softmax_scale, causal=False, **fa_forward_kwargs,
+                    )
             else:
                 assert False, "Not implemented yet!"
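Inside the context-parallel forward, the removed FAv1 branches are replaced by a single `fa_forward_kwargs` dict that only carries arguments the installed flash-attn accepts, and all three causal sub-cases unpack it into `_flash_attn_forward`. A condensed, standalone sketch of that pattern (the helper name is hypothetical; the keys mirror the hunks above):

```python
# Hypothetical helper mirroring the fa_forward_kwargs construction above:
# only add optional arguments when the installed flash-attn supports them.
from typing import Any, Dict, Optional, Tuple

def build_fa_forward_kwargs(
    fa_2_3_plus: bool,
    fa_2_4_plus: bool,
    window_size: Tuple[int, int] = (-1, -1),
    alibi_slopes: Optional[Any] = None,
) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = {"return_softmax": False}
    if fa_2_3_plus:
        kwargs["window_size"] = window_size    # argument added in flash-attn 2.3
    if fa_2_4_plus:
        kwargs["alibi_slopes"] = alibi_slopes  # argument added in flash-attn 2.4
    return kwargs

print(build_fa_forward_kwargs(fa_2_3_plus=True, fa_2_4_plus=True))
```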
@@ -625,10 +602,6 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
         p2p_comm_buffers[0][0].copy_(kv)
         send_recv_reqs = []
 
-        fa_optional_backward_kwargs = {}
-        if not _flash_attn_2_available:
-            fa_optional_backward_kwargs["num_splits"] = 1 if ctx.deterministic else 0
-
         for i in range(cp_size):
             # wait until KV is received
             for req in send_recv_reqs:
@@ -654,6 +627,14 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
             kv = p2p_comm_buffers[i%2][0]
             # In reversed order of fwd
             if ctx.causal:
+                fa_backward_kwargs = {}
+                if _flash_attn_2_3_plus:
+                    fa_backward_kwargs["window_size"] = (-1, -1)
+                if _flash_attn_2_4_plus:
+                    fa_backward_kwargs["alibi_slopes"] = None
+                if _flash_attn_2_4_1_plus:
+                    fa_backward_kwargs["deterministic"] = ctx.deterministic
+                fa_backward_kwargs["rng_state"]=ctx.rng_states[cp_size-i-1]
                 if i == (cp_size-1):
                     # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
                     q_ = q.view(-1, *q.shape[-2:])
@@ -669,8 +650,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
                         dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
                         ctx.max_seqlen_q, ctx.max_seqlen_k,
                         ctx.dropout_p, ctx.softmax_scale, True,
-                        rng_state=ctx.rng_states[cp_size-i-1],
-                        **fa_optional_backward_kwargs
+                        **fa_backward_kwargs,
                     )
                 elif i >= (cp_size-rank-1):
                     # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
@@ -687,8 +667,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
                         dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
                         ctx.max_seqlen_q, ctx.max_seqlen_k//2,
                         ctx.dropout_p, ctx.softmax_scale, False,
-                        rng_state=ctx.rng_states[cp_size-i-1],
-                        **fa_optional_backward_kwargs
+                        **fa_backward_kwargs,
                     )
                 else:
                     # [b, sq//2, np, hn] -> [b*sq//2, np, hn]
@@ -705,8 +684,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
                         dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
                         ctx.max_seqlen_q//2, ctx.max_seqlen_k,
                         ctx.dropout_p, ctx.softmax_scale, False,
-                        rng_state=ctx.rng_states[cp_size-i-1],
-                        **fa_optional_backward_kwargs
+                        **fa_backward_kwargs,
                     )
 
             if i >= (cp_size-rank-1):
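The backward pass follows the same scheme: `fa_backward_kwargs` replaces the old `rng_state=...` keyword plus `fa_optional_backward_kwargs` pair, and the `deterministic` flag added by this PR is only attached when flash-attn >= 2.4.1 is installed. Sketched in isolation (hypothetical helper; keys as in the hunks above):

```python
# Hypothetical helper mirroring fa_backward_kwargs above; `deterministic`
# is only attached for flash-attn >= 2.4.1, which accepts it.
from typing import Any, Dict, Optional, Tuple

def build_fa_backward_kwargs(
    fa_2_3_plus: bool,
    fa_2_4_plus: bool,
    fa_2_4_1_plus: bool,
    rng_state: Any,
    deterministic: bool,
    window_size: Tuple[int, int] = (-1, -1),
    alibi_slopes: Optional[Any] = None,
) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = {"rng_state": rng_state}
    if fa_2_3_plus:
        kwargs["window_size"] = window_size
    if fa_2_4_plus:
        kwargs["alibi_slopes"] = alibi_slopes
    if fa_2_4_1_plus:
        kwargs["deterministic"] = deterministic
    return kwargs

print(build_fa_backward_kwargs(True, True, True, rng_state=None, deterministic=True))
```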
@@ -1143,6 +1121,7 @@ class _PrepareQKVForFA(torch.autograd.Function):
         dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
         return dq, dk, dv
 
+
 def _get_qkv_layout(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -1249,6 +1228,7 @@ def _get_qkv_layout(
     return qkv_layout, q, k, v
 
+
 def check_set_window_size(
     attn_mask_type: str,
     window_size: Tuple[int, int] = None,
@@ -1268,6 +1248,7 @@ def check_set_window_size(
             window_size = (-1, -1)
     return window_size
 
+
 class FlashAttention(torch.nn.Module):
     """Dot product attention, using HazyResearch flash-attn package:
     https://github.com/Dao-AILab/flash-attention
@@ -1308,6 +1289,7 @@ class FlashAttention(torch.nn.Module):
         max_seqlen_kv: Optional[int] = None,
         attn_mask_type: str = "causal",
         window_size: Optional[Tuple[int, int]] = None,
+        alibi_slopes: Optional[torch.Tensor] = None,
         cp_group: Optional[dist_group_type] = None,
         cp_global_ranks: List[int] = None,
         cp_stream: torch.cuda.Stream = None,
@@ -1420,8 +1402,6 @@ class FlashAttention(torch.nn.Module):
                 cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
         elif qkv_format == 'thd':
             assert not context_parallel, "thd format is not supported for context parallelism!"
-            assert (_flash_attn_2_available
-                   ), "flash-attn v2 is required for variable sequence length support!"
             assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
                    ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
             if max_seqlen_q is None:
@@ -1435,6 +1415,9 @@ class FlashAttention(torch.nn.Module):
             assert (
                 window_size in ((-1, -1), (-1, 0))
             ), "Sliding window attention is not supported with context parallelism."
+            assert (
+                alibi_slopes is None
+            ), "Alibi slope bias addition is not supported with context parallelism."
             with self.attention_dropout_ctx():
                 output = flash_attn_forward_func_with_cp(
                     query_layer, key_layer, value_layer,
@@ -1448,16 +1431,18 @@ class FlashAttention(torch.nn.Module):
         else:
             with self.attention_dropout_ctx():
                 fa_optional_forward_kwargs = {}
-                if not _flash_attn_2_available:
-                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                 if _flash_attn_2_3_plus:
                     fa_optional_forward_kwargs["window_size"] = window_size
+                if _flash_attn_2_4_plus:
+                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
+                if _flash_attn_2_4_1_plus:
+                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
                 output = flash_attn_forward_func(
                     query_layer, key_layer, value_layer,
                     cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                     self.attention_dropout if self.training else 0.0,
                     softmax_scale=1.0/self.norm_factor, causal="causal" in attn_mask_type,
-                    **fa_optional_forward_kwargs
+                    **fa_optional_forward_kwargs,
                 )
 
         if 'padding' in attn_mask_type:
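In the non-context-parallel path, `fa_optional_forward_kwargs` now adds `window_size`, `alibi_slopes`, and `deterministic` behind their respective version flags before being unpacked into `flash_attn_forward_func`. One way to sanity-check such a dict in a test, purely as an illustration (the filtering helper below is hypothetical and not how Transformer Engine dispatches), is to compare its keys against the callee's signature:

```python
# Hypothetical check: keep only kwargs that the target callable accepts.
# With flash-attn installed, `func` could be the varlen forward function
# imported above; a stand-in is used here so the snippet runs anywhere.
import inspect
from typing import Any, Dict

def filter_supported_kwargs(func, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    params = inspect.signature(func).parameters
    return {k: v for k, v in kwargs.items() if k in params}

def _stand_in(q, k, v, *, window_size=(-1, -1), deterministic=False):
    return q

print(filter_supported_kwargs(_stand_in, {"window_size": (-1, 0), "alibi_slopes": None}))
# -> {'window_size': (-1, 0)}; 'alibi_slopes' is dropped by the stand-in.
```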
@@ -1542,6 +1527,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                 None, None, None, None, None, None,
                 None, None, None, None, None, None)
 
+
 class FusedAttnFunc_kvpacked(torch.autograd.Function):
     """Function for FusedAttention with packed KV input"""
@@ -1613,6 +1599,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
                 None, None, None, None, None, None,
                 None, None, None, None, None, None)
 
+
 class FusedAttnFunc(torch.autograd.Function):
     """Function for FusedAttention with separate Q, K, V tensors"""
@@ -1686,6 +1673,7 @@ class FusedAttnFunc(torch.autograd.Function):
                 None, None, None, None, None, None,
                 None, None, None, None, None, None)
 
+
 class FusedAttention(torch.nn.Module):
     """Dot product attention, with multiple backends:
@@ -1730,7 +1718,6 @@ class FusedAttention(torch.nn.Module):
         self.attention_dropout_ctx = attention_dropout_ctx
         self.attention_type = attention_type
         self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
-                             and _flash_attn_2_available
                              and get_device_compute_capability() == (9, 0))
         self.layer_number = 1 if layer_number is None else layer_number
         if deterministic:
@@ -1877,7 +1864,7 @@ class DotProductAttention(torch.nn.Module):
     .. warning::
 
         FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
-        deterministic behavior at the cost of performance, use FlashAttention version < `2.0.0`
+        deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
         and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
         to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
@@ -2013,12 +2000,12 @@ class DotProductAttention(torch.nn.Module):
             int(os.getenv("NVTE_FLASH_ATTN", "1"))
             and self.device_compute_capability >= (8, 0)
         )
-        if _flash_attn_2_available and self.deterministic:
+        if not _flash_attn_2_4_1_plus and self.deterministic:
             self.use_flash_attention = False
             warnings.warn(
-                "Disabling usage of FlashAttention since version 2 does not support deterministic "
-                "execution. In order to use FA with deterministic behavior, please install "
-                "FlashAttention version 1."
+                "Disabling usage of FlashAttention since version <2.4.1 does not support "
+                "deterministic execution. In order to use FA with deterministic behavior,"
+                " please install FlashAttention version >=2.4.1."
             )
 
         self.use_fused_attention = (
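Per the updated docstring and warning, deterministic FlashAttention execution now requires flash-attn >= 2.4.1 together with `NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`, while `NVTE_FLASH_ATTN=0` disables the flash-attn backend entirely. A minimal usage sketch (environment variables as named in the docstring; they must be set before Transformer Engine reads them):

```python
# Request deterministic attention kernels (needs flash-attn >= 2.4.1),
# or fall back to disabling flash-attn altogether.
import os

os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # ask for deterministic algorithms
# os.environ["NVTE_FLASH_ATTN"] = "0"                 # alternatively, turn flash-attn off

# import transformer_engine.pytorch as te  # requires a GPU build of Transformer Engine
```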
@@ -2115,6 +2102,7 @@ class DotProductAttention(torch.nn.Module):
         max_seqlen_kv: Optional[int] = None,
         attn_mask_type: Optional[str] = None,
         window_size: Optional[Tuple[int, int]] = None,
+        alibi_slopes: Optional[torch.Tensor] = None,
         checkpoint_core_attention: bool = False,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
@@ -2196,6 +2184,10 @@ class DotProductAttention(torch.nn.Module):
                    softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
         window_size: Optional[Tuple[int, int]], default = `None`
                     sliding window size for local attention.
+        alibi_slopes: Optional[torch.Tensor], default = `None`
+                    An fp32 bias of shape (nheads,) or (batch_size, nheads)
+                    (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+                    is added to the attention score of query i and key j.
         checkpoint_core_attention : bool, default = `False`
                                    If true, forward activations for attention are recomputed
                                    during the backward pass in order to save memory that would
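The new `alibi_slopes` docstring describes the bias added for query i and key j, `-alibi_slope * |i + seqlen_k - seqlen_q - j|`, with slopes supplied as an fp32 tensor of shape `(nheads,)` or `(batch_size, nheads)`. One common way to build such a slope tensor follows the geometric schedule from the ALiBi paper; this is an illustration, not something the PR prescribes:

```python
# Illustrative ALiBi slopes of shape (nheads,), fp32, using the geometric
# schedule 2^(-8*i/nheads) for head i = 1..nheads (assumes nheads is a
# power of two, as in the ALiBi paper).
import torch

def make_alibi_slopes(nheads: int) -> torch.Tensor:
    exponents = torch.arange(1, nheads + 1, dtype=torch.float32)
    return torch.pow(2.0, -8.0 * exponents / nheads)

slopes = make_alibi_slopes(8)
print(slopes.shape, slopes.dtype)  # torch.Size([8]) torch.float32
```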
@@ -2299,26 +2291,14 @@ class DotProductAttention(torch.nn.Module):
             use_fused_attention = False
 
         # Filter: Device and dimensions.
-        # FAv1 supports head_dim <= 128, and for >64 requires sm80/sm90
-        # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
-        # Both FAv1 and FAv2 require head_dim % 8 == 0
-        if not _flash_attn_2_available:
-            if (key_layer.shape[-1] > 128
-                or key_layer.shape[-1] % 8 != 0
-                or (key_layer.shape[-1] > 64
-                    and self.device_compute_capability not in ((8, 0), (9, 0)))):
-                use_flash_attention = False
-        if _flash_attn_2_available:
-            if (key_layer.shape[-1] > 256
-                or key_layer.shape[-1] % 8 != 0
-                or (key_layer.shape[-1] > 192
-                    and self.device_compute_capability not in ((8, 0), (9, 0)))):
-                use_flash_attention = False
-
-        # Filter: MQA/GQA.
-        if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
-            use_flash_attention = False
+        # FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
+        # FAv2 requires head_dim % 8 == 0
+        if (key_layer.shape[-1] > 256
+            or key_layer.shape[-1] % 8 != 0
+            or (key_layer.shape[-1] > 192
+                and self.device_compute_capability not in ((8, 0), (9, 0)))):
+            use_flash_attention = False
 
         # Filter: cross attention + causal mask.
         if (_flash_attn_2_1_plus
             and "causal" in attn_mask_type
@@ -2383,6 +2363,13 @@ class DotProductAttention(torch.nn.Module):
             use_fused_attention = (use_fused_attention
                                    and is_backend_avail)
 
+        # Filter: Alibi slopes.
+        if alibi_slopes is not None:
+            use_fused_attention = False
+            assert (
+                use_flash_attention
+            ), "Alibi slopes bias is only supported in the FlashAttention backend."
+
         # Filter: determinism.
         # backend                      | deterministic
         # ---------------------------------------------------------
@@ -2420,6 +2407,7 @@ class DotProductAttention(torch.nn.Module):
                         cu_seqlens_kv=cu_seqlens_kv,
                         attn_mask_type=attn_mask_type,
                         window_size=window_size,
+                        alibi_slopes=alibi_slopes,
                         cp_group=self.cp_group,
                         cp_global_ranks=self.cp_global_ranks,
                         cp_stream=self.cp_stream,