Unverified commit 23afdfd1 authored by Fan Yin, committed by GitHub

[sgl-kernel] support flashmla libtorch (#11717)

parent 9d61205d
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(sgl-kernel LANGUAGES CXX CUDA)
# utils
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
include(FetchContent)
# CMake
cmake_policy(SET CMP0169 OLD)
cmake_policy(SET CMP0177 NEW)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
set(CMAKE_COLOR_DIAGNOSTICS ON)
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
...@@ -37,11 +40,9 @@ endif()
# Torch
find_package(Torch REQUIRED)
# clean Torch Flag
clear_cuda_arches(CMAKE_FLAG)
# Third Party
# cutlass
FetchContent_Declare(
repo-cutlass
...@@ -69,7 +70,7 @@ FetchContent_Declare(
)
FetchContent_Populate(repo-fmt)
# Triton kernel
FetchContent_Declare(
repo-triton
GIT_REPOSITORY "https://github.com/triton-lang/triton"
...@@ -143,12 +144,6 @@ endif()
include_directories(
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/csrc
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-mscclpp_SOURCE_DIR}/include
${repo-fast-hadamard-transform}/csrc
)
set(SGL_KERNEL_CUDA_FLAGS
...@@ -350,6 +345,7 @@ set(SOURCES
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
)
# =========================== Common SM90 Build ============================= #
# Build SM90 library with fast math optimization (same namespace, different directory)
Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
...@@ -360,7 +356,11 @@ target_compile_options(common_ops_sm90_build PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math>
)
target_include_directories(common_ops_sm90_build PRIVATE
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-mscclpp_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
...@@ -371,6 +371,7 @@ set_target_properties(common_ops_sm90_build PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/sm90"
)
# =========================== Common SM100+ Build ============================= #
# Build SM100+ library with precise math (same namespace, different directory)
Python_add_library(common_ops_sm100_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
...@@ -381,7 +382,11 @@ target_compile_options(common_ops_sm100_build PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>
)
target_include_directories(common_ops_sm100_build PRIVATE
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-mscclpp_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
...@@ -408,7 +413,7 @@ else()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
endif()
# mscclpp option
set(MSCCLPP_USE_CUDA ON)
set(MSCCLPP_BYPASS_GPU_CHECK ON)
set(MSCCLPP_BUILD_TESTS OFF)
...@@ -419,7 +424,7 @@ add_subdirectory(
target_link_libraries(common_ops_sm90_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
target_link_libraries(common_ops_sm100_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
# sparse flash attention
target_compile_definitions(common_ops_sm90_build PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
...@@ -506,6 +511,8 @@ if (SGL_KERNEL_ENABLE_FA3)
target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
target_include_directories(flash_ops PRIVATE
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flash-attention_SOURCE_DIR}/hopper
)
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
...@@ -535,6 +542,8 @@ target_compile_options(spatial_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERN
target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
# ============================ Extra Install ============================= #
include(${CMAKE_CURRENT_LIST_DIR}/cmake/flashmla.cmake)
# ============================ DeepGEMM (JIT) ============================= #
# Create a separate library for DeepGEMM's Python API.
...
include(FetchContent)
# flash_mla
FetchContent_Declare(
repo-flashmla
GIT_REPOSITORY https://github.com/sgl-project/FlashMLA
GIT_TAG bc8576abc3e507425cf6498f3d3393df7733ce37
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashmla)
set(FLASHMLA_CUDA_FLAGS
"--expt-relaxed-constexpr"
"--expt-extended-lambda"
"--use_fast_math"
)
# The FlashMLA kernels require a recent CUDA toolkit: the Hopper (sm90a) kernels
# are built when CUDA > 12.4, and the Blackwell (sm100a) kernels additionally
# require CUDA > 12.8.
if(${CUDA_VERSION} VERSION_GREATER 12.4)
list(APPEND FLASHMLA_CUDA_FLAGS
"-gencode=arch=compute_90a,code=sm_90a"
)
endif()
if(${CUDA_VERSION} VERSION_GREATER 12.8)
list(APPEND FLASHMLA_CUDA_FLAGS
"-gencode=arch=compute_100a,code=sm_100a"
)
endif()
set(FlashMLA_SOURCES
"csrc/flashmla_extension.cc"
${repo-flashmla_SOURCE_DIR}/csrc/python_api.cpp
${repo-flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
${repo-flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
${repo-flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
${repo-flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
)
Python_add_library(flashmla_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FlashMLA_SOURCES})
target_compile_options(flashmla_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${FLASHMLA_CUDA_FLAGS}>)
target_include_directories(flashmla_ops PRIVATE
${repo-flashmla_SOURCE_DIR}/csrc
${repo-flashmla_SOURCE_DIR}/csrc/sm90
${repo-flashmla_SOURCE_DIR}/csrc/cutlass/include
${repo-flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
)
target_link_libraries(flashmla_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
install(TARGETS flashmla_ops LIBRARY DESTINATION "sgl_kernel")
target_compile_definitions(flashmla_ops PRIVATE)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <torch/all.h>
#include <torch/library.h>
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From FlashMLA
*/
m.def(
"get_mla_decoding_metadata(Tensor seqlens_k, int num_q_tokens_per_head_k, int h_k, int? h_q, bool "
"is_fp8_kvcache, int? topk) -> Tensor[]");
m.impl("get_mla_decoding_metadata", torch::kCUDA, &get_mla_decoding_metadata);
m.def(
"fwd_kvcache_mla(Tensor q, Tensor kv_cache, int head_size_v, Tensor seqlens_k, Tensor block_table, float "
"softmax_scale, bool is_causal, Tensor tile_scheduler_metadata, Tensor num_splits, bool is_fp8, Tensor? indices) "
"-> Tensor[]");
m.impl("fwd_kvcache_mla", torch::kCUDA, &fwd_kvcache_mla);
m.def(
"dense_prefill_fwd(Tensor workspace_buffer, Tensor q, Tensor k, Tensor v, Tensor cumulative_seqlen_q, Tensor "
"cumulative_seqlen_kv, Tensor o, Tensor lse, int mask_mode_code, float softmax_scale, int max_seqlen_q, int "
"max_seqlen_kv, bool is_varlen) -> ()");
m.impl("dense_prefill_fwd", torch::kCUDA, &FMHACutlassSM100FwdRun);
m.def("sparse_prefill_fwd(Tensor q, Tensor kv, Tensor indices, float sm_scale, int d_v) -> Tensor[]");
m.impl("sparse_prefill_fwd", torch::kCUDA, &sparse_prefill_fwd);
}
REGISTER_EXTENSION(flashmla_ops)
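# Illustrative sketch (not part of the diff): once the flashmla_ops extension
# built by cmake/flashmla.cmake is imported, the schemas defined by the
# TORCH_LIBRARY_FRAGMENT above become reachable under torch.ops.sgl_kernel.
# The import path assumes the module is installed into the sgl_kernel package,
# as the install() rule specifies; the wrappers in sgl_kernel/flash_mla.py
# below are the intended user-facing API.
import torch
from sgl_kernel import flashmla_ops  # noqa: F401  # registers the ops on import

print(torch.ops.sgl_kernel.get_mla_decoding_metadata)
print(torch.ops.sgl_kernel.fwd_kvcache_mla)
print(torch.ops.sgl_kernel.sparse_prefill_fwd)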
...@@ -842,6 +842,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace);
/*
* From fast-hadamard-transform
*/
...@@ -850,3 +851,47 @@ torch::Tensor fast_hadamard_transform_12N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_20N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
/*
* From FlashMLA
*/
std::vector<at::Tensor> get_mla_decoding_metadata(
at::Tensor& seqlens_k,
const int64_t num_q_tokens_per_head_k,
const int64_t h_k,
const std::optional<int64_t> h_q,
const bool is_fp8_kvcache,
const std::optional<int64_t> topk);
std::vector<at::Tensor> fwd_kvcache_mla(
at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor& kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or
// num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
const int64_t head_size_v,
const at::Tensor& seqlens_k, // batch_size
const at::Tensor& block_table, // batch_size x max_num_blocks_per_seq
const double softmax_scale,
bool is_causal,
const at::Tensor& tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor& num_splits, // batch_size + 1
const bool& is_fp8,
const std::optional<at::Tensor>& indices // None, or batch_size x seqlen_q x topk
);
void FMHACutlassSM100FwdRun(
at::Tensor workspace_buffer,
at::Tensor q,
at::Tensor k,
at::Tensor v,
at::Tensor cumulative_seqlen_q,
at::Tensor cumulative_seqlen_kv,
at::Tensor o,
at::Tensor lse,
int64_t mask_mode_code,
double softmax_scale,
int64_t max_seqlen_q,
int64_t max_seqlen_kv,
bool is_varlen);
std::vector<at::Tensor>
sparse_prefill_fwd(const at::Tensor& q, const at::Tensor& kv, const at::Tensor& indices, double sm_scale, int64_t d_v);
from typing import Optional, Tuple
import torch
try:
from . import flashmla_ops # triggers TORCH extension registration
except Exception as _e:
_flashmla_import_error = _e
else:
_flashmla_import_error = None
_IMPORT_ERROR = ImportError(
"Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
)
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_q_tokens_per_head_k: Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
num_heads_k: The number of k heads.
num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled.
is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
topk: If not None, sparse attention is enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache` will be attended to.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return torch.ops.sgl_kernel.get_mla_decoding_metadata.default(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
num_heads_q,
is_fp8_kvcache,
topk,
)
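# Illustrative usage sketch (not part of the diff): compute the decode split
# schedule for a small dense batch; shape comments follow the docstring above.
# The helper name is hypothetical and must be called manually on a CUDA machine.
def _example_get_mla_metadata():
    cache_seqlens = torch.tensor([1024, 57, 4096, 300], dtype=torch.int32, device="cuda")
    s_q, num_heads_q, num_heads_k = 1, 128, 1
    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens,
        s_q * num_heads_q // num_heads_k,  # num_q_tokens_per_head_k
        num_heads_k,
        num_heads_q,
    )
    # tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), int32
    # num_splits: (batch_size + 1,), int32
    return tile_scheduler_metadata, num_splits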
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md
indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if indices is not None:
assert not causal, "causal must be False if sparse attention is enabled."
out, softmax_lse = torch.ops.sgl_kernel.fwd_kvcache_mla.default(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
is_fp8_kvcache,
indices,
)
return out, softmax_lse
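# Illustrative usage sketch (not part of the diff): a minimal dense, causal
# decode call that chains get_mla_metadata with flash_mla_with_kvcache. All
# sizes and the helper name are made up; run manually on a CUDA machine.
def _example_decode():
    b, s_q, h_q, h_kv, d, dv, block_size, max_blocks = 2, 1, 128, 1, 576, 512, 64, 8
    q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
    k_cache = torch.randn(b * max_blocks, block_size, h_kv, d, dtype=torch.bfloat16, device="cuda")
    block_table = torch.arange(b * max_blocks, dtype=torch.int32, device="cuda").view(b, max_blocks)
    cache_seqlens = torch.tensor([300, 17], dtype=torch.int32, device="cuda")
    metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv, h_q)
    out, lse = flash_mla_with_kvcache(
        q, k_cache, block_table, cache_seqlens, dv, metadata, num_splits, causal=True
    )
    return out, lse  # out: (b, s_q, h_q, dv), lse: (b, h_q, s_q)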
def flash_mla_sparse_fwd(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
Args:
q: [s_q, h_q, d_qk], bfloat16
kv: [s_kv, h_kv, d_qk], bfloat16
indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv
sm_scale: float
d_v: The dimension of value vectors. Can only be 512
Returns:
(output, max_logits, lse)
For the definitions of output, max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = torch.ops.sgl_kernel.sparse_prefill_fwd.default(
q, kv, indices, sm_scale, d_v
)
return results
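# Illustrative usage sketch (not part of the diff): sparse prefill over a
# per-query top-k index set. Out-of-range or negative entries in `indices`
# are treated as invalid, as described in the docstring. The helper name is
# hypothetical; run manually on a CUDA machine.
def _example_sparse_prefill():
    s_q, s_kv, h_q, h_kv, d_qk, d_v, topk = 4, 1024, 128, 1, 576, 512, 256
    q = torch.randn(s_q, h_q, d_qk, dtype=torch.bfloat16, device="cuda")
    kv = torch.randn(s_kv, h_kv, d_qk, dtype=torch.bfloat16, device="cuda")
    indices = torch.randint(0, s_kv, (s_q, h_kv, topk), dtype=torch.int32, device="cuda")
    out, max_logits, lse = flash_mla_sparse_fwd(q, kv, indices, sm_scale=d_qk**-0.5, d_v=d_v)
    return out, max_logits, lse  # out: (s_q, h_q, d_v); max_logits, lse: (s_q, h_q)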
import math
import random
from typing import Optional, Tuple
import pytest
import torch
import triton
from torch.utils.checkpoint import checkpoint
from sgl_kernel.flash_mla import (
flash_mla_sparse_fwd,
flash_mla_with_kvcache,
get_mla_metadata,
)
skip_condition = torch.cuda.get_device_capability() < (10, 0)
# ================ prefill usage ================ #
S_Q_PREFILL = [1, 62]
KV_TOPK_PREFILL = [
# Regular shapes
(128, 128),
(256, 256),
(512, 512),
# Irregular shapes
(592, 128),
(1840, 256),
(1592, 384),
(1521, 512),
# Irregular shapes with OOB TopK
(95, 128),
(153, 256),
(114, 384),
]
# ================= decode usage ================= #
B_DECODE = [1, 2, 6, 64]
S_Q_DECODE = [1, 2, 4]
S_K_DECODE = [20, 140, 4096]
IS_VARLEN = [False, True]
CAUSAL_TOPK = [(True, None), (False, None), (False, 128), (False, 2048)]
DTYPE = [torch.float16, torch.bfloat16]
def quantize_k_cache(
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
dv: int,
tile_size: int = 128,
) -> torch.Tensor:
"""
Quantize the k-cache
Return a tensor of shape (num_blocks, block_size, h_k, dv + 4*(dv/tile_size) + t*(d-dv)) with one byte per element, where t = input_k_cache.element_size().
For more detail about the K/V layout, please refer to the comments in flash_mla_interface.py or README.md.
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, d = input_k_cache.shape
assert h_k == 1
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()
result = torch.empty(
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
dtype=torch.float8_e4m3fn,
device=input_k_cache.device,
)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = (
torch.abs(
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
)
.max(dim=-1)
.values
/ 448.0
) # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (
input_k_cache[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].float()
/ cur_scale_factors_inv.float()
).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_quantized_nope
)
result = result.view(num_blocks, block_size, 1, -1)
return result
def dequantize_k_cache(
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
dv: int = 512,
tile_size: int = 128,
d: int = 576,
) -> torch.Tensor:
"""
De-quantize the k-cache
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, _ = quant_k_cache.shape
assert h_k == 1
result = torch.empty(
(num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
)
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_nope * cur_scales
)
result = result.view(num_blocks, block_size, 1, d)
return result
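# Round-trip sketch (illustrative, not part of the test): with dv=512,
# tile_size=128 and d=576 in bf16, each quantized token occupies
# 512 + 4*(512/128) + 2*(576-512) = 656 bytes. De-quantizing should recover
# the original cache up to fp8 quantization error. The helper name is
# hypothetical; run manually on a CUDA machine.
def _example_kv_cache_roundtrip():
    num_blocks, block_size, h_k, d, dv = 4, 64, 1, 576, 512
    k_cache = torch.randn(num_blocks, block_size, h_k, d, dtype=torch.bfloat16, device="cuda")
    quantized = quantize_k_cache(k_cache, dv, tile_size=128)
    assert quantized.shape == (num_blocks, block_size, h_k, 656)
    recovered = dequantize_k_cache(quantized, dv=dv, tile_size=128, d=d)
    torch.testing.assert_close(recovered, k_cache, atol=0.2, rtol=0.2)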
def cdiv(x: int, y: int):
return (x + y - 1) // y
def get_window_size(causal, window):
if window > 0:
window_size = (window - 1, 0) if causal else (window - 1, window - 1)
else:
window_size = (-1, -1)
return window_size
def get_attn_bias(s_q, s_k, causal, window):
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32, device="cuda")
if causal:
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
diagonal=s_k - s_q
)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
if window > 0:
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
diagonal=s_k - s_q - window
)
attn_bias.masked_fill_(temp_mask, float("-inf"))
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
diagonal=s_k - s_q + window - 1
)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
return attn_bias
def sdpa(query, key, value, attn_bias, softmax_scale=None):
query = query.float().transpose(-3, -2)
key = key.float().transpose(-3, -2)
value = value.float().transpose(-3, -2)
# Broadcast the KV heads to match the number of query heads (MQA/GQA)
num_heads_q, num_heads_kv = query.size(-3), key.size(-3)
key = key.repeat_interleave(num_heads_q // num_heads_kv, dim=-3)
value = value.repeat_interleave(num_heads_q // num_heads_kv, dim=-3)
if softmax_scale is None:
softmax_scale = query.shape[-1] ** (-0.5)
attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight.to(query.dtype) @ value, lse
def sdpa_checkpoint(*args, **kwargs):
return checkpoint(sdpa, *args, use_reentrant=False, **kwargs)
def reference_torch_prefill(
s_q, s_kv, topk, indices, q, kv, sm_scale: float
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
indices = indices[0, :, 0, :] # [s_q, topk]
invalid_indices_mask = (indices < 0) | (indices >= s_kv)
qs = q[0, :, :, :].float() # [s_q, h_q, d_qk]
kvs = kv[0, :, 0, :].float() # [s_kv, d_qk]
kvs = torch.index_select(
kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()
).view(
s_q, topk, 576
) # [s_q, topk, d_qk]
attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk]
attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
attn_score *= sm_scale * math.log2(math.e)
max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q]
lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
result = attn_score @ kvs[:, :, :512]
return (max_logits, lse, result)
def reference_torch_decode(
cache_seqlens: torch.Tensor, # [batch_size]
block_table: torch.Tensor, # [batch_size, ?]
q: torch.Tensor, # [batch_size, s_q, h_q, d]
blocked_k: torch.Tensor, # [?, block_size, h_kv, d]
dv: int,
is_causal: bool,
indices: Optional[torch.Tensor] = None, # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
A reference implementation in PyTorch
"""
def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
mask = torch.zeros(s_q, s_k, dtype=torch.bool, device="cuda")
for i in range(s_q):
cur_indices = indices[i]
valid_indices = cur_indices[cur_indices != -1]
mask[i, valid_indices] = True
return mask
def scaled_dot_product_attention(
batch_idx: int,
query: torch.Tensor, # [h_q, s_q, d]
kv: torch.Tensor, # [h_kv, s_k, d]
dv: int,
is_causal,
indices: Optional[torch.Tensor], # [s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
h_q = query.size(0)
h_kv = kv.size(0)
s_q = query.shape[-2]
s_k = kv.shape[-2]
query = query.float()
kv = kv.float()
if h_kv != 1:
kv = kv.repeat_interleave(h_q // h_kv, dim=0)
kv[kv != kv] = 0.0
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
if (is_causal and query.size(1) > 1) or indices is not None:
mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda")
if is_causal:
assert indices is None
mask = mask.tril(diagonal=s_k - s_q)
if indices is not None:
mask &= get_topk_attn_mask(s_q, s_k, indices)
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device="cuda")
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
attn_weight += attn_bias.to(q.dtype)
attn_weight /= math.sqrt(query.size(-1))
lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q]
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv]
# Handle q tokens that have no attendable k
lonely_q_mask = lse == float("-inf")
output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
lse[lonely_q_mask] = float("+inf")
return output, lse
b, s_q, h_q, d = q.size()
block_size = blocked_k.size(1)
h_kv = blocked_k.size(2)
cache_seqlens_cpu = cache_seqlens.cpu()
out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device="cuda")
lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32, device="cuda")
for i in range(b):
cur_len = cache_seqlens_cpu[i].item()
cur_num_blocks = cdiv(cur_len, block_size)
cur_block_indices = block_table[i][0:cur_num_blocks]
cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
cur_out, cur_lse = scaled_dot_product_attention(
i,
q[i].transpose(0, 1),
cur_kv.transpose(0, 1),
dv,
is_causal,
indices[i] if indices is not None else None,
)
out_ref[i] = cur_out.transpose(0, 1)
lse_ref[i] = cur_lse
out_ref = out_ref.to(torch.bfloat16)
return out_ref, lse_ref
@pytest.mark.parametrize("s_q", S_Q_PREFILL)
@pytest.mark.parametrize("kv_topk", KV_TOPK_PREFILL)
@torch.inference_mode()
def test_flashmla_prefill(
s_q: int,
kv_topk: Tuple[int, int],
):
torch.cuda.empty_cache()
q = torch.randn((1, s_q, 128, 576), dtype=torch.bfloat16, device="cuda") / 10
kv = torch.randn((1, kv_topk[0], 1, 576), dtype=torch.bfloat16, device="cuda") / 10
q.clamp_(-10, 10)
kv.clamp_(-10, 10)
indices = torch.full(
(1, s_q, 1, kv_topk[1]), kv_topk[0], dtype=torch.int32, device="cuda"
)
for s in range(s_q):
# NOTE: We generate indices this way so that most indices lie within [s_kv-20000, s_kv), which is more realistic for sparse attention
near_mask = (
torch.randint(0, 32, (min(kv_topk[1], kv_topk[0]),), device="cuda") < 31
)
cur_indices = torch.randperm(kv_topk[0], device="cuda")[: kv_topk[1]]
cur_indices[near_mask] = torch.randint(
max(0, kv_topk[0] - 20000),
kv_topk[0] - 1,
(near_mask.sum().item(),),
device="cuda",
)
if len(cur_indices) < kv_topk[1]:
cur_indices = torch.cat(
[
cur_indices,
torch.full(
(kv_topk[1] - len(cur_indices),), 2147480000, device="cuda"
),
]
)
cur_indices = cur_indices[torch.randperm(kv_topk[1], device="cuda")]
indices[0, s, 0] = cur_indices
indices = indices.to(q.device)
sm_scale = 1 / math.sqrt(576)
torch.cuda.synchronize()
ans_out, ans_max_logits, ans_lse = flash_mla_sparse_fwd(
q.squeeze(0), kv.squeeze(0), indices.squeeze(0), sm_scale=sm_scale
)
ans_out, ans_max_logits, ans_lse = (
ans_out.float(),
ans_max_logits.float(),
ans_lse.float(),
)
torch.cuda.synchronize()
ref_max_logits, ref_lse, ref_out = reference_torch_prefill(
s_q, kv_topk[0], kv_topk[1], indices, q, kv, sm_scale
)
torch.cuda.synchronize()
torch.testing.assert_close(ans_out, ref_out, atol=8e-4, rtol=2.01 / 128)
torch.testing.assert_close(
ans_max_logits,
ref_max_logits,
atol=1e-6,
rtol=2.01 / 65536,
)
torch.testing.assert_close(ans_lse, ref_lse, atol=1e-6, rtol=2.01 / 65536)
@pytest.mark.parametrize("b", B_DECODE)
@pytest.mark.parametrize("s_q", S_Q_DECODE)
@pytest.mark.parametrize("s_k", S_K_DECODE)
@pytest.mark.parametrize("is_varlen", IS_VARLEN)
@pytest.mark.parametrize("causal_topk", CAUSAL_TOPK)
@pytest.mark.parametrize("dtype", DTYPE)
@torch.inference_mode()
def test_flash_mla_decode(
b: int,
s_q: int,
s_k: int,
is_varlen: bool,
causal_topk: Tuple[bool, Optional[int]],
dtype: torch.dtype,
):
d = 576
dv = 512
block_size = 64
h_q = 128
h_kv = 1
is_causal = causal_topk[0]
topk = causal_topk[1]
# Generating test data
torch.cuda.synchronize()
cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device="cpu")
if is_varlen:
for i in range(b):
cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), s_q)
max_seqlen = cache_seqlens_cpu.max().item()
max_seqlen_pad = cdiv(max_seqlen, 256) * 256
cache_seqlens = cache_seqlens_cpu.cuda()
q = torch.randn(b, s_q, 128, d, dtype=torch.bfloat16, device="cuda")
q.clamp_(min=-1.0, max=1.0)
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32, device="cuda"
).view(b, max_seqlen_pad // block_size)
block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1)
blocked_k = (
torch.randn(
block_table.numel(),
block_size,
h_kv,
d,
dtype=torch.bfloat16,
device="cuda",
)
/ 10
)
blocked_k.clamp_(min=-1.0, max=1.0)
if topk is None:
for i in range(b):
cur_len = cache_seqlens_cpu[i].item()
cur_num_blocks = cdiv(cur_len, block_size)
blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
if cur_len % block_size != 0:
blocked_k[block_table[i][cur_num_blocks - 1]][
cur_len % block_size :
] = float("nan")
block_table[i][cur_num_blocks:] = 2147480000
abs_indices = None
indices_in_kvcache = None
else:
block_table_cpu = block_table.cpu()
abs_indices = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu")
indices_in_kvcache = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu")
for i in range(b):
# Generate indices
for j in range(s_q):
cur_abs_indices = torch.randperm(
int(cache_seqlens_cpu[i].item()), device="cpu"
)[:topk]
cur_blocked_indices = block_table_cpu[
i, cur_abs_indices // block_size
] * block_size + (cur_abs_indices % block_size)
if len(cur_abs_indices) < topk:
pad_len = topk - len(cur_abs_indices)
cur_abs_indices = torch.cat(
[cur_abs_indices, torch.full((pad_len,), -1, device="cpu")]
)
cur_blocked_indices = torch.cat(
[cur_blocked_indices, torch.full((pad_len,), -1, device="cpu")]
)
# Shuffle the indices
perm = torch.randperm(topk, device="cpu")
cur_abs_indices = cur_abs_indices[perm]
cur_blocked_indices = cur_blocked_indices[perm]
abs_indices[i, j, :] = cur_abs_indices
indices_in_kvcache[i, j, :] = cur_blocked_indices
# Mask unused KV as NaN
all_indices = indices_in_kvcache.flatten().tolist()
all_indices = list(set(all_indices))
if -1 in all_indices:
all_indices.remove(-1)
all_indices = torch.tensor(all_indices, dtype=torch.int32, device="cpu")
blocked_k = blocked_k.view(-1, h_kv, d)
nonused_indices_mask = torch.ones(
blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device="cpu"
)
nonused_indices_mask[all_indices] = False
blocked_k[nonused_indices_mask, :, :] = float("nan")
blocked_k = blocked_k.view(-1, block_size, h_kv, d)
abs_indices = abs_indices.to(q.device)
indices_in_kvcache = indices_in_kvcache.to(q.device)
is_fp8 = topk is not None
if is_fp8:
# The quantization error alone could be large enough to hide a wrong kernel,
# so we quantize and then de-quantize the kv-cache here so that the reference
# input carries the same quantization error as the kernel input.
blocked_k_quantized = quantize_k_cache(blocked_k, dv, 128)
blocked_k_dequantized = dequantize_k_cache(blocked_k_quantized)
blocked_k = blocked_k_dequantized
# Get schedule metadata
torch.cuda.synchronize()
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, is_fp8, topk
)
torch.cuda.synchronize()
out_ans, lse_ans = flash_mla_with_kvcache(
q,
blocked_k if not is_fp8 else blocked_k_quantized, # type: ignore
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=is_causal,
is_fp8_kvcache=is_fp8,
indices=indices_in_kvcache,
)
out_ref, lse_ref = reference_torch_decode(
cache_seqlens, block_table, q, blocked_k, dv, is_causal, abs_indices
)
torch.testing.assert_close(out_ans, out_ref, atol=8e-4, rtol=2.01 / 128)
torch.testing.assert_close(lse_ans, lse_ref, atol=1e-6, rtol=8.01 / 65536)
if __name__ == "__main__":
pytest.main([__file__])