The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in compute-bound configuration on H800 SXM5 with CUDA 12.8, and achieves up to 350 TFlops on B200 (which is not really optimized yet).
The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in compute-bound configuration on H800 SXM5 with CUDA 12.8, and achieves up to 350 TFlops on B200 (which is not really optimized yet).
...
@@ -56,14 +58,18 @@ It achieves up to 640 TFlops in forward computation on H800 SXM5 with CUDA 12.8,
...
@@ -56,14 +58,18 @@ It achieves up to 640 TFlops in forward computation on H800 SXM5 with CUDA 12.8,
- CUDA 12.8 and above (CUDA 12.9+ is required for SM100 kernels)
- CUDA 12.8 and above (CUDA 12.9+ is required for SM100 kernels)
- PyTorch 2.0 and above
- PyTorch 2.0 and above
For the Hygon DCU gfx93 path in this branch, build with the local AICC/ROCm
toolchain and the pinned CUTLASS submodule described below.
If you only run `pip install -v .` or `python setup.py install`, avoid launching
tests from a different unbuilt source checkout, otherwise Python may import that
checkout source `flash_mla/` directory and fail with `No module named flash_mla.cuda`.
## Usage
## Usage
### MLA Decoding
### MLA Decoding
...
@@ -119,6 +138,11 @@ Where
...
@@ -119,6 +138,11 @@ Where
**FP8 KV Cache:**
**FP8 KV Cache:**
If `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16.
If `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16.
**BF16 Sparse Decode on gfx93:**
This branch also supports the DSA BF16 sparse decode path when
`is_fp8_kvcache=False` and `k_cache` / `extra_k_cache` are `torch.bfloat16`.
The test suite enables this path with `FLASH_MLA_DECODE_BF16=1`.
In the "FP8 with scale" format, each token's KV cache is 656 Bytes, structured as:
In the "FP8 with scale" format, each token's KV cache is 656 Bytes, structured as:
-**Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on.
-**Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on.