#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/irange.h>
#include <torch/all.h>

#ifndef USE_ROCM
#define WARP_SIZE 64
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#include "utils.h"  // WARP_SIZE
#endif

// Copy one KV item of item_size_bytes (a multiple of 8) with all lanes of a warp,
// 8 bytes per lane per iteration, using non-temporal / streaming loads and stores.
__device__ __forceinline__ void
transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
  const uint64_t* __restrict__ src = static_cast<const uint64_t*>(src_addr);
  uint64_t* __restrict__ dst = static_cast<uint64_t*>(dst_addr);
  const int total_chunks = item_size_bytes / sizeof(uint64_t);
#pragma unroll
  for (int j = lane_id; j < total_chunks; j += WARP_SIZE) {
#ifndef USE_ROCM
    uint64_t tmp;
    asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(tmp) : "l"(src + j) : "memory");
    asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(dst + j), "l"(tmp) : "memory");
#else
    uint64_t tmp = __builtin_nontemporal_load(src + j);
    __builtin_nontemporal_store(tmp, dst + j);
#endif
  }
}

template <typename T>
__device__ __forceinline__ T* get_global_offset_lf(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t layer_dim,
    int64_t page_id,
    int64_t item_size_bytes) {
  // layer first
  return base + layer_id * layer_dim + page_id * item_size_bytes;
}

template <typename T>
__device__ __forceinline__ T* get_global_offset_pf(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t page_dim,
    int64_t page_id,
    int64_t item_size_bytes) {
  // page first
  return base + page_id * page_dim + layer_id * item_size_bytes;
}

// get offset from layer base table when layers are not contiguous
template <typename T>
__device__ __forceinline__ T* get_global_offset_lf_tbl(
    T* /*unused*/,
    const uintptr_t* __restrict__ layer_base_tbl,
    int64_t layer_id,
    int64_t /*unused*/,
    int64_t page_id,
    int64_t item_size_bytes) {
  return reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes;
}

// SrcOffsetFn / DstOffsetFn are instantiations of the get_global_offset_* helpers above
// (for const char* and char* respectively). For MLA, K and V share one buffer, so only
// the K pointers are used.
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
__global__ void transfer_kernel_impl(
    const void* __restrict__ src_k,
    void* __restrict__ dst_k,
    const void* __restrict__ src_v,
    void* __restrict__ dst_v,
    const int64_t* __restrict__ src_indices,
    const int64_t* __restrict__ dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t num_items,
    int64_t items_per_warp,
    int64_t item_size_bytes,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    const uintptr_t* __restrict__ src_k_layer_tbl,
    const uintptr_t* __restrict__ dst_k_layer_tbl,
    const uintptr_t* __restrict__ src_v_layer_tbl,
    const uintptr_t* __restrict__ dst_v_layer_tbl) {
  int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  int32_t lane_id = tid % WARP_SIZE;
  int32_t warp_id = tid / WARP_SIZE;
  for (int i = 0; i < items_per_warp; ++i) {
    int64_t item_id = warp_id * items_per_warp + i;
    if (item_id >= num_items) {
      break;
    }
    const int64_t src_page_id = src_indices[item_id];
    const int64_t dst_page_id = dst_indices[item_id];
    // Loop over layers if necessary
    for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
      const char* src_ptr = SrcOffsetFn(
          static_cast<const char*>(src_k), src_k_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
      char* dst_ptr = DstOffsetFn(
          static_cast<char*>(dst_k), dst_k_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
      transfer_item_warp(lane_id, src_ptr, dst_ptr, item_size_bytes);
      if constexpr (!IsMLA) {
        const char* src_v_ptr = SrcOffsetFn(
            static_cast<const char*>(src_v), src_v_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
        char* dst_v_ptr = DstOffsetFn(
            static_cast<char*>(dst_v), dst_v_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
        transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, item_size_bytes);
      }
    }
  }
}
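// transfer_kv_launcher: host-side entry shared by all wrappers below. It validates the
// index tensors, distributes the pages to copy across at most `block_quota` blocks of
// `num_warps_per_block` warps (each warp copies `items_per_warp` pages), and launches
// transfer_kernel_impl on the current CUDA stream.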
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
void transfer_kv_launcher(
    const at::Tensor& src_k,
    at::Tensor& dst_k,
    const at::Tensor& src_v,
    at::Tensor& dst_v,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    const at::Tensor& src_k_layers,
    const at::Tensor& dst_k_layers,
    const at::Tensor& src_v_layers,
    const at::Tensor& dst_v_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
  TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
  TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
  TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8");

  auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; };
  const int64_t num_items = src_indices.numel();
  if (num_items == 0) {
    return;  // nothing to transfer; also avoids a zero-sized launch below
  }
  const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
  const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
  dim3 grid_dim(num_blocks, 1, 1);
  const int32_t threads_per_block = num_warps_per_block * WARP_SIZE;

  const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr;
  void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr;
  const void* src_v_ptr = IsMLA || !src_v.defined() ? nullptr : src_v.data_ptr();
  void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr();
  const uintptr_t* src_k_tbl_ptr =
      src_k_layers.defined() ? reinterpret_cast<const uintptr_t*>(src_k_layers.data_ptr()) : nullptr;
  const uintptr_t* dst_k_tbl_ptr =
      dst_k_layers.defined() ? reinterpret_cast<const uintptr_t*>(dst_k_layers.data_ptr()) : nullptr;
  const uintptr_t* src_v_tbl_ptr =
      IsMLA || !src_v_layers.defined() ? nullptr : reinterpret_cast<const uintptr_t*>(src_v_layers.data_ptr());
  const uintptr_t* dst_v_tbl_ptr =
      IsMLA || !dst_v_layers.defined() ? nullptr : reinterpret_cast<const uintptr_t*>(dst_v_layers.data_ptr());

  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
      src_k_ptr,
      dst_k_ptr,
      src_v_ptr,
      dst_v_ptr,
      src_indices.data_ptr<int64_t>(),
      dst_indices.data_ptr<int64_t>(),
      start_layer_id,
      num_layers_to_process,
      num_items,
      items_per_warp,
      item_size,
      src_layout_dim,
      dst_layout_dim,
      src_k_tbl_ptr,
      dst_k_tbl_ptr,
      src_v_tbl_ptr,
      dst_v_tbl_ptr);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
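// Thin wrappers around transfer_kv_launcher. Naming convention used below:
//   lf   - layer-first layout (contiguous [layer, page, ...] buffer)
//   pf   - page-first layout (contiguous [page, layer, ...] buffer)
//   tbl  - per-layer base-pointer table for non-contiguous layers
//   mla  - K and V live in a single combined buffer, so no separate V transfer
// A `_pf_lf` / `_lf_pf` suffix names the source layout followed by the destination layout.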
void transfer_kv_per_layer(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, false>(
      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0,
      empty, empty, empty, empty, block_quota, num_warps_per_block);
}

void transfer_kv_per_layer_pf_lf(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, false>(
      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, layer_id, 1, item_size, src_layout_dim, 0,
      empty, empty, empty, empty, block_quota, num_warps_per_block);
}

void transfer_kv_all_layer(
    const at::Tensor src_k_layers,
    const at::Tensor dst_k_layers,
    const at::Tensor src_v_layers,
    const at::Tensor dst_v_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, false>(
      empty, empty, empty, empty, src_indices, dst_indices, 0, num_layers, item_size, 0, 0,
      src_k_layers, dst_k_layers, src_v_layers, dst_v_layers, block_quota, num_warps_per_block);
}

void transfer_kv_all_layer_lf_pf(
    const at::Tensor src_k_layers,
    at::Tensor dst_k,
    const at::Tensor src_v_layers,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, false>(
      empty, dst_k, empty, dst_v, src_indices, dst_indices, 0, num_layers, item_size, 0, dst_layout_dim,
      src_k_layers, empty, src_v_layers, empty, block_quota, num_warps_per_block);
}

void transfer_kv_per_layer_mla(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, true>(
      src, dst, empty, empty, src_indices, dst_indices, 0, 1, item_size, 0, 0,
      empty, empty, empty, empty, block_quota, num_warps_per_block);
}

void transfer_kv_per_layer_mla_pf_lf(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, true>(
      src, dst, empty, empty, src_indices, dst_indices, layer_id, 1, item_size, src_layout_dim, 0,
      empty, empty, empty, empty, block_quota, num_warps_per_block);
}
void transfer_kv_all_layer_mla(
    const at::Tensor src_layers,
    const at::Tensor dst_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, true>(
      empty, empty, empty, empty, src_indices, dst_indices, 0, num_layers, item_size, 0, 0,
      src_layers, dst_layers, empty, empty, block_quota, num_warps_per_block);
}

void transfer_kv_all_layer_mla_lf_pf(
    const at::Tensor src_layers,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, true>(
      empty, dst, empty, empty, src_indices, dst_indices, 0, num_layers, item_size, 0, dst_layout_dim,
      src_layers, empty, empty, empty, block_quota, num_warps_per_block);
}

// Copy page_size rows from src_buffer to dst_buffer with a (possibly async) tensor copy.
inline void transfer_page_direct(
    const at::Tensor src_buffer,
    at::Tensor dst_buffer,
    int64_t src_page_index,
    int64_t dst_page_index,
    int64_t page_size) {
  dst_buffer.slice(0, dst_page_index, dst_page_index + page_size)
      .copy_(src_buffer.slice(0, src_page_index, src_page_index + page_size), /* non_blocking= */ true);
}

void transfer_kv_direct(
    const std::vector<at::Tensor>& src_layers,
    std::vector<at::Tensor> dst_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size) {
  TORCH_CHECK(
      src_layers.size() == dst_layers.size(), "Source and destination layers must have the same number of layers");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  auto src_indices_cpu = src_indices.cpu();
  auto dst_indices_cpu = dst_indices.cpu();
  const auto num_indices = src_indices_cpu.numel();
  const int64_t num_layers = src_layers.size();
  int64_t* src_indices_ptr = src_indices_cpu.data_ptr<int64_t>();
  int64_t* dst_indices_ptr = dst_indices_cpu.data_ptr<int64_t>();

  // Merge runs of consecutive src/dst indices into a single slice copy per layer.
  int64_t start_index = 0;
  int64_t end_index = 0;
  for (int64_t i = 0; i < num_indices; ++i) {
    if (i < num_indices - 1) {
      auto src_diff = src_indices_ptr[i + 1] - src_indices_ptr[i];
      auto dst_diff = dst_indices_ptr[i + 1] - dst_indices_ptr[i];
      if (src_diff == 1 && dst_diff == 1) {
        continue;
      }
      end_index = i + 1;
    } else {
      // last batch
      end_index = num_indices;
    }
    auto src_index = src_indices_ptr[start_index];
    auto dst_index = dst_indices_ptr[start_index];
    auto num_tokens = end_index - start_index;
    for (int64_t j = 0; j < num_layers; ++j) {
      transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, num_tokens);
    }
    start_index = end_index;
  }
}
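// transfer_kv_page_first_direct_impl copies between a list of per-layer layer-first
// tensors and a single page-first buffer indexed as [page][layer][token]. IsLf2Pf
// selects the direction (layer-first source to page-first destination, or the reverse).
// MLA is detected by the page-first side holding one tensor instead of separate K and V
// halves; the per-layer side then holds num_layers tensors instead of 2 * num_layers.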
template <bool IsLf2Pf>
inline void transfer_kv_page_first_direct_impl(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t page_size) {
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  auto src_indices_cpu = src_indices.cpu();
  auto dst_indices_cpu = dst_indices.cpu();
  const int64_t num_pages = src_indices_cpu.size(0) / page_size;

  if constexpr (IsLf2Pf) {
    const bool is_mla = dst_ptrs.size() == 1;
    const int64_t num_layers = is_mla ? src_ptrs.size() : src_ptrs.size() / 2;
    for (const auto i : c10::irange(num_pages)) {
      auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
      auto d_index = dst_indices_cpu[i * page_size].item<int64_t>() / page_size;
      for (int64_t j = 0; j < num_layers; ++j) {
        transfer_page_direct(
            src_ptrs[j], dst_ptrs[0].select(0, d_index).select(0, start_layer_id + j), s_index, 0, page_size);
        if (!is_mla) {
          transfer_page_direct(
              src_ptrs[j + num_layers],
              dst_ptrs[1].select(0, d_index).select(0, start_layer_id + j),
              s_index,
              0,
              page_size);
        }
      }
    }
  } else {
    const bool is_mla = src_ptrs.size() == 1;
    const int64_t num_layers = is_mla ? dst_ptrs.size() : dst_ptrs.size() / 2;
    for (const auto i : c10::irange(num_pages)) {
      auto s_index = src_indices_cpu[i * page_size].item<int64_t>() / page_size;
      auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
      for (int64_t j = 0; j < num_layers; ++j) {
        transfer_page_direct(
            src_ptrs[0].select(0, s_index).select(0, start_layer_id + j), dst_ptrs[j], 0, d_index, page_size);
        if (!is_mla) {
          transfer_page_direct(
              src_ptrs[1].select(0, s_index).select(0, start_layer_id + j),
              dst_ptrs[j + num_layers],
              0,
              d_index,
              page_size);
        }
      }
    }
  }
}

void transfer_kv_per_layer_direct_pf_lf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t layer_id,
    int64_t page_size) {
  transfer_kv_page_first_direct_impl<false>(src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size);
}

void transfer_kv_all_layer_direct_lf_pf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t page_size) {
  transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
}

__device__ int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

__device__ int64_t safe_min(int64_t a, int64_t b) {
  return a < b ? a : b;
}
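// Decode-time page allocation: each request grows by exactly one token. Every thread
// handles one request, recomputes the prefix sum of newly required pages over requests
// [0, pid], and then either appends to the request's current last page (last_loc + 1)
// or takes the next page from the free list.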
__global__ void launch_alloc_decode_kernel(
    const int64_t* seq_lens_ptr,
    const int32_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs,
    int64_t page_size) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_len = seq_lens_ptr[pid];
  int64_t pre_len = seq_len - 1;
  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);

  // Number of new pages required by requests [0, pid].
  int64_t sum_num_new_pages = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = other_seq_len - 1;
    int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
    int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
    int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
    sum_num_new_pages += other_num_new_pages;
  }
  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;

  if (num_page_start_loc_self == 0) {
    // The new token still fits into the request's current last page.
    int32_t last_loc = last_loc_ptr[pid];
    out_indices[pid] = last_loc + 1;
  } else {
    // The new token starts a fresh page taken from the free list.
    int64_t page = free_page_ptr[new_page_start_loc];
    out_indices[pid] = page * page_size;
  }
}

// Extend-time page allocation: each request grows from pre_len to seq_len tokens.
// The output locations are produced in three parts: the remainder of the partially
// used last page, whole new pages from the free list, and a trailing partial page.
__global__ void launch_alloc_extend_kernel(
    const int64_t* pre_lens_ptr,
    const int64_t* seq_lens_ptr,
    const int64_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs,
    int64_t page_size) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_len = seq_lens_ptr[pid];
  int64_t pre_len = pre_lens_ptr[pid];
  int64_t extend_len = seq_len - pre_len;

  // Output offset: total extend length of requests [0, pid - 1].
  int64_t sum_extend_lens = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = pre_lens_ptr[i];
    int64_t other_extend_len = other_seq_len - other_pre_len;
    sum_extend_lens += other_extend_len;
  }
  int64_t output_start_loc = sum_extend_lens - extend_len;

  // Free-list offset: number of new pages needed by requests [0, pid - 1].
  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
  int64_t sum_num_new_pages = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = pre_lens_ptr[i];
    int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
    int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
    int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
    sum_num_new_pages += other_num_new_pages;
  }
  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
  int64_t last_loc = last_loc_ptr[pid];

  // Part 1: fill the remainder of the partially used last page.
  int64_t num_part1 = safe_min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len;
  for (int64_t offset = 0; offset < num_part1 && offset < page_size; offset++) {
    int64_t output_idx = output_start_loc + offset;
    out_indices[output_idx] = last_loc + 1 + offset;
  }
  if (pre_len + num_part1 == seq_len) {
    return;
  }

  // Part 2: whole new pages taken from the free list.
  int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
  for (int64_t offset = 0; offset < num_part2; offset++) {
    int64_t page_idx = new_page_start_loc + offset / page_size;
    int64_t page_start = free_page_ptr[page_idx];
    int64_t output_idx = output_start_loc + num_part1 + offset;
    out_indices[output_idx] = page_start * page_size + offset % page_size;
  }
  if (pre_len + num_part1 + num_part2 == seq_len) {
    return;
  }

  // Part 3: the trailing, partially filled new page.
  int64_t num_part3 = seq_len - (seq_len / page_size) * page_size;
  int64_t last_page_idx = new_page_start_loc + num_page_start_loc_self - 1;
  int64_t start_loc = free_page_ptr[last_page_idx];
  for (int64_t offset = 0; offset < num_part3 && offset < page_size; offset++) {
    int64_t output_idx = output_start_loc + num_part1 + num_part2 + offset;
    out_indices[output_idx] = start_loc * page_size + offset;
  }
}
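// After a speculative-decoding verify step, rebuild the extend metadata: for each
// request, write the positions of its accepted tokens (using the cumulative accepted
// lengths as the output offset) and record the last verified token id. The two kernels
// below differ only in the dtype of accept_lens (int32 vs int64).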
__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
    const int32_t* verified_id_ptr,
    const int64_t* seq_lens_ptr,
    const int32_t* accept_lens_ptr,
    int64_t* positions_ptr,
    int32_t* new_verified_id_ptr,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_length = seq_lens_ptr[pid];
  int32_t accept_length = accept_lens_ptr[pid];

  // Cumulative accepted length of requests [0, pid - 1].
  int32_t accept_len_cumsum = 0;
  for (int32_t offset = 0; offset < pid; offset++) {
    accept_len_cumsum += accept_lens_ptr[offset];
  }

  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
  for (int32_t offset = 0; offset < accept_length && offset < bs; offset++) {
    positions_ptr1[offset] = seq_length - accept_length + offset;
  }
  int32_t verified_idx = accept_len_cumsum + accept_length - 1;
  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}

__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
    const int32_t* verified_id_ptr,
    const int64_t* seq_lens_ptr,
    const int64_t* accept_lens_ptr,
    int64_t* positions_ptr,
    int32_t* new_verified_id_ptr,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_length = seq_lens_ptr[pid];
  int64_t accept_length = accept_lens_ptr[pid];

  int64_t accept_len_cumsum = 0;
  for (int64_t offset = 0; offset < pid; offset++) {
    accept_len_cumsum += accept_lens_ptr[offset];
  }

  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
  for (int64_t offset = 0; offset < accept_length && offset < bs; offset++) {
    positions_ptr1[offset] = seq_length - accept_length + offset;
  }
  int64_t verified_idx = accept_len_cumsum + accept_length - 1;
  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}

void dcu_alloc_decode_kernel(
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t page_size) {
  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
  const int32_t* last_loc_ptr1 = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

void dcu_create_extend_after_decode_spec_info(
    const at::Tensor verified_id,
    const at::Tensor seq_lens,
    const at::Tensor accept_lens,
    at::Tensor positions,
    at::Tensor new_verified_id,
    int64_t bs) {
  const int32_t* verified_id_ptr;
  const int64_t* seq_lens_ptr;
  const int32_t* accept_lens_ptr_int32;
  const int64_t* accept_lens_ptr_int64;
  int64_t* positions_ptr;
  int32_t* new_verified_id_ptr;

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();

  if (accept_lens.dtype() == torch::kInt32) {
    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
    accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
    launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
    accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
    launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
void dcu_alloc_extend_kernel(
    const at::Tensor pre_lens_ptr,
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t page_size) {
  const int64_t* pre_lens_ptr1 = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
  const int64_t* last_loc_ptr1 = static_cast<const int64_t*>(last_loc_ptr.data_ptr());
  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// For each request, append its newly allocated cache locations
// (new_allocate_lens - allocate_lens entries from out_cache_loc) to the request's
// row of the req_to_token pool.
__global__ void launch_assign_req_to_token_pool(
    const int64_t* req_pool_indices_ptr,
    int32_t* req_to_token_ptr,
    const int64_t* allocate_lens_ptr,
    int64_t* new_allocate_lens,
    int64_t* out_cache_loc_ptr,
    int64_t shape,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t kv_start = allocate_lens_ptr[pid];
  int64_t kv_end = new_allocate_lens[pid];
  int64_t pool_idx = req_pool_indices_ptr[pid];
  int32_t* token_pool = req_to_token_ptr + pool_idx * shape;

  // Offset into out_cache_loc: total growth of requests [0, pid - 1].
  int64_t sum_out_offset = 0;
  for (int64_t length_offset = 0; length_offset < pid; length_offset++) {
    int64_t start = allocate_lens_ptr[length_offset];
    int64_t end = new_allocate_lens[length_offset];
    sum_out_offset += end - start;
  }
  int64_t* out_cache_ptr = out_cache_loc_ptr + sum_out_offset;

  int64_t copy_length = kv_end - kv_start;
#pragma unroll(32)
  for (int64_t out_cache_index = 0; out_cache_index < copy_length; out_cache_index++) {
    token_pool[kv_start + out_cache_index] = static_cast<int32_t>(out_cache_ptr[out_cache_index]);
  }
}

void dcu_assign_req_to_token_pool(
    const at::Tensor req_pool_indices_ptr,
    at::Tensor req_to_token_ptr,
    const at::Tensor allocate_lens_ptr,
    at::Tensor new_allocate_lens,
    at::Tensor out_cache_loc_ptr,
    int64_t shape,
    int64_t bs) {
  const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
  int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
  const int64_t* allocate_lens_ptr1 = static_cast<const int64_t*>(allocate_lens_ptr.data_ptr());
  int64_t* new_allocate_lens1 = static_cast<int64_t*>(new_allocate_lens.data_ptr());
  int64_t* out_cache_loc_ptr1 = static_cast<int64_t*>(out_cache_loc_ptr.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_assign_req_to_token_pool<<<grid_size, block_size, 0, torch_current_stream>>>(
      req_pool_indices_ptr1, req_to_token_ptr1, allocate_lens_ptr1, new_allocate_lens1, out_cache_loc_ptr1, shape, bs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

__global__ void get_last_loc_kernel(
    const int32_t* __restrict__ req_to_token,
    const int64_t* __restrict__ req_pool_indices_tensor,
    const int64_t* __restrict__ prefix_lens_tensor,
    int64_t* __restrict__ result,
    int64_t num_tokens,
    int64_t req_to_token_stride) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= num_tokens) return;

  int64_t pre_len = prefix_lens_tensor[pid];
  if (pre_len > 0) {
    // Pool location of the last prefix token of this request.
    int64_t req_idx = req_pool_indices_tensor[pid];
    int64_t token_idx = req_idx * req_to_token_stride + (pre_len - 1);
    result[pid] = static_cast<int64_t>(req_to_token[token_idx]);
  } else {
    result[pid] = static_cast<int64_t>(-1);
  }
}
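// dcu_get_last_loc: host entry for get_last_loc_kernel. Validates shapes, makes the
// inputs contiguous, and returns a new int64 tensor holding, for each request, the
// token-pool location of its last prefix token (or -1 when the prefix is empty).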
at::Tensor dcu_get_last_loc(
    const at::Tensor req_to_token, const at::Tensor req_pool_indices, const at::Tensor prefix_lens) {
  TORCH_CHECK(req_to_token.device().is_cuda(), "req_to_token must be CUDA tensor");
  TORCH_CHECK(req_pool_indices.device().is_cuda(), "req_pool_indices must be CUDA tensor");
  TORCH_CHECK(prefix_lens.device().is_cuda(), "prefix_lens must be CUDA tensor");
  TORCH_CHECK(req_to_token.dim() == 2, "req_to_token must be 2D tensor [batch, seq_len]");
  TORCH_CHECK(prefix_lens.dim() == 1, "prefix_lens must be 1D");
  TORCH_CHECK(req_pool_indices.dim() == 1, "req_pool_indices must be 1D");

  int64_t num_tokens = prefix_lens.numel();
  TORCH_CHECK(req_pool_indices.numel() == num_tokens, "req_pool_indices must have same length as prefix_lens");

  int64_t req_to_token_stride = req_to_token.stride(0);

  auto req_to_token_c = req_to_token.contiguous();
  auto req_pool_indices_c = req_pool_indices.contiguous();
  auto prefix_lens_c = prefix_lens.contiguous();

  const int32_t* req_to_token_ptr = req_to_token_c.data_ptr<int32_t>();
  const int64_t* req_pool_indices_ptr = req_pool_indices_c.data_ptr<int64_t>();
  const int64_t* prefix_lens_ptr = prefix_lens_c.data_ptr<int64_t>();

  auto result = at::empty_like(prefix_lens_c);
  int64_t* result_ptr = result.data_ptr<int64_t>();

  const int64_t block_size = 64;
  const int64_t grid_size = (num_tokens + block_size - 1) / block_size;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  get_last_loc_kernel<<<grid_size, block_size, 0, stream>>>(
      req_to_token_ptr, req_pool_indices_ptr, prefix_lens_ptr, result_ptr, num_tokens, req_to_token_stride);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return result;
}

// For each request, copy the already-assigned cache locations in
// req_to_token[req_id][start_offset:end_offset] into the packed out_cache_loc buffer.
__global__ void launch_assign_extend_cache_locs_kernel(
    const int64_t* __restrict__ req_pool_indices,  // [bs]
    const int32_t* __restrict__ req_to_token,      // [max_num_req, pool_len]
    const int64_t* __restrict__ start_offset,      // [bs]
    const int64_t* __restrict__ end_offset,        // [bs]
    int64_t* __restrict__ out_cache_loc,           // [sum(draft_token_num)]
    int64_t pool_len,
    int64_t bs) {
  int pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t kv_start = start_offset[pid];
  int64_t kv_end = end_offset[pid];
  int64_t req_id = req_pool_indices[pid];

  // Output offset: total span of requests [0, pid - 1].
  int64_t out_offset = 0;
  for (int i = 0; i < pid; ++i) {
    out_offset += end_offset[i] - start_offset[i];
  }

  const int32_t* src = req_to_token + req_id * pool_len + kv_start;
  int64_t* dst = out_cache_loc + out_offset;
  for (int64_t i = 0; i < kv_end - kv_start; ++i) {
    dst[i] = src[i];
  }
}

void dcu_assign_extend_cache_locs(
    const at::Tensor req_pool_indices,
    const at::Tensor req_to_token,
    const at::Tensor start_offset,
    const at::Tensor end_offset,
    at::Tensor out_cache_loc,
    int64_t pool_len,
    int64_t bs) {
  const int64_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int64_t>();
  const int32_t* req_to_token_ptr = req_to_token.data_ptr<int32_t>();
  const int64_t* start_offset_ptr = start_offset.data_ptr<int64_t>();
  const int64_t* end_offset_ptr = end_offset.data_ptr<int64_t>();
  int64_t* out_cache_loc_ptr = out_cache_loc.data_ptr<int64_t>();

  constexpr int64_t threads = 128;
  int64_t blocks = (bs + threads - 1) / threads;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  launch_assign_extend_cache_locs_kernel<<<blocks, threads, 0, stream>>>(
      req_pool_indices_ptr, req_to_token_ptr, start_offset_ptr, end_offset_ptr, out_cache_loc_ptr, pool_len, bs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
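// dcu_create_flashmla_kv_indices: one block per request. The kernel walks the request's
// token locations in req_to_token in steps of PAGED_SIZE and records the page index
// (token location / PAGED_SIZE) of every page the request touches.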
template <int PAGED_SIZE>
__global__ void dcu_create_flashmla_kv_indices_kernel(
    const int32_t* __restrict__ req_to_token,
    const int32_t* __restrict__ req_pool_indices,
    const int32_t* __restrict__ page_kernel_lens,
    const int32_t* __restrict__ kv_start_idx,
    int32_t* __restrict__ kv_indices,
    int req_to_token_stride,
    int kv_indices_stride) {
  int pid = blockIdx.x;  // batch index
  int req_pool_index = req_pool_indices[pid];

  int kv_start = 0;
  int kv_end = 0;
  if (kv_start_idx != nullptr) {
    kv_start = kv_start_idx[pid];
    kv_end = kv_start;
  }
  kv_end += page_kernel_lens[pid];

  int total_len = kv_end - kv_start;
  int num_pages = (total_len + PAGED_SIZE - 1) / PAGED_SIZE;
  for (int pg = 0; pg < num_pages; ++pg) {
    int offset = pg * PAGED_SIZE;
    // token id = req_to_token[req_pool_index][kv_start + offset]
    int64_t token = req_to_token[req_pool_index * req_to_token_stride + kv_start + offset];
    // page index
    kv_indices[pid * kv_indices_stride + pg] = token / PAGED_SIZE;
  }
}

void dcu_create_flashmla_kv_indices(
    const at::Tensor& req_to_token,
    const at::Tensor& req_pool_indices,
    const at::Tensor& page_kernel_lens,
    const c10::optional<at::Tensor>& kv_start_idx,
    at::Tensor& kv_indices,
    int64_t req_to_token_stride,
    int64_t kv_indices_stride,
    int64_t PAGED_SIZE) {
  TORCH_CHECK(req_to_token.is_cuda(), "req_to_token must be CUDA tensor");
  TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");

  int bs = req_pool_indices.size(0);
  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 grid(bs);
  dim3 block(1);

  const int32_t* kv_start_idx_ptr = nullptr;
  if (kv_start_idx.has_value()) {
    kv_start_idx_ptr = kv_start_idx.value().data_ptr<int32_t>();
  }

  if (PAGED_SIZE == 64) {
    dcu_create_flashmla_kv_indices_kernel<64><<<grid, block, 0, stream>>>(
        req_to_token.data_ptr<int32_t>(),
        req_pool_indices.data_ptr<int32_t>(),
        page_kernel_lens.data_ptr<int32_t>(),
        kv_start_idx_ptr,
        kv_indices.data_ptr<int32_t>(),
        req_to_token_stride,
        kv_indices_stride);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    TORCH_CHECK(false, "Unsupported PAGED_SIZE");
  }
}

// For each request, gather the kv indices of one prefix chunk
// (req_to_token[req][chunk_start : chunk_start + chunk_seq_len]) into the packed
// chunk_kv_indices buffer at the chunk's cumulative offset.
__global__ void launch_create_chunked_prefix_cache_kv_indices(
    int32_t* req_to_token_ptr,
    const int64_t* req_pool_indices_ptr,
    const int32_t* chunk_starts_ptr,
    const int32_t* chunk_seq_lens_ptr,
    const int32_t* chunk_cu_seq_lens_ptr,
    int32_t* chunk_kv_indices_ptr,
    int64_t col_num,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t req_pool_index = req_pool_indices_ptr[pid];
  int64_t chunk_kv_indices_offset = chunk_cu_seq_lens_ptr[pid];
  int32_t chunk_start_pos = chunk_starts_ptr[pid];
  int32_t chunk_seq_len = chunk_seq_lens_ptr[pid];

#pragma unroll(32)
  for (int32_t offset = 0; offset < chunk_seq_len; offset++) {
    chunk_kv_indices_ptr[chunk_kv_indices_offset + offset] =
        req_to_token_ptr[req_pool_index * col_num + chunk_start_pos + offset];
  }
}

void dcu_create_chunked_prefix_cache_kv_indices(
    at::Tensor req_to_token_ptr,
    const at::Tensor req_pool_indices_ptr,
    const at::Tensor chunk_starts_ptr,
    const at::Tensor chunk_seq_lens_ptr,
    const at::Tensor chunk_cu_seq_lens_ptr,
    at::Tensor chunk_kv_indices_ptr,
    int64_t col_num,
    int64_t bs) {
  int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
  const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
  const int32_t* chunk_starts_ptr1 = static_cast<const int32_t*>(chunk_starts_ptr.data_ptr());
  const int32_t* chunk_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_seq_lens_ptr.data_ptr());
  const int32_t* chunk_cu_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_cu_seq_lens_ptr.data_ptr());
  int32_t* chunk_kv_indices_ptr1 = static_cast<int32_t*>(chunk_kv_indices_ptr.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_create_chunked_prefix_cache_kv_indices<<<grid_size, block_size, 0, torch_current_stream>>>(
      req_to_token_ptr1,
      req_pool_indices_ptr1,
      chunk_starts_ptr1,
      chunk_seq_lens_ptr1,
      chunk_cu_seq_lens_ptr1,
      chunk_kv_indices_ptr1,
      col_num,
      bs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}