Unverified Commit faee0e8b authored by yuzhongw-nvidia's avatar yuzhongw-nvidia Committed by GitHub
Browse files

Support Context Parallel for Multi Latent Attention (MLA) (#1729)



* Support MLA (qk_dim != v_dim) for AttnFuncWithCPAndKVP2P
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>

* add UT for MLA CP
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refine the code
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refine the code
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarXiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
parent 031c6cf6
...@@ -107,6 +107,18 @@ model_configs_fused_attn = { ...@@ -107,6 +107,18 @@ model_configs_fused_attn = {
"cp_2_4": ModelConfig( "cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA ), # GQA
"cp_3_0": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # MLA
"cp_3_1": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64
), # MLA
"cp_3_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
), # MLA
} }
...@@ -159,6 +171,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ...@@ -159,6 +171,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
) )
if dtype != "fp8" and fp8_mha: if dtype != "fp8" and fp8_mha:
pytest.skip("Only fp8 works with fp8_mha=True!") pytest.skip("Only fp8 works with fp8_mha=True!")
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
subprocess.run( subprocess.run(
get_bash_arguments( get_bash_arguments(
......
...@@ -461,6 +461,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -461,6 +461,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
enable_mla = k.shape[-1] != v.shape[-1]
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
...@@ -498,6 +499,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -498,6 +499,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_half, cu_seqlens_kv_half = None, None cu_seqlens_q_half, cu_seqlens_kv_half = None, None
if qkv_format in ["bshd", "sbhd"]: if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s") seq_dim = qkv_format.index("s")
if enable_mla:
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
else:
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None
if use_fused_attention: if use_fused_attention:
...@@ -676,9 +680,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -676,9 +680,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fwd_results_correction_done = torch.cuda.Event() fwd_results_correction_done = torch.cuda.Event()
p2p_comm_buffers = [None for _ in range(cp_size)] p2p_comm_buffers = [None for _ in range(cp_size)]
if qkv_format in ["bshd", "sbhd"]: if enable_mla:
# If MLA, the shape of k and v does not match, so we flatten them
# and split them after receiving them.
k_shape = k.shape
k_numel = k.numel()
v_shape = v.shape
p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1)
elif qkv_format in ["bshd", "sbhd"]:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
else: else: # qkv_format == "thd"
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
send_recv_reqs = [[], []] send_recv_reqs = [[], []]
...@@ -707,6 +718,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -707,6 +718,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else: else:
# KV exchange is in BF16/FP16, cast received KV in each step # KV exchange is in BF16/FP16, cast received KV in each step
kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data
if enable_mla:
# If MLA, k and v are flattened, so split them after receiving.
k_part = kv_inputs[i % 2][:k_numel].view(*k_shape)
v_part = kv_inputs[i % 2][k_numel:].view(*v_shape)
if causal: if causal:
if i == 0: if i == 0:
if pad_between_seqs: if pad_between_seqs:
...@@ -725,6 +740,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -725,6 +740,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
if enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view( kv_inputs[i % 2] = kv_inputs[i % 2].view(
k.shape[0], -1, 2, *k.shape[-2:] k.shape[0], -1, 2, *k.shape[-2:]
...@@ -732,6 +752,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -732,6 +752,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
if enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[2:])
v_part = v_part.view(-1, *v_part.shape[2:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view( kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:] -1, k.shape[2], 2, *k.shape[-2:]
...@@ -750,6 +775,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -750,6 +775,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).contiguous() ).contiguous()
q_part = q_inputs[i % 2] q_part = q_inputs[i % 2]
if not enable_mla:
                            # If MHA, then split the KV into k_part and v_part.
                            # Otherwise (MLA), k_part and v_part have already been split.
k_part = ( k_part = (
kv_inputs[i % 2][..., 0, :, :] kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
...@@ -810,6 +838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -810,6 +838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
) )
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd( fa_outputs = flash_attn_fwd(
q_inputs[i % 2], q_inputs[i % 2],
( (
...@@ -858,26 +887,50 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -858,26 +887,50 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
if enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn]
k_part = k_part[:, 0, ...]
v_part = v_part[:, 0, ...]
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...]
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
if enable_mla:
# [2, sk//2, b, np, hn] -> [sk//2, b, np, hn]
k_part = k_part[0]
v_part = v_part[0]
else:
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][0] kv_inputs[i % 2] = kv_inputs[i % 2][0]
elif qkv_format == "thd": elif qkv_format == "thd":
q_inputs[i % 2] = q q_inputs[i % 2] = q
if enable_mla:
# [t, np, hn] -> [t/2, np, hn]
k_part = tex.thd_read_half_tensor(
k_part, cu_seqlens_kv_padded, 0
)
v_part = tex.thd_read_half_tensor(
v_part, cu_seqlens_kv_padded, 0
)
else:
# [2, t, np, hn] -> [2, t/2, np, hn] # [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i % 2] = tex.thd_read_half_tensor( kv_inputs[i % 2] = tex.thd_read_half_tensor(
kv_inputs[i % 2], cu_seqlens_kv_padded, 0 kv_inputs[i % 2], cu_seqlens_kv_padded, 0
) )
if use_fused_attention: if use_fused_attention:
if enable_mla:
k_part = k_part.contiguous()
v_part = v_part.contiguous()
else:
kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() kv_inputs[i % 2] = kv_inputs[i % 2].contiguous()
if attn_bias is not None: if attn_bias is not None:
idx = (rank - i) % cp_size idx = (rank - i) % cp_size
attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
q_part = q_inputs[i % 2] q_part = q_inputs[i % 2]
if not enable_mla:
k_part = ( k_part = (
kv_inputs[i % 2][..., 0, :, :] kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
...@@ -948,6 +1001,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -948,6 +1001,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif fa_utils.v2_7_0_plus: elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = -1 fa_forward_kwargs["window_size_right"] = -1
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd( fa_outputs = flash_attn_fwd(
q_inputs[i % 2], q_inputs[i % 2],
( (
...@@ -996,6 +1050,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -996,6 +1050,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i % 2] = q[:, 1, ...] q_inputs[i % 2] = q[:, 1, ...]
if enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view( kv_inputs[i % 2] = kv_inputs[i % 2].view(
k.shape[0], -1, 2, *k.shape[-2:] k.shape[0], -1, 2, *k.shape[-2:]
...@@ -1003,6 +1062,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1003,6 +1062,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_inputs[i % 2] = q[1] q_inputs[i % 2] = q[1]
if enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[2:])
v_part = v_part.view(-1, *v_part.shape[2:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view( kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:] -1, k.shape[2], 2, *k.shape[-2:]
...@@ -1025,6 +1089,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1025,6 +1089,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).contiguous() ).contiguous()
q_part = q_inputs[i % 2] q_part = q_inputs[i % 2]
if not enable_mla:
k_part = ( k_part = (
kv_inputs[i % 2][..., 0, :, :] kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
...@@ -1095,6 +1160,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1095,6 +1160,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif fa_utils.v2_7_0_plus: elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = -1 fa_forward_kwargs["window_size_right"] = -1
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd( fa_outputs = flash_attn_fwd(
q_inputs[i % 2], q_inputs[i % 2],
( (
...@@ -1152,6 +1218,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1152,6 +1218,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).contiguous() ).contiguous()
q_part = q q_part = q
if not enable_mla:
k_part = ( k_part = (
kv_inputs[i % 2][..., 0, :, :] kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
...@@ -1211,6 +1278,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1211,6 +1278,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
) )
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd( fa_outputs = flash_attn_fwd(
q, q,
( (
...@@ -1257,7 +1325,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1257,7 +1325,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if i == 1: if i == 1:
softmax_lse = torch.clone(softmax_lse_per_step[0]) softmax_lse = torch.clone(softmax_lse_per_step[0])
if qkv_format == "thd": if qkv_format == "thd":
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) if enable_mla:
out = torch.zeros_like(v if not fp8 else out_per_step[0]).view(
v_shape
)
else:
# MHA or GQA
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(
q.shape
)
elif (i - 1) <= rank or not causal: elif (i - 1) <= rank or not causal:
flash_attn_fwd_softmax_lse_correction( flash_attn_fwd_softmax_lse_correction(
softmax_lse, softmax_lse_per_step[i - 1] softmax_lse, softmax_lse_per_step[i - 1]
...@@ -1295,6 +1371,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1295,6 +1371,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[0], softmax_lse_per_step[0],
seq_dim, seq_dim,
) )
if enable_mla:
out = out.view(v_shape)
else:
out = out.view(q.shape) out = out.view(q.shape)
else: else:
flash_attn_fwd_out_correction( flash_attn_fwd_out_correction(
...@@ -1417,6 +1496,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1417,6 +1496,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.is_output_fp8 = is_output_fp8 ctx.is_output_fp8 = is_output_fp8
ctx.use_flash_attn_3 = use_flash_attn_3 ctx.use_flash_attn_3 = use_flash_attn_3
ctx.enable_mla = enable_mla
if enable_mla:
ctx.k_numel = k_numel
ctx.k_shape = k_shape
ctx.v_shape = v_shape
ctx.qkv_dtype = qkv_dtype ctx.qkv_dtype = qkv_dtype
ctx.dQKV_quantizer = dQKV_quantizer ctx.dQKV_quantizer = dQKV_quantizer
ctx.dQKV_CP_quantizer = dQKV_CP_quantizer ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
...@@ -1466,6 +1551,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1466,6 +1551,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
seq_dim = None seq_dim = None
if ctx.qkv_format in ["bshd", "sbhd"]: if ctx.qkv_format in ["bshd", "sbhd"]:
seq_dim = ctx.qkv_format.index("s") seq_dim = ctx.qkv_format.index("s")
if ctx.enable_mla:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
else:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
else: else:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
...@@ -1595,6 +1683,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1595,6 +1683,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
) )
dout = dout.dequantize(dtype=dout_dtype) dout = dout.dequantize(dtype=dout_dtype)
if ctx.enable_mla:
out = out.view(*ctx.v_shape)
dout = dout.view(*ctx.v_shape)
else:
# MHA or GQA
out = out.view(*q.shape) out = out.view(*q.shape)
dout = dout.view(*q.shape) dout = dout.view(*q.shape)
send_recv_reqs = [] send_recv_reqs = []
...@@ -1672,6 +1765,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1672,6 +1765,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv = p2p_comm_buffers[i % 2][0] kv = p2p_comm_buffers[i % 2][0]
q_, kv_, out_, dout_ = None, None, None, None q_, kv_, out_, dout_ = None, None, None, None
dq_, dk_, dv_ = None, None, None dq_, dk_, dv_ = None, None, None
if ctx.enable_mla:
k_part = kv[: ctx.k_numel].view(*ctx.k_shape)
v_part = kv[ctx.k_numel :].view(*ctx.v_shape)
# In reversed order of fwd # In reversed order of fwd
if causal: if causal:
if i == (cp_size - 1): if i == (cp_size - 1):
...@@ -1680,11 +1776,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1680,11 +1776,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_, out_, dout_ = [ q_, out_, dout_ = [
x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
] ]
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[-3:])
v_part = v_part.view(-1, *v_part.shape[-3:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:]) kv_ = kv.view(-1, *kv.shape[-4:])
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
...@@ -1701,8 +1807,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1701,8 +1807,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q_ q_part = q_
k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] if not ctx.enable_mla:
v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
out_part = out_ out_part = out_
dout_part = dout_ dout_part = dout_
...@@ -1784,6 +1895,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1784,6 +1895,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = 0 fa_backward_kwargs["window_size_right"] = 0
if not ctx.use_flash_attn_3: if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd( flash_attn_bwd(
dout_, dout_,
q_, q_,
...@@ -1801,18 +1913,37 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1801,18 +1913,37 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_, out_, dout_ = [ q_, out_, dout_ = [
x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
] ]
if ctx.enable_mla:
                                # [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn]
k_part = k_part[:, 0]
v_part = v_part[:, 0]
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_ = kv[:, 0] kv_ = kv[:, 0]
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
if ctx.enable_mla:
                                # [2, sk//2, b, np, hn] -> [sk//2, b, np, hn]
k_part = k_part[0]
v_part = v_part[0]
else:
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_ = kv[0] kv_ = kv[0]
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
q_, out_, dout_ = q, out, dout q_, out_, dout_ = q, out, dout
if ctx.enable_mla:
# [t, np, hn] -> [t/2, np, hn]
k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0)
v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0)
else:
# [2, t, np, hn] -> [2, t/2, np, hn] # [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
if ctx.use_fused_attention: if ctx.use_fused_attention:
if ctx.enable_mla:
k_part = k_part.contiguous()
v_part = v_part.contiguous()
else:
kv_ = kv_.contiguous() kv_ = kv_.contiguous()
if ctx.fp8: if ctx.fp8:
aux_ctx_tensors = [ aux_ctx_tensors = [
...@@ -1825,8 +1956,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1825,8 +1956,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q_ q_part = q_
k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] if not ctx.enable_mla:
v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
out_part = out_ out_part = out_
dout_part = dout_ dout_part = dout_
...@@ -1910,6 +2046,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1910,6 +2046,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1 fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3: if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd( flash_attn_bwd(
dout_, dout_,
q_, q_,
...@@ -1925,11 +2062,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1925,11 +2062,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1]
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_, out_, dout_ = q[1], out[1], dout[1] q_, out_, dout_ = q[1], out[1], dout[1]
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[-3:])
v_part = v_part.view(-1, *v_part.shape[-3:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:]) kv_ = kv.view(-1, *kv.shape[-4:])
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
...@@ -1953,8 +2100,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1953,8 +2100,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q_ q_part = q_
k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] if not ctx.enable_mla:
v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
out_part = out_ out_part = out_
dout_part = dout_ dout_part = dout_
...@@ -2038,6 +2190,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2038,6 +2190,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1 fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3: if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd( flash_attn_bwd(
dout_, dout_,
q_, q_,
...@@ -2058,6 +2211,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2058,6 +2211,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q q_part = q
if not ctx.enable_mla:
k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0]
v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1]
out_part = out out_part = out
...@@ -2133,6 +2287,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2133,6 +2287,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1 fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3: if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd( flash_attn_bwd(
dout, dout,
q, q,
...@@ -2225,15 +2380,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2225,15 +2380,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else: else:
dkv = p2p_comm_buffers[(i + 1) % 2][1] dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.use_fused_attention: if ctx.use_fused_attention:
if ctx.qkv_format in ["bshd", "sbhd"]: if ctx.enable_mla:
dkv_ = None
elif ctx.qkv_format in ["bshd", "sbhd"]:
dkv_ = combine_tensors([dk_, dv_], -2) dkv_ = combine_tensors([dk_, dv_], -2)
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
dkv_ = torch.cat( dkv_ = torch.cat(
(dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
) # pylint: disable=used-before-assignment ) # pylint: disable=used-before-assignment
if ctx.qkv_format in ["bshd", "sbhd"]: if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]:
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
# dkv is a buffer, so we do not need to transpose it, but only need to reshape it.
dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
dkv_ = dkv_.movedim(-3, 0) dkv_ = dkv_.movedim(-3, 0)
if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
...@@ -2241,7 +2399,101 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2241,7 +2399,101 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
dkv_ = dkv_.view(*dkv.shape) dkv_ = dkv_.view(*dkv.shape)
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] or
# [2, sk//2, b, np, hn]
dk = dkv[: ctx.k_numel].view(*ctx.k_shape)
dv = dkv[ctx.k_numel :].view(*ctx.v_shape)
if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
dk_ = dk_.view(*ctx.k_shape)
dv_ = dv_.view(*ctx.v_shape)
if ctx.fp8: if ctx.fp8:
# enable_mla and fp8
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd":
dk[:, 0, ...].copy_(dk_)
dk[:, 1, ...].fill_(0)
dv[:, 0, ...].copy_(dv_)
dv[:, 1, ...].fill_(0)
elif ctx.qkv_format == "sbhd":
dk[0].copy_(dk_)
dk[1].fill_(0)
dv[0].copy_(dv_)
dv[1].fill_(0)
else:
dk.copy_(dk_)
dv.copy_(dv_)
elif causal:
# enable_mla and not fp8 and causal
if i == (cp_size - 1):
if rank == 0:
if ctx.qkv_format == "bshd":
dk[:, 0, ...].add_(dk_[:, 0, ...])
dk[:, 1, ...].copy_(dk_[:, 1, ...])
dv[:, 0, ...].add_(dv_[:, 0, ...])
dv[:, 1, ...].copy_(dv_[:, 1, ...])
elif ctx.qkv_format == "sbhd":
dk[0, ...].add_(dk_[0, ...])
dk[1, ...].copy_(dk_[1, ...])
dv[0, ...].add_(dv_[0, ...])
dv[1, ...].copy_(dv_[1, ...])
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(
dk, dk_, cu_seqlens_kv_padded, "add", "copy"
)
tex.thd_grad_correction(
dv, dv_, cu_seqlens_kv_padded, "add", "copy"
)
else:
dk.add_(dk_)
dv.add_(dv_)
elif i >= (cp_size - rank - 1):
if i == 0 and rank == (cp_size - 1):
if ctx.qkv_format == "bshd":
dk[:, 0, ...].copy_(dk_)
dv[:, 0, ...].copy_(dv_)
elif ctx.qkv_format == "sbhd":
dk[0, ...].copy_(dk_)
dv[0, ...].copy_(dv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(
dk, dk_, cu_seqlens_kv_padded, "copy", "none"
)
tex.thd_grad_correction(
dv, dv_, cu_seqlens_kv_padded, "copy", "none"
)
else:
if ctx.qkv_format == "bshd":
dk[:, 0, ...].add_(dk_)
dv[:, 0, ...].add_(dv_)
elif ctx.qkv_format == "sbhd":
dk[0, ...].add_(dk_)
dv[0, ...].add_(dv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(
dk, dk_, cu_seqlens_kv_padded, "add", "none"
)
tex.thd_grad_correction(
dv, dv_, cu_seqlens_kv_padded, "add", "none"
)
elif i > 0:
dk.add_(dk_)
dv.add_(dv_)
else: # i == 0
dk.copy_(dk_)
dv.copy_(dv_)
else:
# enable_mla and not fp8 and not causal
if i == 0:
dk.copy_(dk_)
dv.copy_(dv_)
else: # i > 0
dk.add_(dk_)
dv.add_(dv_)
else:
if ctx.fp8:
# fp8
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].copy_(dkv_) dkv[:, :, 0, ...].copy_(dkv_)
...@@ -2252,6 +2504,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2252,6 +2504,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else: else:
dkv.copy_(dkv_) dkv.copy_(dkv_)
elif causal: elif causal:
# not fp8 and causal
if i == (cp_size - 1): if i == (cp_size - 1):
if rank == 0: if rank == 0:
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
...@@ -2261,7 +2514,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2261,7 +2514,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dkv[:, 0, ...].add_(dkv_[:, 0, ...]) dkv[:, 0, ...].add_(dkv_[:, 0, ...])
dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy") tex.thd_grad_correction(
dkv, dkv_, cu_seqlens_kv_padded, "add", "copy"
)
else: else:
dkv.add_(dkv_) dkv.add_(dkv_)
elif i >= (cp_size - rank - 1): elif i >= (cp_size - rank - 1):
...@@ -2271,35 +2526,54 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2271,35 +2526,54 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].copy_(dkv_) dkv[:, 0, ...].copy_(dkv_)
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none") tex.thd_grad_correction(
dkv, dkv_, cu_seqlens_kv_padded, "copy", "none"
)
else: else:
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].add_(dkv_) dkv[:, :, 0, ...].add_(dkv_)
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_) dkv[:, 0, ...].add_(dkv_)
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none") tex.thd_grad_correction(
dkv, dkv_, cu_seqlens_kv_padded, "add", "none"
)
elif i > 0: elif i > 0:
dkv.add_(dkv_) dkv.add_(dkv_)
else: else: # i == 0
dkv.copy_(dkv_) dkv.copy_(dkv_)
else: else:
# not fp8 and not causal
if i == 0: if i == 0:
dkv.copy_(dkv_) dkv.copy_(dkv_)
else: else: # i > 0
dkv.add_(dkv_) dkv.add_(dkv_)
if ctx.fp8 and ctx.use_fused_attention: if ctx.fp8 and ctx.use_fused_attention:
amax_cp_bwd = amax_per_step.amax(dim=1) amax_cp_bwd = amax_per_step.amax(dim=1)
ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0])
ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1]) ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1])
dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dq_fp8, fake_dtype=torch.float32, internal=True
)
if ctx.enable_mla:
# [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn]
dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape)
dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape)
dk = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dk_fp8, fake_dtype=torch.float32, internal=True
)
dv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dv_fp8, fake_dtype=torch.float32, internal=True
)
dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]]
dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]]
else:
if ctx.qkv_format in ["bshd", "sbhd"]: if ctx.qkv_format in ["bshd", "sbhd"]:
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dq_fp8, fake_dtype=torch.float32, internal=True
)
dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dkv_fp8, fake_dtype=torch.float32, internal=True dkv_fp8, fake_dtype=torch.float32, internal=True
) )
...@@ -2310,21 +2584,39 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2310,21 +2584,39 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
dk = dk.view(*dk.shape[0], -1, *dk.shape[-2:])
dv = dv.view(*dv.shape[0], -1, *dv.shape[-2:])
else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
dq = dq.view(-1, *dq.shape[-3:]) dq = dq.view(-1, *dq.shape[-3:])
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
dk = dk.view(-1, *dk.shape[-3:])
dv = dv.view(-1, *dv.shape[-3:])
else:
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])
if ctx.qkv_format == "thd" and not ctx.use_fused_attention: if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
dq[cu_seqlens_q_padded[-1] :].fill_(0) dq[cu_seqlens_q_padded[-1] :].fill_(0)
if ctx.enable_mla:
dk[cu_seqlens_kv_padded[-1] :].fill_(0)
dv[cu_seqlens_kv_padded[-1] :].fill_(0)
else:
dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0)
if ctx.fp8 and ctx.is_input_fp8: if ctx.fp8 and ctx.is_input_fp8:
assert torch.uint8 not in [dq.dtype, dkv.dtype] assert torch.uint8 not in [dq.dtype, dkv.dtype]
if ctx.enable_mla:
dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]]
else:
dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]]
if not ctx.enable_mla:
dk, dv = dkv[0], dkv[1] dk, dv = dkv[0], dkv[1]
if cp_size_a2a > 1: if cp_size_a2a > 1:
...@@ -3584,6 +3876,12 @@ def attn_forward_func_with_cp( ...@@ -3584,6 +3876,12 @@ def attn_forward_func_with_cp(
"all_gather", "all_gather",
], "The context parallel running configs cannot support sliding window attetnion!" ], "The context parallel running configs cannot support sliding window attetnion!"
enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
"p2p",
"a2a+p2p",
], "The context parallel running configs cannot support MLA!"
args = [ args = [
is_training, is_training,
q, q,
......
...@@ -608,11 +608,6 @@ def get_attention_backend( ...@@ -608,11 +608,6 @@ def get_attention_backend(
" bias for THD format" " bias for THD format"
) )
use_fused_attention = False use_fused_attention = False
elif head_dim_qk != head_dim_v:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with MLA"
)
use_fused_attention = False
# Filter: Attention mask # Filter: Attention mask
# attn_mask_type | attention_mask | supported backends # attn_mask_type | attention_mask | supported backends
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment