Unverified Commit e0b2d3ee authored by DarkSharpness, committed by GitHub

[Feature] Add a fast-topk to sgl-kernel for DeepSeek v3.2 (#11194)


Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
parent 4cb5a523
...@@ -268,6 +268,7 @@ set(SOURCES
"csrc/elementwise/concat_mla.cu"
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
"csrc/elementwise/rope.cu"
"csrc/elementwise/topk.cu"
"csrc/common_extension.cc"
"csrc/gemm/awq_kernel.cu"
......
...@@ -107,6 +107,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("concat_mla_absorb_q(Tensor a, Tensor b, Tensor! out) -> ()");
m.impl("concat_mla_absorb_q", torch::kCUDA, &concat_mla_absorb_q);
m.def("fast_topk(Tensor score, Tensor indices, Tensor lengths) -> ()");
m.impl("fast_topk", torch::kCUDA, &fast_topk_interface);
m.def(
"fast_topk_transform_fused(Tensor score, Tensor lengths, Tensor dst_page_table, Tensor src_page_table, Tensor "
"cu_seqlens_q) -> ()");
m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface);
/*
 * From csrc/gemm
 */
......
/**
 * @NOTE: This file is adapted from
 * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py
 * We:
 * 1. port the kernel from tilelang to pure CUDA
 * 2. optimize the performance a little
 * 3. fix a potential illegal memory access
 */
#include <ATen/core/TensorBase.h>
#include <ATen/core/TensorBody.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cstddef>
#include <cstdint>
#include <optional>
namespace {
constexpr int TopK = 2048;
constexpr int kThreadsPerBlock = 1024;
constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t);  // 128 KB of dynamic shared memory, opted in via setup_kernel_smem_once()
struct FastTopKParams {
const float* __restrict__ input; // [B, input_stride]
int32_t* __restrict__ indices; // [B, TopK]
int32_t* __restrict__ lengths; // [B]
int64_t input_stride;
};
// when length <= TopK, we can directly write the indices
__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) {
const auto tid = threadIdx.x;
for (int i = tid; i < TopK; i += kThreadsPerBlock) {
indice[i] = (i < length) ? i : -1;
}
}
// keep the first `length` entries, set others to -1
__device__ void naive_topk_transform(
const float* __restrict__ score,
int32_t length,
int32_t* __restrict__ dst_page_table,
const int32_t* __restrict__ src_page_table) {
const auto tid = threadIdx.x;
for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
}
}
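// The two helpers below map a float score to an unsigned key whose unsigned ordering
// matches the float ordering (set the sign bit for non-negatives, flip all bits for
// negatives). convert_to_uint8 keeps only the top 8 bits of the half-precision key for
// the coarse histogram pass; convert_to_uint32 keeps the full fp32 key for refinement.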
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
__half h = __float2half_rn(x);
uint16_t bits = __half_as_ushort(h);
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits) : static_cast<uint16_t>(bits | 0x8000);
return static_cast<uint8_t>(key >> 8);
}
__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
uint32_t bits = __float_as_uint(x);
return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}
__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int length) {
// An optimized top-k kernel adapted from the tilelang kernel.
// We assume length > TopK here; otherwise this will crash.
int topk = TopK;
constexpr auto BLOCK_SIZE = 1024;
constexpr auto RADIX = 256;
constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
alignas(128) __shared__ int s_counter;
alignas(128) __shared__ int s_threshold_bin_id;
alignas(128) __shared__ int s_num_input[2];
auto& s_histogram = s_histogram_buf[0];
// double-buffered candidate-index storage, alternated between refinement rounds
extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
const int tx = threadIdx.x;
// stage 1: 8bit coarse histogram
if (tx < RADIX + 1) s_histogram[tx] = 0;
__syncthreads();
for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
const auto bin = convert_to_uint8(input[idx]);
::atomicAdd(&s_histogram[bin], 1);
}
__syncthreads();
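// run_cumsum is a suffix scan: afterwards s_histogram[b] holds the number of
// candidates whose bin is >= b (with s_histogram[RADIX] staying 0), which is what
// the threshold-bin search compares against `topk`.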
const auto run_cumsum = [&] {
#pragma unroll 8
for (int i = 0; i < 8; ++i) {
static_assert(1 << 8 == RADIX);
if (C10_LIKELY(tx < RADIX)) {
const auto j = 1 << i;
const auto k = i & 1;
auto value = s_histogram_buf[k][tx];
if (tx < RADIX - j) {
value += s_histogram_buf[k][tx + j];
}
s_histogram_buf[k ^ 1][tx] = value;
}
__syncthreads();
}
};
run_cumsum();
if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
s_threshold_bin_id = tx;
s_num_input[0] = 0;
s_counter = 0;
}
__syncthreads();
const auto threshold_bin = s_threshold_bin_id;
topk -= s_histogram[threshold_bin + 1];
if (topk == 0) {
for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
const auto bin = static_cast<int>(convert_to_uint8(input[idx]));
if (bin > threshold_bin) {
const auto pos = ::atomicAdd(&s_counter, 1);
index[pos] = idx;
}
}
__syncthreads();
return;
} else {
__syncthreads();
if (tx < RADIX + 1) {
s_histogram[tx] = 0;
}
__syncthreads();
for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
const auto raw_input = input[idx];
const auto bin = static_cast<int>(convert_to_uint8(raw_input));
if (bin > threshold_bin) {
const auto pos = ::atomicAdd(&s_counter, 1);
index[pos] = idx;
} else if (bin == threshold_bin) {
const auto pos = ::atomicAdd(&s_num_input[0], 1);
/// NOTE: (dark) fuse the histogram computation here
if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
s_input_idx[0][pos] = idx;
const auto bin = convert_to_uint32(raw_input);
const auto sub_bin = (bin >> 24) & 0xFF;
::atomicAdd(&s_histogram[sub_bin], 1);
}
}
}
__syncthreads();
}
// stage 2: refine with 8bit radix passes
#pragma unroll 4
for (int round = 0; round < 4; ++round) {
__shared__ int s_last_remain;
const auto r_idx = round % 2;
// clip here to prevent overflow
const auto _raw_num_input = s_num_input[r_idx];
const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE);
run_cumsum();
if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
s_threshold_bin_id = tx;
s_num_input[r_idx ^ 1] = 0;
s_last_remain = topk - s_histogram[tx + 1];
}
__syncthreads();
const auto threshold_bin = s_threshold_bin_id;
topk -= s_histogram[threshold_bin + 1];
if (topk == 0) {
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
const auto idx = s_input_idx[r_idx][i];
const auto offset = 24 - round * 8;
const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF;
if (bin > threshold_bin) {
const auto pos = ::atomicAdd(&s_counter, 1);
index[pos] = idx;
}
}
__syncthreads();
break;
} else {
__syncthreads();
if (tx < RADIX + 1) {
s_histogram[tx] = 0;
}
__syncthreads();
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
const auto idx = s_input_idx[r_idx][i];
const auto raw_input = input[idx];
const auto offset = 24 - round * 8;
const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF;
if (bin > threshold_bin) {
const auto pos = ::atomicAdd(&s_counter, 1);
index[pos] = idx;
} else if (bin == threshold_bin) {
if (round == 3) {
const auto pos = ::atomicAdd(&s_last_remain, -1);
if (pos > 0) {
index[TopK - pos] = idx;
}
} else {
const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
/// NOTE: (dark) fuse the histogram computation here
s_input_idx[r_idx ^ 1][pos] = idx;
const auto bin = convert_to_uint32(raw_input);
const auto sub_bin = (bin >> (offset - 8)) & 0xFF;
::atomicAdd(&s_histogram[sub_bin], 1);
}
}
}
}
__syncthreads();
}
}
}
__global__ __launch_bounds__(kThreadsPerBlock) // topk
void topk_kernel(const FastTopKParams params) {
const auto& [input, indices, lengths, input_stride] = params;
const auto bid = static_cast<uint64_t>(blockIdx.x);
const auto length = lengths[bid];
const auto indice = indices + bid * TopK;
const auto score = input + bid * input_stride;
if (length <= TopK) {
return naive_topk_cuda(score, indice, length);
} else {
return fast_topk_cuda_tl(score, indice, length);
}
}
__global__ __launch_bounds__(kThreadsPerBlock) // decode
void topk_transform_decode_kernel(
const FastTopKParams params,
int32_t* __restrict__ dst_page_table,
const int32_t* __restrict__ src_page_table,
const int64_t src_stride) {
const auto& [input, _, lengths, input_stride] = params;
const auto bid = static_cast<uint64_t>(blockIdx.x);
const auto tid = threadIdx.x;
const auto length = lengths[bid];
const auto src_page_entry = src_page_table + bid * src_stride;
const auto dst_page_entry = dst_page_table + bid * TopK;
const auto score = input + bid * input_stride;
if (length <= TopK) {
return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
} else {
__shared__ int s_indices[TopK];
fast_topk_cuda_tl(score, s_indices, length);
// copy src[s_indices] to dst, we manually unroll here
static_assert(TopK % kThreadsPerBlock == 0);
static_assert(TopK / kThreadsPerBlock == 2);
const auto idx_0 = tid;
const auto pos_0 = s_indices[idx_0];
dst_page_entry[idx_0] = src_page_entry[pos_0];
const auto idx_1 = tid + kThreadsPerBlock;
const auto pos_1 = s_indices[idx_1];
dst_page_entry[idx_1] = src_page_entry[pos_1];
}
}
__global__ __launch_bounds__(kThreadsPerBlock) // prefill
void topk_transform_prefill_kernel(
const FastTopKParams params,
int32_t* __restrict__ dst_page_table,
const int32_t* __restrict__ src_page_table,
const int64_t src_stride,
const int32_t* __restrict__ cu_seqlens_q,
const int64_t prefill_bs) {
const auto& [input, _, lengths, input_stride] = params;
const auto bid = static_cast<uint64_t>(blockIdx.x);
const auto tid = threadIdx.x;
const auto length = lengths[bid];
const auto dst_page_entry = dst_page_table + bid * TopK;
const auto score = input + bid * input_stride;
/// NOTE: the prefill batch size is usually small, so a simple loop suffices here.
/// The caller guarantees that the last cu_seqlens entry equals the number of blocks launched.
__shared__ const int32_t* s_src_page_entry;
if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) {
if (tid < prefill_bs) {
if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) {
s_src_page_entry = src_page_table + tid * src_stride;
}
}
} else {
for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) {
if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) {
s_src_page_entry = src_page_table + i * src_stride;
}
}
}
__syncthreads();
const auto src_page_entry = s_src_page_entry;
if (length <= TopK) {
return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
} else {
__shared__ int s_indices[TopK];
fast_topk_cuda_tl(score, s_indices, length);
// copy src[s_indices] to dst, we manually unroll here
static_assert(TopK % kThreadsPerBlock == 0);
static_assert(TopK / kThreadsPerBlock == 2);
const auto idx_0 = tid;
const auto pos_0 = s_indices[idx_0];
dst_page_entry[idx_0] = src_page_entry[pos_0];
const auto idx_1 = tid + kThreadsPerBlock;
const auto pos_1 = s_indices[idx_1];
dst_page_entry[idx_1] = src_page_entry[pos_1];
}
}
auto get_params(at::Tensor score, at::Tensor lengths, std::optional<at::Tensor> indices_opt = std::nullopt)
-> FastTopKParams {
const auto B = score.size(0);
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
TORCH_CHECK(lengths.size(0) == B);
int32_t* indices_data_ptr = nullptr;
if (indices_opt.has_value()) {
const auto& indices = indices_opt.value();
TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
TORCH_CHECK(indices.size(0) == B);
TORCH_CHECK(indices.size(1) == TopK);
indices_data_ptr = indices.data_ptr<int32_t>();
}
return FastTopKParams{
.input = score.data_ptr<float>(),
.indices = indices_data_ptr,
.lengths = lengths.data_ptr<int32_t>(),
.input_stride = score.stride(0),
};
}
template <auto* f, size_t max_dynamic_smem>
void setup_kernel_smem_once() {
[[maybe_unused]]
static const auto result =
[] { return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); }();
TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result));
}
} // namespace
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths) {
CHECK_CUDA(score);
CHECK_CUDA(indices);
CHECK_CUDA(lengths);
const auto params = get_params(score, lengths, indices);
const auto B = score.size(0);
const auto stream = at::cuda::getCurrentCUDAStream().stream();
const auto grid = dim3{static_cast<uint32_t>(B)};
const auto block = dim3{kThreadsPerBlock};
setup_kernel_smem_once<topk_kernel, kSmem>();
topk_kernel<<<grid, block, kSmem, stream>>>(params);
const auto result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
void fast_topk_transform_interface(
at::Tensor score,
at::Tensor lengths,
at::Tensor dst_page_table,
at::Tensor src_page_table,
at::Tensor cu_seqlens_q) {
CHECK_CUDA(score);
CHECK_CUDA(lengths);
CHECK_CUDA(dst_page_table);
CHECK_CUDA(src_page_table);
CHECK_CUDA(cu_seqlens_q);
const auto params = get_params(score, lengths);
const auto B = score.size(0);
TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous());
const auto prefill_bs = cu_seqlens_q.size(0) - 1;
TORCH_CHECK(dst_page_table.size(0) == B);
TORCH_CHECK(dst_page_table.size(1) == TopK);
TORCH_CHECK(src_page_table.size(0) == prefill_bs);
TORCH_CHECK(prefill_bs <= B);  // prefill_bs must not exceed the expanded batch size
// launch kernel
const auto stream = at::cuda::getCurrentCUDAStream().stream();
const auto grid = dim3{static_cast<uint32_t>(B)};
const auto block = dim3{kThreadsPerBlock};
const auto src_stride = src_page_table.stride(0);
// dispatch to decode or prefill
const auto is_decode = (prefill_bs == B);
if (is_decode) {
setup_kernel_smem_once<topk_transform_decode_kernel, kSmem>();
topk_transform_decode_kernel<<<grid, block, kSmem, stream>>>(
params, dst_page_table.data_ptr<int32_t>(), src_page_table.data_ptr<int32_t>(), src_stride);
} else {
setup_kernel_smem_once<topk_transform_prefill_kernel, kSmem>();
topk_transform_prefill_kernel<<<grid, block, kSmem, stream>>>(
params,
dst_page_table.data_ptr<int32_t>(),
src_page_table.data_ptr<int32_t>(),
src_stride,
cu_seqlens_q.data_ptr<int32_t>(),
prefill_bs);
}
const auto result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
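
For orientation, the selection kernel above (fast_topk_cuda_tl) is a radix select: it histograms an 8-bit key, finds the bin where the running count crosses TopK, emits everything above that bin, and refines inside that bin with further 8-bit passes. The CPU sketch below is illustrative only and not part of this commit (function names are made up here); it works on a single score row and uses the full fp32 key for every round, whereas the kernel's coarse pass histograms an 8-bit half-precision key.

import numpy as np

def float_to_ordered_u32(x: np.ndarray) -> np.ndarray:
    # Map float32 to uint32 so that unsigned ordering matches float ordering.
    bits = x.astype(np.float32).view(np.uint32)
    mask = np.uint32(0x80000000)
    return np.where(bits & mask, ~bits, bits | mask)

def radix_topk_indices(score_row: np.ndarray, k: int) -> np.ndarray:
    # Assumes score_row.size > k, mirroring the kernel's fast path.
    keys = float_to_ordered_u32(score_row)
    candidates = np.arange(score_row.size)
    out = []
    for shift in (24, 16, 8, 0):
        bins = ((keys[candidates] >> shift) & 0xFF).astype(np.int64)
        hist = np.bincount(bins, minlength=256)
        # suffix[b] = number of candidates whose bin is >= b (suffix[256] = 0)
        suffix = np.append(np.cumsum(hist[::-1])[::-1], 0)
        # threshold bin: suffix[thr] > k >= suffix[thr + 1]
        thr = int(np.argmax(suffix <= k)) - 1
        out.append(candidates[bins > thr])  # these are certainly within the top-k
        k -= int(suffix[thr + 1])
        if k == 0:
            break
        candidates = candidates[bins == thr]  # refine ties on the next 8 bits
    if k > 0:
        out.append(candidates[:k])  # exact key ties: any k of them are valid
    return np.concatenate(out)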
...@@ -174,6 +174,14 @@ void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths);
void fast_topk_transform_interface(
at::Tensor score,
at::Tensor lengths,
at::Tensor dst_page_table,
at::Tensor src_page_table,
at::Tensor cu_seqlens_q);
#ifdef USE_ROCM
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif
......
...@@ -309,7 +309,7 @@ from sgl_kernel.speculative import (
    tree_speculative_sampling_target_only,
    verify_tree_greedy,
)
-from sgl_kernel.top_k import fast_topk
+from sgl_kernel.top_k import fast_topk, fast_topk_transform_fused, fast_topk_v2
from sgl_kernel.version import __version__
if torch.version.hip is not None:
......
...@@ -9,3 +9,32 @@ def fast_topk(values, topk, dim):
    # Use topk for efficiency with larger k values
    # TODO: implement faster cuda kernels for large vocab sizes
    return torch.topk(values, topk, dim=dim)
def fast_topk_v2(score: torch.Tensor, lengths: torch.Tensor, topk: int) -> torch.Tensor:
    assert (
        topk == 2048
    ), "fast_topk_v2 is only optimized for deepseek v3.2 model, where topk=2048"
    assert score.dim() == 2
    topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
    torch.ops.sgl_kernel.fast_topk(score, topk_indices, lengths)
    return topk_indices


def fast_topk_transform_fused(
    score: torch.Tensor,
    lengths: torch.Tensor,
    page_table_size_1: torch.Tensor,  # NOTE: page size should be 1
    cu_seqlens_q: torch.Tensor,
    topk: int,
) -> torch.Tensor:
    assert (
        topk == 2048
    ), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
    assert score.dim() == 2
    src_page_table = page_table_size_1
    dst_page_table = score.new_empty((score.size(0), topk), dtype=torch.int32)
    torch.ops.sgl_kernel.fast_topk_transform_fused(
        score, lengths, dst_page_table, src_page_table, cu_seqlens_q
    )
    return dst_page_table
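
A minimal decode-style usage sketch of the two wrappers above (illustrative only, not part of this commit; variable names and shapes simply mirror the decode test below):

import torch
from sgl_kernel import fast_topk_transform_fused, fast_topk_v2

bs, max_seq_len, seq_len, k = 4, 8192, 4096, 2048
score = torch.randn(bs, max_seq_len, dtype=torch.float32, device="cuda")
lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")

# Plain top-k: int32 indices of shape [bs, k] into each score row.
topk_indices = fast_topk_v2(score, lengths, k)

# Fused variant: gathers page-table entries (page size 1) at the top-k positions.
page_table = torch.arange(seq_len, dtype=torch.int32, device="cuda").unsqueeze(0).expand(bs, -1)
cu_seqlens_q = torch.arange(bs + 1, dtype=torch.int32, device="cuda")  # decode: one query token per request
dst_page_table = fast_topk_transform_fused(
    score, lengths, page_table_size_1=page_table, cu_seqlens_q=cu_seqlens_q, topk=k
)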
import pytest
import torch
from sgl_kernel import fast_topk_transform_fused, fast_topk_v2


def _ref_torch_impl(score: torch.Tensor, seq_len: int, topk: int) -> torch.Tensor:
    assert score.dim() == 2
    return torch.topk(score[:, :seq_len], topk, dim=-1, sorted=False).indices


def _ref_torch_transform_decode_impl(
    score: torch.Tensor,
    seq_len: int,
    src_page_table: torch.Tensor,
    topk: int,
) -> torch.Tensor:
    batch_size, _ = score.shape
    assert score.shape[0] == src_page_table.shape[0]
    assert seq_len >= topk
    indices = _ref_torch_impl(score, seq_len, topk)
    topk_indices = torch.empty(
        (batch_size, topk), dtype=torch.int32, device=score.device
    )
    for i in range(batch_size):
        topk_indices[i] = src_page_table[i, indices[i]]
    return topk_indices
MAX_SEQ_LEN = 131072
MAX_PERMIT_ERROR = 0


def assert_equal(
    score: torch.Tensor,
    indices_ref: torch.Tensor,
    indices_our: torch.Tensor,
    bs: int,
    k: int,
    seq_len: int,
):
    indices_our_cpu = indices_our.cpu().tolist()
    indices_ref_cpu = indices_ref.cpu().tolist()
    for i in range(bs):
        indices_ref_set_i = set(indices_ref_cpu[i])
        indices_our_set_i = set(indices_our_cpu[i])
        more = indices_our_set_i - indices_ref_set_i
        less = indices_ref_set_i - indices_our_set_i
        if len(more) > MAX_PERMIT_ERROR or len(less) > MAX_PERMIT_ERROR:
            # check whether the extra indices carry the same scores as the missing ones;
            # if so, either selection is acceptable, since their values tie
            more_values = sorted(score[i, idx].item() for idx in more)
            less_values = sorted(score[i, idx].item() for idx in less)
            assert (
                more_values == less_values
            ), f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=} failed, with {more_values=}, {less_values=}"
@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
@pytest.mark.parametrize("k", [2048])  # we only support 2048 now
@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
@torch.inference_mode()
def test_topk_kernel(bs: int, k: int, seq_len: int) -> None:
    torch.manual_seed(42)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
    lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
    indices_ref = _ref_torch_impl(score, seq_len, k)
    indices_our = fast_topk_v2(score, lengths, k)
    # sort and compare
    indices_ref = torch.sort(indices_ref, dim=-1).values
    indices_our = torch.sort(indices_our, dim=-1).values
    assert_equal(score, indices_ref, indices_our, bs, k, seq_len)
@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
@pytest.mark.parametrize("k", [2048])  # we only support 2048 now
@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
@torch.inference_mode()
def test_topk_transform_kernel(bs: int, k: int, seq_len: int) -> None:
    # TODO(dark): test prefill kernel, though nothing special
    MAX_PERMIT_ERROR = 1
    torch.manual_seed(42)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
    lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
    src_page_table = torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
    src_page_table = src_page_table.unsqueeze(0).expand(bs, -1)
    # NOTE: for decode, cumulative seqlens_q is just 0..=bs
    # NOTE: since the page table is an arange, the gathered entries equal the top-k indices
    cu_seqlens_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda")
    dst_page_table_ref = _ref_torch_transform_decode_impl(
        score=score,
        seq_len=seq_len,
        src_page_table=src_page_table,
        topk=k,
    )
    dst_page_table_our = fast_topk_transform_fused(
        score=score,
        lengths=lengths,
        page_table_size_1=src_page_table,
        cu_seqlens_q=cu_seqlens_q,
        topk=k,
    )
    # sort and compare
    dst_page_table_our = torch.sort(dst_page_table_our, dim=-1).values
    dst_page_table_ref = torch.sort(dst_page_table_ref, dim=-1).values
    assert_equal(score, dst_page_table_ref, dst_page_table_our, bs, k, seq_len)
if __name__ == "__main__":
    pytest.main([__file__])