Unverified Commit 3b80232d authored by hlu1's avatar hlu1 Committed by GitHub
Browse files

[DeepseekV32] Add fast_topk_transform_ragged_fused kernel (#11815)


Signed-off-by: default avatarHao Lu <14827759+hlu1@users.noreply.github.com>
parent 252dc4e1
...@@ -113,6 +113,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -113,6 +113,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"fast_topk_transform_fused(Tensor score, Tensor lengths, Tensor dst_page_table, Tensor src_page_table, Tensor " "fast_topk_transform_fused(Tensor score, Tensor lengths, Tensor dst_page_table, Tensor src_page_table, Tensor "
"cu_seqlens_q) -> ()"); "cu_seqlens_q) -> ()");
m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface); m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface);
m.def(
"fast_topk_transform_ragged_fused(Tensor score, Tensor lengths, Tensor topk_indices_ragged, Tensor "
"topk_indices_offset) -> ()");
m.impl("fast_topk_transform_ragged_fused", torch::kCUDA, &fast_topk_transform_ragged_interface);
/* /*
 * From gguf quantization  * From gguf quantization
......
...@@ -51,6 +51,15 @@ __device__ void naive_topk_transform( ...@@ -51,6 +51,15 @@ __device__ void naive_topk_transform(
} }
} }
// Trivial selection path used when the whole row fits within TopK: emit the
// identity permutation [0, length) shifted by `offset`, and pad the remaining
// output slots with the sentinel -1. `score` is accepted but not read here.
__device__ void naive_topk_transform_ragged(
    const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
  // Block-stride loop: thread t fills slots t, t + kThreadsPerBlock, ...
  for (auto slot = threadIdx.x; slot < TopK; slot += kThreadsPerBlock) {
    if (slot < length) {
      topk_indices_ragged[slot] = static_cast<int32_t>(slot) + offset;
    } else {
      topk_indices_ragged[slot] = -1;
    }
  }
}
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
__half h = __float2half_rn(x); __half h = __float2half_rn(x);
uint16_t bits = __half_as_ushort(h); uint16_t bits = __half_as_ushort(h);
...@@ -322,8 +331,40 @@ __global__ __launch_bounds__(kThreadsPerBlock) // prefill ...@@ -322,8 +331,40 @@ __global__ __launch_bounds__(kThreadsPerBlock) // prefill
} }
} }
// Prefill variant that writes top-k indices into a ragged (non-paged) KV
// index buffer instead of a page table. One block handles one query row:
// block b reads score row b and writes TopK indices, each shifted by
// topk_indices_offset[b], into topk_indices_ragged[b * TopK .. (b+1) * TopK).
// Slots beyond the row's length are -1 (see naive_topk_transform_ragged).
__global__ __launch_bounds__(kThreadsPerBlock)  // prefill, ragged kv
void topk_transform_prefill_ragged_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ topk_indices_ragged,
    const int32_t* __restrict__ topk_indices_offset) {
  const auto& [input, _, lengths, input_stride] = params;
  const auto row = static_cast<uint64_t>(blockIdx.x);
  const auto lane = threadIdx.x;
  const auto kv_len = lengths[row];
  const auto row_scores = input + row * input_stride;
  const auto row_dst = topk_indices_ragged + row * TopK;
  const auto row_offset = topk_indices_offset[row];

  if (kv_len <= TopK) {
    // Fewer candidates than TopK: selection is the identity, no ranking needed.
    naive_topk_transform_ragged(row_scores, kv_len, row_dst, row_offset);
    return;
  }

  __shared__ int s_indices[TopK];
  // NOTE(review): assumes fast_topk_cuda_tl publishes s_indices to the whole
  // block before returning (i.e. it ends with a block-wide sync) — confirm.
  fast_topk_cuda_tl(row_scores, s_indices, kv_len);
  // Scatter the selected indices to global memory; manually unrolled twice
  // since each thread owns exactly two of the TopK output slots.
  static_assert(TopK % kThreadsPerBlock == 0);
  static_assert(TopK / kThreadsPerBlock == 2);
  row_dst[lane] = s_indices[lane] + row_offset;
  row_dst[lane + kThreadsPerBlock] = s_indices[lane + kThreadsPerBlock] + row_offset;
}
auto get_params(
const at::Tensor& score,
const at::Tensor& lengths,
std::optional<at::Tensor> indices_opt = std::nullopt) -> FastTopKParams {
const auto B = score.size(0); const auto B = score.size(0);
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
...@@ -357,7 +398,7 @@ void setup_kernel_smem_once() { ...@@ -357,7 +398,7 @@ void setup_kernel_smem_once() {
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths) { void fast_topk_interface(const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths) {
CHECK_CUDA(score); CHECK_CUDA(score);
CHECK_CUDA(indices); CHECK_CUDA(indices);
CHECK_CUDA(lengths); CHECK_CUDA(lengths);
...@@ -373,11 +414,11 @@ void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor length ...@@ -373,11 +414,11 @@ void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor length
} }
void fast_topk_transform_interface( void fast_topk_transform_interface(
at::Tensor score, const at::Tensor& score,
at::Tensor lengths, const at::Tensor& lengths,
at::Tensor dst_page_table, at::Tensor& dst_page_table,
at::Tensor src_page_table, const at::Tensor& src_page_table,
at::Tensor cu_seqlens_q) { const at::Tensor& cu_seqlens_q) {
CHECK_CUDA(score); CHECK_CUDA(score);
CHECK_CUDA(lengths); CHECK_CUDA(lengths);
CHECK_CUDA(dst_page_table); CHECK_CUDA(dst_page_table);
...@@ -420,3 +461,35 @@ void fast_topk_transform_interface( ...@@ -420,3 +461,35 @@ void fast_topk_transform_interface(
const auto result = cudaGetLastError(); const auto result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
} }
// Validate arguments and launch the ragged (non-paged) top-k transform kernel.
//
// score:               [B, num_candidates] float scores (layout checked in get_params).
// lengths:             [B] per-row kv lengths.
// topk_indices_ragged: [B, TopK] contiguous int32 output buffer.
// topk_indices_offset: [B] int32 per-row offset added to every selected index.
void fast_topk_transform_ragged_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& topk_indices_ragged,
    const at::Tensor& topk_indices_offset) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(topk_indices_ragged);
  CHECK_CUDA(topk_indices_offset);

  const auto params = get_params(score, lengths);
  const auto B = score.size(0);
  TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous());
  TORCH_CHECK(topk_indices_ragged.size(0) == B);
  TORCH_CHECK(topk_indices_ragged.size(1) == TopK);
  // The kernel reads topk_indices_offset[bid] with unit stride, so the tensor
  // must be dense, not just 1-D.
  TORCH_CHECK(topk_indices_offset.dim() == 1 && topk_indices_offset.is_contiguous());
  TORCH_CHECK(topk_indices_offset.size(0) == B);

  // launch kernel: one block per query row
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  setup_kernel_smem_once<topk_transform_prefill_ragged_kernel, kSmem>();
  topk_transform_prefill_ragged_kernel<<<grid, block, kSmem, stream>>>(
      params, topk_indices_ragged.data_ptr<int32_t>(), topk_indices_offset.data_ptr<int32_t>());
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
...@@ -174,13 +174,18 @@ void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output); ...@@ -174,13 +174,18 @@ void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope); void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out); void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths); void fast_topk_interface(const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths);
void fast_topk_transform_interface( void fast_topk_transform_interface(
at::Tensor score, const at::Tensor& score,
at::Tensor lengths, const at::Tensor& lengths,
at::Tensor dst_page_table, at::Tensor& dst_page_table,
at::Tensor src_page_table, const at::Tensor& src_page_table,
at::Tensor cu_seqlens_q); const at::Tensor& cu_seqlens_q);
void fast_topk_transform_ragged_interface(
const at::Tensor& score,
const at::Tensor& lengths,
at::Tensor& topk_indices_ragged,
const at::Tensor& topk_indices_offset);
#ifdef USE_ROCM #ifdef USE_ROCM
void gelu_quick(at::Tensor& out, const at::Tensor& input); void gelu_quick(at::Tensor& out, const at::Tensor& input);
......
...@@ -327,7 +327,12 @@ from sgl_kernel.speculative import ( ...@@ -327,7 +327,12 @@ from sgl_kernel.speculative import (
tree_speculative_sampling_target_only, tree_speculative_sampling_target_only,
verify_tree_greedy, verify_tree_greedy,
) )
from sgl_kernel.top_k import fast_topk, fast_topk_transform_fused, fast_topk_v2 from sgl_kernel.top_k import (
fast_topk,
fast_topk_transform_fused,
fast_topk_transform_ragged_fused,
fast_topk_v2,
)
from sgl_kernel.version import __version__ from sgl_kernel.version import __version__
if torch.version.hip is not None: if torch.version.hip is not None:
......
...@@ -28,13 +28,36 @@ def fast_topk_transform_fused( ...@@ -28,13 +28,36 @@ def fast_topk_transform_fused(
cu_seqlens_q: torch.Tensor, cu_seqlens_q: torch.Tensor,
topk: int, topk: int,
) -> torch.Tensor: ) -> torch.Tensor:
"""
Transform topk indices to indices to the page table (page_size = 1)
"""
assert ( assert (
topk == 2048 topk == 2048
), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048" ), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
assert score.dim() == 2 assert score.dim() == 2
src_page_table = page_table_size_1 src_page_table = page_table_size_1
dst_page_table = score.new_empty((score.size(0), topk), dtype=torch.int32) dst_page_table = score.new_empty((score.shape[0], topk), dtype=torch.int32)
torch.ops.sgl_kernel.fast_topk_transform_fused( torch.ops.sgl_kernel.fast_topk_transform_fused(
score, lengths, dst_page_table, src_page_table, cu_seqlens_q score, lengths, dst_page_table, src_page_table, cu_seqlens_q
) )
return dst_page_table return dst_page_table
def fast_topk_transform_ragged_fused(
    score: torch.Tensor,
    lengths: torch.Tensor,
    topk_indices_offset: torch.Tensor,  # ragged kv
    topk: int,
) -> torch.Tensor:
    """
    Transform topk indices to indices to ragged kv (non-paged).

    Args:
        score: [num_tokens, max_seq_len] float score matrix.
        lengths: per-token kv lengths.
        topk_indices_offset: per-token offset added to each selected index.
        topk: number of indices selected per token; must be 2048.

    Returns:
        [num_tokens, topk] int32 tensor of offset-shifted topk indices.
    """
    # The fused kernel is specialized for TopK == 2048 (DeepSeek V3.2).
    # Fix: the message previously named the op "fast_topk_transform_fused_ragged".
    assert (
        topk == 2048
    ), "fast_topk_transform_ragged_fused is only optimized for deepseek v3.2 model, where topk=2048"
    assert score.dim() == 2
    topk_indices_ragged = score.new_empty((score.shape[0], topk), dtype=torch.int32)
    torch.ops.sgl_kernel.fast_topk_transform_ragged_fused(
        score, lengths, topk_indices_ragged, topk_indices_offset
    )
    return topk_indices_ragged
from typing import Optional
import pytest import pytest
import torch import torch
from sgl_kernel import fast_topk_transform_fused, fast_topk_v2 from sgl_kernel import (
fast_topk_transform_fused,
fast_topk_transform_ragged_fused,
fast_topk_v2,
)
def _ref_torch_impl(score: torch.Tensor, seq_len: int, topk: int) -> torch.Tensor: def _ref_torch_impl(score: torch.Tensor, seq_len: int, topk: int) -> torch.Tensor:
...@@ -26,6 +32,21 @@ def _ref_torch_transform_decode_impl( ...@@ -26,6 +32,21 @@ def _ref_torch_transform_decode_impl(
return topk_indices return topk_indices
def _ref_torch_transform_ragged_impl(
    score: torch.Tensor,
    seq_len: int,
    topk_indices_offset: torch.Tensor,
    topk: int,
) -> torch.Tensor:
    """Reference impl: shift the plain topk indices by the per-row ragged
    offset, leaving -1 padding entries untouched."""
    assert score.shape[0] == topk_indices_offset.shape[0]
    assert seq_len >= topk
    base_indices = _ref_torch_impl(score, seq_len, topk)
    # Broadcast the per-row offset across the topk dimension; -1 is a pad
    # sentinel and must stay -1.
    shifted = base_indices + topk_indices_offset.unsqueeze(1)
    return torch.where(base_indices == -1, base_indices, shifted)
MAX_SEQ_LEN = 131072 MAX_SEQ_LEN = 131072
MAX_PERMIT_ERROR = 0 MAX_PERMIT_ERROR = 0
...@@ -37,6 +58,7 @@ def assert_equal( ...@@ -37,6 +58,7 @@ def assert_equal(
bs: int, bs: int,
k: int, k: int,
seq_len: int, seq_len: int,
topk_indices_offset: Optional[torch.Tensor] = None,
): ):
indices_our_cpu = indices_our.cpu().tolist() indices_our_cpu = indices_our.cpu().tolist()
indices_ref_cpu = indices_ref.cpu().tolist() indices_ref_cpu = indices_ref.cpu().tolist()
...@@ -45,11 +67,13 @@ def assert_equal( ...@@ -45,11 +67,13 @@ def assert_equal(
indices_our_set_i = set(indices_our_cpu[i]) indices_our_set_i = set(indices_our_cpu[i])
more = indices_our_set_i - indices_ref_set_i more = indices_our_set_i - indices_ref_set_i
less = indices_ref_set_i - indices_our_set_i less = indices_ref_set_i - indices_our_set_i
if len(more) > MAX_PERMIT_ERROR or len(less) > MAX_PERMIT_ERROR: offset = topk_indices_offset[i].item() if topk_indices_offset is not None else 0
if len(more) > 0 or len(less) > 0:
print(f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=}")
# check whether more values are the same with less values # check whether more values are the same with less values
# if so, either one is acceptable, since their values are the same # if so, either one is acceptable, since their values are the same
more_values = sorted(score[i, idx].item() for idx in more) more_values = sorted(score[i, idx - offset].item() for idx in more)
less_values = sorted(score[i, idx].item() for idx in less) less_values = sorted(score[i, idx - offset].item() for idx in less)
assert ( assert (
more_values == less_values more_values == less_values
), f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=} failed, with {more_values=}, {less_values=}" ), f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=} failed, with {more_values=}, {less_values=}"
...@@ -116,5 +140,52 @@ def test_topk_transform_kernel(bs: int, k: int, seq_len: int) -> None: ...@@ -116,5 +140,52 @@ def test_topk_transform_kernel(bs: int, k: int, seq_len: int) -> None:
assert_equal(score, dst_page_table_ref, dst_page_table_our, bs, k, seq_len) assert_equal(score, dst_page_table_ref, dst_page_table_our, bs, k, seq_len)
@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
@pytest.mark.parametrize("k", [2048])  # we only support 2048 now
@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
@torch.inference_mode()
def test_topk_transform_ragged_kernel(bs: int, k: int, seq_len: int) -> None:
    """Compare the fused ragged topk-transform kernel against the torch reference.

    Fix: removed the dead local ``MAX_PERMIT_ERROR = 1`` — it shadowed the
    module-level constant and was never read by this test or assert_equal.
    """
    # TODO(dark): test prefill kernel, though nothing special
    torch.manual_seed(42)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)

    # bs: # of q tokens
    score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
    # kv_len
    lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
    topk_indices_offset = torch.randint(
        0, 1024, (bs,), dtype=torch.int32, device="cuda"
    )

    dst_page_table_ref = _ref_torch_transform_ragged_impl(
        score=score,
        seq_len=seq_len,
        topk_indices_offset=topk_indices_offset,
        topk=k,
    )
    dst_page_table_our = fast_topk_transform_ragged_fused(
        score=score,
        lengths=lengths,
        topk_indices_offset=topk_indices_offset,
        topk=k,
    )

    # sort and compare (the kernel does not guarantee output order)
    dst_page_table_our = torch.sort(dst_page_table_our, dim=-1).values
    dst_page_table_ref = torch.sort(dst_page_table_ref, dim=-1).values
    assert_equal(
        score,
        dst_page_table_ref,
        dst_page_table_our,
        bs,
        k,
        seq_len,
        topk_indices_offset,
    )
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment