Unverified commit f2bd53c4, authored by Kirthi Shankar Sivamani and committed by GitHub

Bump FlashAttn version and add deterministic option for FAv2 (#585)



* Deterministic FA, bump minimum supported version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix MQA/GQA
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e2a75314
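For context, a minimal sketch of the flash-attn call pattern this change gates on version 2.4.1 and newer. The wrapper function name, tensor shapes, and dtypes below are illustrative assumptions, not code from this diff; only the `flash_attn_varlen_func` import path and the `deterministic` keyword come from the change itself.

# Hedged sketch: flash-attn >= 2.4.1 exposes a `deterministic` flag on its
# varlen kernel; the patch forwards it when _flash_attn_2_4_1_plus is True.
import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_func

def varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k,
                     max_seqlen_q, max_seqlen_k, deterministic=True):
    # q, k, v: [total_tokens, num_heads, head_dim], fp16/bf16, on a CUDA device
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k,
        dropout_p=0.0,
        softmax_scale=q.shape[-1] ** -0.5,
        causal=True,
        deterministic=deterministic,  # keyword only exists in flash-attn >= 2.4.1
    )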
@@ -284,7 +284,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6,<=2.3.3,!=2.0.9,!=2.1.0"])
add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks():
if not found_pybind11():
@@ -54,19 +54,17 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("1.0.6")
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
_flash_attn_version_required = packaging.version.Version("2.0.6")
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1")
if _flash_attn_2_available:
if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd # pylint: disable=no-name-in-module
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module
else:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func # pylint: disable=no-name-in-module,ungrouped-imports
from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
_cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None
@@ -442,8 +440,8 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
# [b, s, np, hn] -> [b, 2, s//2, np, hn]
q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
if _flash_attn_2_available:
assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be a multiple of 8"
assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be a multiple of 8"
# Flash Attn inputs
q_inputs = [None, None]
kv_inputs = [None, None]
@@ -480,67 +478,46 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
kv_inputs[i%2] = p2p_comm_buffers[i]
if causal:
fa_forward_kwargs = {}
if _flash_attn_2_3_plus:
fa_forward_kwargs["window_size"] = (-1, -1)
if _flash_attn_2_4_plus:
fa_forward_kwargs["alibi_slopes"] = None
fa_forward_kwargs["return_softmax"] = False
if i == 0:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_available:
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=True, return_softmax=False,
)
else:
out_per_step[i] = torch.empty_like(q_inputs[i%2])
_, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward( # pylint: disable=unbalanced-tuple-unpacking
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
out_per_step[i], cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
causal=True, return_softmax=False,
)
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=True, **fa_forward_kwargs,
)
elif i <= rank:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_available:
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
dropout_p, softmax_scale, causal=False, return_softmax=False,
)
else:
out_per_step[i] = torch.empty_like(q_inputs[i%2])
_, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward( # pylint: disable=unbalanced-tuple-unpacking
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
out_per_step[i], cu_seqlens_q, cu_seqlens_k//2,
max_seqlen_q, max_seqlen_k//2, dropout_p, softmax_scale,
causal=False, return_softmax=False,
)
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
dropout_p, softmax_scale, causal=False, **fa_forward_kwargs,
)
else:
# [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_available:
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
dropout_p, softmax_scale, causal=False, return_softmax=False,
)
else:
out_per_step[i] = torch.empty_like(q_inputs[i%2])
_, softmax_lse_per_step[i], rng_states[i], _ = _flash_attn_forward( # pylint: disable=unbalanced-tuple-unpacking
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
out_per_step[i], cu_seqlens_q//2, cu_seqlens_k,
max_seqlen_q//2, max_seqlen_k, dropout_p, softmax_scale,
causal=False, return_softmax=False,
)
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
dropout_p, softmax_scale, causal=False, **fa_forward_kwargs,
)
else:
assert False, "Not implemented yet!"
@@ -625,10 +602,6 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
p2p_comm_buffers[0][0].copy_(kv)
send_recv_reqs = []
fa_optional_backward_kwargs = {}
if not _flash_attn_2_available:
fa_optional_backward_kwargs["num_splits"] = 1 if ctx.deterministic else 0
for i in range(cp_size):
# wait until KV is received
for req in send_recv_reqs:
@@ -654,6 +627,14 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
kv = p2p_comm_buffers[i%2][0]
# In reversed order of fwd
if ctx.causal:
fa_backward_kwargs = {}
if _flash_attn_2_3_plus:
fa_backward_kwargs["window_size"] = (-1, -1)
if _flash_attn_2_4_plus:
fa_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_backward_kwargs["deterministic"] = ctx.deterministic
fa_backward_kwargs["rng_state"] = ctx.rng_states[cp_size-i-1]
if i == (cp_size-1):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
@@ -669,8 +650,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, True,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
**fa_backward_kwargs,
)
elif i >= (cp_size-rank-1):
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
@@ -687,8 +667,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
ctx.dropout_p, ctx.softmax_scale, False,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
**fa_backward_kwargs,
)
else:
# [b, sq//2, np, hn] -> [b*sq//2, np, hn]
@@ -705,8 +684,7 @@ class FlashAttnUnpaddedFuncWithCP(torch.autograd.Function):
dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, False,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
**fa_backward_kwargs,
)
if i >= (cp_size-rank-1):
@@ -1143,6 +1121,7 @@ class _PrepareQKVForFA(torch.autograd.Function):
dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
return dq, dk, dv
def _get_qkv_layout(
q: torch.Tensor,
k: torch.Tensor,
@@ -1249,6 +1228,7 @@ def _get_qkv_layout(
return qkv_layout, q, k, v
def check_set_window_size(
attn_mask_type: str,
window_size: Tuple[int, int] = None,
@@ -1268,6 +1248,7 @@ def check_set_window_size(
window_size = (-1, -1)
return window_size
class FlashAttention(torch.nn.Module):
"""Dot product attention, using HazyResearch flash-attn package:
https://github.com/Dao-AILab/flash-attention
@@ -1308,6 +1289,7 @@ class FlashAttention(torch.nn.Module):
max_seqlen_kv: Optional[int] = None,
attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
cp_group: Optional[dist_group_type] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
@@ -1420,8 +1402,6 @@ class FlashAttention(torch.nn.Module):
cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
elif qkv_format == 'thd':
assert not context_parallel, "thd format is not supported for context parallelism!"
assert (_flash_attn_2_available
), "flash-attn v2 is required for variable sequence length support!"
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
if max_seqlen_q is None:
@@ -1435,6 +1415,9 @@ class FlashAttention(torch.nn.Module):
assert (
window_size in ((-1, -1), (-1, 0))
), "Sliding window attention is not supported with context parallelism."
assert (
alibi_slopes is None
), "Alibi slope bias addition is not supported with context parallelism."
with self.attention_dropout_ctx():
output = flash_attn_forward_func_with_cp(
query_layer, key_layer, value_layer,
@@ -1448,16 +1431,18 @@ class FlashAttention(torch.nn.Module):
else:
with self.attention_dropout_ctx():
fa_optional_forward_kwargs = {}
if not _flash_attn_2_available:
fa_optional_forward_kwargs["deterministic"] = self.deterministic
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = window_size
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
if _flash_attn_2_4_1_plus:
fa_optional_forward_kwargs["deterministic"] = self.deterministic
output = flash_attn_forward_func(
query_layer, key_layer, value_layer,
cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
self.attention_dropout if self.training else 0.0,
softmax_scale=1.0/self.norm_factor, causal="causal" in attn_mask_type,
**fa_optional_forward_kwargs
**fa_optional_forward_kwargs,
)
if 'padding' in attn_mask_type:
@@ -1542,6 +1527,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
None, None, None, None, None, None,
None, None, None, None, None, None)
class FusedAttnFunc_kvpacked(torch.autograd.Function):
"""Function for FusedAttention with packed KV input"""
@@ -1613,6 +1599,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
None, None, None, None, None, None,
None, None, None, None, None, None)
class FusedAttnFunc(torch.autograd.Function):
"""Function for FusedAttention with separate Q, K, V tensors"""
@@ -1686,6 +1673,7 @@ class FusedAttnFunc(torch.autograd.Function):
None, None, None, None, None, None,
None, None, None, None, None, None)
class FusedAttention(torch.nn.Module):
"""Dot product attention, with multiple backends:
@@ -1730,7 +1718,6 @@ class FusedAttention(torch.nn.Module):
self.attention_dropout_ctx = attention_dropout_ctx
self.attention_type = attention_type
self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
and _flash_attn_2_available
and get_device_compute_capability() == (9, 0))
self.layer_number = 1 if layer_number is None else layer_number
if deterministic:
@@ -1877,7 +1864,7 @@ class DotProductAttention(torch.nn.Module):
.. warning::
FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
deterministic behavior at the cost of performance, use FlashAttention version < `2.0.0`
deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
to disable `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
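For reference, a hedged user-side sketch of the environment variables documented above. The module sizes, tensor shapes, and dtypes are placeholders, not part of this change, and the snippet assumes an sm80+ GPU so the FlashAttention backend is actually eligible.

# Hedged sketch: set the env vars before constructing the attention module,
# which is an assumption about when Transformer Engine reads them.
import os
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # deterministic FA, needs flash-attn >= 2.4.1
# os.environ["NVTE_FLASH_ATTN"] = "0"                 # uncomment to disable flash-attn entirely

import torch
import transformer_engine.pytorch as te

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
# Default qkv_format is 'sbhd': [seq, batch, heads, head_dim]
q = torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.bfloat16)
out = attn(q, k, v)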
@@ -2013,12 +2000,12 @@ class DotProductAttention(torch.nn.Module):
int(os.getenv("NVTE_FLASH_ATTN", "1"))
and self.device_compute_capability >= (8, 0)
)
if _flash_attn_2_available and self.deterministic:
if not _flash_attn_2_4_1_plus and self.deterministic:
self.use_flash_attention = False
warnings.warn(
"Disabling usage of FlashAttention since version 2 does not support deterministic "
"execution. In order to use FA with deterministic behavior, please install "
"FlashAttention version 1."
"Disabling usage of FlashAttention since version <2.4.1 does not support "
"deterministic execution. In order to use FA with deterministic behavior,"
" please install FlashAttention version >=2.4.1."
)
self.use_fused_attention = (
@@ -2115,6 +2102,7 @@ class DotProductAttention(torch.nn.Module):
max_seqlen_kv: Optional[int] = None,
attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
@@ -2196,6 +2184,10 @@ class DotProductAttention(torch.nn.Module):
softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention.
alibi_slopes: Optional[torch.Tensor], default = `None`
ALiBi slopes in fp32, with shape (nheads,) or (batch_size, nheads).
A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to
the attention score of query i and key j.
checkpoint_core_attention : bool, default = `False`
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
@@ -2299,24 +2291,12 @@ class DotProductAttention(torch.nn.Module):
use_fused_attention = False
# Filter: Device and dimensions.
# FAv1 supports head_dim <= 128, and for >64 requires sm80/sm90
# FAv2 supports head_dim <= 256, and for >192 requires sm80/sm90
# Both FAv1 and FAv2 require head_dim % 8 == 0
if not _flash_attn_2_available:
if (key_layer.shape[-1] > 128
or key_layer.shape[-1] % 8 != 0
or (key_layer.shape[-1] > 64
and self.device_compute_capability not in ((8, 0), (9, 0)))):
use_flash_attention = False
if _flash_attn_2_available:
if (key_layer.shape[-1] > 256
or key_layer.shape[-1] % 8 != 0
or (key_layer.shape[-1] > 192
and self.device_compute_capability not in ((8, 0), (9, 0)))):
use_flash_attention = False
# Filter: MQA/GQA.
if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
# FAv2 requires head_dim % 8 == 0
if (key_layer.shape[-1] > 256
or key_layer.shape[-1] % 8 != 0
or (key_layer.shape[-1] > 192
and self.device_compute_capability not in ((8, 0), (9, 0)))):
use_flash_attention = False
# Filter: cross attention + causal mask.
@@ -2383,6 +2363,13 @@ class DotProductAttention(torch.nn.Module):
use_fused_attention = (use_fused_attention
and is_backend_avail)
# Filter: Alibi slopes
if alibi_slopes is not None:
use_fused_attention = False
assert (
use_flash_attention
), "Alibi slopes bias is only supported in the FlashAttention backend."
# Filter: determinism.
# backend | deterministic
# ---------------------------------------------------------
@@ -2420,6 +2407,7 @@ class DotProductAttention(torch.nn.Module):
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
window_size=window_size,
alibi_slopes=alibi_slopes,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
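Finally, a hedged sketch of passing the new `alibi_slopes` argument through `DotProductAttention.forward`, which per the filters above requires the FlashAttention backend with flash-attn >= 2.4. The slope schedule, shapes, and dtypes below are illustrative assumptions, not values prescribed by this change.

# Hedged sketch: per-head ALiBi slopes forwarded to the flash-attn backend.
import torch
import transformer_engine.pytorch as te

num_heads, head_dim = 16, 64
attn = te.DotProductAttention(num_attention_heads=num_heads, kv_channels=head_dim)

# fp32 slopes of shape (nheads,); the geometric schedule below is the common
# ALiBi choice and is an assumption, not something mandated by this diff.
alibi_slopes = torch.tensor(
    [2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)],
    dtype=torch.float32, device="cuda",
)

# Default qkv_format is 'sbhd': [seq, batch, heads, head_dim]
q = torch.randn(128, 2, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(128, 2, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(128, 2, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
out = attn(q, k, v, alibi_slopes=alibi_slopes)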