2026-05-22 sparse prefill phase1 optimization

Scope
- Target kernel: csrc/gfx93/prefill/sparse/phase1.cuh
- Target shape: D_QK=512, h_q=64, topk=512
- Optimized dispatch: only when HAVE_TOPK_LENGTH=true and attn_sink is enabled.
- All other D512/H64 combinations stay on the generic KernelTemplate path. Routing every D512/H64 case through the H64 fast path risked slowdowns in measurements (see Notes).
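The routing rule in Scope, restricted to D_QK=512 and h_q=64, can be modeled as a small Python predicate. This is a sketch, not the actual HIP dispatch code: the function name select_d512_h64_kernel is hypothetical, and only the kernel names are taken from this README. Shapes other than D512/H64 are deliberately not modeled.

```python
from itertools import product


def select_d512_h64_kernel(d_qk: int, h_q: int, have_topk_length: bool, attn_sink: bool) -> str:
    """Hypothetical model of the D_QK=512, h_q=64 routing described in Scope."""
    if (d_qk, h_q) != (512, 64):
        # Routing for other shapes is not described in this README, so it is not modeled.
        raise ValueError("this sketch only models D_QK=512, h_q=64")
    if have_topk_length and attn_sink:
        # The single target combination gets the new fast-path wrapper.
        return "KernelTemplate_D512_H64_TopkLen_AttnSink"
    # The three non-target combinations stay on the generic path.
    return "KernelTemplate"


# Print the kernel chosen for each of the four flag combinations.
for topk_len, sink in product([False, True], repeat=2):
    print(f"HAVE_TOPK_LENGTH={topk_len!s:<5} attn_sink={sink!s:<5} -> "
          f"{select_d512_h64_kernel(512, 64, topk_len, sink)}")
```

Only one of the four combinations leaves the generic path, which matches the narrow dispatch chosen in Notes.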

Implementation
- Extended KernelTemplate_B_H_64 so that its QK pipeline also supports D_QK=512 with 16 q/k chunks; previously it handled only D_QK=576 with an 18-chunk schedule.
- Skipped the index prefetch into LDS when IS_TOPK_2048 is false, removing its overhead on that path.
- Added attn_sink output scaling for the D512/H64 topk_length path.
- Added the KernelTemplate_D512_H64_TopkLen_AttnSink wrapper and dispatch to it when D_QK=512 && HAVE_TOPK_LENGTH && h_q=64 && attn_sink.
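The attn_sink output scaling can be illustrated with the common attention-sink formulation, in which a per-head sink logit joins the softmax denominator but contributes no value vector. Under that assumption (this README does not spell out the kernel's exact convention), scaling the ordinary attention output by 1 / (1 + exp(sink - lse)) gives the same result as a softmax that includes the sink. The pure-Python sketch checks that identity for one query and one head; both function names are hypothetical.

```python
import math


def softmax_attention_with_sink(scores, values, sink):
    """Reference: the sink is an extra logit in the softmax denominator with no value vector."""
    m = max(max(scores), sink)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights) + math.exp(sink - m)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]


def sink_scaled_output(scores, values, sink):
    """Two-step form: plain softmax attention, then scale by 1 / (1 + exp(sink - lse))."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    lse = m + math.log(total)  # log-sum-exp of the real scores only
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    scale = 1.0 / (1.0 + math.exp(sink - lse))  # the assumed attn_sink output scale
    return [o * scale for o in out]


scores = [0.5, -1.0, 2.0]  # QK logits for one query, one head
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]
print(softmax_attention_with_sink(scores, values, sink=0.3))
print(sink_scaled_output(scores, values, sink=0.3))  # same values up to rounding
```

Under this assumption the sink only needs the final lse, so it can be applied as an output scale after the main QK/softmax loop rather than inside it.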

Build
- Command:
  source /parastor/home/public_user/zhanghj/dtk-26.04-DCC2602-0317/env.sh
  touch csrc/gfx93/prefill/sparse/instantiations/phase1_k512.hip csrc/gfx93/prefill/sparse/instantiations/phase1_k512_topklen.hip
  FLASH_MLA_OPT=phase1_d512_h64 python setup.py build_ext --inplace
- The touch is needed because phase1.cuh is included by the generated instantiation .hip sources; without it, the header change may not trigger a rebuild of those files.

Benchmark
- Command:
  PYTHONPATH=/parastor/home/public_user/zhanghj/flashmla/tests:$PYTHONPATH HIP_VISIBLE_DEVICES=1 python /parastor/home/public_user/zhanghj/hygon_tmp/bench_sparse_prefill_target.py --runs 20 --correctness --topk-length
- Device reported by PyTorch in this run: gfx936:sramecc+:xnack-
- hy-smi was run immediately before the measurement; all HCUs were idle at the start.

Target results: D_QK=512, h_q=64, topk=512, s_q=4096, HAVE_TOPK_LENGTH=true, attn_sink=true

| s_kv | baseline (us) | optimized (us) | latency reduction |
| ---: | ------------: | -------------: | ----------------: |
| 8192 | 3727.012 | 1955.701 | 47.53% |
| 32768 | 7798.721 | 2955.996 | 62.10% |
| 49152 | 8790.056 | 3162.484 | 64.02% |
| 65536 | 8959.296 | 3212.299 | 64.15% |

Average latency reduction (mean of the four rows): 59.45%.

The baseline values come from the same target benchmark, run before this fast path was enabled. The optimized run also ran the correctness check (--correctness) on every measured target row.
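The latency-reduction column and the average can be recomputed from the raw timings as (baseline - optimized) / baseline, with the average taken as the plain mean of the four per-row percentages. This Python snippet is only an arithmetic sanity check of the table; it does not rerun the benchmark, and the names ROWS and latency_reduction_pct are illustrative.

```python
# (s_kv, baseline_us, optimized_us) copied from the results table.
ROWS = [
    (8192, 3727.012, 1955.701),
    (32768, 7798.721, 2955.996),
    (49152, 8790.056, 3162.484),
    (65536, 8959.296, 3212.299),
]


def latency_reduction_pct(baseline_us: float, optimized_us: float) -> float:
    """Percentage of the baseline latency removed by the optimized kernel."""
    return (baseline_us - optimized_us) / baseline_us * 100.0


reductions = [latency_reduction_pct(b, o) for _, b, o in ROWS]
average = sum(reductions) / len(reductions)

for (s_kv, _, _), r in zip(ROWS, reductions):
    print(f"s_kv={s_kv:>6}: {r:.2f}%")
print(f"average: {average:.2f}%")

# Every row should clear the 30% improvement requirement mentioned in Notes.
assert all(r >= 30.0 for r in reductions)
```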

Full regression
- Command:
  HIP_VISIBLE_DEVICES=1 python tests/test_flash_mla_sparse_prefill.py
- Result:
  All 958 cases passed.

Notes
- A broader D512/H64 dispatch, also covering HAVE_TOPK_LENGTH=false with attn_sink both true and false, was tested. It gave only small, noisy gains in some rows and could regress performance on the existing paths, so the committed dispatch is limited to the slow target path, which clears the 30% improvement requirement.
- The test environment prints NumPy 2.x compatibility warnings when importing PyTorch. They did not prevent the correctness checks or the benchmark from running.