Unverified Commit 560bccf8 authored by Xiaowei Ren, committed by GitHub

clean CP implementation for flash attention and cuDNN 9.6 (#1387)



* make the pad_between_seqs check ignore padding at the end
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
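
Note (an illustrative sketch, not TE's actual helper): one way to make a pad_between_seqs check ignore trailing padding is to compare all cumulative offsets except the last one, so padding after the final sequence alone does not trigger the padded-THD path.

    import torch

    def pad_between_seqs(cu_seqlens: torch.Tensor, cu_seqlens_padded: torch.Tensor) -> bool:
        # Drop the final offset so padding at the end of the batch is ignored;
        # any mismatch before that means there is padding *between* sequences.
        return not torch.equal(cu_seqlens[:-1], cu_seqlens_padded[:-1])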

* change the CP THD test so it covers 0-length sequences
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
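
Note (a hedged sketch with made-up names and sizes, not the test's actual code): covering 0-length sequences just means drawing per-sequence lengths with a lower bound of 0.

    import torch

    batch_size, max_seqlen_q = 4, 512  # made-up sizes
    # A low bound of 0 lets empty sequences appear in the THD batch.
    seqlens_q = torch.randint(0, max_seqlen_q + 1, [batch_size], dtype=torch.int32)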

* minor change to flash func name
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* only use the varlen flash attention function when qkv_format is THD
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
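
Note (a minimal dispatch sketch under assumed names; the real call sites in TransformerEngine pass many more arguments): the varlen flash-attention entry point is only needed for the packed THD layout, while fixed-shape bshd/sbhd batches can call the dense kernel.

    from flash_attn import flash_attn_func, flash_attn_varlen_func

    def flash_attn_fwd(q, k, v, qkv_format, cu_seqlens_q=None, cu_seqlens_kv=None,
                       max_seqlen_q=None, max_seqlen_kv=None, causal=True):
        if qkv_format == "thd":
            # Packed variable-length layout: needs cumulative sequence lengths
            # and per-batch maximum lengths.
            return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv,
                                          max_seqlen_q, max_seqlen_kv, causal=causal)
        # bshd/sbhd layouts have fixed sequence lengths, so the dense kernel suffices.
        return flash_attn_func(q, k, v, causal=causal)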

* try to converge the code paths of flash and fused attention
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bwd compute with P2P
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant out_per_step view
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* enable cuDNN > 9.6 and THD+GQA
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* enable CP with FusedAttn+SWA+All_Gather
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* enable CP with FusedAttn+SWA+All_Gather
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning for cu_seqlens
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix some pylint errors
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor import change for pylint
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* more fixes for pylint
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix lse_seqlen in thd out correction
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a4cb1d17
@@ -163,12 +163,10 @@ def run_dpa_with_cp(
                 torch.tensor([q_input_shape[0]], dtype=torch.int32),
             ]
         ).cuda()
-        if kernel_backend == "FlashAttention":
-            cu_seqlens_q = cu_seqlens_q_padded[:-1]
-        else:
-            cu_seqlens_q = torch.cat(
-                [torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)]
-            ).cuda()
+        cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
+        if kernel_backend == "FusedAttention":
+            cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
+            cu_seqlens_q[-1] = cu_seqlens_q[-2]
         cu_seqlens_kv = cu_seqlens_q
         cu_seqlens_kv_padded = cu_seqlens_q_padded
     else:
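
A worked toy example of the new construction above (all numbers are made up): two sequences with real lengths [3, 5] padded to [4, 8] in a 16-token buffer.

    import torch

    seqlens_q = torch.tensor([3, 5], dtype=torch.int32)
    seqlens_q_padded = torch.tensor([4, 8], dtype=torch.int32)
    total_tokens = 16
    cu_seqlens_q_padded = torch.cat(
        [
            torch.zeros([1], dtype=torch.int32),
            seqlens_q_padded.cumsum(0, dtype=torch.int32),
            torch.tensor([total_tokens], dtype=torch.int32),
        ]
    )                                        # tensor([ 0,  4, 12, 16])
    cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
    # FusedAttention path: interior offsets use the unpadded lengths, and the extra
    # trailing entry duplicates the last real offset, i.e. a zero-length sequence
    # that accounts for the padding at the end of the buffer.
    cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32)
    cu_seqlens_q[-1] = cu_seqlens_q[-2]
    print(cu_seqlens_q)                      # tensor([0, 3, 8, 8])
    # FlashAttention path: cu_seqlens_q simply stays equal to cu_seqlens_q_padded.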
@@ -204,10 +202,8 @@ def run_dpa_with_cp(
             core_attention_bias=bias,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
-            cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
-            cu_seqlens_kv_padded=(
-                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
-            ),
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
         )
     if fp8_mha:
         dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2)
@@ -276,10 +272,8 @@ def run_dpa_with_cp(
             core_attention_bias=bias_,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
-            cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
-            cu_seqlens_kv_padded=(
-                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
-            ),
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
         )
     if fp8_mha:
         dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2)
@@ -311,7 +305,7 @@ def run_dpa_with_cp(
         dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
         dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
         dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
-        cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size
+        cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
             cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
         )
@@ -327,7 +321,7 @@ def run_dpa_with_cp(
             ).item()
             == 0
         )
-        cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size
+        cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
         cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
             cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
         )
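
A toy illustration of the // world_size rescaling above (values made up): with context parallelism each padded sequence is split into 2 * world_size equal chunks and every rank keeps two of them, so one rank's padded cumulative offsets are the global ones divided by world_size; the [:-1] slicing disappears because the extra trailing entry is now kept throughout.

    import torch

    world_size = 2
    cu_seqlens_q_padded = torch.tensor([0, 4, 12, 16], dtype=torch.int32)  # global offsets
    per_rank = cu_seqlens_q_padded // world_size                           # tensor([0, 2, 6, 8])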
@@ -121,22 +121,14 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
     if dtype == "fp8" and get_device_compute_capability() < (9, 0):
         pytest.skip("FP8 attention is only supported on sm90+!")
-    if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0):
-        pytest.skip("THD format is not supported for cuDNN 9.6+!")
     config = model_configs_fused_attn[model]
-    if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
-        pytest.skip("THD format does not support QGA/MQA yet!")
     if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
         pytest.skip("THD format does not support post_scale_bias yet!")
     if qkv_format == "thd" and cp_comm_type == "all_gather":
         pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
     if qkv_format == "thd" and "a2a" in cp_comm_type:
         pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
-    if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
-        pytest.skip(
-            "Sliding window attention only can be supported with the implementation of QKVO A2A!"
-        )
     if dtype == "fp8" and cp_comm_type == "all_gather":
         pytest.skip(
             "CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
@@ -147,6 +139,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("FP8 attention cannot work with bias yet!")
     if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
         pytest.skip("FP8 attention cannot work with sliding window yet!")
+    if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+        pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
     if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
         pytest.skip("CP implementation with KV all-gather does not support bias yet!")
     if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
This diff is collapsed.
@@ -1537,10 +1537,10 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
   int batch, lse_seqlen;
   if (lse_packed) {
     batch = cu_seqlens.size(0) - 1;
-    lse_seqlen = total_tokens;
+    lse_seqlen = lse.size(1);
     NVTE_CHECK(lse.size(0) == num_heads);
-    NVTE_CHECK(lse.size(1) == lse_seqlen);
+    NVTE_CHECK(lse_seqlen >= total_tokens);
     NVTE_CHECK(lse_per_step.size(0) == num_heads);
     NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1));
   } else {
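
A plausible reading of the change above, sketched with made-up sizes in Python: with padded THD input the packed log-sum-exp tensor has shape [num_heads, lse_seqlen], where lse_seqlen counts padded positions, so it can exceed the number of valid tokens and is now read from the tensor rather than assumed equal to total_tokens.

    import torch

    num_heads, total_tokens, trailing_pad = 16, 4096, 128     # made-up sizes
    lse = torch.empty(num_heads, total_tokens + trailing_pad)
    lse_seqlen = lse.size(1)               # 4224, taken from lse itself
    assert lse_seqlen >= total_tokens      # mirrors NVTE_CHECK(lse_seqlen >= total_tokens)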