Unverified Commit 9d173c93 authored by yuzhongw-nvidia, committed by GitHub

Fix MLA CP Bugs (#1896)



* fix: (1) UT ignores MLA; (2) bshd format runtime error. Ban FP8 MLA attn + CP due to a correctness problem
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>

* only disable FP8 CP for MLA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent cc0cb35d
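The common thread in the hunks below is that MLA uses a different head size for Q/K than for V (head_dim_qk != head_dim_v), so K and V can no longer share one shape, one packed KV buffer, or one output width. A minimal sketch of why the old packed-KV assumption breaks; the dimensions and tensors here are illustrative only and are not part of this commit:

# Illustrative only -- not from this commit.
import torch

b, s, heads, hd_qk, hd_v = 2, 8, 4, 192, 128   # MLA-style: hd_qk != hd_v

k = torch.randn(b, s, heads, hd_qk)
v = torch.randn(b, s, heads, hd_v)

# K and V with different head dims cannot be packed into a single kv tensor,
# and the attention output width is heads * hd_v, not heads * hd_qk.
try:
    kv = torch.stack([k, v])
except RuntimeError as err:
    print("cannot pack K/V with different head dims:", err)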
@@ -89,7 +89,7 @@ def run_dpa_with_cp(
     # instantiate core attn module
     core_attn = DotProductAttention(
         config.num_heads,
-        config.head_dim_qk,
+        (config.head_dim_qk, config.head_dim_v),
         num_gqa_groups=config.num_gqa_groups,
         attention_dropout=config.dropout_p,
         qkv_format=qkv_format,
@@ -106,16 +106,22 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
             config.batch_size,
             config.max_seqlen_kv,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
+        v_input_shape = (
+            config.batch_size,
+            config.max_seqlen_kv,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.batch_size,
             config.max_seqlen_q,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         cu_seqlens_q = None
         cu_seqlens_kv = None
@@ -128,16 +134,22 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
             config.max_seqlen_kv,
             config.batch_size,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
+        v_input_shape = (
+            config.max_seqlen_kv,
+            config.batch_size,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.max_seqlen_q,
             config.batch_size,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         cu_seqlens_q = None
         cu_seqlens_kv = None
@@ -149,14 +161,19 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
             config.batch_size * config.max_seqlen_q,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
+        v_input_shape = (
+            config.batch_size * config.max_seqlen_q,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.batch_size * config.max_seqlen_q,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
         seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
@@ -177,8 +194,8 @@ def run_dpa_with_cp(
         assert False, f"{qkv_format} is an unsupported qkv_format!"
     q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
-    k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
-    v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
+    k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
+    v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
     dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
     dout_quantizer = Float8Quantizer(
         fp8_dtype=tex.DType.kFloat8E5M2,
......
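The test-harness hunks above now build K and V with their own shapes and hand DotProductAttention a (head_dim_qk, head_dim_v) pair instead of a single head dim. A hedged, self-contained sketch of the resulting bshd setup; the values are illustrative, and it assumes the tuple form of the second constructor argument shown in the first hunk plus an available CUDA device:

# Sketch only; mirrors the updated test setup, not library documentation.
import torch
from transformer_engine.pytorch import DotProductAttention

b, sq, skv, heads, groups, hd_qk, hd_v = 2, 512, 512, 16, 16, 192, 128

core_attn = DotProductAttention(
    heads,
    (hd_qk, hd_v),                  # separate Q/K and V head dims for MLA
    num_gqa_groups=groups,
    qkv_format="bshd",
)

q = torch.randn(b, sq, heads, hd_qk, dtype=torch.bfloat16, device="cuda")
k = torch.randn(b, skv, groups, hd_qk, dtype=torch.bfloat16, device="cuda")
v = torch.randn(b, skv, groups, hd_v, dtype=torch.bfloat16, device="cuda")

out = core_attn(q, k, v)            # output feature size is heads * hd_v, not heads * hd_qk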
@@ -173,6 +173,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("Only fp8 works with fp8_mha=True!")
     if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently only support KV P2P!")
+    if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
+        pytest.skip("MLA CP currently does not support FP8 attention!")
     subprocess.run(
         get_bash_arguments(
......
@@ -2559,8 +2559,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
        if ctx.enable_mla:
            # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn]
-           dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape)
-           dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape)
+           dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape)
+           dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape)
            dk = ctx.dQKV_CP_quantizer.create_tensor_from_data(
                dk_fp8, fake_dtype=torch.float32, internal=True
            )
@@ -2586,8 +2586,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
        dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
        if ctx.enable_mla:
            # [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
-           dk = dk.view(*dk.shape[0], -1, *dk.shape[-2:])
-           dv = dv.view(*dv.shape[0], -1, *dv.shape[-2:])
+           dk = dk.view(dk.shape[0], -1, *dk.shape[-2:])
+           dv = dv.view(dv.shape[0], -1, *dv.shape[-2:])
        else:
            # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
            dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
......
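Two small but real bugs in the AttnFuncWithCPAndKVP2P backward are fixed above: the combined FP8 dKV buffer evidently keeps the CP-rank dimension on dim 0, with the flattened K gradient followed by the V gradient on dim 1, so the K/V split must slice dim 1 rather than dim 0; and dk.shape[0] is a plain int, so unpacking it with * raises a TypeError. A tiny sketch with made-up shapes (not from the library) showing both fixes:

# Illustrative shapes only.
import torch

cp_size, b, sk, np_, hd_qk, hd_v = 2, 1, 8, 4, 6, 4
k_shape = (b, 2, sk // 2, np_, hd_qk)
v_shape = (b, 2, sk // 2, np_, hd_v)
k_numel = b * sk * np_ * hd_qk
v_numel = b * sk * np_ * hd_v

# Combined K/V gradient buffer: one flat row of (k_numel + v_numel) per CP rank.
dkv = torch.randn(cp_size, k_numel + v_numel)

# Fix 1: split along dim 1 (the flattened K/V axis), keeping the cp dim intact.
dk = dkv[:, :k_numel].view(cp_size, *k_shape)
dv = dkv[:, k_numel:].view(cp_size, *v_shape)

# Fix 2: dk.shape[0] is an int, so it is passed directly to view();
# writing view(*dk.shape[0], ...) would raise "TypeError: ... object is not iterable".
dk0 = dk[0]
dk0 = dk0.view(dk0.shape[0], -1, *dk0.shape[-2:])   # [b, 2, sk//2, np, hn] -> [b, sk, np, hn]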
@@ -608,6 +608,12 @@ def get_attention_backend(
                 " bias for THD format"
             )
             use_fused_attention = False
+        elif fp8 and head_dim_qk != head_dim_v:
+            logger.debug(
+                "Disabling FusedAttention as it does not support context parallelism with FP8"
+                " MLA attention"
+            )
+            use_fused_attention = False
     # Filter: Attention mask
     # attn_mask_type | attention_mask | supported backends
......
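The backend filter above (get_attention_backend) mirrors the new pytest.skip: under context parallelism, fused attention is disabled whenever FP8 is requested together with MLA-style unequal head dims. A hedged restatement of that guard as a standalone predicate; the function name and argument list here are invented for illustration and are not part of the library:

# Hypothetical helper restating the new filter condition.
import logging

logger = logging.getLogger("DotProductAttention")

def allow_fused_attention_with_cp(fp8: bool, head_dim_qk: int, head_dim_v: int, use_fused_attention: bool) -> bool:
    """Return False for CP + FP8 MLA attention (head_dim_qk != head_dim_v)."""
    if fp8 and head_dim_qk != head_dim_v:
        logger.debug(
            "Disabling FusedAttention as it does not support context parallelism with FP8"
            " MLA attention"
        )
        return False
    return use_fused_attention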