Commit b3230e1a authored by Yongye Zhu's avatar Yongye Zhu Committed by simon-mo
Browse files
parent 03df0fb5
......@@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
......@@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
# Only build FlashMLA kernels if we are building for something compatible with
# sm90a
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
set(SUPPORT_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
list(APPEND SUPPORT_ARCHS 9.0a)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
list(APPEND SUPPORT_ARCHS 10.0a)
endif()
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
if(FLASH_MLA_ARCHS)
set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
set(FlashMLA_SOURCES
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
${flashmla_SOURCE_DIR}/csrc/pybind.cpp
${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
)
set(FlashMLA_Extension_SOURCES
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
)
set(FlashMLA_INCLUDES
${flashmla_SOURCE_DIR}/csrc
${flashmla_SOURCE_DIR}/csrc/sm90
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
)
set(FlashMLA_Extension_INCLUDES
${flashmla_SOURCE_DIR}/csrc
${flashmla_SOURCE_DIR}/csrc/sm90
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc)
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
)
set_gencode_flags_for_srcs(
SRCS "${FlashMLA_SOURCES}"
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
set_gencode_flags_for_srcs(
SRCS "${FlashMLA_Extension_SOURCES}"
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
define_gpu_extension_target(
_flashmla_C
DESTINATION vllm
......@@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
USE_SABI 3
WITH_SOABI)
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
target_compile_options(_flashmla_C PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
define_gpu_extension_target(
_flashmla_extension_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${FlashMLA_Extension_SOURCES}
COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
USE_SABI 3
WITH_SOABI)
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
target_compile_options(_flashmla_extension_C PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
else()
# Create an empty target for setup.py when not targeting sm90a systems
# Create empty targets for setup.py when not targeting sm90a systems
add_custom_target(_flashmla_C)
add_custom_target(_flashmla_extension_C)
endif()
......@@ -56,3 +56,11 @@ void cp_gather_cache(
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);
......@@ -16,6 +16,7 @@
#include <algorithm>
#include <cassert>
#include <cfloat> // FLT_MIN
#include <map>
#include <vector>
......@@ -396,6 +397,176 @@ __global__ void concat_and_cache_mla_kernel(
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_ds_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int64_t dst_idx_start =
block_idx * block_stride + block_offset * entry_stride;
// Create 4 tile scales in shared memory
__shared__ float smem[20];
float* shard_abs_max = smem;
float* tile_scales = smem + 16;
// For the NoPE part, each tile of 128 elements is handled by 4 warps
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
// The first thread of the first warp in each tile writes the scale
// value for the tile. The RoPE part (last 64 elements) is handled
// by another 2 warps (64 threads).
// So in total, we use 18 warps (576 threads) per block.
// Cast kv_cache to 16_bit for RoPE values
scalar_t* kv_cache_16bit =
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
// The last 64 threads handle the RoPE part
if (threadIdx.x >= kv_lora_rank) {
const int8_t pe_idx = threadIdx.x - kv_lora_rank;
const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
// RoPE values start after the packed 8-bit NoPE values and the
// 32-bit scales
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
kv_cache_16bit[dst_idx] = k_pe[src_idx];
return;
}
// Determine the scale for each chunk of NoPE
const int16_t tile_idx = threadIdx.x >> 7;
const int16_t warp_idx = (threadIdx.x & 127) >> 5;
const int16_t lane_idx = threadIdx.x & 31;
// Load the NoPE element for this thread into registers
const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
const scalar_t src_val = kv_c[src_idx];
// Warp-level reduction to find the max absolute value in the warp
float max_abs = fabsf(src_val);
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
#ifdef USE_ROCM
max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
#else
max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
#endif
}
// The first lane of each warp in each tile writes the max_abs of this part
// of the tile to shared memory
if (lane_idx == 0) {
shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
}
__syncthreads();
// The first lane of the first warp in each tile computes the scale for the
// tile and writes it to shared memory and to kv_cache
if (warp_idx == 0 && lane_idx == 0) {
float4 shard_abs_max_vec =
reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
448.f;
// Avoid division by zero in `scaled_convert`
tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
}
__syncthreads();
// Now all threads in the block scale and write their element
const float scale_val = tile_scales[tile_idx];
const int64_t dst_idx = dst_idx_start + threadIdx.x;
kv_cache[dst_idx] =
fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
src_val, scale_val);
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void indexer_k_quant_and_cache_kernel(
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int head_dim, // dimension of each head
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
threadIdx.y * blockDim.x + threadIdx.x) *
VEC_SIZE;
const int64_t slot_idx = slot_mapping[token_idx];
const int64_t block_idx = slot_idx / cache_block_size;
const int64_t block_offset = slot_idx % cache_block_size;
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
return;
}
float2 k_val = (reinterpret_cast<const float2*>(
k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
float amax = 0.0f;
for (int i = 0; i < VEC_SIZE; i++) {
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
}
__syncwarp();
// Reduced amax
for (int mask = 16; mask > 0; mask /= 2) {
#ifdef USE_ROCM
amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
#else
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
#endif
}
__syncwarp();
float scale = fmaxf(amax, 1e-4) / 448.0f;
if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale)));
}
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
block_offset * head_dim + head_dim_idx;
for (int i = 0; i < VEC_SIZE; i++) {
kv_cache[dst_offset + i] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val_ptr[i], scale);
}
if (threadIdx.x == 0) {
const int64_t dst_scale_idx =
block_idx * cache_block_size * cache_stride +
cache_block_size * head_dim +
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
}
}
} // namespace vllm
// KV_T is the data type of key and value tensors.
......@@ -438,7 +609,7 @@ void reshape_and_cache(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE)
CALL_RESHAPE_AND_CACHE);
}
// KV_T is the data type of key and value tensors.
......@@ -509,6 +680,18 @@ void reshape_and_cache_flash(
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
......@@ -531,20 +714,44 @@ void concat_and_cache_mla(
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
if (kv_cache_dtype == "fp8_ds_mla") {
TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
"kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
TORCH_CHECK(kv_c.itemsize() == 2,
"kv_c.itemsize() must be 2 for fp8_ds_mla");
TORCH_CHECK(k_pe.itemsize() == 2,
"k_pe.itemsize() must be 2 for fp8_ds_mla");
} else {
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
}
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
if (kv_cache_dtype == "fp8_ds_mla") {
dim3 grid(num_tokens);
// For the NoPE part, each tile of 128 elements is handled by 4 warps
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
// The first thread of the first warp in each tile writes the scale
// value for the tile. The RoPE part (last 64 elements) is handled
// by another 2 warps (64 threads).
// So in total, we use 18 warps (576 threads) per block.
dim3 block(576);
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_DS_MLA);
} else {
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
}
}
namespace vllm {
......@@ -922,3 +1129,42 @@ void cp_gather_cache(
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
}
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(k.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
cache_block_size, cache_stride, use_ue8m0);
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt) {
int num_tokens = k.size(0);
int head_dim = k.size(1);
int cache_block_size = kv_cache.size(1);
int cache_stride = kv_cache.size(2);
bool use_ue8m0 = scale_fmt == "ue8m0";
TORCH_CHECK(k.device() == kv_cache.device(),
"k and kv_cache must be on the same device");
TORCH_CHECK(k.device() == slot_mapping.device(),
"k and slot_mapping must be on the same device");
TORCH_CHECK(head_dim % quant_block_size == 0,
"head_dim must be divisible by quant_block_size");
constexpr int vec_size = 4;
dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
(quant_block_size * vec_size));
dim3 block(32, vec_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
CALL_INDEXER_K_QUANT_AND_CACHE);
}
\ No newline at end of file
......@@ -576,6 +576,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_ds_mla") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
......
......@@ -713,6 +713,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"int quant_block_size, str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
......
......@@ -322,6 +322,8 @@ class precompiled_wheel_utils:
"vllm/_C.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/_flashmla_C.abi3.so",
"vllm/_flashmla_extension_C.abi3.so",
"vllm/_sparse_flashmla_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/cumem_allocator.abi3.so",
......@@ -589,6 +591,8 @@ if _is_cuda():
# not targeting a hopper system
ext_modules.append(
CMakeExtension(name="vllm._flashmla_C", optional=True))
ext_modules.append(
CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
if _build_custom_ops():
......
......@@ -191,7 +191,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_dtype,
use_mla=False,
),
layer_names=[self.attn.layer_name],
vllm_config=self.vllm_config,
......
......@@ -593,6 +593,119 @@ def test_concat_and_cache_mla(
torch.testing.assert_close(kv_cache, ref_kv_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
kv_lora_rank: int,
qk_rope_head_dim: int,
num_tokens: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
if dtype.itemsize != 2:
pytest.skip("ds_mla only supports 16-bit input")
kv_cache_dtype = "fp8_ds_mla"
current_platform.seed_everything(seed)
torch.set_default_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
scale = torch.tensor(1.0, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
kv_cache_dtype=kv_cache_dtype,
device=device)
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
tile_data = torch.zeros(128, dtype=dtype, device=device)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
ref_cache_slice = ref_cache[block_idx, block_offset]
ref_cache_16bit = ref_cache_slice.view(dtype)
ref_cache_32bit = ref_cache_slice.view(torch.float32)
kv_c_data = kv_c[i]
for tile_idx in range(4):
tile_start = tile_idx * 128
tile_end = (tile_idx + 1) * 128
tile_data[:] = kv_c_data[tile_start:tile_end]
# tile_scale = tile_data.amax().to(torch.float32) / 448.
# NOTE: Using torch's amax() gives different results,
# so this must be manually computed.
tile_data_float = tile_data.to(torch.float32)
manual_max = abs(tile_data_float[0])
for j in range(1, 128):
manual_max = max(manual_max, abs(tile_data_float[j]))
tile_scale = manual_max / 448.
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
tile_data,
tile_scale.item(),
kv_dtype="fp8")
for j in range(qk_rope_head_dim):
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
kv_cache_slice = kv_cache[block_idx, block_offset]
ref_cache_slice = ref_cache[block_idx, block_offset]
kv_nope = kv_cache_slice[:kv_lora_rank]
ref_nope = ref_cache_slice[:kv_lora_rank]
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4]
ref_scales = ref_cache_slice.view(
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import torch
from vllm.platforms import current_platform
from vllm.utils import cdiv, has_deep_gemm
from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
fp8_paged_mqa_logits, get_num_sms,
get_paged_mqa_logits_metadata)
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
# x: (num_blocks, block_size, 1, head_dim)
num_blocks, block_size, num_heads, head_dim = x.shape
assert num_heads == 1
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
x_fp8 = torch.empty(
(num_blocks, block_size * (head_dim + 4)),
device=x.device,
dtype=torch.uint8,
)
x_fp8[:, :block_size * head_dim] = x_scaled.view(
num_blocks, block_size * head_dim).view(dtype=torch.uint8)
x_fp8[:,
block_size * head_dim:] = sf.view(num_blocks,
block_size).view(dtype=torch.uint8)
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
def per_custom_dims_cast_to_fp8(
x: torch.Tensor, dims: tuple,
use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled, sf.squeeze()
def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
chunk_size = seq_len // 2
cp_size = seq_len_kv // seq_len
cp_id = cp_size // 3
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
for i in range(chunk_size):
ke[i] = cp_id * chunk_size + i
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
return ks, ke
def _ref_fp8_mqa_logits(
q: torch.Tensor,
kv: torch.Tensor,
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
):
seq_len_kv = kv.shape[0]
k = kv
q = q.float()
k = k.float()
mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
>= cu_seqlen_ks[:, None])
mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
< cu_seqlen_ke[:, None])
mask = mask_lo & mask_hi
score = torch.einsum("mhd,and->hmn", q, k)
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float("-inf"))
return logits
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="SM90 and SM100 only")
def test_deepgemm_fp8_mqa_logits():
torch.manual_seed(0)
random.seed(0)
num_heads, head_dim = 32, 128
for seq_len in (512, ):
for seq_len_kv in (1024, ):
for disable_cp in (False, True):
q = torch.randn(
seq_len,
num_heads,
head_dim,
device="cuda",
dtype=torch.bfloat16,
)
kv = torch.randn(seq_len_kv,
head_dim,
device="cuda",
dtype=torch.bfloat16)
weights = torch.randn(seq_len,
num_heads,
device="cuda",
dtype=torch.float32)
if disable_cp:
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
ke = torch.arange(seq_len, dtype=torch.int,
device="cuda") + (seq_len_kv - seq_len)
else:
ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
ref_logits = _ref_fp8_mqa_logits(
q=q,
kv=kv,
weights=weights,
cu_seqlen_ks=ks,
cu_seqlen_ke=ke,
)
ref_neginf_mask = ref_logits == float("-inf")
neginf_mask = logits == float("-inf")
assert torch.equal(neginf_mask, ref_neginf_mask)
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
logits = logits.masked_fill(neginf_mask, 0)
diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f"{diff=}"
def _ref_fp8_paged_mqa_logits(
q: torch.Tensor,
kv_cache: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
max_model_len: int,
):
batch_size, next_n, _, _ = q.size()
_, block_size, _, _ = kv_cache.size()
logits = torch.full(
[batch_size * next_n, max_model_len],
float("-inf"),
device=q.device,
dtype=torch.float32,
)
context_lens_list = context_lens.tolist()
for i in range(batch_size):
context_len = context_lens_list[i]
q_offsets = torch.arange(context_len - next_n,
context_len,
device="cuda")
weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
0, 1).contiguous())
for block_rk in range(cdiv(context_len, block_size)):
block_idx = block_tables[i][block_rk]
qx, kx = q[i], kv_cache[block_idx]
k_offsets = torch.arange(
block_rk * block_size,
(block_rk + 1) * block_size,
device="cuda",
)
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
<= q_offsets[:, None])
s = torch.where(
mask[None, :, :],
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
logits.dtype),
float("-inf"),
)
s = torch.relu(s) * weight_slice[..., None]
s = s.sum(dim=0)
logits[
i * next_n:(i + 1) * next_n,
block_rk * block_size:(block_rk + 1) * block_size,
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
float("-inf"))
return logits
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="SM90 and SM100 only")
def test_deepgemm_fp8_paged_mqa_logits():
torch.manual_seed(0)
random.seed(0)
max_model_len = 4096
for batch_size, next_n in [(4, 1), (2, 2)]:
for heads, index_dim in [(32, 128)]:
for avg_kv in (2048, ):
num_blocks, blocksize = max_model_len * 2, 64
q = torch.randn(
(batch_size, next_n, heads, index_dim),
device="cuda",
dtype=torch.bfloat16,
)
kv_cache = torch.randn(
(num_blocks, blocksize, 1, index_dim),
device="cuda",
dtype=torch.bfloat16,
)
weights = torch.randn(
(batch_size * next_n, heads),
device="cuda",
dtype=torch.float32,
)
context_lens = (torch.randint(int(0.8 * avg_kv),
int(1.2 * avg_kv),
(batch_size, )).cuda().to(
torch.int32))
max_block_len = ((context_lens.max().item() + blocksize - 1) //
blocksize * blocksize)
block_tables = torch.zeros(
(batch_size, max_block_len),
device="cuda",
dtype=torch.int32,
)
counter = 0
block_idx_pool = list(range(num_blocks))
random.shuffle(block_idx_pool)
for i in range(batch_size):
ctx_len = int(context_lens[i].item())
for j in range((ctx_len + blocksize - 1) // blocksize):
block_tables[i][j] = block_idx_pool[counter]
counter += 1
q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
schedule_metadata = get_paged_mqa_logits_metadata(
context_lens, blocksize, get_num_sms())
logits = fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
weights,
context_lens,
block_tables,
schedule_metadata,
max_model_len,
)
ref_logits = _ref_fp8_paged_mqa_logits(
q,
kv_cache,
weights,
context_lens,
block_tables,
max_model_len,
)
positions = (torch.arange(max_model_len,
device="cuda").unsqueeze(0).expand(
batch_size * next_n, -1))
row_indices = (
torch.arange(batch_size * next_n, device="cuda") // next_n)
next_n_offset = (
torch.arange(batch_size * next_n, device="cuda") % next_n)
mask = positions <= (context_lens[row_indices] - next_n +
next_n_offset).unsqueeze(1)
logits = logits.masked_fill(~mask, 0)
ref_logits = ref_logits.masked_fill(~mask, 0)
diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f"{diff=}"
......@@ -97,18 +97,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
descale_k = None
def flash_mla():
return flash_mla_with_kvcache(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k,
)
return flash_mla_with_kvcache(q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k)
def scaled_dot_product_attention(query, key, value, is_causal=False):
query = query.float()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
def _cuda_sm90_available() -> bool:
if not torch.cuda.is_available():
return False
major, _ = torch.cuda.get_device_capability()
return major == 9
def test_sparse_flashmla_metadata_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
device = torch.device("cuda")
batch_size = 1
seqlen_q = 1
num_heads_q = 128
num_heads_k = 1
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
topk = 128
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
q_seq_per_hk,
num_heads_k,
num_heads_q=num_heads_q,
topk=topk,
is_fp8_kvcache=True)
assert tile_md.dtype == torch.int32
assert num_splits.dtype == torch.int32
def test_sparse_flashmla_decode_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
device = torch.device("cuda")
batch_size = 1
seqlen_q = 1
num_heads_q = 1
head_dim_k = 576
head_dim_v = 512
num_heads_k = 1
page_block_size = 64
bytes_per_token = 656
topk = 128
# Metadata
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
# q_heads_per_hk = num_heads_q // num_heads_k
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
q_seq_per_hk,
num_heads_k,
num_heads_q=num_heads_q,
topk=topk,
is_fp8_kvcache=True)
# Inputs
q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
dtype=torch.bfloat16,
device=device)
k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
dtype=torch.uint8,
device=device)
indices = torch.zeros((batch_size, seqlen_q, topk),
dtype=torch.int32,
device=device)
block_table = torch.zeros((batch_size, 128),
dtype=torch.int32,
device=device)
out, lse = fm.flash_mla_with_kvcache(q,
k_cache,
block_table,
cache_seqlens,
head_dim_v,
tile_md,
num_splits,
indices=indices,
is_fp8_kvcache=True)
assert out.shape[0] == batch_size
assert out.shape[-1] == head_dim_v
assert lse.shape[0] == batch_size
def test_sparse_flashmla_prefill_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
device = torch.device("cuda")
s_q = 1
s_kv = 1
h_q = 64 # kernel expects multiple of 64
h_kv = 1
d_qk = 576
d_v = 512
topk = 128
q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0,
d_v)
assert out.shape == (s_q, h_q, d_v)
assert max_logits.shape == (s_q, h_q)
assert lse.shape == (s_q, h_q)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch.testing import assert_close
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
def test_pack_seq_basic_fp8():
"""Test basic functionality of pack_seq_triton with fp8 and 3D tensors."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test cases with 3D tensors (N, H, D)
test_cases = [
(6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4)
(10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8)
(20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32)
]
for N, H, D, B, lengths_list in test_cases:
# Create input tensor with small values for fp8
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor(lengths_list, device=device)
# Pack the data
packed = pack_seq_triton(x, lengths)
# Check output shape and properties
expected_shape = (B, max(lengths_list), H, D)
assert packed.shape == expected_shape
assert packed.dtype == dtype
assert packed.device == x.device
# Check that valid data is preserved (within fp8 precision)
for b in range(B):
start_idx = sum(lengths_list[:b])
seq_len = lengths_list[b]
expected_data = x[start_idx:start_idx + seq_len].to(torch.float32)
actual_data = packed[b, :seq_len].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
def test_pack_seq_custom_padding_fp8():
"""Test pack_seq_triton with custom padding values for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
N, H, D, B = 20, 8, 16, 2
lengths = torch.tensor([10, 10], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
# Test with different padding values
for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]:
result = pack_seq_triton(x, lengths, pad_value=pad_value)
# Check valid data
for b in range(B):
start_idx = b * 10
expected_data = x[start_idx:start_idx + 10].to(torch.float32)
actual_data = result[b, :10].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
# Check padding (fp8 has limited range, so check for large values)
padded_data = result[:, 10:].to(torch.float32)
if pad_value < 0:
assert torch.all(padded_data < -50) # Large negative values
elif pad_value > 0:
assert torch.all(padded_data > 50) # Large positive values
else:
assert torch.allclose(padded_data,
torch.zeros_like(padded_data),
atol=1e-2)
def test_pack_seq_default_negative_inf_padding_fp8():
"""Test that pack_seq_triton uses -inf padding by default for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# B = 2
N, H, D = 20, 8, 16
lengths = torch.tensor([10, 10], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
result = pack_seq_triton(x, lengths)
# Check that padding is large negative values (fp8 representation of -inf)
padded_data = result[:, 10:].to(torch.float32)
assert torch.all(
padded_data < -100) # fp8 -inf is represented as large negative number
def test_pack_seq_edge_cases_fp8():
"""Test pack_seq_triton with edge cases for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test with single batch element
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([10], device=device)
result = pack_seq_triton(x, lengths)
assert result.shape == (1, 10, 8, 16)
# Test with very short sequences
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([1, 1, 1], device=device)
result = pack_seq_triton(x, lengths)
assert result.shape == (3, 1, 4, 8)
# Test with different sequence lengths
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([5, 7, 3], device=device)
result = pack_seq_triton(x, lengths)
assert result.shape == (3, 7, 8, 16)
def test_pack_seq_different_block_sizes_fp8():
"""Test pack_seq_triton with different block sizes for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
N, H, D, B = 100, 16, 32, 4
lengths = torch.tensor([25, 25, 25, 25], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
# Test different block sizes
for block_t, block_d in [(32, 32), (64, 64), (128, 128)]:
result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d)
assert result.shape == (B, 25, H, D)
# Check that valid data is preserved (within fp8 precision)
for b in range(B):
start_idx = b * 25
expected_data = x[start_idx:start_idx + 25].to(torch.float32)
actual_data = result[b, :25].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
def test_pack_seq_shape_consistency():
"""Test that pack_seq_triton maintains shape consistency."""
device = "cuda"
dtype = torch.float8_e4m3fn
N, H, D, B = 20, 8, 16, 2
lengths = torch.tensor([10, 10], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
result = pack_seq_triton(x, lengths)
# Check shape consistency
assert result.shape[0] == B # Batch dimension
assert result.shape[1] == lengths.max().item() # Max sequence length
assert result.shape[2:] == x.shape[1:] # Feature dimensions preserved
def test_pack_unpack_roundtrip_fp8():
"""Test that pack -> unpack gives us back the original data for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test cases with 3D tensors
test_cases = [
(6, 8, 4, 2, [3, 3]),
(10, 4, 8, 3, [2, 4, 4]),
(20, 16, 32, 4, [5, 5, 5, 5]),
(15, 8, 16, 3, [7, 5, 3]),
]
for N, H, D, B, lengths_list in test_cases:
# Create input tensor with small values for fp8
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor(lengths_list, device=device)
# Pack the data
packed = pack_seq_triton(x, lengths)
# Unpack the data
unpacked = unpack_seq_triton(packed, lengths)
# Check that we get back the original data (within fp8 precision)
assert unpacked.shape == x.shape
x_f32 = x.to(torch.float32)
unpacked_f32 = unpacked.to(torch.float32)
assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3)
# Unpack without explicit start locations (computed in kernel)
unpacked_with_loc = unpack_seq_triton(packed, lengths)
assert_close(x_f32,
unpacked_with_loc.to(torch.float32),
rtol=1e-3,
atol=1e-2)
def test_unpack_seq_triton_edge_cases_fp8():
"""Test unpack function with edge cases for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test with single batch element
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([10], device=device)
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
assert unpacked.shape == x.shape
assert_close(x.to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)
# Test with very short sequences
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([1, 1, 1], device=device)
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
# Only compare the first 3 elements that were actually packed
assert_close(x[:3].to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([5, 7, 3], device=device)
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
assert unpacked.shape == x.shape
assert_close(x.to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)
......@@ -207,6 +207,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
trust_remote_code=True),
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT",
min_transformers_version="4.54"),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
......
......@@ -8,7 +8,8 @@ import pytest
from vllm import LLM
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
get_kv_cache_configs)
from vllm.v1.engine.core import EngineCore as V1EngineCore
from ..utils import create_new_process_for_each_test
......@@ -62,11 +63,13 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
# Avoid calling model.forward()
def _initialize_kv_caches_v1(self, vllm_config):
kv_cache_specs = self.model_executor.get_kv_cache_specs()
scheduler_kv_cache_config = get_kv_cache_configs(
kv_cache_configs = get_kv_cache_configs(
vllm_config,
kv_cache_specs,
[10 * GiB_bytes],
)[0]
)
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
kv_cache_configs)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config
......
......@@ -26,5 +26,5 @@ class DummyPlatform(Platform):
def get_attn_backend_cls(self, backend_name, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink):
has_sink, use_sparse):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
from typing import Optional, Union
import pytest
import torch
......@@ -10,6 +11,7 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
......@@ -78,7 +80,9 @@ def create_and_prepopulate_kv_cache(
device: torch.device,
num_blocks: int,
common_attn_metadata: CommonAttentionMetadata,
randomize_blocks: bool = True) -> torch.Tensor:
randomize_blocks: bool = True,
kv_cache_dtype: Optional[str] = None,
scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor:
"""Create and prepopulate an MLA KV cache with context data.
Args:
......@@ -93,6 +97,11 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata: Common attention metadata
randomize_blocks: Whether to randomly permute blocks
or use sequential order
kv_cache_dtype: Optional kv cache dtype string. When set to
"fp8_ds_mla" the cache is populated using the
fp8 DeepSeek MLA layout via concat_and_cache_mla.
scale: Scaling factor forwarded to concat_and_cache_mla when the
fp8 cache layout is requested.
Returns:
MLA KV cache tensor
......@@ -105,23 +114,61 @@ def create_and_prepopulate_kv_cache(
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
# Create MLA KV cache: (num_blocks, block_size, head_size)
kv_cache = torch.empty(num_blocks,
block_size,
head_size,
dtype=dtype,
device=device)
kv_cache_flat = kv_cache.view(-1, head_size)
use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
if use_fp8_ds_mla:
if not kv_c_contexts:
raise ValueError("kv_c_contexts cannot be empty when using"
" fp8_ds_mla cache dtype")
kv_lora_rank = kv_c_contexts[0].shape[-1]
rope_dim = k_pe_contexts[0].shape[-1]
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
kv_cache = torch.zeros(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
device=device)
scale_tensor = (scale
if isinstance(scale, torch.Tensor) else torch.tensor(
scale, dtype=torch.float32, device=device))
scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
else:
# Create MLA KV cache: (num_blocks, block_size, head_size)
kv_cache = torch.empty(num_blocks,
block_size,
head_size,
dtype=dtype,
device=device)
kv_cache_flat = kv_cache.view(-1, head_size)
# Populate the cache with the context tokens
# Start from block_id=1 since block_id=0 is considered the null block
start_block_idx = 1
for i in range(batch_size):
kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i]
kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
context_len = kv_c_context.shape[0]
if context_len == 0:
start_block_idx += cdiv(int(seq_lens[i]), block_size)
continue
start = start_block_idx * block_size
end = start + kv_context.shape[0]
kv_cache_flat[start:end, ...] = kv_context
if use_fp8_ds_mla:
slots = torch.arange(context_len, device=device,
dtype=torch.long) + start
ops.concat_and_cache_mla(
kv_c_context,
k_pe_context.squeeze(1),
kv_cache,
slots,
kv_cache_dtype="fp8_ds_mla",
scale=scale_tensor,
)
else:
kv_context = torch.cat(
[kv_c_context, k_pe_context.squeeze(1)], dim=-1)
end = start + kv_context.shape[0]
kv_cache_flat[start:end, ...] = kv_context
# Stay block aligned and allocate enough blocks for the new tokens
start_block_idx += cdiv(int(seq_lens[i]), block_size)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the FlashMLA sparse backend utilities."""
import math
from types import MethodType, SimpleNamespace
import numpy as np
import pytest
import torch
from tests.v1.attention.test_mla_backends import (
BATCH_SPECS, BatchSpec, MockAttentionLayer,
create_and_prepopulate_kv_cache)
from tests.v1.attention.utils import (create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config)
from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
FlashMLASparseImpl, FlashMLASparseMetadata)
SPARSE_BACKEND_BATCH_SPECS = {
name: BATCH_SPECS[name]
for name in [
"mixed_small",
"mixed_medium",
"small_prefill",
"medium_prefill",
"single_prefill",
]
}
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2,
query_lens=[256] * 2)
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
seq_lens=[256] * 2, query_lens=[256] * 2)
def _dequantize_fp8_ds_mla_entry(
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int,
dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
# The first kv_lora_rank bytes store FP8 latent values with one scale per
# 128 element tile written as float32 right after the latent payload.
scales = cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4]
latent = torch.empty(kv_lora_rank,
dtype=torch.float16,
device=cache_slice.device)
for tile_idx in range(4):
tile_start = tile_idx * 128
tile_end = tile_start + 128
ops.convert_fp8(latent[tile_start:tile_end],
cache_slice[tile_start:tile_end],
float(scales[tile_idx].item()),
kv_dtype="fp8")
latent = latent.to(dtype)
rope_offset = kv_lora_rank // 2 + 8
rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim]
return latent, rope_vals.clone()
def _quantize_dequantize_fp8_ds_mla(
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int,
scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
if kv_c.numel() == 0:
return kv_c.clone(), k_pe.clone()
kv_lora_rank = kv_c.shape[-1]
rope_dim = k_pe.shape[-1]
num_tokens = kv_c.shape[0]
num_blocks = max(1, math.ceil(num_tokens / block_size))
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
tmp_cache = torch.zeros(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
device=kv_c.device)
slot_mapping = torch.arange(num_tokens,
dtype=torch.long,
device=kv_c.device)
ops.concat_and_cache_mla(kv_c,
k_pe,
tmp_cache,
slot_mapping,
kv_cache_dtype="fp8_ds_mla",
scale=scale)
dequant_kv_c = torch.empty_like(kv_c)
dequant_k_pe = torch.empty_like(k_pe)
for token_idx in range(num_tokens):
slot = slot_mapping[token_idx].item()
block_idx = slot // block_size
block_offset = slot % block_size
cache_slice = tmp_cache[block_idx, block_offset]
latent, rope_vals = _dequantize_fp8_ds_mla_entry(
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype)
dequant_kv_c[token_idx] = latent
dequant_k_pe[token_idx] = rope_vals
return dequant_kv_c, dequant_k_pe
def test_sparse_backend_metadata_registration():
backend = FlashMLASparseBackend
assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1"
assert backend.get_metadata_cls() is FlashMLASparseMetadata
assert backend.get_impl_cls() is FlashMLASparseImpl
dtype_list = backend.get_supported_dtypes()
assert torch.bfloat16 in dtype_list
shape = backend.get_kv_cache_shape(num_blocks=2,
block_size=64,
num_kv_heads=1,
head_size=576)
assert shape == (2, 64, 576)
def test_sparse_decode_metadata_filters_prefill_indices():
prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32)
metadata = FlashMLASparseDecodeAndContextMetadata(
scheduler_metadata=torch.tensor([[0]], dtype=torch.int32),
num_splits=torch.tensor([1, 1], dtype=torch.int32),
cache_lens=torch.tensor([10, 12], dtype=torch.int32),
prefill_context_lengths=prefill_context_lengths,
)
indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
context_indices, new_token_indices = metadata.filter_prefill_indices(
indices)
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]],
dtype=torch.int32)
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]],
dtype=torch.int32)
assert torch.equal(context_indices, expected_context)
assert torch.equal(new_token_indices, expected_new_tokens)
def test_sparse_impl_zero_fills_when_metadata_missing():
impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl)
dummy_layer = object()
q = torch.zeros((2, 1, 3))
k_c = torch.zeros((2, 3))
k_pe = torch.zeros((2, 1, 1))
kv_cache = torch.zeros((1, 1, 1))
output = torch.ones((2, 4))
result = FlashMLASparseImpl.forward(impl,
dummy_layer,
q,
k_c,
k_pe,
kv_cache,
attn_metadata=None,
output=output)
assert result is output
assert torch.all(result == 0)
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
def test_sparse_backend_decode_correctness(dist_init, batch_name,
kv_cache_dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test")
device = torch.device("cuda")
dtype = torch.bfloat16
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
# Model hyper-parameters (kept intentionally small for the unit test)
num_heads = 128
kv_lora_rank = 512
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128
head_size = kv_lora_rank + qk_rope_head_dim
topk_tokens = 2048
max_seqlen = max(batch_spec.seq_lens)
total_cache_tokens = sum(batch_spec.seq_lens)
block_size = 64
vllm_config = create_vllm_config(
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
max_model_len=max_seqlen,
num_gpu_blocks=max(2048,
cdiv(total_cache_tokens, block_size) + 1),
block_size=block_size)
model_config = vllm_config.model_config
model_config.hf_config = SimpleNamespace(
attn_module_list_cfg=[{
"topk_tokens": topk_tokens
}])
model_config.hf_text_config = SimpleNamespace(
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
model_type="deepseek_v2",
)
model_config.dtype = dtype
model_config.get_num_attention_heads = MethodType(
lambda self, parallel_config: num_heads, model_config)
model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1,
model_config)
model_config.get_head_size = MethodType(lambda self: head_size,
model_config)
model_config.get_sliding_window = MethodType(lambda self: None,
model_config)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
torch.manual_seed(0)
scale = 1.0 / math.sqrt(head_size)
# Shared MLA projection weights to keep reference and backend in sync
W_UK = torch.randn(kv_lora_rank,
num_heads,
qk_nope_head_dim,
dtype=dtype,
device=device)
W_UV = torch.randn(kv_lora_rank,
num_heads,
v_head_dim,
dtype=dtype,
device=device)
# Build synthetic decode-only workload
seq_lens = batch_spec.seq_lens
query_lens = batch_spec.query_lens
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
kv_c_contexts, k_pe_contexts = [], []
reference_outputs = []
kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
for i in range(batch_spec.batch_size):
s_len = seq_lens[i]
q_len = query_lens[i]
ctx_len = s_len - q_len
q_c = torch.rand(q_len,
num_heads,
qk_nope_head_dim + qk_rope_head_dim,
dtype=dtype,
device=device)
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
k_pe_full = torch.rand(s_len,
1,
qk_rope_head_dim,
dtype=dtype,
device=device)
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
kv_c_full,
k_pe_full.squeeze(1),
block_size=vllm_config.cache_config.block_size,
scale=kv_cache_scale,
)
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, ctx_len:] = causal_mask
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
all_q_vllm.append(q_c)
all_kv_c_vllm.append(kv_c_full[ctx_len:])
all_k_pe_vllm.append(k_pe_full[ctx_len:])
kv_c_contexts.append(kv_c_full[:ctx_len + 1])
k_pe_contexts.append(k_pe_full[:ctx_len + 1])
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_reference = torch.cat(reference_outputs, dim=0)
vllm_config.cache_config.cache_dtype = kv_cache_dtype
common_attn_metadata = create_common_attn_metadata(
batch_spec,
vllm_config.cache_config.block_size,
device,
arange_block_indices=True)
kv_cache = create_and_prepopulate_kv_cache(
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=vllm_config.cache_config.block_size,
head_size=head_size,
dtype=dtype,
device=device,
num_blocks=vllm_config.cache_config.num_gpu_blocks,
common_attn_metadata=common_attn_metadata,
randomize_blocks=False,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
scale=kv_cache_scale,
)
builder_cls = FlashMLASparseBackend.get_builder_cls()
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
metadata = builder.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
dtype=np.int32)
seg_lengths = np.diff(starts)
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
starts[:-1], seg_lengths)
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
prefix_lengths = seq_lengths - seg_lengths
positions += np.repeat(prefix_lengths, seg_lengths)
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
topk = metadata.topk_tokens
debug_indices = torch.arange(topk, device=device,
dtype=torch.int32).unsqueeze(0)
token_positions = pos_gpu.unsqueeze(1)
causal_mask = (debug_indices <= token_positions)
debug_indices = torch.where(causal_mask, debug_indices,
torch.full_like(debug_indices, -1))
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
# buffer, so emulate that contract with a simple namespace mock.
debug_indices = debug_indices.expand(metadata.num_actual_tokens,
-1).clone()
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
ok, reason = flashmla.is_flashmla_supported()
if not ok:
pytest.skip(reason)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
kv_b_proj_weight = kv_b_proj_weight.view(
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
output_size=num_heads *
(qk_nope_head_dim + v_head_dim),
bias=False).to(device=device,
dtype=dtype)
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
impl_cls = FlashMLASparseBackend.get_impl_cls()
impl = impl_cls(num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
indexer=mock_indexer)
impl.process_weights_after_loading(dtype)
layer = MockAttentionLayer(device)
out_buffer = torch.empty(metadata.num_actual_tokens,
num_heads * v_head_dim,
dtype=dtype,
device=device)
backend_output = impl.forward(layer,
query_vllm,
kv_c_vllm,
k_pe_vllm,
kv_cache,
metadata,
output=out_buffer)
assert backend_output.shape == sdpa_reference.shape
assert backend_output.dtype == sdpa_reference.dtype
assert torch.isfinite(backend_output).all()
torch.testing.assert_close(backend_output,
sdpa_reference,
rtol=0.5,
atol=0.5)
......@@ -168,7 +168,6 @@ def create_standard_kv_cache_spec(
vllm_config.parallel_config),
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
use_mla=vllm_config.model_config.use_mla,
sliding_window=vllm_config.model_config.get_sliding_window(),
)
......
......@@ -24,7 +24,8 @@ from vllm.v1.core.kv_cache_utils import (
make_block_hash_with_group_id)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec,
KVCacheTensor, MLAAttentionSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
......@@ -77,13 +78,11 @@ def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False,
sliding_window=None):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla,
sliding_window=sliding_window)
......@@ -91,13 +90,11 @@ def new_sliding_window_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False,
sliding_window=1):
return SlidingWindowSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla,
sliding_window=sliding_window)
......@@ -894,7 +891,6 @@ def test_merge_kv_cache_spec():
num_kv_heads=full_spec.num_kv_heads,
head_size=full_spec.head_size,
dtype=full_spec.dtype,
use_mla=full_spec.use_mla,
sliding_window=1,
),
]
......@@ -991,7 +987,6 @@ def test_estimate_max_model_len(model_id, max_model_len,
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
use_mla=False,
)
# Estimate the maximum model length, 16384 model_len need 8GB
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
......@@ -1022,7 +1017,6 @@ def test_get_max_concurrency_for_kv_cache_config():
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
use_mla=False,
)
sliding_window_spec = SlidingWindowSpec(
......@@ -1030,7 +1024,6 @@ def test_get_max_concurrency_for_kv_cache_config():
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
use_mla=False,
sliding_window=1024,
)
......@@ -1412,3 +1405,48 @@ def test_generate_scheduler_kv_cache_config():
KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec())
],
)
def new_mla_spec(cache_dtype_str=None):
return MLAAttentionSpec(block_size=16,
num_kv_heads=16,
head_size=64,
dtype=torch.float32,
cache_dtype_str=cache_dtype_str)
def test_merge_mla_spec():
kv_cache_specs = [
new_mla_spec(),
new_mla_spec(),
]
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
assert mla_spec == new_mla_spec()
kv_cache_specs = [
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
]
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
assert mla_spec == new_mla_spec(cache_dtype_str="fp8_ds_mla")
kv_cache_specs = [
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
new_mla_spec(cache_dtype_str=None),
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)
kv_cache_specs = [
new_kv_cache_spec(),
new_mla_spec(),
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)
kv_cache_specs = [
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
new_kv_cache_spec(),
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)
......@@ -76,7 +76,7 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
FullAttentionSpec(block_size, 1, 1, torch.float32),
)
],
)
......@@ -90,7 +90,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
FullAttentionSpec(block_size, 1, 1, torch.float32),
),
KVCacheGroupSpec(
["layer2"],
......@@ -98,7 +98,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
1,
1,
torch.float32,
False,
sliding_window=2 * block_size),
),
KVCacheGroupSpec(
......@@ -107,7 +106,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
1,
1,
torch.float32,
False,
sliding_window=2 * block_size),
),
],
......@@ -1338,7 +1336,6 @@ def test_eagle_with_sliding_window():
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
use_mla=False,
)
manager = KVCacheManager(
KVCacheConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment