Unverified commit b4f64e5b, authored by Lucas Wilkinson and committed by GitHub
Browse files

Update FlashMLA (#32491)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 7ab80a8e
......@@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 46d64a8ebef03fa50b4ae74937276a5c940e3f95
GIT_TAG 526781394b33d9888e4c41952e692266267dd8bf
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
......@@ -55,16 +55,43 @@ if(FLASH_MLA_ARCHS)
set(FlashMLA_SOURCES
${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
${flashmla_SOURCE_DIR}/csrc/pybind.cpp
${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
# Misc kernels for decoding
${flashmla_SOURCE_DIR}/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
${flashmla_SOURCE_DIR}/csrc/smxx/decode/combine/combine.cu
# sm90 dense decode
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/instantiations/fp16.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/instantiations/bf16.cu
# sm90 sparse decode
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu
# sm90 sparse prefill
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu
# sm100 dense prefill & backward
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
# sm100 sparse prefill
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu
# sm100 sparse decode
${flashmla_SOURCE_DIR}/csrc/sm100/decode/head64/instantiations/v32.cu
${flashmla_SOURCE_DIR}/csrc/sm100/decode/head64/instantiations/model1.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu
)
set(FlashMLA_Extension_SOURCES
......@@ -76,6 +103,7 @@ if(FLASH_MLA_ARCHS)
set(FlashMLA_INCLUDES
${flashmla_SOURCE_DIR}/csrc
${flashmla_SOURCE_DIR}/csrc/kerutils/include
${flashmla_SOURCE_DIR}/csrc/sm90
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
......@@ -83,7 +111,6 @@ if(FLASH_MLA_ARCHS)
set(FlashMLA_Extension_INCLUDES
${flashmla_SOURCE_DIR}/csrc
${flashmla_SOURCE_DIR}/csrc/sm90
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
......@@ -110,9 +137,12 @@ if(FLASH_MLA_ARCHS)
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
# Also enable C++20 for the FlashMLA sources (required for std::span, requires, etc.)
target_compile_options(_flashmla_C PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-std=c++20>
$<$<COMPILE_LANGUAGE:CUDA>:-std=c++20>)
define_extension_target(
_flashmla_extension_C
......
......@@ -43,7 +43,7 @@ def test_sparse_flashmla_decode_smoke():
device = torch.device("cuda")
batch_size = 1
seqlen_q = 1
num_heads_q = 1
num_heads_q = 64
head_dim_k = 576
head_dim_v = 512
num_heads_k = 1
......
......@@ -51,10 +51,34 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
)
def _float_to_e8m0_truncate(f: float) -> float:
    """Simulate SM100's float -> e8m0 -> bf16 scale conversion.

    e8m0 format only stores the exponent (power of 2).
    cudaRoundZero truncates toward zero, meaning we round down to the
    nearest power of 2.

    Args:
        f: Scale value to convert.

    Returns:
        ``0.0`` when ``f <= 0``; otherwise the largest power of two that is
        less than or equal to ``f``.

    Raises:
        ValueError: If ``f`` is NaN.
        OverflowError: If ``f`` is infinite.
    """
    if f <= 0:
        return 0.0
    # Non-finite inputs have no power-of-two floor; keep failing loudly with
    # the same exception types that floor(log2(f)) would have produced.
    if math.isnan(f):
        raise ValueError("cannot convert NaN scale to e8m0")
    if math.isinf(f):
        raise OverflowError("cannot convert infinite scale to e8m0")
    # frexp returns f == mantissa * 2**exponent with mantissa in [0.5, 1), so
    # floor(log2(f)) is exactly exponent - 1. Going through log2 instead can
    # round up to the next integer for values just below a power of two
    # (e.g. 1024 - 2**-43), which would yield a result larger than f.
    _, exponent = math.frexp(f)
    return math.ldexp(1.0, exponent - 1)
def _dequantize_fp8_ds_mla_entry(
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, dtype: torch.dtype
cache_slice: torch.Tensor,
kv_lora_rank: int,
rope_dim: int,
dtype: torch.dtype,
simulate_sm100_e8m0_scales: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope.
Args:
simulate_sm100_e8m0_scales: If True, simulate the SM100 kernel's
float -> e8m0 -> bf16 scale conversion path.
"""
# The first kv_lora_rank bytes store FP8 latent values with one scale per
# 128 element tile written as float32 right after the latent payload.
......@@ -63,10 +87,14 @@ def _dequantize_fp8_ds_mla_entry(
for tile_idx in range(4):
tile_start = tile_idx * 128
tile_end = tile_start + 128
scale_val = float(scales[tile_idx].item())
if simulate_sm100_e8m0_scales:
# Simulate the lossy float -> e8m0 -> bf16 conversion
scale_val = _float_to_e8m0_truncate(scale_val)
ops.convert_fp8(
latent[tile_start:tile_end],
cache_slice[tile_start:tile_end],
float(scales[tile_idx].item()),
scale_val,
kv_dtype="fp8",
)
latent = latent.to(dtype)
......@@ -77,9 +105,18 @@ def _dequantize_fp8_ds_mla_entry(
def _quantize_dequantize_fp8_ds_mla(
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, scale: torch.Tensor
kv_c: torch.Tensor,
k_pe: torch.Tensor,
block_size: int,
scale: torch.Tensor,
simulate_sm100_e8m0_scales: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
"""Round-trip kv_c/k_pe through the fp8_ds_mla cache layout.
Args:
simulate_sm100_e8m0_scales: If True, simulate the SM100 kernel's
float -> e8m0 -> bf16 scale conversion in dequantization.
"""
if kv_c.numel() == 0:
return kv_c.clone(), k_pe.clone()
......@@ -108,7 +145,11 @@ def _quantize_dequantize_fp8_ds_mla(
block_offset = slot % block_size
cache_slice = tmp_cache[block_idx, block_offset]
latent, rope_vals = _dequantize_fp8_ds_mla_entry(
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype
cache_slice,
kv_lora_rank,
rope_dim,
kv_c.dtype,
simulate_sm100_e8m0_scales=simulate_sm100_e8m0_scales,
)
dequant_kv_c[token_idx] = latent
dequant_k_pe[token_idx] = rope_vals
......@@ -143,7 +184,10 @@ def test_sparse_backend_decode_correctness(
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
# Model hyper-parameters (kept intentionally small for the unit test)
num_heads = 128
total_num_heads = 128
# Compute per-rank heads for simulated TP
num_heads = max(1, total_num_heads // tensor_parallel_size)
kv_lora_rank = 512
qk_nope_head_dim = 128
qk_rope_head_dim = 64
......@@ -179,7 +223,7 @@ def test_sparse_backend_decode_correctness(
)
model_config.dtype = dtype
model_config.get_num_attention_heads = MethodType(
lambda self, parallel_config: max(1, num_heads // tensor_parallel_size),
lambda self, parallel_config: num_heads,
model_config,
)
model_config.get_num_kv_heads = MethodType(
......@@ -195,10 +239,10 @@ def test_sparse_backend_decode_correctness(
scale = 1.0 / math.sqrt(head_size)
# Shared MLA projection weights to keep reference and backend in sync
W_UK = torch.randn(
W_UK = torch.rand(
kv_lora_rank, num_heads, qk_nope_head_dim, dtype=dtype, device=device
)
W_UV = torch.randn(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device)
W_UV = torch.rand(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device)
# Build synthetic decode-only workload
seq_lens = batch_spec.seq_lens
......@@ -225,11 +269,15 @@ def test_sparse_backend_decode_correctness(
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
# SM100 (Blackwell) uses float -> e8m0 -> bf16 scale conversion
# which truncates scales to powers of 2. Simulate this in reference.
is_sm100 = torch.cuda.get_device_capability()[0] >= 10
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
kv_c_full,
k_pe_full.squeeze(1),
block_size=vllm_config.cache_config.block_size,
scale=kv_cache_scale,
simulate_sm100_e8m0_scales=is_sm100,
)
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
......@@ -381,7 +429,12 @@ def test_sparse_backend_decode_correctness(
assert backend_output.dtype == sdpa_reference.dtype
assert torch.isfinite(backend_output).all()
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)
# FP8 quantization introduces some error, but should be within reasonable bounds
# BF16 (auto) should be very accurate, FP8 allows slightly more tolerance
if kv_cache_dtype == "fp8_ds_mla":
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.05, atol=0.05)
else:
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.01, atol=0.01)
def _triton_convert_reference_impl(
......
......@@ -17,7 +17,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
......@@ -397,6 +396,10 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
# FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
self.fp8_decode_padded_heads = (
FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
)
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
......@@ -417,14 +420,20 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
(max_num_seqs, 1), dtype=torch.int32, device=self.device
)
# Equation taken from FlashMLA/csrc/pybind.cpp
h_q, h_k = self.num_heads, 1
s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest
max_num_sm_parts = int(
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)
)
# Equation taken from FlashMLA/csrc/api/sparse_decode.h
# For sparse FP8 decode, the formula depends on architecture:
# - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
# - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
# - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
# For max buffer size, use s_q = 1 (the case that produces largest output)
# Use padded head count since that's what will be passed to the kernel
h_q = self.fp8_decode_padded_heads
if current_platform.is_device_capability_family(100):
max_num_sm_parts *= 2
# SM100 head64 or head64x2 uses full SM count
max_num_sm_parts = sm_count
else:
# SM90 uses h_q/64 divisor
max_num_sm_parts = sm_count // max(1, h_q // 64)
self.tile_scheduler_metadata_buffer = torch.empty(
# TileSchedulerMetaDataSize = 8
# see: FlashMLA/csrc/params.h
......@@ -455,12 +464,15 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
"""
num_tokens = common_attn_metadata.num_actual_tokens
# Use padded head count since that's what the kernel will see
padded_heads = self.fp8_decode_padded_heads
# Build metadata for all tokens as a single batch
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:1], # Single batch
num_q_tokens_per_head_k=num_tokens * self.num_heads,
num_q_tokens_per_head_k=num_tokens * padded_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_q=padded_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)
......@@ -606,11 +618,13 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
# Use padded head count since that's what the kernel will see
padded_heads = self.fp8_decode_padded_heads
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:num_decodes],
num_q_tokens_per_head_k=decode_query_len * self.num_heads,
num_q_tokens_per_head_k=decode_query_len * padded_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_q=padded_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)
......@@ -689,6 +703,12 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
@staticmethod
def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
    """Return the query-head count used for the FP8 decode kernel.

    The FP8 decode kernel only supports h_q = 64 or 128, so the local head
    count is mapped onto one of those two sizes.

    Args:
        num_heads: Number of query heads on this rank.

    Returns:
        ``64`` when ``num_heads`` is at most 64, otherwise ``128`` (this
        includes any value above 128).
    """
    if num_heads > 64:
        return 128
    return 64
def __init__(
self,
num_heads: int,
......@@ -722,7 +742,11 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
self.padding = 128 if current_platform.is_device_capability_family(100) else 64
# Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
self.prefill_padding = (
128 if current_platform.is_device_capability_family(100) else 64
)
self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
if kv_cache_dtype == "fp8_ds_mla":
# Reserve workspace during initialization
......@@ -903,8 +927,22 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
) -> torch.Tensor:
return flash_mla_with_kvcache(
) -> tuple[torch.Tensor, torch.Tensor]:
# q shape: (batch, seq_len, num_heads, head_dim)
actual_num_heads = q.size(2)
padded_num_heads = self.fp8_decode_padded_heads
# Pad query if needed (kernel only supports h_q = 64 or 128)
if actual_num_heads < padded_num_heads:
logger.warning_once(
f"Padding num_heads from {actual_num_heads} to "
f"{padded_num_heads} for FP8 sparse decode kernel"
)
q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
q_padded[:, :, :actual_num_heads, :] = q
q = q_padded
out, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
block_table=kernel_metadata.dummy_block_table,
......@@ -917,6 +955,12 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
softmax_scale=self.softmax_scale,
)
# Slice output back to actual head count if we padded
if actual_num_heads < padded_num_heads:
out = out[:, :, :actual_num_heads, :]
return out, lse
def _bf16_flash_mla_kernel(
self,
q: torch.Tensor,
......@@ -930,13 +974,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
# NOTE(Chen): kernel requires num_local_head to be a multiple of
# 64 on hopper and 128 on blackwell
if self.num_heads % self.padding != 0:
assert self.padding % self.num_heads == 0
if self.num_heads % self.prefill_padding != 0:
assert self.prefill_padding % self.num_heads == 0
logger.warning_once(
f"padding num_heads to {self.padding} due to sparse attn "
"kernel requirement"
f"Padding num_heads from {self.num_heads} to "
f"{self.prefill_padding} for BF16 sparse prefill kernel"
)
q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
q_padded[:, : self.num_heads, :] = q
q = q_padded
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment