Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b4f64e5b
Unverified
Commit
b4f64e5b
authored
Jan 20, 2026
by
Lucas Wilkinson
Committed by
GitHub
Jan 21, 2026
Browse files
Update FlashMLA (#32491)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
7ab80a8e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
169 additions
and
42 deletions
+169
-42
cmake/external_projects/flashmla.cmake
cmake/external_projects/flashmla.cmake
+40
-10
tests/kernels/attention/test_flashmla_sparse.py
tests/kernels/attention/test_flashmla_sparse.py
+1
-1
tests/v1/attention/test_sparse_mla_backends.py
tests/v1/attention/test_sparse_mla_backends.py
+64
-11
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+64
-20
No files found.
cmake/external_projects/flashmla.cmake
View file @
b4f64e5b
...
@@ -19,7 +19,7 @@ else()
...
@@ -19,7 +19,7 @@ else()
FetchContent_Declare
(
FetchContent_Declare
(
flashmla
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG
46d64a8ebef03fa50b4ae74937276a5c940e3f95
GIT_TAG
526781394b33d9888e4c41952e692266267dd8bf
GIT_PROGRESS TRUE
GIT_PROGRESS TRUE
CONFIGURE_COMMAND
""
CONFIGURE_COMMAND
""
BUILD_COMMAND
""
BUILD_COMMAND
""
...
@@ -55,16 +55,43 @@ if(FLASH_MLA_ARCHS)
...
@@ -55,16 +55,43 @@ if(FLASH_MLA_ARCHS)
set
(
FlashMLA_SOURCES
set
(
FlashMLA_SOURCES
${
flashmla_SOURCE_DIR
}
/csrc/torch_api.cpp
${
flashmla_SOURCE_DIR
}
/csrc/torch_api.cpp
${
flashmla_SOURCE_DIR
}
/csrc/pybind.cpp
${
flashmla_SOURCE_DIR
}
/csrc/smxx/get_mla_metadata.cu
# Misc kernels for decoding
${
flashmla_SOURCE_DIR
}
/csrc/smxx/mla_combine.cu
${
flashmla_SOURCE_DIR
}
/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/decode/dense/splitkv_mla.cu
${
flashmla_SOURCE_DIR
}
/csrc/smxx/decode/combine/combine.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
# sm90 dense decode
${
flashmla_SOURCE_DIR
}
/csrc/sm90/decode/dense/instantiations/fp16.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/decode/dense/instantiations/bf16.cu
# sm90 sparse decode
${
flashmla_SOURCE_DIR
}
/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu
# sm90 sparse prefill
${
flashmla_SOURCE_DIR
}
/csrc/sm90/prefill/sparse/fwd.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/prefill/sparse/fwd.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu
# sm100 dense prefill & backward
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/sparse/fwd.cu
# sm100 sparse prefill
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu
# sm100 sparse decode
${
flashmla_SOURCE_DIR
}
/csrc/sm100/decode/head64/instantiations/v32.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/decode/head64/instantiations/model1.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu
)
)
set
(
FlashMLA_Extension_SOURCES
set
(
FlashMLA_Extension_SOURCES
...
@@ -76,6 +103,7 @@ if(FLASH_MLA_ARCHS)
...
@@ -76,6 +103,7 @@ if(FLASH_MLA_ARCHS)
set
(
FlashMLA_INCLUDES
set
(
FlashMLA_INCLUDES
${
flashmla_SOURCE_DIR
}
/csrc
${
flashmla_SOURCE_DIR
}
/csrc
${
flashmla_SOURCE_DIR
}
/csrc/kerutils/include
${
flashmla_SOURCE_DIR
}
/csrc/sm90
${
flashmla_SOURCE_DIR
}
/csrc/sm90
${
flashmla_SOURCE_DIR
}
/csrc/cutlass/include
${
flashmla_SOURCE_DIR
}
/csrc/cutlass/include
${
flashmla_SOURCE_DIR
}
/csrc/cutlass/tools/util/include
${
flashmla_SOURCE_DIR
}
/csrc/cutlass/tools/util/include
...
@@ -83,7 +111,6 @@ if(FLASH_MLA_ARCHS)
...
@@ -83,7 +111,6 @@ if(FLASH_MLA_ARCHS)
set
(
FlashMLA_Extension_INCLUDES
set
(
FlashMLA_Extension_INCLUDES
${
flashmla_SOURCE_DIR
}
/csrc
${
flashmla_SOURCE_DIR
}
/csrc
${
flashmla_SOURCE_DIR
}
/csrc/sm90
${
flashmla_SOURCE_DIR
}
/csrc/extension/sm90/dense_fp8/
${
flashmla_SOURCE_DIR
}
/csrc/extension/sm90/dense_fp8/
${
flashmla_SOURCE_DIR
}
/csrc/cutlass/include
${
flashmla_SOURCE_DIR
}
/csrc/cutlass/include
${
flashmla_SOURCE_DIR
}
/csrc/cutlass/tools/util/include
${
flashmla_SOURCE_DIR
}
/csrc/cutlass/tools/util/include
...
@@ -110,9 +137,12 @@ if(FLASH_MLA_ARCHS)
...
@@ -110,9 +137,12 @@ if(FLASH_MLA_ARCHS)
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
# Also enable C++20 for the FlashMLA sources (required for std::span, requires, etc.)
target_compile_options
(
_flashmla_C PRIVATE
target_compile_options
(
_flashmla_C PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>
)
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-std=c++20>
$<$<COMPILE_LANGUAGE:CUDA>:-std=c++20>
)
define_extension_target
(
define_extension_target
(
_flashmla_extension_C
_flashmla_extension_C
...
...
tests/kernels/attention/test_flashmla_sparse.py
View file @
b4f64e5b
...
@@ -43,7 +43,7 @@ def test_sparse_flashmla_decode_smoke():
...
@@ -43,7 +43,7 @@ def test_sparse_flashmla_decode_smoke():
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
"cuda"
)
batch_size
=
1
batch_size
=
1
seqlen_q
=
1
seqlen_q
=
1
num_heads_q
=
1
num_heads_q
=
64
head_dim_k
=
576
head_dim_k
=
576
head_dim_v
=
512
head_dim_v
=
512
num_heads_k
=
1
num_heads_k
=
1
...
...
tests/v1/attention/test_sparse_mla_backends.py
View file @
b4f64e5b
...
@@ -51,10 +51,34 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
...
@@ -51,10 +51,34 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
)
)
def
_float_to_e8m0_truncate
(
f
:
float
)
->
float
:
"""Simulate SM100's float -> e8m0 -> bf16 scale conversion.
e8m0 format only stores the exponent (power of 2).
cudaRoundZero truncates toward zero, meaning we round down to the
nearest power of 2.
"""
if
f
<=
0
:
return
0.0
# e8m0 = floor(log2(f)), then 2^(e8m0)
# This is equivalent to truncating to the nearest power of 2 below f
exp
=
math
.
floor
(
math
.
log2
(
f
))
return
2.0
**
exp
def
_dequantize_fp8_ds_mla_entry
(
def
_dequantize_fp8_ds_mla_entry
(
cache_slice
:
torch
.
Tensor
,
kv_lora_rank
:
int
,
rope_dim
:
int
,
dtype
:
torch
.
dtype
cache_slice
:
torch
.
Tensor
,
kv_lora_rank
:
int
,
rope_dim
:
int
,
dtype
:
torch
.
dtype
,
simulate_sm100_e8m0_scales
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope.
Args:
simulate_sm100_e8m0_scales: If True, simulate the SM100 kernel's
float -> e8m0 -> bf16 scale conversion path.
"""
# The first kv_lora_rank bytes store FP8 latent values with one scale per
# The first kv_lora_rank bytes store FP8 latent values with one scale per
# 128 element tile written as float32 right after the latent payload.
# 128 element tile written as float32 right after the latent payload.
...
@@ -63,10 +87,14 @@ def _dequantize_fp8_ds_mla_entry(
...
@@ -63,10 +87,14 @@ def _dequantize_fp8_ds_mla_entry(
for
tile_idx
in
range
(
4
):
for
tile_idx
in
range
(
4
):
tile_start
=
tile_idx
*
128
tile_start
=
tile_idx
*
128
tile_end
=
tile_start
+
128
tile_end
=
tile_start
+
128
scale_val
=
float
(
scales
[
tile_idx
].
item
())
if
simulate_sm100_e8m0_scales
:
# Simulate the lossy float -> e8m0 -> bf16 conversion
scale_val
=
_float_to_e8m0_truncate
(
scale_val
)
ops
.
convert_fp8
(
ops
.
convert_fp8
(
latent
[
tile_start
:
tile_end
],
latent
[
tile_start
:
tile_end
],
cache_slice
[
tile_start
:
tile_end
],
cache_slice
[
tile_start
:
tile_end
],
float
(
scales
[
tile_idx
].
item
())
,
scale_val
,
kv_dtype
=
"fp8"
,
kv_dtype
=
"fp8"
,
)
)
latent
=
latent
.
to
(
dtype
)
latent
=
latent
.
to
(
dtype
)
...
@@ -77,9 +105,18 @@ def _dequantize_fp8_ds_mla_entry(
...
@@ -77,9 +105,18 @@ def _dequantize_fp8_ds_mla_entry(
def
_quantize_dequantize_fp8_ds_mla
(
def
_quantize_dequantize_fp8_ds_mla
(
kv_c
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
block_size
:
int
,
scale
:
torch
.
Tensor
kv_c
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
block_size
:
int
,
scale
:
torch
.
Tensor
,
simulate_sm100_e8m0_scales
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout.
Args:
simulate_sm100_e8m0_scales: If True, simulate the SM100 kernel's
float -> e8m0 -> bf16 scale conversion in dequantization.
"""
if
kv_c
.
numel
()
==
0
:
if
kv_c
.
numel
()
==
0
:
return
kv_c
.
clone
(),
k_pe
.
clone
()
return
kv_c
.
clone
(),
k_pe
.
clone
()
...
@@ -108,7 +145,11 @@ def _quantize_dequantize_fp8_ds_mla(
...
@@ -108,7 +145,11 @@ def _quantize_dequantize_fp8_ds_mla(
block_offset
=
slot
%
block_size
block_offset
=
slot
%
block_size
cache_slice
=
tmp_cache
[
block_idx
,
block_offset
]
cache_slice
=
tmp_cache
[
block_idx
,
block_offset
]
latent
,
rope_vals
=
_dequantize_fp8_ds_mla_entry
(
latent
,
rope_vals
=
_dequantize_fp8_ds_mla_entry
(
cache_slice
,
kv_lora_rank
,
rope_dim
,
kv_c
.
dtype
cache_slice
,
kv_lora_rank
,
rope_dim
,
kv_c
.
dtype
,
simulate_sm100_e8m0_scales
=
simulate_sm100_e8m0_scales
,
)
)
dequant_kv_c
[
token_idx
]
=
latent
dequant_kv_c
[
token_idx
]
=
latent
dequant_k_pe
[
token_idx
]
=
rope_vals
dequant_k_pe
[
token_idx
]
=
rope_vals
...
@@ -143,7 +184,10 @@ def test_sparse_backend_decode_correctness(
...
@@ -143,7 +184,10 @@ def test_sparse_backend_decode_correctness(
batch_spec
=
SPARSE_BACKEND_BATCH_SPECS
[
batch_name
]
batch_spec
=
SPARSE_BACKEND_BATCH_SPECS
[
batch_name
]
# Model hyper-parameters (kept intentionally small for the unit test)
# Model hyper-parameters (kept intentionally small for the unit test)
num_heads
=
128
total_num_heads
=
128
# Compute per-rank heads for simulated TP
num_heads
=
max
(
1
,
total_num_heads
//
tensor_parallel_size
)
kv_lora_rank
=
512
kv_lora_rank
=
512
qk_nope_head_dim
=
128
qk_nope_head_dim
=
128
qk_rope_head_dim
=
64
qk_rope_head_dim
=
64
...
@@ -179,7 +223,7 @@ def test_sparse_backend_decode_correctness(
...
@@ -179,7 +223,7 @@ def test_sparse_backend_decode_correctness(
)
)
model_config
.
dtype
=
dtype
model_config
.
dtype
=
dtype
model_config
.
get_num_attention_heads
=
MethodType
(
model_config
.
get_num_attention_heads
=
MethodType
(
lambda
self
,
parallel_config
:
max
(
1
,
num_heads
//
tensor_parallel_size
)
,
lambda
self
,
parallel_config
:
num_heads
,
model_config
,
model_config
,
)
)
model_config
.
get_num_kv_heads
=
MethodType
(
model_config
.
get_num_kv_heads
=
MethodType
(
...
@@ -195,10 +239,10 @@ def test_sparse_backend_decode_correctness(
...
@@ -195,10 +239,10 @@ def test_sparse_backend_decode_correctness(
scale
=
1.0
/
math
.
sqrt
(
head_size
)
scale
=
1.0
/
math
.
sqrt
(
head_size
)
# Shared MLA projection weights to keep reference and backend in sync
# Shared MLA projection weights to keep reference and backend in sync
W_UK
=
torch
.
rand
n
(
W_UK
=
torch
.
rand
(
kv_lora_rank
,
num_heads
,
qk_nope_head_dim
,
dtype
=
dtype
,
device
=
device
kv_lora_rank
,
num_heads
,
qk_nope_head_dim
,
dtype
=
dtype
,
device
=
device
)
)
W_UV
=
torch
.
rand
n
(
kv_lora_rank
,
num_heads
,
v_head_dim
,
dtype
=
dtype
,
device
=
device
)
W_UV
=
torch
.
rand
(
kv_lora_rank
,
num_heads
,
v_head_dim
,
dtype
=
dtype
,
device
=
device
)
# Build synthetic decode-only workload
# Build synthetic decode-only workload
seq_lens
=
batch_spec
.
seq_lens
seq_lens
=
batch_spec
.
seq_lens
...
@@ -225,11 +269,15 @@ def test_sparse_backend_decode_correctness(
...
@@ -225,11 +269,15 @@ def test_sparse_backend_decode_correctness(
kv_c_full
=
torch
.
rand
(
s_len
,
kv_lora_rank
,
dtype
=
dtype
,
device
=
device
)
kv_c_full
=
torch
.
rand
(
s_len
,
kv_lora_rank
,
dtype
=
dtype
,
device
=
device
)
k_pe_full
=
torch
.
rand
(
s_len
,
1
,
qk_rope_head_dim
,
dtype
=
dtype
,
device
=
device
)
k_pe_full
=
torch
.
rand
(
s_len
,
1
,
qk_rope_head_dim
,
dtype
=
dtype
,
device
=
device
)
# SM100 (Blackwell) uses float -> e8m0 -> bf16 scale conversion
# which truncates scales to powers of 2. Simulate this in reference.
is_sm100
=
torch
.
cuda
.
get_device_capability
()[
0
]
>=
10
kv_c_full
,
k_pe_full
=
_quantize_dequantize_fp8_ds_mla
(
kv_c_full
,
k_pe_full
=
_quantize_dequantize_fp8_ds_mla
(
kv_c_full
,
kv_c_full
,
k_pe_full
.
squeeze
(
1
),
k_pe_full
.
squeeze
(
1
),
block_size
=
vllm_config
.
cache_config
.
block_size
,
block_size
=
vllm_config
.
cache_config
.
block_size
,
scale
=
kv_cache_scale
,
scale
=
kv_cache_scale
,
simulate_sm100_e8m0_scales
=
is_sm100
,
)
)
q_nope
,
q_pe
=
q_c
.
split
([
qk_nope_head_dim
,
qk_rope_head_dim
],
dim
=-
1
)
q_nope
,
q_pe
=
q_c
.
split
([
qk_nope_head_dim
,
qk_rope_head_dim
],
dim
=-
1
)
...
@@ -381,7 +429,12 @@ def test_sparse_backend_decode_correctness(
...
@@ -381,7 +429,12 @@ def test_sparse_backend_decode_correctness(
assert
backend_output
.
dtype
==
sdpa_reference
.
dtype
assert
backend_output
.
dtype
==
sdpa_reference
.
dtype
assert
torch
.
isfinite
(
backend_output
).
all
()
assert
torch
.
isfinite
(
backend_output
).
all
()
torch
.
testing
.
assert_close
(
backend_output
,
sdpa_reference
,
rtol
=
0.5
,
atol
=
0.5
)
# FP8 quantization introduces some error, but should be within reasonable bounds
# BF16 (auto) should be very accurate, FP8 allows slightly more tolerance
if
kv_cache_dtype
==
"fp8_ds_mla"
:
torch
.
testing
.
assert_close
(
backend_output
,
sdpa_reference
,
rtol
=
0.05
,
atol
=
0.05
)
else
:
torch
.
testing
.
assert_close
(
backend_output
,
sdpa_reference
,
rtol
=
0.01
,
atol
=
0.01
)
def
_triton_convert_reference_impl
(
def
_triton_convert_reference_impl
(
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
b4f64e5b
...
@@ -17,7 +17,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
...
@@ -17,7 +17,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionBackend
,
AttentionCGSupport
,
AttentionCGSupport
,
...
@@ -397,6 +396,10 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -397,6 +396,10 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
self
.
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
self
.
model_config
)
self
.
mla_dims
=
get_mla_dims
(
self
.
model_config
)
# FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
self
.
fp8_decode_padded_heads
=
(
FlashMLASparseImpl
.
_compute_fp8_decode_padded_heads
(
self
.
num_heads
)
)
self
.
topk_tokens
=
vllm_config
.
model_config
.
hf_config
.
index_topk
self
.
topk_tokens
=
vllm_config
.
model_config
.
hf_config
.
index_topk
self
.
use_fp8_kv_cache
=
cache_config
.
cache_dtype
==
"fp8_ds_mla"
self
.
use_fp8_kv_cache
=
cache_config
.
cache_dtype
==
"fp8_ds_mla"
...
@@ -417,14 +420,20 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -417,14 +420,20 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
(
max_num_seqs
,
1
),
dtype
=
torch
.
int32
,
device
=
self
.
device
(
max_num_seqs
,
1
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
# Equation taken from FlashMLA/csrc/pybind.cpp
# Equation taken from FlashMLA/csrc/api/sparse_decode.h
h_q
,
h_k
=
self
.
num_heads
,
1
# For sparse FP8 decode, the formula depends on architecture:
s_q
=
1
# inversely proportional to s_q, so s_q = 1 is the largest
# - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
max_num_sm_parts
=
int
(
# - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
max
((
sm_count
//
2
)
/
h_k
//
(
cdiv
(
h_q
//
h_k
,
2
*
64
)
*
s_q
),
1
)
# - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
)
# For max buffer size, use s_q = 1 (the case that produces largest output)
# Use padded head count since that's what will be passed to the kernel
h_q
=
self
.
fp8_decode_padded_heads
if
current_platform
.
is_device_capability_family
(
100
):
if
current_platform
.
is_device_capability_family
(
100
):
max_num_sm_parts
*=
2
# SM100 head64 or head64x2 uses full SM count
max_num_sm_parts
=
sm_count
else
:
# SM90 uses h_q/64 divisor
max_num_sm_parts
=
sm_count
//
max
(
1
,
h_q
//
64
)
self
.
tile_scheduler_metadata_buffer
=
torch
.
empty
(
self
.
tile_scheduler_metadata_buffer
=
torch
.
empty
(
# TileSchedulerMetaDataSize = 8
# TileSchedulerMetaDataSize = 8
# see: FlashMLA/csrc/params.h
# see: FlashMLA/csrc/params.h
...
@@ -455,12 +464,15 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -455,12 +464,15 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
"""
"""
num_tokens
=
common_attn_metadata
.
num_actual_tokens
num_tokens
=
common_attn_metadata
.
num_actual_tokens
# Use padded head count since that's what the kernel will see
padded_heads
=
self
.
fp8_decode_padded_heads
# Build metadata for all tokens as a single batch
# Build metadata for all tokens as a single batch
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
=
self
.
topk_tokens_tensor
[:
1
],
# Single batch
cache_seqlens
=
self
.
topk_tokens_tensor
[:
1
],
# Single batch
num_q_tokens_per_head_k
=
num_tokens
*
self
.
num
_heads
,
num_q_tokens_per_head_k
=
num_tokens
*
padded
_heads
,
topk
=
self
.
topk_tokens
,
topk
=
self
.
topk_tokens
,
num_heads_q
=
self
.
num
_heads
,
num_heads_q
=
padded
_heads
,
num_heads_k
=
1
,
num_heads_k
=
1
,
is_fp8_kvcache
=
True
,
is_fp8_kvcache
=
True
,
)
)
...
@@ -606,11 +618,13 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -606,11 +618,13 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
decode_query_len
=
(
query_start_loc_cpu
[
1
]
-
query_start_loc_cpu
[
0
]).
item
()
decode_query_len
=
(
query_start_loc_cpu
[
1
]
-
query_start_loc_cpu
[
0
]).
item
()
# Use padded head count since that's what the kernel will see
padded_heads
=
self
.
fp8_decode_padded_heads
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
=
self
.
topk_tokens_tensor
[:
num_decodes
],
cache_seqlens
=
self
.
topk_tokens_tensor
[:
num_decodes
],
num_q_tokens_per_head_k
=
decode_query_len
*
self
.
num
_heads
,
num_q_tokens_per_head_k
=
decode_query_len
*
padded
_heads
,
topk
=
self
.
topk_tokens
,
topk
=
self
.
topk_tokens
,
num_heads_q
=
self
.
num
_heads
,
num_heads_q
=
padded
_heads
,
num_heads_k
=
1
,
num_heads_k
=
1
,
is_fp8_kvcache
=
True
,
is_fp8_kvcache
=
True
,
)
)
...
@@ -689,6 +703,12 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -689,6 +703,12 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
class
FlashMLASparseImpl
(
MLACommonBaseImpl
[
FlashMLASparseMetadata
]):
class
FlashMLASparseImpl
(
MLACommonBaseImpl
[
FlashMLASparseMetadata
]):
@
staticmethod
def
_compute_fp8_decode_padded_heads
(
num_heads
:
int
)
->
int
:
# FP8 decode kernel only supports h_q = 64 or 128
# Compute padded head count for decode
return
64
if
num_heads
<=
64
else
128
def
__init__
(
def
__init__
(
self
,
self
,
num_heads
:
int
,
num_heads
:
int
,
...
@@ -722,7 +742,11 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
...
@@ -722,7 +742,11 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
self
.
softmax_scale
=
scale
self
.
softmax_scale
=
scale
assert
indexer
is
not
None
assert
indexer
is
not
None
self
.
topk_indices_buffer
:
torch
.
Tensor
|
None
=
indexer
.
topk_indices_buffer
self
.
topk_indices_buffer
:
torch
.
Tensor
|
None
=
indexer
.
topk_indices_buffer
self
.
padding
=
128
if
current_platform
.
is_device_capability_family
(
100
)
else
64
# Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
self
.
prefill_padding
=
(
128
if
current_platform
.
is_device_capability_family
(
100
)
else
64
)
self
.
fp8_decode_padded_heads
=
self
.
_compute_fp8_decode_padded_heads
(
num_heads
)
if
kv_cache_dtype
==
"fp8_ds_mla"
:
if
kv_cache_dtype
==
"fp8_ds_mla"
:
# Reserve workspace during initialization
# Reserve workspace during initialization
...
@@ -903,8 +927,22 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
...
@@ -903,8 +927,22 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
kernel_metadata
:
FlashMLASparseMetadata
.
FP8KernelMetadata
,
kernel_metadata
:
FlashMLASparseMetadata
.
FP8KernelMetadata
,
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
flash_mla_with_kvcache
(
# q shape: (batch, seq_len, num_heads, head_dim)
actual_num_heads
=
q
.
size
(
2
)
padded_num_heads
=
self
.
fp8_decode_padded_heads
# Pad query if needed (kernel only supports h_q = 64 or 128)
if
actual_num_heads
<
padded_num_heads
:
logger
.
warning_once
(
f
"Padding num_heads from
{
actual_num_heads
}
to "
f
"
{
padded_num_heads
}
for FP8 sparse decode kernel"
)
q_padded
=
q
.
new_zeros
((
q
.
size
(
0
),
q
.
size
(
1
),
padded_num_heads
,
q
.
size
(
3
)))
q_padded
[:,
:,
:
actual_num_heads
,
:]
=
q
q
=
q_padded
out
,
lse
=
flash_mla_with_kvcache
(
q
=
q
,
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
view
(
torch
.
uint8
).
unsqueeze
(
-
2
),
k_cache
=
kv_c_and_k_pe_cache
.
view
(
torch
.
uint8
).
unsqueeze
(
-
2
),
block_table
=
kernel_metadata
.
dummy_block_table
,
block_table
=
kernel_metadata
.
dummy_block_table
,
...
@@ -917,6 +955,12 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
...
@@ -917,6 +955,12 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
softmax_scale
=
self
.
softmax_scale
,
softmax_scale
=
self
.
softmax_scale
,
)
)
# Slice output back to actual head count if we padded
if
actual_num_heads
<
padded_num_heads
:
out
=
out
[:,
:,
:
actual_num_heads
,
:]
return
out
,
lse
def
_bf16_flash_mla_kernel
(
def
_bf16_flash_mla_kernel
(
self
,
self
,
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
...
@@ -930,13 +974,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
...
@@ -930,13 +974,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
# NOTE(Chen): kernel requires num_local_head to be a multiple of
# NOTE(Chen): kernel requires num_local_head to be a multiple of
# 64 on hopper and 128 on blackwell
# 64 on hopper and 128 on blackwell
if
self
.
num_heads
%
self
.
padding
!=
0
:
if
self
.
num_heads
%
self
.
prefill_
padding
!=
0
:
assert
self
.
padding
%
self
.
num_heads
==
0
assert
self
.
prefill_
padding
%
self
.
num_heads
==
0
logger
.
warning_once
(
logger
.
warning_once
(
f
"
p
adding num_heads
to
{
self
.
padding
}
due to sparse attn
"
f
"
P
adding num_heads
from
{
self
.
num_heads
}
to
"
"kernel requirement
"
f
"
{
self
.
prefill_padding
}
for BF16 sparse prefill kernel
"
)
)
q_padded
=
q
.
new_empty
((
q
.
shape
[
0
],
self
.
padding
,
q
.
shape
[
2
]))
q_padded
=
q
.
new_empty
((
q
.
shape
[
0
],
self
.
prefill_
padding
,
q
.
shape
[
2
]))
q_padded
[:,
:
self
.
num_heads
,
:]
=
q
q_padded
[:,
:
self
.
num_heads
,
:]
=
q
q
=
q_padded
q
=
q_padded
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment