Unverified Commit 3e41992f authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2 (#27532)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 91401c7a
#pragma once #pragma once
#include <torch/all.h> #include <torch/all.h>
#include <c10/util/Optional.h>
#include <map> #include <map>
#include <vector> #include <vector>
...@@ -58,6 +59,15 @@ void cp_gather_cache( ...@@ -58,6 +59,15 @@ void cp_gather_cache(
torch::Tensor const& cu_seq_lens, // [BATCH+1] torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt); int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
// Gather and upconvert FP8 KV cache to BF16 workspace
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size);
// Indexer K quantization and cache function // Indexer K quantization and cache function
void indexer_k_quant_and_cache( void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim] torch::Tensor& k, // [num_tokens, head_dim]
...@@ -72,4 +82,4 @@ void cp_gather_indexer_k_quant_cache( ...@@ -72,4 +82,4 @@ void cp_gather_indexer_k_quant_cache(
torch::Tensor& dst_k, // [num_tokens, head_dim] torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks] const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens); // [batch_size + 1] const torch::Tensor& cu_seq_lens); // [batch_size + 1]
\ No newline at end of file
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h> #include <c10/cuda/CUDAException.h>
#include <c10/util/Optional.h>
#include "cuda_utils.h" #include "cuda_utils.h"
#include "cuda_compat.h" #include "cuda_compat.h"
...@@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel( ...@@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
const int quant_block_size, // quantization block size const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache const int cache_stride, // stride for each token in kv_cache
const bool use_ue8m0 // use ue8m0 scale format
const bool use_ue8m0 // use ue8m0 scale format
) { ) {
constexpr int VEC_SIZE = 4; constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
...@@ -1061,6 +1063,82 @@ void gather_and_maybe_dequant_cache( ...@@ -1061,6 +1063,82 @@ void gather_and_maybe_dequant_cache(
} }
namespace vllm { namespace vllm {
// Gather and upconvert FP8 KV cache tokens to BF16 workspace
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ seq_lens, // [BATCH]
const int32_t* __restrict__ workspace_starts, // [BATCH]
const int32_t block_size, const int32_t head_dim,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = workspace_starts[bid];
const int32_t seq_len = seq_lens[bid];
const int32_t tot_slots = seq_len;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
const int32_t split_start = split * split_slots;
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
const bool is_active_split = (split_start < tot_slots);
if (!is_active_split) return;
// Adjust the pointer for the block_table for this batch
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = split_start;
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;
// Adjust dst pointer based on the cumulative sequence lengths
dst += seq_start * dst_entry_stride;
const int tid = threadIdx.x;
// Process each token in this split
for (int pid = split_start; pid < split_end; ++pid) {
auto block_id = batch_block_table[offset_div];
const uint8_t* token_ptr =
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;
// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
const uint8_t* no_pe_ptr = token_ptr;
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const __nv_bfloat16* rope_ptr =
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
if (tid < 512) {
// FP8 dequantization
const int tile = tid >> 7; // each tile is 128 elements
const float scale = scales_ptr[tile];
const uint8_t val = no_pe_ptr[tid];
dst_ptr[tid] =
fp8::scaled_convert<__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
} else if (tid < 576) {
// Rope copy (64 bf16 elements)
const int rope_idx = tid - 512;
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
}
// Move to next token
offset += 1;
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
}
}
template <typename scalar_t> template <typename scalar_t>
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by // Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
// block_size. // block_size.
...@@ -1202,6 +1280,57 @@ void cp_gather_cache( ...@@ -1202,6 +1280,57 @@ void cp_gather_cache(
} }
} }
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
int32_t head_dim = dst.size(1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
"workspace_starts must be int32");
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == seq_lens.device(),
"src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device");
TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
// Decide on the number of splits based on the batch size
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(576);
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
src_cache.data_ptr<uint8_t>(),
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride);
}
// Macro to dispatch the kernel based on the data type. // Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ #define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \ vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
......
...@@ -754,6 +754,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -754,6 +754,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
cache_ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()");
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
&cp_gather_and_upconvert_fp8_kv_cache);
cache_ops.def( cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor " "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, " "slot_mapping, "
......
...@@ -202,6 +202,27 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): ...@@ -202,6 +202,27 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.fixture
def workspace_init():
"""Initialize the workspace manager for tests that need it.
This fixture initializes the workspace manager with a CUDA device
if available, and resets it after the test completes. Tests that
create a full vLLM engine should NOT use this fixture as the engine
will initialize the workspace manager itself.
"""
from vllm.v1.worker.workspace import (
init_workspace_manager,
reset_workspace_manager,
)
if torch.cuda.is_available():
device = torch.device("cuda:0")
init_workspace_manager(device)
yield
reset_workspace_manager()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def dynamo_reset(): def dynamo_reset():
yield yield
......
...@@ -27,7 +27,7 @@ BLOCK_SIZE = [128, 128] ...@@ -27,7 +27,7 @@ BLOCK_SIZE = [128, 128]
@pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert @pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4]) @pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton( def test_batched_deepgemm_vs_triton(
E: int, T: int, K: int, N: int, topk: int, monkeypatch E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init
): ):
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts.""" """Compare BatchedDeepGemmExperts to BatchedTritonExperts."""
......
...@@ -248,6 +248,7 @@ def test_fused_moe_batched_experts( ...@@ -248,6 +248,7 @@ def test_fused_moe_batched_experts(
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: list[int] | None, block_shape: list[int] | None,
input_scales: bool, input_scales: bool,
workspace_init,
): ):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89, """Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware.""" and those tests will be skipped on unsupported hardware."""
......
...@@ -137,7 +137,7 @@ def setup_cuda(): ...@@ -137,7 +137,7 @@ def setup_cuda():
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_fused_moe( def test_w8a8_block_fp8_fused_moe(
M, N, K, E, topk, block_size, dtype, seed, monkeypatch M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init
): ):
if topk > E: if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}") pytest.skip(f"Skipping test; topk={topk} > E={E}")
......
...@@ -274,6 +274,7 @@ def test_cutlass_moe_8_bit_no_graph( ...@@ -274,6 +274,7 @@ def test_cutlass_moe_8_bit_no_graph(
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
monkeypatch, monkeypatch,
workspace_init,
ep_size: int | None = None, ep_size: int | None = None,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
...@@ -329,6 +330,7 @@ def test_cutlass_moe_8_bit_cuda_graph( ...@@ -329,6 +330,7 @@ def test_cutlass_moe_8_bit_cuda_graph(
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
monkeypatch, monkeypatch,
workspace_init,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
...@@ -385,9 +387,19 @@ def test_cutlass_moe_8_bit_EP( ...@@ -385,9 +387,19 @@ def test_cutlass_moe_8_bit_EP(
per_out_channel: bool, per_out_channel: bool,
ep_size: int, ep_size: int,
monkeypatch, monkeypatch,
workspace_init,
): ):
test_cutlass_moe_8_bit_no_graph( test_cutlass_moe_8_bit_no_graph(
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size m,
n,
k,
e,
topk,
per_act_token,
per_out_channel,
monkeypatch,
workspace_init,
ep_size,
) )
...@@ -419,9 +431,19 @@ def test_cutlass_moe_8_bit_EP_large( ...@@ -419,9 +431,19 @@ def test_cutlass_moe_8_bit_EP_large(
per_out_channel: bool, per_out_channel: bool,
ep_size: int, ep_size: int,
monkeypatch, monkeypatch,
workspace_init,
): ):
test_cutlass_moe_8_bit_no_graph( test_cutlass_moe_8_bit_no_graph(
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size m,
n,
k,
e,
topk,
per_act_token,
per_out_channel,
monkeypatch,
workspace_init,
ep_size,
) )
...@@ -445,6 +467,7 @@ def test_run_cutlass_moe_fp8( ...@@ -445,6 +467,7 @@ def test_run_cutlass_moe_fp8(
per_act_token: bool, per_act_token: bool,
per_out_channel: bool, per_out_channel: bool,
ep_size: int, ep_size: int,
workspace_init,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
......
...@@ -29,6 +29,7 @@ from vllm.utils.deep_gemm import ( ...@@ -29,6 +29,7 @@ from vllm.utils.deep_gemm import (
is_deep_gemm_supported, is_deep_gemm_supported,
) )
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
...@@ -363,6 +364,9 @@ def _test_deepep_deepgemm_moe( ...@@ -363,6 +364,9 @@ def _test_deepep_deepgemm_moe(
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
): ):
device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device)
current_platform.seed_everything(pgi.rank) current_platform.seed_everything(pgi.rank)
w1 = w1.to(device=torch.cuda.current_device()) w1 = w1.to(device=torch.cuda.current_device())
...@@ -445,6 +449,7 @@ def test_ht_deepep_deepgemm_moe( ...@@ -445,6 +449,7 @@ def test_ht_deepep_deepgemm_moe(
topk: int, topk: int,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
disable_deepgemm_ue8m0, disable_deepgemm_ue8m0,
workspace_init,
): ):
""" """
Tests for High-Throughput DeepEP + DeepGemm integration. Tests for High-Throughput DeepEP + DeepGemm integration.
...@@ -518,6 +523,7 @@ def test_ll_deepep_deepgemm_moe( ...@@ -518,6 +523,7 @@ def test_ll_deepep_deepgemm_moe(
block_size: list[int], block_size: list[int],
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
disable_deepgemm_ue8m0, disable_deepgemm_ue8m0,
workspace_init,
): ):
""" """
Tests for Low-Latency DeepEP + DeepGemm integration. Tests for Low-Latency DeepEP + DeepGemm integration.
......
...@@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep from vllm.utils.import_utils import has_deep_ep
from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
...@@ -342,6 +343,9 @@ def _deep_ep_moe( ...@@ -342,6 +343,9 @@ def _deep_ep_moe(
use_fp8_dispatch: bool, use_fp8_dispatch: bool,
per_act_token_quant: bool, per_act_token_quant: bool,
): ):
device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device)
if not low_latency_mode: if not low_latency_mode:
assert not use_fp8_dispatch, ( assert not use_fp8_dispatch, (
"FP8 dispatch interface is available only in low-latency mode" "FP8 dispatch interface is available only in low-latency mode"
...@@ -437,6 +441,7 @@ def test_deep_ep_moe( ...@@ -437,6 +441,7 @@ def test_deep_ep_moe(
topk: int, topk: int,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
per_act_token_quant: bool, per_act_token_quant: bool,
workspace_init,
): ):
low_latency_mode = False low_latency_mode = False
use_fp8_dispatch = False use_fp8_dispatch = False
...@@ -492,6 +497,7 @@ def test_low_latency_deep_ep_moe( ...@@ -492,6 +497,7 @@ def test_low_latency_deep_ep_moe(
topk: int, topk: int,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
use_fp8_dispatch: bool, use_fp8_dispatch: bool,
workspace_init,
): ):
low_latency_mode = True low_latency_mode = True
......
...@@ -143,7 +143,7 @@ NUM_EXPERTS = [32] ...@@ -143,7 +143,7 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels") @pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch): def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init):
with monkeypatch.context() as mp: with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1") mp.setenv("VLLM_USE_DEEP_GEMM", "1")
......
...@@ -206,6 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -206,6 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
topk: int, topk: int,
activation: str, activation: str,
monkeypatch, monkeypatch,
workspace_init,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
......
...@@ -51,7 +51,14 @@ MNK_FACTORS = [ ...@@ -51,7 +51,14 @@ MNK_FACTORS = [
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"]) @pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
@torch.inference_mode() @torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph( def test_flashinfer_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
activation: str,
workspace_init,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
with set_current_vllm_config( with set_current_vllm_config(
......
...@@ -269,7 +269,7 @@ class Case: ...@@ -269,7 +269,7 @@ class Case:
) )
@pytest.mark.parametrize("num_token", [2]) @pytest.mark.parametrize("num_token", [2])
@pytest.mark.parametrize("tp", [1, 2, 4, 8]) @pytest.mark.parametrize("tp", [1, 2, 4, 8])
def test_equiv(num_token, a_dtype, w_dtype, tp): def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
from triton_kernels.tensor_details import layout from triton_kernels.tensor_details import layout
if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"): if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
......
...@@ -16,6 +16,7 @@ from vllm.platforms import current_platform ...@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.worker.workspace import init_workspace_manager
from .modular_kernel_tools.common import ( from .modular_kernel_tools.common import (
Config, Config,
...@@ -77,6 +78,10 @@ def rank_worker( ...@@ -77,6 +78,10 @@ def rank_worker(
weights: WeightTensors, weights: WeightTensors,
verbose: bool, verbose: bool,
): ):
# Initialize workspace manager in child process
device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device)
current_platform.seed_everything(pgi.rank) current_platform.seed_everything(pgi.rank)
# sanity check # sanity check
...@@ -300,6 +305,7 @@ def test_modular_kernel_combinations_singlegpu( ...@@ -300,6 +305,7 @@ def test_modular_kernel_combinations_singlegpu(
chunk_size: int | None, chunk_size: int | None,
world_size: int, world_size: int,
pytestconfig, pytestconfig,
workspace_init,
): ):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89, """Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware.""" and those tests will be skipped on unsupported hardware."""
......
...@@ -209,6 +209,7 @@ def test_oai_triton_moe( ...@@ -209,6 +209,7 @@ def test_oai_triton_moe(
num_experts: int, num_experts: int,
topk: int, topk: int,
unfused: bool, unfused: bool,
workspace_init,
): ):
current_platform.seed_everything(0) current_platform.seed_everything(0)
( (
......
...@@ -231,6 +231,7 @@ def test_fused_moe( ...@@ -231,6 +231,7 @@ def test_fused_moe(
padding: bool, padding: bool,
chunk_size: int, chunk_size: int,
monkeypatch, monkeypatch,
workspace_init,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
......
...@@ -40,7 +40,7 @@ MNK_FACTORS = [ ...@@ -40,7 +40,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode() @torch.inference_mode()
def test_cutlass_fp4_moe_no_graph( def test_cutlass_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
with set_current_vllm_config( with set_current_vllm_config(
......
...@@ -46,6 +46,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -46,6 +46,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
...@@ -181,6 +182,7 @@ def test_fused_moe_batched_experts( ...@@ -181,6 +182,7 @@ def test_fused_moe_batched_experts(
e: int, e: int,
topk: int, topk: int,
dtype: torch.dtype, dtype: torch.dtype,
workspace_init,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
...@@ -863,6 +865,9 @@ def _pplx_test_loop( ...@@ -863,6 +865,9 @@ def _pplx_test_loop(
make_weights: bool, make_weights: bool,
test_fn: Callable, test_fn: Callable,
): ):
device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device)
def format_result(msg, ex=None): def format_result(msg, ex=None):
if ex is not None: if ex is not None:
x = str(ex) x = str(ex)
......
...@@ -22,10 +22,14 @@ from tests.v1.attention.utils import ( ...@@ -22,10 +22,14 @@ from tests.v1.attention.utils import (
) )
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla from vllm.attention.ops import flashmla
from vllm.config import set_current_vllm_config
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend from vllm.v1.attention.backends.mla.flashmla_sparse import (
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks FlashMLASparseBackend,
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import split_prefill_chunks
SPARSE_BACKEND_BATCH_SPECS = { SPARSE_BACKEND_BATCH_SPECS = {
name: BATCH_SPECS[name] name: BATCH_SPECS[name]
...@@ -114,8 +118,12 @@ def _quantize_dequantize_fp8_ds_mla( ...@@ -114,8 +118,12 @@ def _quantize_dequantize_fp8_ds_mla(
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) @pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) @pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_sparse_backend_decode_correctness( def test_sparse_backend_decode_correctness(
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
): ):
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test") pytest.skip("CUDA is required for sparse MLA decode test")
...@@ -320,28 +328,29 @@ def test_sparse_backend_decode_correctness( ...@@ -320,28 +328,29 @@ def test_sparse_backend_decode_correctness(
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous()) mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
impl_cls = FlashMLASparseBackend.get_impl_cls() impl_cls = FlashMLASparseBackend.get_impl_cls()
impl = impl_cls( with set_current_vllm_config(vllm_config):
num_heads=num_heads, impl = impl_cls(
head_size=head_size, num_heads=num_heads,
scale=scale, head_size=head_size,
num_kv_heads=1, scale=scale,
alibi_slopes=None, num_kv_heads=1,
sliding_window=None, alibi_slopes=None,
kv_cache_dtype=vllm_config.cache_config.cache_dtype, sliding_window=None,
logits_soft_cap=None, kv_cache_dtype=vllm_config.cache_config.cache_dtype,
attn_type="decoder", logits_soft_cap=None,
kv_sharing_target_layer_name=None, attn_type="decoder",
q_lora_rank=None, kv_sharing_target_layer_name=None,
kv_lora_rank=kv_lora_rank, q_lora_rank=None,
qk_nope_head_dim=qk_nope_head_dim, kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim, qk_nope_head_dim=qk_nope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim, qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
kv_b_proj=mock_kv_b_proj, v_head_dim=v_head_dim,
indexer=mock_indexer, kv_b_proj=mock_kv_b_proj,
) indexer=mock_indexer,
)
impl.process_weights_after_loading(dtype) impl.process_weights_after_loading(dtype)
layer = MockAttentionLayer(device) layer = MockAttentionLayer(device)
out_buffer = torch.empty( out_buffer = torch.empty(
...@@ -366,22 +375,192 @@ def test_sparse_backend_decode_correctness( ...@@ -366,22 +375,192 @@ def test_sparse_backend_decode_correctness(
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5) torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)
def _triton_convert_reference_impl(
req_ids: torch.Tensor,
block_table: torch.Tensor,
token_indices: torch.Tensor,
block_size: int,
num_topk_tokens: int,
HAS_PREFILL_WORKSPACE: bool = False,
prefill_workspace_request_ids: torch.Tensor | None = None,
prefill_workspace_starts: torch.Tensor | None = None,
) -> torch.Tensor:
"""Reference implementation for triton_convert_req_index_to_global_index."""
num_tokens = req_ids.shape[0]
max_blocks_per_req = block_table.shape[1]
result = torch.empty(
num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device
)
for token_id in range(num_tokens):
req_id = req_ids[token_id].item()
# Determine if this token uses workspace or paged cache
use_prefill_workspace = False
workspace_start = 0
if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None:
assert prefill_workspace_starts is not None
prefill_req_id = prefill_workspace_request_ids[token_id].item()
if prefill_req_id >= 0:
use_prefill_workspace = True
workspace_start = prefill_workspace_starts[prefill_req_id].item()
for idx_id in range(num_topk_tokens):
token_idx = token_indices[token_id, idx_id].item()
if token_idx == -1:
result[token_id, idx_id] = -1
elif use_prefill_workspace:
# Prefill + using prefill workspace: map to workspace offset
result[token_id, idx_id] = workspace_start + token_idx
else:
# Decode: map to paged cache
block_id = token_idx // block_size
if block_id >= max_blocks_per_req:
result[token_id, idx_id] = -1
else:
block_num = block_table[req_id, block_id].item()
offset = token_idx % block_size
result[token_id, idx_id] = block_num * block_size + offset
return result
@pytest.mark.parametrize("block_size", [16, 64, 128])
@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_triton_convert_req_index_to_global_index_decode_only(
block_size, num_topk_tokens
):
device = torch.device("cuda")
num_tokens = 8
num_requests = 4
max_blocks_per_req = 10
req_id = torch.randint(
0, num_requests, (num_tokens,), dtype=torch.int32, device=device
)
block_table = torch.randint(
0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
)
token_indices = torch.randint(
0,
block_size * max_blocks_per_req,
(num_tokens, num_topk_tokens),
dtype=torch.int32,
device=device,
)
# Set some to -1 to test masking
token_indices[0, :10] = -1
token_indices[3, 50:60] = -1
# Set some to out of bounds
token_indices[2, 100:110] = max_blocks_per_req * block_size
token_indices[6, 150:160] = max_blocks_per_req * block_size
result = triton_convert_req_index_to_global_index(
req_id,
block_table,
token_indices,
BLOCK_SIZE=block_size,
NUM_TOPK_TOKENS=num_topk_tokens,
)
reference_result = _triton_convert_reference_impl(
req_id,
block_table,
token_indices,
block_size,
num_topk_tokens,
)
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
device = torch.device("cuda")
num_requests = 4
max_blocks_per_req = 8
num_topk_tokens = 128
# First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3)
req_id = torch.tensor(
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device
)
prefill_workspace_request_ids = torch.tensor(
[-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device
)
# Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100
prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device)
block_table = torch.randint(
0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
)
token_indices = torch.randint(
0,
block_size * max_blocks_per_req,
(req_id.shape[0], num_topk_tokens),
dtype=torch.int32,
device=device,
)
# Set some to -1 to test masking
token_indices[0, :10] = -1
token_indices[3, 50:60] = -1
# Set some to out of bounds
token_indices[2, 100:110] = max_blocks_per_req * block_size
token_indices[6, 150:160] = max_blocks_per_req * block_size
result = triton_convert_req_index_to_global_index(
req_id,
block_table,
token_indices,
BLOCK_SIZE=block_size,
NUM_TOPK_TOKENS=num_topk_tokens,
HAS_PREFILL_WORKSPACE=True,
prefill_workspace_request_ids=prefill_workspace_request_ids,
prefill_workspace_starts=prefill_workspace_starts,
)
reference_result = _triton_convert_reference_impl(
req_id,
block_table,
token_indices,
block_size,
num_topk_tokens,
HAS_PREFILL_WORKSPACE=True,
prefill_workspace_request_ids=prefill_workspace_request_ids,
prefill_workspace_starts=prefill_workspace_starts,
)
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"seq_lens,max_buf,start,expected", "seq_lens,max_buf,expected",
[ [
# Basic split: totals per chunk ≤ max_buf # Basic split: totals per chunk ≤ max_buf
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]), (torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]),
# Non-zero start index # Exact fits should split between items when adding the next would overflow
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]), (torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]),
# Exact fits should split between items when adding the next would
# overflow
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
# All requests fit in a single chunk # All requests fit in a single chunk
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]), (torch.tensor([1, 1, 1]), 10, [(0, 3)]),
# Large buffer with non-zero start # Large buffer
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]), (torch.tensor([4, 4, 4]), 100, [(0, 3)]),
], ],
) )
def test_split_prefill_chunks(seq_lens, max_buf, start, expected): def test_split_prefill_chunks(seq_lens, max_buf, expected):
out = split_prefill_chunks(seq_lens, max_buf, start) out = split_prefill_chunks(seq_lens, max_buf)
assert out == expected assert out == expected
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment