#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/irange.h>
#include <vector>

#ifndef USE_ROCM
#define WARP_SIZE 64
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#include "utils.h"  // WARP_SIZE
#endif

// Copy one item (item_size_bytes bytes, required to be a multiple of 8) from src_addr to dst_addr,
// striping the 64-bit chunks across the lanes of a warp. Non-temporal / streaming loads and stores
// are used so the copied data does not pollute the cache.
__device__ __forceinline__ void
transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
  const uint64_t* __restrict__ src = static_cast<const uint64_t*>(src_addr);
  uint64_t* __restrict__ dst = static_cast<uint64_t*>(dst_addr);
  const int total_chunks = item_size_bytes / sizeof(uint64_t);
#pragma unroll
  for (int j = lane_id; j < total_chunks; j += WARP_SIZE) {
#ifndef USE_ROCM
    uint64_t tmp;
    asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(tmp) : "l"(src + j) : "memory");
    asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(dst + j), "l"(tmp) : "memory");
#else
    uint64_t tmp = __builtin_nontemporal_load(src + j);
    __builtin_nontemporal_store(tmp, dst + j);
#endif
  }
}

template <typename T>
__device__ __forceinline__ T* get_global_offset_lf(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t layer_dim,
    int64_t page_id,
    int64_t item_size_bytes) {
  // layer first
  return base + layer_id * layer_dim + page_id * item_size_bytes;
}

template <typename T>
__device__ __forceinline__ T* get_global_offset_pf(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t page_dim,
    int64_t page_id,
    int64_t item_size_bytes) {
  // page first
  return base + page_id * page_dim + layer_id * item_size_bytes;
}

// get offset from layer base table when layers are not contiguous
template <typename T>
__device__ __forceinline__ T* get_global_offset_lf_tbl(
    T* /*unused*/,
    const uintptr_t* __restrict__ layer_base_tbl,
    int64_t layer_id,
    int64_t /*unused*/,
    int64_t page_id,
    int64_t item_size_bytes) {
  return reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes;
}

// Each warp copies up to items_per_warp items; the source and destination address computations are
// selected at compile time through SrcOffsetFn / DstOffsetFn, and IsMLA disables the V-side copy.
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
__global__ void transfer_kernel_impl(
    const void* __restrict__ src_k,
    void* __restrict__ dst_k,
    const void* __restrict__ src_v,
    void* __restrict__ dst_v,
    const int64_t* __restrict__ src_indices,
    const int64_t* __restrict__ dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t num_items,
    int64_t items_per_warp,
    int64_t item_size_bytes,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    const uintptr_t* __restrict__ src_k_layer_tbl,
    const uintptr_t* __restrict__ dst_k_layer_tbl,
    const uintptr_t* __restrict__ src_v_layer_tbl,
    const uintptr_t* __restrict__ dst_v_layer_tbl) {
  int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  int32_t lane_id = tid % WARP_SIZE;
  int32_t warp_id = tid / WARP_SIZE;
  for (int i = 0; i < items_per_warp; ++i) {
    int64_t item_id = warp_id * items_per_warp + i;
    if (item_id >= num_items) {
      break;
    }
    const int64_t src_page_id = src_indices[item_id];
    const int64_t dst_page_id = dst_indices[item_id];
    // Loop over layers if necessary
    for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
      const char* src_ptr = SrcOffsetFn(
          static_cast<const char*>(src_k), src_k_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
      char* dst_ptr = DstOffsetFn(
          static_cast<char*>(dst_k), dst_k_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
      transfer_item_warp(lane_id, src_ptr, dst_ptr, item_size_bytes);
      if constexpr (!IsMLA) {
        const char* src_v_ptr = SrcOffsetFn(
            static_cast<const char*>(src_v), src_v_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
        char* dst_v_ptr = DstOffsetFn(
            static_cast<char*>(dst_v), dst_v_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
        transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, item_size_bytes);
      }
    }
  }
}
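// Worked example of the two layouts (illustrative numbers only, not tied to any particular model
// configuration): with 4 layers, 1024 pages and item_size_bytes = 512, a layer-first pool has
// layer_dim = 1024 * 512 bytes per layer, so get_global_offset_lf(base, nullptr, /*layer_id=*/2,
// /*layer_dim=*/524288, /*page_id=*/7, 512) returns base + 2 * 524288 + 7 * 512 = base + 1052160.
// A page-first pool instead has page_dim = 4 * 512 bytes per page, so get_global_offset_pf(base,
// nullptr, /*layer_id=*/2, /*page_dim=*/2048, /*page_id=*/7, 512) returns
// base + 7 * 2048 + 2 * 512 = base + 15360.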
// Host-side launcher shared by all transfer_kv_* entry points below. The offset functions and the
// MLA flag are compile-time parameters so that the kernel body contains no layout branches.
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
void transfer_kv_launcher(
    const at::Tensor& src_k,
    at::Tensor& dst_k,
    const at::Tensor& src_v,
    at::Tensor& dst_v,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    const at::Tensor& src_k_layers,
    const at::Tensor& dst_k_layers,
    const at::Tensor& src_v_layers,
    const at::Tensor& dst_v_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
  TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
  TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
  TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8");

  auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; };
  const int64_t num_items = src_indices.numel();
  const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
  const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
  dim3 grid_dim(num_blocks, 1, 1);
  const int32_t threads_per_block = num_warps_per_block * WARP_SIZE;

  const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr;
  void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr;
  const void* src_v_ptr = IsMLA || !src_v.defined() ? nullptr : src_v.data_ptr();
  void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr();
  const uintptr_t* src_k_tbl_ptr = src_k_layers.defined() ? src_k_layers.data_ptr<uintptr_t>() : nullptr;
  const uintptr_t* dst_k_tbl_ptr = dst_k_layers.defined() ? dst_k_layers.data_ptr<uintptr_t>() : nullptr;
  const uintptr_t* src_v_tbl_ptr = IsMLA || !src_v_layers.defined() ? nullptr : src_v_layers.data_ptr<uintptr_t>();
  const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? nullptr : dst_v_layers.data_ptr<uintptr_t>();

  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA>
      <<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
          src_k_ptr,
          dst_k_ptr,
          src_v_ptr,
          dst_v_ptr,
          src_indices.data_ptr<int64_t>(),
          dst_indices.data_ptr<int64_t>(),
          start_layer_id,
          num_layers_to_process,
          num_items,
          items_per_warp,
          item_size,
          src_layout_dim,
          dst_layout_dim,
          src_k_tbl_ptr,
          dst_k_tbl_ptr,
          src_v_tbl_ptr,
          dst_v_tbl_ptr);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
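// Grid-sizing example (illustrative numbers only): with num_items = 1000, block_quota = 8 and
// num_warps_per_block = 4, items_per_warp = div_up(1000, 8 * 4) = 32, so each warp copies up to
// 32 items and num_blocks = div_up(1000, 32 * 4) = 8. block_quota therefore acts as an upper
// bound on the number of blocks launched; smaller transfers simply use fewer blocks.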
void transfer_kv_per_layer(
    const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t item_size,
    int64_t block_quota, int64_t num_warps_per_block) {
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, false>(
      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0,
      empty, empty, empty, empty, block_quota, num_warps_per_block);
}

void transfer_kv_per_layer_pf_lf(
    const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t layer_id, int64_t item_size,
    int64_t src_layout_dim, int64_t block_quota, int64_t num_warps_per_block) {
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, false>(
      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, layer_id, 1, item_size, src_layout_dim, 0,
      empty, empty, empty, empty, block_quota, num_warps_per_block);
}

void transfer_kv_all_layer(
    const at::Tensor src_k_layers, const at::Tensor dst_k_layers, const at::Tensor src_v_layers,
    const at::Tensor dst_v_layers, const at::Tensor src_indices, const at::Tensor dst_indices,
    int64_t item_size, int64_t num_layers, int64_t block_quota, int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, false>(
      empty, empty, empty, empty, src_indices, dst_indices, 0, num_layers, item_size, 0, 0,
      src_k_layers, dst_k_layers, src_v_layers, dst_v_layers, block_quota, num_warps_per_block);
}

void transfer_kv_all_layer_lf_pf(
    const at::Tensor src_k_layers, at::Tensor dst_k, const at::Tensor src_v_layers, at::Tensor dst_v,
    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t item_size, int64_t dst_layout_dim,
    int64_t num_layers, int64_t block_quota, int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, false>(
      empty, dst_k, empty, dst_v, src_indices, dst_indices, 0, num_layers, item_size, 0, dst_layout_dim,
      src_k_layers, empty, src_v_layers, empty, block_quota, num_warps_per_block);
}
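// Naming convention of the wrappers above and below: "lf" means the buffer is laid out layer-first,
// "pf" means page-first, "tbl" means the per-layer base pointers come from a table because the
// layers are not contiguous, and "pf_lf" / "lf_pf" names the source layout followed by the
// destination layout. The *_mla variants move a single combined KV buffer, so only the k arguments
// are used. A call for the simple per-layer case could look like the following sketch
// (hypothetical tensors and sizes, shown for illustration only):
//
//   auto src_idx = torch::tensor({3, 9}, torch::dtype(torch::kLong).device(torch::kCUDA));
//   auto dst_idx = torch::tensor({0, 1}, torch::dtype(torch::kLong).device(torch::kCUDA));
//   // src_k/dst_k/src_v/dst_v: one layer's pool, shape [num_pages, elems_per_page]
//   transfer_kv_per_layer(src_k, dst_k, src_v, dst_v, src_idx, dst_idx,
//                         /*item_size=*/src_k.stride(0) * src_k.element_size(),
//                         /*block_quota=*/2, /*num_warps_per_block=*/32);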
void transfer_kv_per_layer_mla(
    const at::Tensor src, at::Tensor dst, const at::Tensor src_indices, const at::Tensor dst_indices,
    int64_t item_size, int64_t block_quota, int64_t num_warps_per_block) {
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, true>(
      src, dst, empty, empty, src_indices, dst_indices, 0, 1, item_size, 0, 0,
      empty, empty, empty, empty, block_quota, num_warps_per_block);
}

void transfer_kv_per_layer_mla_pf_lf(
    const at::Tensor src, at::Tensor dst, const at::Tensor src_indices, const at::Tensor dst_indices,
    int64_t layer_id, int64_t item_size, int64_t src_layout_dim, int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, true>(
      src, dst, empty, empty, src_indices, dst_indices, layer_id, 1, item_size, src_layout_dim, 0,
      empty, empty, empty, empty, block_quota, num_warps_per_block);
}

void transfer_kv_all_layer_mla(
    const at::Tensor src_layers, const at::Tensor dst_layers, const at::Tensor src_indices,
    const at::Tensor dst_indices, int64_t item_size, int64_t num_layers, int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, true>(
      empty, empty, empty, empty, src_indices, dst_indices, 0, num_layers, item_size, 0, 0,
      src_layers, dst_layers, empty, empty, block_quota, num_warps_per_block);
}

void transfer_kv_all_layer_mla_lf_pf(
    const at::Tensor src_layers, at::Tensor dst, const at::Tensor src_indices, const at::Tensor dst_indices,
    int64_t item_size, int64_t dst_layout_dim, int64_t num_layers, int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, true>(
      empty, dst, empty, empty, src_indices, dst_indices, 0, num_layers, item_size, 0, dst_layout_dim,
      src_layers, empty, empty, empty, block_quota, num_warps_per_block);
}

// Copy a contiguous range of page_size rows from src_buffer to dst_buffer with an asynchronous
// tensor copy.
inline void transfer_page_direct(
    const at::Tensor src_buffer, at::Tensor dst_buffer, int64_t src_page_index, int64_t dst_page_index,
    int64_t page_size) {
  dst_buffer.slice(0, dst_page_index, dst_page_index + page_size)
      .copy_(src_buffer.slice(0, src_page_index, src_page_index + page_size), /* non_blocking= */ true);
}

void transfer_kv_direct(
    const std::vector<at::Tensor>& src_layers,
    std::vector<at::Tensor> dst_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size) {
  TORCH_CHECK(
      src_layers.size() == dst_layers.size(), "Source and destination layers must have the same number of layers");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  auto src_indices_cpu = src_indices.cpu();
  auto dst_indices_cpu = dst_indices.cpu();
  const auto num_indices = src_indices_cpu.numel();
  const int64_t num_layers = src_layers.size();
  int64_t* src_indices_ptr = src_indices_cpu.data_ptr<int64_t>();
  int64_t* dst_indices_ptr = dst_indices_cpu.data_ptr<int64_t>();

  // Coalesce runs of consecutive source/destination indices into one slice copy per layer.
  int64_t start_index = 0;
  int64_t end_index = 0;
  for (int64_t i = 0; i < num_indices; ++i) {
    if (i < num_indices - 1) {
      auto src_diff = src_indices_ptr[i + 1] - src_indices_ptr[i];
      auto dst_diff = dst_indices_ptr[i + 1] - dst_indices_ptr[i];
      if (src_diff == 1 && dst_diff == 1) {
        continue;
      }
      end_index = i + 1;
    } else {
      // last batch
      end_index = num_indices;
    }
    auto src_index = src_indices_ptr[start_index];
    auto dst_index = dst_indices_ptr[start_index];
    auto num_tokens = end_index - start_index;
    for (int64_t j = 0; j < num_layers; ++j) {
      transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, num_tokens);
    }
    start_index = end_index;
  }
}
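// Example of the run coalescing in transfer_kv_direct above: with src_indices = {4, 5, 6, 10} and
// dst_indices = {0, 1, 2, 7}, the first three pairs advance by 1 on both sides and are merged, so
// each layer is copied with two slice copies: rows [4, 7) -> [0, 3) and rows [10, 11) -> [7, 8).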
// Copy pages one slice at a time between a list of per-layer (layer-first) tensors and a single
// page-first pool tensor. IsLf2Pf selects the direction: true copies the per-layer source tensors
// into the page-first destination pool, false copies the page-first source pool into the per-layer
// destination tensors.
template <bool IsLf2Pf>
inline void transfer_kv_page_first_direct_impl(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t page_size) {
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  auto src_indices_cpu = src_indices.cpu();
  auto dst_indices_cpu = dst_indices.cpu();
  const int64_t num_pages = src_indices_cpu.size(0) / page_size;

  if constexpr (IsLf2Pf) {
    const bool is_mla = dst_ptrs.size() == 1;
    const int64_t num_layers = is_mla ? src_ptrs.size() : src_ptrs.size() / 2;
    for (const auto i : c10::irange(num_pages)) {
      auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
      auto d_index = dst_indices_cpu[i * page_size].item<int64_t>() / page_size;
      for (int64_t j = 0; j < num_layers; ++j) {
        transfer_page_direct(
            src_ptrs[j], dst_ptrs[0].select(0, d_index).select(0, start_layer_id + j), s_index, 0, page_size);
        if (!is_mla) {
          transfer_page_direct(
              src_ptrs[j + num_layers],
              dst_ptrs[1].select(0, d_index).select(0, start_layer_id + j),
              s_index,
              0,
              page_size);
        }
      }
    }
  } else {
    const bool is_mla = src_ptrs.size() == 1;
    const int64_t num_layers = is_mla ? dst_ptrs.size() : dst_ptrs.size() / 2;
    for (const auto i : c10::irange(num_pages)) {
      auto s_index = src_indices_cpu[i * page_size].item<int64_t>() / page_size;
      auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
      for (int64_t j = 0; j < num_layers; ++j) {
        transfer_page_direct(
            src_ptrs[0].select(0, s_index).select(0, start_layer_id + j), dst_ptrs[j], 0, d_index, page_size);
        if (!is_mla) {
          transfer_page_direct(
              src_ptrs[1].select(0, s_index).select(0, start_layer_id + j),
              dst_ptrs[j + num_layers],
              0,
              d_index,
              page_size);
        }
      }
    }
  }
}

void transfer_kv_per_layer_direct_pf_lf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t layer_id,
    int64_t page_size) {
  transfer_kv_page_first_direct_impl<false>(src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size);
}

void transfer_kv_all_layer_direct_lf_pf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t page_size) {
  transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
}
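// Indexing note for transfer_kv_page_first_direct_impl above: the page-first pool is expected to
// be indexable as pool[page][layer][token, ...], so the pool page is recovered as index / page_size
// while the per-layer tensors are sliced directly at the token index. For example, with
// page_size = 16 a destination index of 32 selects pool page 2, and the matching source rows are
// src_ptrs[j].slice(0, s_index, s_index + 16).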
__device__ int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

__device__ int64_t safe_min(int64_t a, int64_t b) {
  return a < b ? a : b;
}

// One thread per request: compute where the next decoded token of request pid is written, either
// appending to the request's current last page or starting a new page taken from free_page_ptr.
__global__ void launch_alloc_decode_kernel(
    const int64_t* seq_lens_ptr,
    const int32_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs,
    int64_t page_size) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_len = seq_lens_ptr[pid];
  int64_t pre_len = seq_len - 1;
  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);

  // Prefix sum of new pages needed by requests 0..pid (inclusive).
  int64_t sum_num_new_pages = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = (i <= pid) ? (other_seq_len - 1) : other_seq_len;
    int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
    int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
    int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
    sum_num_new_pages += other_num_new_pages;
  }
  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;

  if (num_page_start_loc_self == 0) {
    // The new token still fits into the current page.
    int32_t last_loc = last_loc_ptr[pid];
    out_indices[pid] = last_loc + 1;
  } else {
    // The new token starts a fresh page.
    int64_t page = free_page_ptr[new_page_start_loc];
    out_indices[pid] = page * page_size;
  }
}

// One thread per request: compute the output slot for every newly extended token of request pid.
// The extension is split into three parts:
//   part 1 - tokens that still fit into the partially filled last page,
//   part 2 - tokens that fill whole new pages,
//   part 3 - tokens that start the final, partially filled page.
__global__ void launch_alloc_extend_kernel(
    const int64_t* pre_lens_ptr,
    const int64_t* seq_lens_ptr,
    const int64_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs,
    int64_t page_size) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_len = seq_lens_ptr[pid];
  int64_t pre_len = pre_lens_ptr[pid];
  int64_t extend_len = seq_len - pre_len;

  // Prefix sum of extension lengths for requests 0..pid (inclusive).
  int64_t sum_extend_lens = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = pre_lens_ptr[i];
    int64_t other_extend_len = other_seq_len - other_pre_len;
    sum_extend_lens += other_extend_len;
  }
  int64_t output_start_loc = sum_extend_lens - extend_len;

  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);

  // Prefix sum of new pages needed by requests 0..pid (inclusive).
  int64_t sum_num_new_pages = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = pre_lens_ptr[i];
    int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
    int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
    int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
    sum_num_new_pages += other_num_new_pages;
  }
  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;

  int64_t last_loc = last_loc_ptr[pid];

  // Part 1: fill up the current, partially used page.
  int64_t num_part1 = safe_min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len;
  for (int64_t offset = 0; offset < num_part1 && offset < page_size; offset++) {
    int64_t output_idx = output_start_loc + offset;
    out_indices[output_idx] = last_loc + 1 + offset;
  }
  if (pre_len + num_part1 == seq_len) {
    return;
  }

  // Part 2: whole new pages.
  int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
  for (int64_t offset = 0; offset < num_part2; offset++) {
    int64_t page_idx = new_page_start_loc + offset / page_size;
    int64_t page_start = free_page_ptr[page_idx];
    int64_t output_idx = output_start_loc + num_part1 + offset;
    out_indices[output_idx] = page_start * page_size + offset % page_size;
  }
  if (pre_len + num_part1 + num_part2 == seq_len) {
    return;
  }

  // Part 3: the trailing, partially filled page.
  int64_t num_part3 = seq_len - (seq_len / page_size) * page_size;
  int64_t last_page_idx = new_page_start_loc + num_page_start_loc_self - 1;
  int64_t start_loc = free_page_ptr[last_page_idx];
  for (int64_t offset = 0; offset < num_part3 && offset < page_size; offset++) {
    int64_t output_idx = output_start_loc + num_part1 + num_part2 + offset;
    out_indices[output_idx] = start_loc * page_size + offset;
  }
}
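// Worked example for launch_alloc_extend_kernel (illustrative numbers only): with page_size = 4,
// pre_len = 6, seq_len = 13 and last_loc = L, the request needs ceil(13/4) - ceil(6/4) = 2 new
// pages. Part 1 writes 8 - 6 = 2 slots (L + 1, L + 2), part 2 fills one whole new page with
// 12 - 8 = 4 slots, and part 3 writes the final 13 - 12 = 1 slot at the start of the second new
// page.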
void dcu_alloc_decode_kernel(
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t page_size) {
  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
  const int32_t* last_loc_ptr1 = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

void dcu_alloc_extend_kernel(
    const at::Tensor pre_lens_ptr,
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t page_size) {
  const int64_t* pre_lens_ptr1 = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
  const int64_t* last_loc_ptr1 = static_cast<const int64_t*>(last_loc_ptr.data_ptr());
  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());

  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
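// These entry points are expected to be exposed to Python through a separate binding/registration
// file rather than from this translation unit. A minimal sketch of such a registration (library
// name and schema string are assumptions, not the actual binding) could look like:
//
//   TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
//     m.def(
//         "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, "
//         "Tensor src_indices, Tensor dst_indices, int item_size, int block_quota, "
//         "int num_warps_per_block) -> ()");
//     m.impl("transfer_kv_per_layer", &transfer_kv_per_layer);
//   }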