"src/targets/vscode:/vscode.git/clone" did not exist on "58711bbc2da5db8126c775f2163e3562098b35ab"
Unverified commit faee0e8b authored by yuzhongw-nvidia, committed by GitHub

Support Context Parallel for Multi Latent Attention (MLA) (#1729)



* Support MLA (qk_dim != v_dim) for AttnFuncWithCPAndKVP2P
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>

* add UT for MLA CP
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refine the code
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refine the code
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
parent 031c6cf6
@@ -107,6 +107,18 @@ model_configs_fused_attn = {
"cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
"cp_3_0": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # MLA
"cp_3_1": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64
), # MLA
"cp_3_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
), # MLA
}
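Note: the new cp_3_* entries exercise MLA, where head_dim_qk (128) differs from head_dim_v (64). A minimal sketch of the property under test, in plain PyTorch rather than the TE test harness (shapes shrunk for a quick check; illustrative only):

import torch
import torch.nn.functional as F

b, h, sq, skv = 2, 12, 512, 512
head_dim_qk, head_dim_v = 128, 64
q = torch.randn(b, h, sq, head_dim_qk)
k = torch.randn(b, h, skv, head_dim_qk)  # k shares q's head dim
v = torch.randn(b, h, skv, head_dim_v)   # v does not
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert out.shape == (b, h, sq, head_dim_v)  # output inherits v's head dim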
@@ -159,6 +171,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
)
if dtype != "fp8" and fp8_mha:
pytest.skip("Only fp8 works with fp8_mha=True!")
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
subprocess.run(
get_bash_arguments(
......
@@ -461,6 +461,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
enable_mla = k.shape[-1] != v.shape[-1]
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
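Note: enable_mla keys off mismatched k/v head dims alone, and the default softmax scale follows q's head dim (the contraction dim of q @ k^T), never v's. A quick check with illustrative MLA shapes in bshd format:

import torch

k = torch.empty(2, 4096, 12, 128)  # head_dim_qk = 128
v = torch.empty(2, 4096, 12, 64)   # head_dim_v = 64
assert k.shape[-1] != v.shape[-1]  # enable_mla
assert abs(128 ** (-0.5) - 0.0883883) < 1e-6  # scale uses 128, not 64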
@@ -498,6 +499,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_half, cu_seqlens_kv_half = None, None
if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s")
if enable_mla:
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
else:
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None
if use_fused_attention:
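Note: with MLA, k and v no longer share a head dim, so the interleaved kv layouts (e.g. "bs2hd") cannot be used and the layout string falls back to three separate tensors. A sketch of the string logic above (assuming qkv_format is "bshd" or "sbhd"; the "thd" path always takes the separate-tensor form):

def make_qkv_layout(qkv_format: str, enable_mla: bool) -> str:
    if enable_mla:
        # k and v stay separate tensors with different head dims
        return "_".join([qkv_format] * 3)
    # MHA/GQA packs k and v into one [..., 2, np, hn] tensor
    return qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]

assert make_qkv_layout("bshd", True) == "bshd_bshd_bshd"
assert make_qkv_layout("bshd", False) == "bshd_bs2hd"
assert make_qkv_layout("sbhd", False) == "sbhd_sb2hd"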
@@ -676,9 +680,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fwd_results_correction_done = torch.cuda.Event()
p2p_comm_buffers = [None for _ in range(cp_size)]
if qkv_format in ["bshd", "sbhd"]:
if enable_mla:
# If MLA, the shapes of k and v do not match, so we flatten them
# and split them after receiving them.
k_shape = k.shape
k_numel = k.numel()
v_shape = v.shape
p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1)
elif qkv_format in ["bshd", "sbhd"]:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
else:
else: # qkv_format == "thd"
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
send_recv_reqs = [[], []]
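Note: k and v cannot be stacked along a new dim when their head dims differ, so the MLA P2P buffer carries them flattened back to back and each step splits the received buffer at k.numel(). A shape-only sketch with small illustrative dims:

import torch

b, s, np_, hn_qk, hn_v = 2, 8, 4, 128, 64
k = torch.randn(b, s, np_, hn_qk)
v = torch.randn(b, s, np_, hn_v)

buf = torch.cat((k.view(-1), v.view(-1)), dim=-1)  # one contiguous send buffer
k_part = buf[: k.numel()].view(k.shape)            # recover k after recv
v_part = buf[k.numel() :].view(v.shape)            # recover v after recv
assert torch.equal(k_part, k) and torch.equal(v_part, v)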
@@ -707,6 +718,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
# KV exchange is in BF16/FP16, cast received KV in each step
kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data
if enable_mla:
# If MLA, k and v are flattened, so split them after receiving.
k_part = kv_inputs[i % 2][:k_numel].view(*k_shape)
v_part = kv_inputs[i % 2][k_numel:].view(*v_shape)
if causal:
if i == 0:
if pad_between_seqs:
@@ -725,6 +740,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
if enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
k.shape[0], -1, 2, *k.shape[-2:]
@@ -732,6 +752,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
if enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[2:])
v_part = v_part.view(-1, *v_part.shape[2:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:]
@@ -750,6 +775,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).contiguous()
q_part = q_inputs[i % 2]
if not enable_mla:
# If MHA/GQA, split the packed KV into k_part and v_part.
# Otherwise (MLA), k_part and v_part have already been split.
k_part = (
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
@@ -810,6 +838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
(
@@ -858,26 +887,50 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
if enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn]
k_part = k_part[:, 0, ...]
v_part = v_part[:, 0, ...]
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...]
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
if enable_mla:
# [2, sk//2, b, np, hn] -> [sk//2, b, np, hn]
k_part = k_part[0]
v_part = v_part[0]
else:
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][0]
elif qkv_format == "thd":
q_inputs[i % 2] = q
if enable_mla:
# [t, np, hn] -> [t/2, np, hn]
k_part = tex.thd_read_half_tensor(
k_part, cu_seqlens_kv_padded, 0
)
v_part = tex.thd_read_half_tensor(
v_part, cu_seqlens_kv_padded, 0
)
else:
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i % 2] = tex.thd_read_half_tensor(
kv_inputs[i % 2], cu_seqlens_kv_padded, 0
)
if use_fused_attention:
if enable_mla:
k_part = k_part.contiguous()
v_part = v_part.contiguous()
else:
kv_inputs[i % 2] = kv_inputs[i % 2].contiguous()
if attn_bias is not None:
idx = (rank - i) % cp_size
attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
q_part = q_inputs[i % 2]
if not enable_mla:
k_part = (
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
@@ -948,6 +1001,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = -1
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
(
@@ -996,6 +1050,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i % 2] = q[:, 1, ...]
if enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
k.shape[0], -1, 2, *k.shape[-2:]
@@ -1003,6 +1062,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_inputs[i % 2] = q[1]
if enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[2:])
v_part = v_part.view(-1, *v_part.shape[2:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:]
@@ -1025,6 +1089,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).contiguous()
q_part = q_inputs[i % 2]
if not enable_mla:
k_part = (
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
@@ -1095,6 +1160,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = -1
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
(
@@ -1152,6 +1218,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).contiguous()
q_part = q
if not enable_mla:
k_part = (
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
@@ -1211,6 +1278,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd(
q,
(
@@ -1257,7 +1325,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if i == 1:
softmax_lse = torch.clone(softmax_lse_per_step[0])
if qkv_format == "thd":
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
if enable_mla:
out = torch.zeros_like(v if not fp8 else out_per_step[0]).view(
v_shape
)
else:
# MHA or GQA
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(
q.shape
)
elif (i - 1) <= rank or not causal:
flash_attn_fwd_softmax_lse_correction(
softmax_lse, softmax_lse_per_step[i - 1]
@@ -1295,6 +1371,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[0],
seq_dim,
)
if enable_mla:
out = out.view(v_shape)
else:
out = out.view(q.shape)
else:
flash_attn_fwd_out_correction(
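Note: the two correction helpers above implement the standard ring-attention accumulation: per-step partial outputs are merged through their log-sum-exp statistics, and the math is unchanged by MLA since it only affects the output (v) side. A numerical sketch of the identity (not TE's fused kernels):

import torch

def partial_attn(q, k, v):
    # softmax attention over one kv chunk, plus its log-sum-exp
    s = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(s, dim=-1, keepdim=True)
    return torch.exp(s - lse) @ v, lse

q, k, v = torch.randn(4, 8), torch.randn(6, 8), torch.randn(6, 8)
o1, lse1 = partial_attn(q, k[:3], v[:3])
o2, lse2 = partial_attn(q, k[3:], v[3:])
lse = torch.logaddexp(lse1, lse2)                              # lse correction
out = torch.exp(lse1 - lse) * o1 + torch.exp(lse2 - lse) * o2  # out correction
ref, _ = partial_attn(q, k, v)
assert torch.allclose(out, ref, atol=1e-5)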
@@ -1417,6 +1496,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.is_output_fp8 = is_output_fp8
ctx.use_flash_attn_3 = use_flash_attn_3
ctx.enable_mla = enable_mla
if enable_mla:
ctx.k_numel = k_numel
ctx.k_shape = k_shape
ctx.v_shape = v_shape
ctx.qkv_dtype = qkv_dtype
ctx.dQKV_quantizer = dQKV_quantizer
ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
@@ -1466,6 +1551,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
seq_dim = None
if ctx.qkv_format in ["bshd", "sbhd"]:
seq_dim = ctx.qkv_format.index("s")
if ctx.enable_mla:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
else:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
else:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
@@ -1595,6 +1683,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
dout = dout.dequantize(dtype=dout_dtype)
if ctx.enable_mla:
out = out.view(*ctx.v_shape)
dout = dout.view(*ctx.v_shape)
else:
# MHA or GQA
out = out.view(*q.shape)
dout = dout.view(*q.shape)
send_recv_reqs = []
@@ -1672,6 +1765,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv = p2p_comm_buffers[i % 2][0]
q_, kv_, out_, dout_ = None, None, None, None
dq_, dk_, dv_ = None, None, None
if ctx.enable_mla:
k_part = kv[: ctx.k_numel].view(*ctx.k_shape)
v_part = kv[ctx.k_numel :].view(*ctx.v_shape)
# In reversed order of fwd
if causal:
if i == (cp_size - 1):
@@ -1680,11 +1776,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_, out_, dout_ = [
x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
]
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[-3:])
v_part = v_part.view(-1, *v_part.shape[-3:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:])
elif ctx.qkv_format == "thd":
@@ -1701,8 +1807,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q_
k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
if not ctx.enable_mla:
k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
out_part = out_
dout_part = dout_
@@ -1784,6 +1895,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = 0
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd(
dout_,
q_,
@@ -1801,18 +1913,37 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_, out_, dout_ = [
x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
]
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn]
k_part = k_part[:, 0]
v_part = v_part[:, 0]
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_ = kv[:, 0]
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk//2, b, np, hn]
k_part = k_part[0]
v_part = v_part[0]
else:
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_ = kv[0]
elif ctx.qkv_format == "thd":
q_, out_, dout_ = q, out, dout
if ctx.enable_mla:
# [t, np, hn] -> [t/2, np, hn]
k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0)
v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0)
else:
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
if ctx.use_fused_attention:
if ctx.enable_mla:
k_part = k_part.contiguous()
v_part = v_part.contiguous()
else:
kv_ = kv_.contiguous()
if ctx.fp8:
aux_ctx_tensors = [
@@ -1825,8 +1956,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q_
k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
if not ctx.enable_mla:
k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
out_part = out_
dout_part = dout_
@@ -1910,6 +2046,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd(
dout_,
q_,
@@ -1925,11 +2062,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1]
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_, out_, dout_ = q[1], out[1], dout[1]
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[-3:])
v_part = v_part.view(-1, *v_part.shape[-3:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:])
elif ctx.qkv_format == "thd":
@@ -1953,8 +2100,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q_
k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
if not ctx.enable_mla:
k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
out_part = out_
dout_part = dout_
@@ -2038,6 +2190,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd(
dout_,
q_,
@@ -2058,6 +2211,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q
if not ctx.enable_mla:
k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0]
v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1]
out_part = out
@@ -2133,6 +2287,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd(
dout,
q,
@@ -2225,15 +2380,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.use_fused_attention:
if ctx.qkv_format in ["bshd", "sbhd"]:
if ctx.enable_mla:
dkv_ = None
elif ctx.qkv_format in ["bshd", "sbhd"]:
dkv_ = combine_tensors([dk_, dv_], -2)
elif ctx.qkv_format == "thd":
dkv_ = torch.cat(
(dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
) # pylint: disable=used-before-assignment
if ctx.qkv_format in ["bshd", "sbhd"]:
if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]:
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
# dkv is a buffer, so we only need to reshape it, not transpose it.
dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
dkv_ = dkv_.movedim(-3, 0)
if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
@@ -2241,7 +2399,101 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
dkv_ = dkv_.view(*dkv.shape)
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] or
# [2, sk//2, b, np, hn]
dk = dkv[: ctx.k_numel].view(*ctx.k_shape)
dv = dkv[ctx.k_numel :].view(*ctx.v_shape)
if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
dk_ = dk_.view(*ctx.k_shape)
dv_ = dv_.view(*ctx.v_shape)
if ctx.fp8:
# enable_mla and fp8
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd":
dk[:, 0, ...].copy_(dk_)
dk[:, 1, ...].fill_(0)
dv[:, 0, ...].copy_(dv_)
dv[:, 1, ...].fill_(0)
elif ctx.qkv_format == "sbhd":
dk[0].copy_(dk_)
dk[1].fill_(0)
dv[0].copy_(dv_)
dv[1].fill_(0)
else:
dk.copy_(dk_)
dv.copy_(dv_)
elif causal:
# enable_mla and not fp8 and causal
if i == (cp_size - 1):
if rank == 0:
if ctx.qkv_format == "bshd":
dk[:, 0, ...].add_(dk_[:, 0, ...])
dk[:, 1, ...].copy_(dk_[:, 1, ...])
dv[:, 0, ...].add_(dv_[:, 0, ...])
dv[:, 1, ...].copy_(dv_[:, 1, ...])
elif ctx.qkv_format == "sbhd":
dk[0, ...].add_(dk_[0, ...])
dk[1, ...].copy_(dk_[1, ...])
dv[0, ...].add_(dv_[0, ...])
dv[1, ...].copy_(dv_[1, ...])
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(
dk, dk_, cu_seqlens_kv_padded, "add", "copy"
)
tex.thd_grad_correction(
dv, dv_, cu_seqlens_kv_padded, "add", "copy"
)
else:
dk.add_(dk_)
dv.add_(dv_)
elif i >= (cp_size - rank - 1):
if i == 0 and rank == (cp_size - 1):
if ctx.qkv_format == "bshd":
dk[:, 0, ...].copy_(dk_)
dv[:, 0, ...].copy_(dv_)
elif ctx.qkv_format == "sbhd":
dk[0, ...].copy_(dk_)
dv[0, ...].copy_(dv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(
dk, dk_, cu_seqlens_kv_padded, "copy", "none"
)
tex.thd_grad_correction(
dv, dv_, cu_seqlens_kv_padded, "copy", "none"
)
else:
if ctx.qkv_format == "bshd":
dk[:, 0, ...].add_(dk_)
dv[:, 0, ...].add_(dv_)
elif ctx.qkv_format == "sbhd":
dk[0, ...].add_(dk_)
dv[0, ...].add_(dv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(
dk, dk_, cu_seqlens_kv_padded, "add", "none"
)
tex.thd_grad_correction(
dv, dv_, cu_seqlens_kv_padded, "add", "none"
)
elif i > 0:
dk.add_(dk_)
dv.add_(dv_)
else: # i == 0
dk.copy_(dk_)
dv.copy_(dv_)
else:
# enable_mla and not fp8 and not causal
if i == 0:
dk.copy_(dk_)
dv.copy_(dv_)
else: # i > 0
dk.add_(dk_)
dv.add_(dv_)
else:
if ctx.fp8:
# fp8
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].copy_(dkv_)
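Note: the MLA branch above mirrors the existing dkv correction ladder, just applied to dk and dv separately. A compact restatement of the causal, non-FP8 policy as read from the code (which op folds step i's dk_/dv_ into the running dk/dv, and which kv half it touches):

def mla_dkv_grad_correction(i, rank, cp_size):
    # returns (op, region) for folding step i's dk_/dv_ into dk/dv
    if i == cp_size - 1:
        # full-kv step; rank 0 writes its second half for the first time here
        return ("add half 0, copy half 1" if rank == 0 else "add", "full")
    if i >= cp_size - rank - 1:
        # steps that attended to only the first half of kv
        op = "copy" if (i == 0 and rank == cp_size - 1) else "add"
        return (op, "first half")
    return ("copy" if i == 0 else "add", "full")

assert mla_dkv_grad_correction(0, 0, 4) == ("copy", "full")
assert mla_dkv_grad_correction(3, 0, 4) == ("add half 0, copy half 1", "full")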
@@ -2252,6 +2504,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
dkv.copy_(dkv_)
elif causal:
# not fp8 and causal
if i == (cp_size - 1):
if rank == 0:
if ctx.qkv_format == "bshd":
@@ -2261,7 +2514,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dkv[:, 0, ...].add_(dkv_[:, 0, ...])
dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy")
tex.thd_grad_correction(
dkv, dkv_, cu_seqlens_kv_padded, "add", "copy"
)
else:
dkv.add_(dkv_)
elif i >= (cp_size - rank - 1):
@@ -2271,35 +2526,54 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].copy_(dkv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none")
tex.thd_grad_correction(
dkv, dkv_, cu_seqlens_kv_padded, "copy", "none"
)
else:
if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].add_(dkv_)
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none")
tex.thd_grad_correction(
dkv, dkv_, cu_seqlens_kv_padded, "add", "none"
)
elif i > 0:
dkv.add_(dkv_)
else:
else: # i == 0
dkv.copy_(dkv_)
else:
# not fp8 and not causal
if i == 0:
dkv.copy_(dkv_)
else:
else: # i > 0
dkv.add_(dkv_)
if ctx.fp8 and ctx.use_fused_attention:
amax_cp_bwd = amax_per_step.amax(dim=1)
ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0])
ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1])
dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dq_fp8, fake_dtype=torch.float32, internal=True
)
if ctx.enable_mla:
# [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn]
dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape)
dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape)
dk = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dk_fp8, fake_dtype=torch.float32, internal=True
)
dv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dv_fp8, fake_dtype=torch.float32, internal=True
)
dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]]
dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]]
else:
if ctx.qkv_format in ["bshd", "sbhd"]:
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dq_fp8, fake_dtype=torch.float32, internal=True
)
dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dkv_fp8, fake_dtype=torch.float32, internal=True
)
@@ -2310,21 +2584,39 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
dk = dk.view(dk.shape[0], -1, *dk.shape[-2:])
dv = dv.view(dv.shape[0], -1, *dv.shape[-2:])
else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
dq = dq.view(-1, *dq.shape[-3:])
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
dk = dk.view(-1, *dk.shape[-3:])
dv = dv.view(-1, *dv.shape[-3:])
else:
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])
if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
dq[cu_seqlens_q_padded[-1] :].fill_(0)
if ctx.enable_mla:
dk[cu_seqlens_kv_padded[-1] :].fill_(0)
dv[cu_seqlens_kv_padded[-1] :].fill_(0)
else:
dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0)
if ctx.fp8 and ctx.is_input_fp8:
assert torch.uint8 not in [dq.dtype, dkv.dtype]
if ctx.enable_mla:
dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]]
else:
dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]]
if not ctx.enable_mla:
dk, dv = dkv[0], dkv[1]
if cp_size_a2a > 1:
@@ -3584,6 +3876,12 @@ def attn_forward_func_with_cp(
"all_gather",
], "The context parallel running configs cannot support sliding window attetnion!"
enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
"p2p",
"a2a+p2p",
], "The context parallel running configs cannot support MLA!"
args = [
is_training,
q,
......
@@ -608,11 +608,6 @@ def get_attention_backend(
" bias for THD format"
)
use_fused_attention = False
elif head_dim_qk != head_dim_v:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with MLA"
)
use_fused_attention = False
# Filter: Attention mask
# attn_mask_type | attention_mask | supported backends
......