Unverified Commit c334fc46 authored by zhujian, committed by GitHub

[PyTorch] Support FA3 for MLA and with CP (#1907)



feature(FA3,MLA,CP):
1. Update FA3 to commit 3ba6f82 (tag 2.8.0.post2 with a compile error fixed); PR-1604 adds backward support for hdimQK != hdimV
2. Update the get_attention_backend method, since FA3 now supports MLA
3. Add CP MLA support for FA3
4. Add unit tests for FA3 MLA with CP
5. Update the attention doc
Signed-off-by: zhujian <zhujian.whu.cs@gmail.com>
parent 8aee1bb7
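Editor's note on what this change enables: FA3 can now run attention where Q/K and V have different head dimensions (MLA). Below is a minimal, hedged smoke test, assuming FA3's hopper interface has been installed under the `flash_attn_3` module name (per the installation steps updated later in this diff) and a Hopper GPU is available; the exact return convention varies across FA3 revisions, so the sketch normalizes it.

```python
# Hypothetical FA3 MLA smoke test: head_dim_qk = 192, head_dim_v = 128.
import torch
from flash_attn_3.flash_attn_interface import flash_attn_func

b, s, h = 2, 4096, 12
q = torch.randn(b, s, h, 192, dtype=torch.bfloat16, device="cuda")
k = torch.randn(b, s, h, 192, dtype=torch.bfloat16, device="cuda")
v = torch.randn(b, s, h, 128, dtype=torch.bfloat16, device="cuda")

outputs = flash_attn_func(q, k, v, causal=True)
# Some FA3 revisions return (out, softmax_lse); normalize defensively.
out = outputs[0] if isinstance(outputs, tuple) else outputs
print(out.shape)  # torch.Size([2, 4096, 12, 128]): output follows V's head dim
```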
@@ -390,7 +390,7 @@
 "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
 "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
 "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
-"| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n",
+"| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n",
 "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
 "\n",
 "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
...
@@ -36,6 +36,12 @@ model_configs_flash_attn = {
         2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
     ),  # GQA
     "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)),  # GQA
+    "cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128),  # MLA
+    "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128),  # MLA
+    "cp_3_2": ModelConfig(
+        2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128
+    ),  # MLA
+    "cp_3_3": ModelConfig(2, 4096, 12, 192, window_size=(512, 512), head_dim_v=128),  # MLA
 }
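The new `cp_3_*` entries exercise MLA shapes: Q and K use head dim 192 while V uses 128 (`head_dim_v=128`). A sketch of the tensor shapes such a config implies, with illustrative names that are not part of the test harness:

```python
import torch

# Shapes implied by ModelConfig(2, 4096, 12, 192, head_dim_v=128) in bshd layout.
batch, seqlen, heads = 2, 4096, 12
head_dim_qk, head_dim_v = 192, 128

q = torch.empty(batch, seqlen, heads, head_dim_qk)
k = torch.empty(batch, seqlen, heads, head_dim_qk)
v = torch.empty(batch, seqlen, heads, head_dim_v)  # V is narrower than Q/K under MLA
# The attention output inherits V's head dim: [2, 4096, 12, 128]
```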
@@ -81,6 +87,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
             f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
             f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
         )
+    if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
+        pytest.skip("MLA CP currently only supports KV P2P!")
     subprocess.run(
         get_bash_arguments(
...
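The skip added above encodes the current support matrix: MLA with context parallelism requires a P2P-based KV exchange. A hedged restatement of that predicate (the helper name is ours, not the test suite's):

```python
# Hypothetical helper mirroring the test's skip condition.
def mla_cp_supported(cp_comm_type: str, head_dim_qk: int, head_dim_v: int) -> bool:
    # MLA (head_dim_qk != head_dim_v) with CP currently needs P2P-based KV exchange.
    return head_dim_qk == head_dim_v or "p2p" in cp_comm_type

assert mla_cp_supported("p2p", 192, 128)
assert not mla_cp_supported("all_gather", 192, 128)
```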
@@ -358,7 +358,7 @@ def get_fa_args(
             max_seqlen_q,
             max_seqlen_kv,
             *[None]
-            * 8,  # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
+            * 9,  # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale
         ]
     return [
         *[None]
@@ -366,7 +366,7 @@ def get_fa_args(
             max_seqlen_q,
             max_seqlen_kv,
             *[None]
-            * 8,  # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
+            * 9,  # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale
         ]
     if qkv_format == "thd":
         return [
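The 8-to-9 bump above tracks FA3's updated forward signature, which gained a `seqlens_rotary` positional between `rotary_sin` and the descale arguments; the `None` placeholders keep later positionals aligned. A toy illustration of the pattern, with hypothetical values:

```python
# Nine Nones stand in for page_table, kv_batch_idx, leftpad_k, rotary_cos,
# rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale.
cu_seqlens_q = cu_seqlens_kv = None  # placeholders for illustration only
max_seqlen_q = max_seqlen_kv = 4096
args = [cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, *[None] * 9]
assert len(args) == 13
```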
@@ -829,6 +829,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                     attn_biases[i] = rest[0] if len(rest) > 0 else None
                 else:
+                    if not enable_mla:
+                        # If MHA, split the packed KV into k_part and v_part.
+                        # Otherwise (MLA), k_part and v_part have already been split.
+                        k_part = (
+                            kv_inputs[i % 2][..., 0, :, :]
+                            if qkv_format in ["bshd", "sbhd"]
+                            else kv_inputs[i % 2][0]
+                        )
+                        v_part = (
+                            kv_inputs[i % 2][..., 1, :, :]
+                            if qkv_format in ["bshd", "sbhd"]
+                            else kv_inputs[i % 2][1]
+                        )
                     fa_forward_args_thd = get_fa_args(
                         True,
                         use_flash_attn_3,
@@ -838,19 +851,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         max_seqlen_q=max_seqlen_q,
                         max_seqlen_kv=max_seqlen_kv,
                     )
-                    # Need to add MLA support once Flash Attention supports MLA
                     fa_outputs = flash_attn_fwd(
                         q_inputs[i % 2],
-                        (
-                            kv_inputs[i % 2][..., 0, :, :]
-                            if qkv_format in ["bshd", "sbhd"]
-                            else kv_inputs[i % 2][0]
-                        ),
-                        (
-                            kv_inputs[i % 2][..., 1, :, :]
-                            if qkv_format in ["bshd", "sbhd"]
-                            else kv_inputs[i % 2][1]
-                        ),
+                        k_part,
+                        v_part,
                         *fa_forward_args_thd,
                         causal=True,
                         **fa_forward_kwargs,
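The refactor above hoists the K/V split out of the call site into `k_part`/`v_part`. In `bshd`/`sbhd` layouts the packed KV tensor carries a size-2 axis before the head axis; in `thd` it is stacked along dim 0. A small sketch with made-up shapes:

```python
import torch

# bshd-style packed KV: [batch, seq, 2, heads, head_dim]
kv_bshd = torch.randn(2, 128, 2, 12, 64)
k_part = kv_bshd[..., 0, :, :]  # [2, 128, 12, 64]
v_part = kv_bshd[..., 1, :, :]

# thd-style packed KV: [2, total_tokens, heads, head_dim]
kv_thd = torch.randn(2, 256, 12, 64)
k_part_thd, v_part_thd = kv_thd[0], kv_thd[1]
```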
@@ -985,6 +989,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                     attn_biases[i] = rest[0] if len(rest) > 0 else None
                 else:
+                    if enable_mla:
+                        k_part = k_part.contiguous()
+                        v_part = v_part.contiguous()
+                    else:
+                        # If MHA, split the packed KV into k_part and v_part.
+                        # Otherwise (MLA), k_part and v_part have already been split.
+                        k_part = (
+                            kv_inputs[i % 2][..., 0, :, :]
+                            if qkv_format in ["bshd", "sbhd"]
+                            else kv_inputs[i % 2][0]
+                        )
+                        v_part = (
+                            kv_inputs[i % 2][..., 1, :, :]
+                            if qkv_format in ["bshd", "sbhd"]
+                            else kv_inputs[i % 2][1]
+                        )
                     fa_forward_args_thd = get_fa_args(
                         True,
                         use_flash_attn_3,
@@ -1001,19 +1021,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     elif fa_utils.v2_7_0_plus:
                         fa_forward_kwargs["window_size_left"] = -1
                         fa_forward_kwargs["window_size_right"] = -1
-                    # Need to add MLA support once Flash Attention supports MLA
                     fa_outputs = flash_attn_fwd(
                         q_inputs[i % 2],
-                        (
-                            kv_inputs[i % 2][..., 0, :, :]
-                            if qkv_format in ["bshd", "sbhd"]
-                            else kv_inputs[i % 2][0]
-                        ),
-                        (
-                            kv_inputs[i % 2][..., 1, :, :]
-                            if qkv_format in ["bshd", "sbhd"]
-                            else kv_inputs[i % 2][1]
-                        ),
+                        k_part,
+                        v_part,
                         *fa_forward_args_thd,
                         causal=False,
                         **fa_forward_kwargs,
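In the MLA branch above, `k_part`/`v_part` can be non-contiguous views (for example, one chunk of the CP-sliced KV), and the FA kernels expect contiguous inputs, hence the `.contiguous()` calls. A minimal demonstration, with made-up shapes, of why such a slice is non-contiguous:

```python
import torch

kv = torch.randn(2, 2, 128, 12, 64)  # e.g. [batch, 2 CP chunks, seq/2, heads, dim]
k_half = kv[:, 0]                    # a view over chunk 0
print(k_half.is_contiguous())        # False: dim-0 stride still spans both chunks
k_half = k_half.contiguous()         # materializes a dense copy for the kernel
```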
@@ -1144,6 +1155,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                     attn_biases[i] = rest[0] if len(rest) > 0 else None
                 else:
+                    if not enable_mla:
+                        # If MHA, split the packed KV into k_part and v_part.
+                        # Otherwise (MLA), k_part and v_part have already been split.
+                        k_part = (
+                            kv_inputs[i % 2][..., 0, :, :]
+                            if qkv_format in ["bshd", "sbhd"]
+                            else kv_inputs[i % 2][0]
+                        )
+                        v_part = (
+                            kv_inputs[i % 2][..., 1, :, :]
+                            if qkv_format in ["bshd", "sbhd"]
+                            else kv_inputs[i % 2][1]
+                        )
                     fa_forward_args_thd = get_fa_args(
                         True,
                         use_flash_attn_3,
@@ -1160,19 +1184,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     elif fa_utils.v2_7_0_plus:
                         fa_forward_kwargs["window_size_left"] = -1
                         fa_forward_kwargs["window_size_right"] = -1
-                    # Need to add MLA support once Flash Attention supports MLA
                     fa_outputs = flash_attn_fwd(
                         q_inputs[i % 2],
-                        (
-                            kv_inputs[i % 2][..., 0, :, :]
-                            if qkv_format in ["bshd", "sbhd"]
-                            else kv_inputs[i % 2][0]
-                        ),
-                        (
-                            kv_inputs[i % 2][..., 1, :, :]
-                            if qkv_format in ["bshd", "sbhd"]
-                            else kv_inputs[i % 2][1]
-                        ),
+                        k_part,
+                        v_part,
                         *fa_forward_args_thd,
                         causal=False,
                         **fa_forward_kwargs,
@@ -1269,6 +1284,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                     attn_biases[i] = rest[0] if len(rest) > 0 else None
                 else:
+                    if not enable_mla:
+                        # If MHA, split the packed KV into k_part and v_part.
+                        # Otherwise (MLA), k_part and v_part have already been split.
+                        k_part = (
+                            kv_inputs[i % 2][..., 0, :, :]
+                            if qkv_format in ["bshd", "sbhd"]
+                            else kv_inputs[i % 2][0]
+                        )
+                        v_part = (
+                            kv_inputs[i % 2][..., 1, :, :]
+                            if qkv_format in ["bshd", "sbhd"]
+                            else kv_inputs[i % 2][1]
+                        )
                     fa_forward_args_thd = get_fa_args(
                         True,
                         use_flash_attn_3,
@@ -1278,19 +1306,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         max_seqlen_q=max_seqlen_q,
                         max_seqlen_kv=max_seqlen_kv,
                     )
-                    # Need to add MLA support once Flash Attention supports MLA
                     fa_outputs = flash_attn_fwd(
                         q,
-                        (
-                            kv_inputs[i % 2][..., 0, :, :]
-                            if qkv_format in ["bshd", "sbhd"]
-                            else kv_inputs[i % 2][0]
-                        ),
-                        (
-                            kv_inputs[i % 2][..., 1, :, :]
-                            if qkv_format in ["bshd", "sbhd"]
-                            else kv_inputs[i % 2][1]
-                        ),
+                        k_part,
+                        v_part,
                         *fa_forward_args_thd,
                         causal=False,
                         **fa_forward_kwargs,
@@ -1865,7 +1884,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     dv_ = dv_._data
                 else:
                     dq_ = torch.empty_like(q_)
+                    if ctx.enable_mla:
+                        dk_ = torch.empty_like(k_part)
+                        dv_ = torch.empty_like(v_part)
+                    else:
+                        k_part = (
+                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
+                        )
+                        v_part = (
+                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                        )
-                    dkv_ = torch.empty_like(kv_)
+                        dkv_ = torch.empty_like(kv_)
+                        dk_ = (
+                            dkv_[..., 0, :, :]
+                            if ctx.qkv_format in ["bshd", "sbhd"]
+                            else dkv_[0]
+                        )
+                        dv_ = (
+                            dkv_[..., 1, :, :]
+                            if ctx.qkv_format in ["bshd", "sbhd"]
+                            else dkv_[1]
+                        )
                 fa_backward_args_thd = get_fa_args(
                     False,
                     ctx.use_flash_attn_3,
@@ -1875,16 +1914,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     max_seqlen_q=ctx.max_seqlen_q,
                     max_seqlen_kv=ctx.max_seqlen_kv,
                     dq=dq_,
-                    dk=(
-                        dkv_[..., 0, :, :]
-                        if ctx.qkv_format in ["bshd", "sbhd"]
-                        else dkv_[0]
-                    ),
-                    dv=(
-                        dkv_[..., 1, :, :]
-                        if ctx.qkv_format in ["bshd", "sbhd"]
-                        else dkv_[1]
-                    ),
+                    dk=dk_,
+                    dv=dv_,
                 )
                 if ctx.use_flash_attn_3 or (
                     fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
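With MLA, dK and dV can no longer share a packed `dkv_` buffer because their head dims differ, so the backward path above allocates them separately. Schematically, under hypothetical shapes:

```python
import torch

k_part = torch.empty(2, 128, 12, 192)  # head_dim_qk = 192
v_part = torch.empty(2, 128, 12, 128)  # head_dim_v = 128
dk_ = torch.empty_like(k_part)
dv_ = torch.empty_like(v_part)  # a packed dkv_ would force equal widths
```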
@@ -1895,12 +1926,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         fa_backward_kwargs["window_size_right"] = 0
                 if not ctx.use_flash_attn_3:
                     fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
-                # Need to add MLA support once Flash Attention supports MLA
                 flash_attn_bwd(
                     dout_,
                     q_,
-                    kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
-                    kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
+                    k_part,
+                    v_part,
                     out_,
                     softmax_lse,
                     *fa_backward_args_thd,
@@ -2016,7 +2046,29 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     dv_ = dv_._data
                 else:
                     dq_ = torch.empty_like(q_)
+                    if ctx.enable_mla:
+                        k_part = k_part.contiguous()
+                        v_part = v_part.contiguous()
+                        dk_ = torch.empty_like(k_part)
+                        dv_ = torch.empty_like(v_part)
+                    else:
+                        k_part = (
+                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
+                        )
+                        v_part = (
+                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                        )
-                    dkv_ = torch.empty_like(kv_)
+                        dkv_ = torch.empty_like(kv_)
+                        dk_ = (
+                            dkv_[..., 0, :, :]
+                            if ctx.qkv_format in ["bshd", "sbhd"]
+                            else dkv_[0]
+                        )
+                        dv_ = (
+                            dkv_[..., 1, :, :]
+                            if ctx.qkv_format in ["bshd", "sbhd"]
+                            else dkv_[1]
+                        )
                 fa_backward_args_thd = get_fa_args(
                     False,
                     ctx.use_flash_attn_3,
@@ -2026,16 +2078,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     max_seqlen_q=ctx.max_seqlen_q,
                     max_seqlen_kv=ctx.max_seqlen_kv // 2,
                     dq=dq_,
-                    dk=(
-                        dkv_[..., 0, :, :]
-                        if ctx.qkv_format in ["bshd", "sbhd"]
-                        else dkv_[0]
-                    ),
-                    dv=(
-                        dkv_[..., 1, :, :]
-                        if ctx.qkv_format in ["bshd", "sbhd"]
-                        else dkv_[1]
-                    ),
+                    dk=dk_,
+                    dv=dv_,
                 )
                 if ctx.use_flash_attn_3 or (
                     fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
@@ -2046,12 +2090,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         fa_backward_kwargs["window_size_right"] = -1
                 if not ctx.use_flash_attn_3:
                     fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
-                # Need to add MLA support once Flash Attention supports MLA
                 flash_attn_bwd(
                     dout_,
                     q_,
-                    kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
-                    kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
+                    k_part,
+                    v_part,
                     out_,
                     softmax_lse,
                     *fa_backward_args_thd,
@@ -2160,7 +2203,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     dv_ = dv_._data
                 else:
                     dq_ = torch.empty_like(q_)
+                    if ctx.enable_mla:
+                        dk_ = torch.empty_like(k_part)
+                        dv_ = torch.empty_like(v_part)
+                    else:
+                        k_part = (
+                            kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
+                        )
+                        v_part = (
+                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
+                        )
-                    dkv_ = torch.empty_like(kv_)
+                        dkv_ = torch.empty_like(kv_)
+                        dk_ = (
+                            dkv_[..., 0, :, :]
+                            if ctx.qkv_format in ["bshd", "sbhd"]
+                            else dkv_[0]
+                        )
+                        dv_ = (
+                            dkv_[..., 1, :, :]
+                            if ctx.qkv_format in ["bshd", "sbhd"]
+                            else dkv_[1]
+                        )
                 fa_backward_args_thd = get_fa_args(
                     False,
                     ctx.use_flash_attn_3,
@@ -2170,16 +2233,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     max_seqlen_q=ctx.max_seqlen_q // 2,
                     max_seqlen_kv=ctx.max_seqlen_kv,
                     dq=dq_,
-                    dk=(
-                        dkv_[..., 0, :, :]
-                        if ctx.qkv_format in ["bshd", "sbhd"]
-                        else dkv_[0]
-                    ),
-                    dv=(
-                        dkv_[..., 1, :, :]
-                        if ctx.qkv_format in ["bshd", "sbhd"]
-                        else dkv_[1]
-                    ),
+                    dk=dk_,
+                    dv=dv_,
                 )
                 if ctx.use_flash_attn_3 or (
                     fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
@@ -2190,12 +2245,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                         fa_backward_kwargs["window_size_right"] = -1
                 if not ctx.use_flash_attn_3:
                     fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
-                # Need to add MLA support once Flash Attention supports MLA
                 flash_attn_bwd(
                     dout_,
                     q_,
-                    kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
-                    kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
+                    k_part,
+                    v_part,
                     out_,
                     softmax_lse_,
                     *fa_backward_args_thd,
@@ -2267,7 +2321,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
             else:
                 dq_ = torch.empty_like(q)
+                if ctx.enable_mla:
+                    dk_ = torch.empty_like(k_part)
+                    dv_ = torch.empty_like(v_part)
+                else:
+                    k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0]
+                    v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1]
-                dkv_ = torch.empty_like(kv)
+                    dkv_ = torch.empty_like(kv)
+                    dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0]
+                    dv_ = dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1]
             fa_backward_args_thd = get_fa_args(
                 False,
                 ctx.use_flash_attn_3,
@@ -2277,8 +2339,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                 max_seqlen_q=ctx.max_seqlen_q,
                 max_seqlen_kv=ctx.max_seqlen_kv,
                 dq=dq_,
-                dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
-                dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
+                dk=dk_,
+                dv=dv_,
             )
             if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
                 fa_backward_kwargs["window_size"] = (-1, -1)
@@ -2287,12 +2349,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                 fa_backward_kwargs["window_size_right"] = -1
             if not ctx.use_flash_attn_3:
                 fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
-            # Need to add MLA support once Flash Attention supports MLA
             flash_attn_bwd(
                 dout,
                 q,
-                kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
-                kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
+                k_part,
+                v_part,
                 out,
                 softmax_lse,
                 *fa_backward_args_thd,
...
@@ -126,10 +126,10 @@ class FlashAttentionUtils:
     # Please follow these instructions to install FA3
     v3_installation_steps = """\
 (1) git clone https://github.com/Dao-AILab/flash-attention.git
-(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
+(2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install
 (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
 (4) mkdir -p $python_path/flash_attn_3
-(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py"""
+(5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py"""

     v3_warning_printed = False

     @staticmethod
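After step (5), the copied interface should be importable; a quick, hedged sanity check (this assumes Python treats the `flash_attn_3` directory as a namespace package, which holds for Python 3.3+):

```python
# Hypothetical post-install check for the FA3 interface copied in step (5).
from flash_attn_3 import flash_attn_interface
print(flash_attn_interface.__file__)  # should resolve inside site-packages/flash_attn_3
```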
@@ -477,11 +477,10 @@ def get_attention_backend(
     # Filter: Head dimension
     if head_dim_qk != head_dim_v:
-        if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
-            use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
-        ):
-            logger.debug("Disabling FlashAttention as it does not support MLA.")
-            use_flash_attention = False
+        if use_flash_attention_2 and FlashAttentionUtils.is_installed:
+            logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
+            use_flash_attention_2 = False
         qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
         if use_fused_attention and qkv_layout_group != "hd_hd_hd":
             logger.debug(
@@ -508,9 +507,40 @@ def get_attention_backend(
                 ".".join([str(i) for i in device_compute_capability]),
             )
             use_flash_attention_2 = False
-    if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128):
+    if use_flash_attention_3:
+
+        def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
+            if head_dim_qk > 256 or num_heads % num_gqa_groups != 0:
+                return False
+            if head_dim_qk != head_dim_v:
+                cond1 = 128 < head_dim_qk <= 192
+                cond2 = 96 < head_dim_v <= 128
+                cond3 = head_dim_qk <= 64 and head_dim_v <= 512
+                if not ((cond1 and cond2) or cond3):
+                    return False
+            if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16):
+                return False
+            return True
+
+        if not _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
-        if FlashAttentionUtils.v3_is_installed:
-            logger.debug("Disabling FlashAttention 3 for head_dim > 128")
+            if FlashAttentionUtils.v3_is_installed:
+                logger.debug(
+                    "Disabling FlashAttention 3 due to unsupported num_heads, num_gqa_groups, "
+                    "head_dim_qk, head_dim_v, or qkv_dtype. "
+                    "Supported: head_dim_qk <= 256 and num_heads %% num_gqa_groups == 0; "
+                    "if head_dim_qk differs from head_dim_v, then either "
+                    "head_dim_qk in (128, 192] and head_dim_v in (96, 128], or "
+                    "head_dim_qk <= 64 and head_dim_v <= 512; "
+                    "if additionally head_dim_v > 256, qkv_dtype must be fp16 or bf16. "
+                    "Found: num_heads = %s, num_gqa_groups = %s, "
+                    "head_dim_qk = %s, head_dim_v = %s, qkv_dtype = %s.",
+                    num_heads,
+                    num_gqa_groups,
+                    head_dim_qk,
+                    head_dim_v,
+                    qkv_dtype,
+                )
-        use_flash_attention_3 = False
+            use_flash_attention_3 = False
     # Filter: QKV layout
...
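To make the new filter concrete, here is how the head-dim rules play out for a few shapes. This is a standalone restatement of `_is_fa3_supported` for illustration; the real function is nested inside `get_attention_backend`:

```python
import torch

def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
    # Mirrors the filter added in this commit.
    if head_dim_qk > 256 or num_heads % num_gqa_groups != 0:
        return False
    if head_dim_qk != head_dim_v:
        ok = (128 < head_dim_qk <= 192 and 96 < head_dim_v <= 128) or (
            head_dim_qk <= 64 and head_dim_v <= 512
        )
        if not ok:
            return False
    if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16):
        return False
    return True

assert _is_fa3_supported(12, 12, 192, 128, torch.bfloat16)     # MLA: 192/128 is allowed
assert _is_fa3_supported(12, 12, 64, 512, torch.bfloat16)      # narrow-QK, wide-V case
assert not _is_fa3_supported(12, 12, 192, 64, torch.bfloat16)  # head_dim_v outside (96, 128]
```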