Commit 635b2b33 authored by shenzhe, committed by zhanghj2

Update README for DSA gfx93 workflow

parent 98c9821f
@@ -10,6 +10,7 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee
- Token-level sparse attention for the prefill stage
- Token-level sparse attention for the decoding stage, with FP8 KV cache
- DSA MLS sparse prefill and DSA BF16 sparse decoding on Hygon DCU gfx93
**Dense Attention Kernels**
@@ -30,6 +31,7 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee
```bash
python tests/test_flash_mla_dense_decoding.py
python tests/test_flash_mla_sparse_decoding.py
FLASH_MLA_DECODE_BF16=1 python tests/test_flash_mla_sparse_decoding.py
```
The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in compute-bound configuration on H800 SXM5 with CUDA 12.8, and achieves up to 350 TFlops on B200 (which is not really optimized yet).
@@ -56,14 +58,18 @@ It achieves up to 640 TFlops in forward computation on H800 SXM5 with CUDA 12.8,
- CUDA 12.8 and above (CUDA 12.9+ is required for SM100 kernels)
- PyTorch 2.0 and above
For the Hygon DCU gfx93 path in this branch, build with the local AICC/ROCm
toolchain and the pinned CUTLASS submodule described below.
Support matrix:
| Kernel | GPU Architecture | MLA Mode [2] | KVCache Format |
| :---: | :---: | :---: | :---: |
| Dense Decoding | SM90 | MQA | BF16 |
| Sparse Decoding | SM90 & SM100 | MQA | FP8 [1] |
| DSA BF16 Sparse Decoding | gfx93 | MQA | BF16 |
| Dense Prefill | SM100 | MHA | |
-| Sparse Prefill | SM90 & SM100 | MQA | |
+| Sparse Prefill | SM90 & SM100 / gfx93 | MQA | |
[1]: For more details on using FP8 KV cache, see documents below.
@@ -72,8 +78,8 @@ Support matrix:
## Installation
```bash
-git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
+git clone -b master-aicc ssh://git@10.16.1.204:10022/dcutoolkit/deeplearing/flashmla.git flashmla
-cd flash-mla
+cd flashmla
git submodule update --init --recursive
pip install -v .
```
@@ -82,6 +88,19 @@ The CUTLASS dependency is pinned as `csrc/cutlass/cutlass_3.2.1` on branch
`feature/16x64-mmac`. If the submodule is missing, `setup.py` will try to
initialize it before compiling.
When running tests directly from the source tree, prefer an in-place build so
the local `flash_mla` package can find the compiled `flash_mla.cuda` extension:
```bash
python setup.py build_ext --inplace
python tests/test_flash_mla_sparse_prefill.py
FLASH_MLA_DECODE_BF16=1 python tests/test_flash_mla_sparse_decoding.py
```
If you only run `pip install -v .` or `python setup.py install`, avoid launching
tests from a different, unbuilt source checkout; otherwise Python may import that
checkout's `flash_mla/` source directory and fail with `No module named flash_mla.cuda`.
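
To diagnose the import problem described above, the following minimal sketch reports which directory Python would load `flash_mla` from and whether a compiled extension sits next to it. The helper names `locate_package` and `has_compiled_extension` are hypothetical, and the sketch assumes the extension is built as a file named `cuda` plus a platform extension suffix inside the package directory, which matches an in-place build but is not confirmed here for other install layouts.

```python
import importlib.machinery
import importlib.util
import os


def locate_package(name):
    """Return the directory a top-level package would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        return None
    return os.path.abspath(list(spec.submodule_search_locations)[0])


def has_compiled_extension(package_dir, stem="cuda"):
    """Check whether package_dir holds a compiled extension module named `stem`.

    Assumption: the extension file is `<stem><suffix>`, where the suffix is one
    of the interpreter's known extension suffixes (for example `.so`).
    """
    candidates = {stem + suffix for suffix in importlib.machinery.EXTENSION_SUFFIXES}
    return any(entry in candidates for entry in os.listdir(package_dir))


if __name__ == "__main__":
    package_dir = locate_package("flash_mla")
    if package_dir is None:
        print("flash_mla is not importable from this working directory")
    elif has_compiled_extension(package_dir):
        print(f"flash_mla resolves to {package_dir} (compiled extension present)")
    else:
        print(f"flash_mla resolves to {package_dir}, but no compiled extension was found there")
```

Run it from the same working directory you launch the tests from: if the reported path points at an unbuilt checkout, build in place there or change directories before running the tests.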
## Usage
### MLA Decoding
@@ -119,6 +138,11 @@ Where
**FP8 KV Cache:**
If `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16.
**BF16 Sparse Decode on gfx93:**
This branch also supports the DSA BF16 sparse decode path when
`is_fp8_kvcache=False` and `k_cache` / `extra_k_cache` are `torch.bfloat16`.
The test suite enables this path with `FLASH_MLA_DECODE_BF16=1`.
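
The condition above can be sketched as a small predicate. This is an illustration only: the function names are hypothetical, dtypes are represented by their names as strings so the snippet runs without PyTorch, and both treating `extra_k_cache` as optional and comparing the environment variable against the string `"1"` are assumptions, since how the test suite actually reads it is not shown here.

```python
import os


def bf16_decode_requested(environ=os.environ):
    """True when FLASH_MLA_DECODE_BF16 is set to "1" (assumed convention)."""
    return environ.get("FLASH_MLA_DECODE_BF16") == "1"


def selects_bf16_sparse_decode(is_fp8_kvcache, k_cache_dtype, extra_k_cache_dtype=None):
    """Mirror the stated condition using dtype names instead of torch dtypes."""
    if is_fp8_kvcache:
        # The FP8 path is taken whenever the FP8 flag is set.
        return False
    dtypes = [k_cache_dtype]
    if extra_k_cache_dtype is not None:
        dtypes.append(extra_k_cache_dtype)
    # Every cache tensor that is present must be bfloat16.
    return all(dtype == "bfloat16" for dtype in dtypes)
```

For example, `selects_bf16_sparse_decode(False, "bfloat16", "bfloat16")` evaluates to `True`, while setting `is_fp8_kvcache=True` or passing any non-bfloat16 dtype makes it `False`.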
In the "FP8 with scale" format, each token's KV cache is 656 Bytes, structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on.
......
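
The byte offsets in the layout above can be made concrete with a short sketch. It covers only the two fields shown in this hunk: the 512 quantized bytes and the four `float32` scales. The remaining 656 - 512 - 16 = 128 bytes are returned untouched because their contents are not described in the lines visible here. The function names are hypothetical and little-endian `float32` storage is an assumption.

```python
import struct

TOKEN_BYTES = 656                        # per-token size stated in the text
NOPE_BYTES = 512                         # 512 float8_e4m3 values, one byte each
SCALE_COUNT = 4                          # four float32 scale factors
SCALE_BYTES = SCALE_COUNT * 4            # 16 bytes
GROUP_SIZE = NOPE_BYTES // SCALE_COUNT   # 128 quantized values share one scale
REMAINING_BYTES = TOKEN_BYTES - NOPE_BYTES - SCALE_BYTES  # 128 bytes, not decoded here


def split_token(buf):
    """Split one 656-byte token record into (nope_bytes, scales, remaining_bytes)."""
    if len(buf) != TOKEN_BYTES:
        raise ValueError(f"expected {TOKEN_BYTES} bytes, got {len(buf)}")
    nope = buf[:NOPE_BYTES]
    # Assumption: scales are stored as little-endian float32 values.
    scales = struct.unpack("<4f", buf[NOPE_BYTES:NOPE_BYTES + SCALE_BYTES])
    remaining = buf[NOPE_BYTES + SCALE_BYTES:]
    return nope, scales, remaining


def scale_for(index, scales):
    """Return the scale that applies to the quantized value at `index` (0..511)."""
    return scales[index // GROUP_SIZE]
```

Dequantizing a value then amounts to decoding its `float8_e4m3` byte and multiplying by `scale_for(index, scales)`; the decoding step is left out because it needs a library with FP8 support.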