Unverified commit b4326330, authored by Zhiqiang Xie, committed by GitHub

Hicache IO kernel refactoring (#8264)

parent 8abd3e77
@@ -249,34 +249,39 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
  m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
  m.def(
-      "transfer_kv_per_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
-      "dst_indices, int page_size) -> ()");
-  m.impl("transfer_kv_per_layer_direct", torch::kCUDA, &transfer_kv_per_layer_direct);
+      "transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
  m.def(
-      "transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
-      "dst_indices, int item_size, int num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int "
+      "transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
+      "Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
      "num_warps_per_block) -> ()");
  m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
  m.def(
-      "transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
-      "dst_indices, int page_size, int num_layers) -> ()");
-  m.impl("transfer_kv_all_layer_direct", torch::kCUDA, &transfer_kv_all_layer_direct);
+      "transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
+      "Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
+      "num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf);
  m.def(
      "transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
      "block_quota, int num_warps_per_block) -> ()");
  m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
  m.def(
-      "transfer_kv_per_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size) "
-      "-> ()");
-  m.impl("transfer_kv_per_layer_mla_direct", torch::kCUDA, &transfer_kv_per_layer_mla_direct);
+      "transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, "
+      "int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf);
  m.def(
-      "transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
-      "num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int num_warps_per_block) -> ()");
+      "transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
+      "item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()");
  m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
  m.def(
-      "transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, "
-      "int num_layers) -> ()");
-  m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
+      "transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
+      "int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf);
+  m.def(
+      "transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
+      "page_size) -> ()");
+  m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
  /*
   * From csrc/moe/cutlass_moe/w4a8
...
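The renamed ops encode memory layouts rather than transfer direction: judging from the offset helpers and launcher instantiations later in this diff, lf appears to mean a layer-first pool, pf a page-first pool, and a suffix such as pf_lf reads source layout first, destination layout second. item_size, src_layout_dim and dst_layout_dim are byte counts at the op level (the per-layer Python wrappers below multiply by element_size() as a compatibility hot fix). A small sketch of the two offset formulas, with made-up sizes:

# Offset arithmetic behind the lf/pf naming (all sizes here are made up).
item_size_bytes = 512                      # bytes per (page, layer) item
num_pages, num_layers = 1024, 32
layer_dim = num_pages * item_size_bytes    # layer stride of a densely packed layer-first pool
page_dim = num_layers * item_size_bytes    # page stride of a densely packed page-first pool

def offset_lf(layer_id: int, page_id: int) -> int:
    # layer first: all pages of a layer are contiguous
    return layer_id * layer_dim + page_id * item_size_bytes

def offset_pf(layer_id: int, page_id: int) -> int:
    # page first: all layers of a page are contiguous
    return page_id * page_dim + layer_id * item_size_bytes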
@@ -22,17 +22,40 @@ transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_
  }
}

-// todo, structs for different memory layout
-__device__ __forceinline__ int64_t
-get_global_offset_lf(int64_t layer_id, int64_t layer_dim, int64_t page_id, int64_t item_size_bytes) {
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_lf(
+    T* base,
+    const uintptr_t* __restrict__ /*unused*/,
+    int64_t layer_id,
+    int64_t layer_dim,
+    int64_t page_id,
+    int64_t item_size_bytes) {
  // layer first
-  return layer_id * layer_dim + page_id * item_size_bytes;
+  return base + layer_id * layer_dim + page_id * item_size_bytes;
}

-__device__ __forceinline__ int64_t
-get_global_offset_pf(int64_t layer_id, int64_t page_dim, int64_t page_id, int64_t item_size_bytes) {
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_pf(
+    T* base,
+    const uintptr_t* __restrict__ /*unused*/,
+    int64_t layer_id,
+    int64_t page_dim,
+    int64_t page_id,
+    int64_t item_size_bytes) {
  // page first
-  return page_id * page_dim + layer_id * item_size_bytes;
+  return base + page_id * page_dim + layer_id * item_size_bytes;
+}
+
+// get offset from layer base table when layers are not contiguous
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_lf_tbl(
+    T* /*unused*/,
+    const uintptr_t* __restrict__ layer_base_tbl,
+    int64_t layer_id,
+    int64_t /*unused*/,
+    int64_t page_id,
+    int64_t item_size_bytes) {
+  return reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes;
}

template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
@@ -49,42 +72,37 @@ __global__ void transfer_kernel_impl(
    int64_t items_per_warp,
    int64_t item_size_bytes,
    int64_t src_layout_dim,
-    int64_t dst_layout_dim) {
+    int64_t dst_layout_dim,
+    const uintptr_t* __restrict__ src_k_layer_tbl,
+    const uintptr_t* __restrict__ dst_k_layer_tbl,
+    const uintptr_t* __restrict__ src_v_layer_tbl,
+    const uintptr_t* __restrict__ dst_v_layer_tbl) {
  int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  int32_t lane_id = tid % 32;
  int32_t warp_id = tid / 32;
  for (int i = 0; i < items_per_warp; ++i) {
-    int32_t item_id = warp_id * items_per_warp + i;
+    int64_t item_id = warp_id * items_per_warp + i;
    if (item_id >= num_items) {
-      return;
+      break;
    }
    const int64_t src_page_id = src_indices[item_id];
    const int64_t dst_page_id = dst_indices[item_id];
    // Loop over layers if necessary
    for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
-      // Calculate offsets using the provided function pointers
-      const int64_t src_offset = SrcOffsetFn(layer_id, src_layout_dim, src_page_id, item_size_bytes);
-      const int64_t dst_offset = DstOffsetFn(layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
-      if constexpr (IsMLA) {
-        transfer_item_warp(
-            lane_id,
-            static_cast<const char*>(src_k) + src_offset,
-            static_cast<char*>(dst_k) + dst_offset,
-            item_size_bytes);
-      } else {
-        transfer_item_warp(
-            lane_id,
-            static_cast<const char*>(src_k) + src_offset,
-            static_cast<char*>(dst_k) + dst_offset,
-            item_size_bytes);
-        transfer_item_warp(
-            lane_id,
-            static_cast<const char*>(src_v) + src_offset,
-            static_cast<char*>(dst_v) + dst_offset,
-            item_size_bytes);
+      const char* src_ptr = SrcOffsetFn(
+          static_cast<const char*>(src_k), src_k_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
+      char* dst_ptr = DstOffsetFn(
+          static_cast<char*>(dst_k), dst_k_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
+      transfer_item_warp(lane_id, src_ptr, dst_ptr, item_size_bytes);
+      if constexpr (!IsMLA) {
+        const char* src_v_ptr = SrcOffsetFn(
+            static_cast<const char*>(src_v), src_v_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
+        char* dst_v_ptr = DstOffsetFn(
+            static_cast<char*>(dst_v), dst_v_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
+        transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, item_size_bytes);
      }
    }
  }
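The new *_tbl offset helper lets the all-layer kernels address layers that are not contiguous in memory: instead of a single base pointer plus a layer stride, the kernel looks each layer's base address up in a device-resident pointer table. A minimal sketch of building such a table on the host, mirroring what the updated test does further down (tensor names and sizes are illustrative):

import torch

num_layers, num_pages, item_numel = 4, 64, 256   # illustrative sizes
layers = [
    torch.empty(num_pages, item_numel, dtype=torch.bfloat16, device="cuda")
    for _ in range(num_layers)
]
# Pack the per-layer base addresses as uint64 on the GPU; the kernel's
# get_global_offset_lf_tbl indexes this table by layer_id.
layer_tbl = torch.tensor([t.data_ptr() for t in layers], dtype=torch.uint64, device="cuda")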
@@ -103,44 +121,54 @@ void transfer_kv_launcher(
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
+    const at::Tensor& src_k_layers,
+    const at::Tensor& dst_k_layers,
+    const at::Tensor& src_v_layers,
+    const at::Tensor& dst_v_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
-  TORCH_CHECK(src_k.scalar_type() == dst_k.scalar_type(), "Source and destination keys must have the same type");
  TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
  TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
  TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
  TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
-  if (!IsMLA) {
-    TORCH_CHECK(src_v.scalar_type() == dst_v.scalar_type(), "Source and destination values must have the same type");
-  }
-  int dtype_size = src_k.element_size();
-  TORCH_CHECK((item_size * dtype_size) % 8 == 0, "Item byte size must be divisible by 8");
-  auto div_up = [](int32_t x, int32_t y) { return (x + y - 1) / y; };
+  TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8");
+  auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; };
  const int64_t num_items = src_indices.numel();
  const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
  const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
  dim3 grid_dim(num_blocks, 1, 1);
  const int32_t threads_per_block = num_warps_per_block * 32;
+  const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr;
+  void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr;
+  const void* src_v_ptr = IsMLA || !src_v.defined() ? nullptr : src_v.data_ptr();
+  void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr();
+  const uintptr_t* src_k_tbl_ptr = src_k_layers.defined() ? src_k_layers.data_ptr<uintptr_t>() : nullptr;
+  const uintptr_t* dst_k_tbl_ptr = dst_k_layers.defined() ? dst_k_layers.data_ptr<uintptr_t>() : nullptr;
+  const uintptr_t* src_v_tbl_ptr = IsMLA || !src_v_layers.defined() ? nullptr : src_v_layers.data_ptr<uintptr_t>();
+  const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? nullptr : dst_v_layers.data_ptr<uintptr_t>();
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
-      src_k.data_ptr(),
-      dst_k.data_ptr(),
-      (IsMLA ? nullptr : src_v.data_ptr()),
-      (IsMLA ? nullptr : dst_v.data_ptr()),
+      src_k_ptr,
+      dst_k_ptr,
+      src_v_ptr,
+      dst_v_ptr,
      src_indices.data_ptr<int64_t>(),
      dst_indices.data_ptr<int64_t>(),
      start_layer_id,
      num_layers_to_process,
      num_items,
      items_per_warp,
-      item_size * dtype_size,
-      src_layout_dim * dtype_size,
-      dst_layout_dim * dtype_size);
+      item_size,
+      src_layout_dim,
+      dst_layout_dim,
+      src_k_tbl_ptr,
+      dst_k_tbl_ptr,
+      src_v_tbl_ptr,
+      dst_v_tbl_ptr);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
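The launcher sizes the grid so that at most block_quota blocks are launched and each warp walks items_per_warp pages. A quick check of that arithmetic with example values:

# Launch-geometry arithmetic from transfer_kv_launcher (example values).
def div_up(x: int, y: int) -> int:
    return (x + y - 1) // y

num_items = 4096                # number of pages to move
block_quota = 2
num_warps_per_block = 32

items_per_warp = div_up(num_items, block_quota * num_warps_per_block)  # 64
num_blocks = div_up(num_items, items_per_warp * num_warps_per_block)   # 2
assert num_blocks <= block_quota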
@@ -154,11 +182,28 @@ void transfer_kv_per_layer(
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block) {
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
-      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0, block_quota, num_warps_per_block);
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, false>(
+      src_k,
+      dst_k,
+      src_v,
+      dst_v,
+      src_indices,
+      dst_indices,
+      0,
+      1,
+      item_size,
+      0,
+      0,
+      empty,
+      empty,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
}

-void transfer_kv_all_layer(
+void transfer_kv_per_layer_pf_lf(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
@@ -166,12 +211,11 @@ void transfer_kv_all_layer(
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
-    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
+    int64_t src_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block) {
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, false>(
      src_k,
      dst_k,
      src_v,
@@ -179,10 +223,81 @@ void transfer_kv_all_layer(
      src_indices,
      dst_indices,
      0,
+      1,
+      item_size,
+      src_layout_dim,
+      0,
+      empty,
+      empty,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer(
+    const at::Tensor src_k_layers,
+    const at::Tensor dst_k_layers,
+    const at::Tensor src_v_layers,
+    const at::Tensor dst_v_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, false>(
+      empty,
+      empty,
+      empty,
+      empty,
+      src_indices,
+      dst_indices,
+      0,
+      num_layers,
+      item_size,
+      0,
+      0,
+      src_k_layers,
+      dst_k_layers,
+      src_v_layers,
+      dst_v_layers,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer_lf_pf(
+    const at::Tensor src_k_layers,
+    at::Tensor dst_k,
+    const at::Tensor src_v_layers,
+    at::Tensor dst_v,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, false>(
+      empty,
+      dst_k,
+      empty,
+      dst_v,
+      src_indices,
+      dst_indices,
+      0,
      num_layers,
      item_size,
-      src_layer_offset,
-      dst_layer_offset,
+      0,
+      dst_layout_dim,
+      src_k_layers,
+      empty,
+      src_v_layers,
+      empty,
      block_quota,
      num_warps_per_block);
}
@@ -195,12 +310,12 @@ void transfer_kv_per_layer_mla(
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, true>(
      src,
      dst,
-      empty_tensor,
-      empty_tensor,
+      empty,
+      empty,
      src_indices,
      dst_indices,
      0,
@@ -208,41 +323,110 @@ void transfer_kv_per_layer_mla(
      item_size,
      0,
      0,
+      empty,
+      empty,
+      empty,
+      empty,
      block_quota,
      num_warps_per_block);
}

-void transfer_kv_all_layer_mla(
+void transfer_kv_per_layer_mla_pf_lf(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
-    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
+    int64_t src_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, true>(
      src,
      dst,
-      empty_tensor,
-      empty_tensor,
+      empty,
+      empty,
+      src_indices,
+      dst_indices,
+      0,
+      1,
+      item_size,
+      src_layout_dim,
+      0,
+      empty,
+      empty,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer_mla(
+    const at::Tensor src_layers,
+    const at::Tensor dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, true>(
+      empty,
+      empty,
+      empty,
+      empty,
+      src_indices,
+      dst_indices,
+      0,
+      num_layers,
+      item_size,
+      0,
+      0,
+      src_layers,
+      dst_layers,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer_mla_lf_pf(
+    const at::Tensor src_layers,
+    at::Tensor dst,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, true>(
+      empty,
+      dst,
+      empty,
+      empty,
      src_indices,
      dst_indices,
      0,
      num_layers,
      item_size,
-      src_layer_offset,
-      dst_layer_offset,
+      0,
+      dst_layout_dim,
+      src_layers,
+      empty,
+      empty,
+      empty,
      block_quota,
      num_warps_per_block);
}

inline void transfer_page_direct(
-    const at::Tensor src_buffer,
-    at::Tensor dst_buffer,
+    const at::Tensor& src_buffer,
+    at::Tensor& dst_buffer,
    int64_t src_page_index,
    int64_t dst_page_index,
    int64_t page_size) {
@@ -252,16 +436,14 @@ inline void transfer_page_direct(
      /* non_blocking= */ true);
}

-template <bool IsMLA, bool AllLayers>
-inline void transfer_kv_direct_impl(
-    const at::Tensor& src_k,
-    at::Tensor& dst_k,
-    const at::Tensor& src_v_opt,  // Only used when IsMLA is false (for src_v)
-    at::Tensor& dst_v_opt,  // Only used when IsMLA is false (for dst_v)
-    const at::Tensor& src_indices,
-    const at::Tensor& dst_indices,
-    int64_t page_size,
-    int64_t num_layers = 1) {
+void transfer_kv_direct(
+    const std::vector<at::Tensor>& src_layers,
+    std::vector<at::Tensor> dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size) {
+  TORCH_CHECK(
+      src_layers.size() == dst_layers.size(), "Source and destination layers must have the same number of layers");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");
@@ -270,73 +452,14 @@ inline void transfer_kv_direct_impl(
  auto dst_indices_cpu = dst_indices.cpu();

  const int64_t num_pages = src_indices_cpu.size(0) / page_size;
+  const int64_t num_layers = src_layers.size();

-  for (const auto i : c10::irange(num_pages)) {
-    auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
-    auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
-    if constexpr (AllLayers) {
-      for (const auto j : c10::irange(num_layers)) {
-        if constexpr (IsMLA) {
-          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
-        } else {
-          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
-          transfer_page_direct(src_v_opt.select(0, j), dst_v_opt.select(0, j), s_index, d_index, page_size);
-        }
-      }
-    } else {  // Per-layer
-      if constexpr (IsMLA) {
-        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
-      } else {
-        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
-        transfer_page_direct(src_v_opt, dst_v_opt, s_index, d_index, page_size);
-      }
+  for (int64_t i = 0; i < num_pages; ++i) {
+    auto src_index = src_indices_cpu[i * page_size].item<int64_t>();
+    auto dst_index = dst_indices_cpu[i * page_size].item<int64_t>();
+    for (int64_t j = 0; j < num_layers; ++j) {
+      transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, page_size);
    }
  }
}

-void transfer_kv_per_layer_direct(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size) {
-  transfer_kv_direct_impl<false, false>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size);
-}
-
-void transfer_kv_all_layer_direct(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers) {
-  transfer_kv_direct_impl<false, true>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers);
-}
-
-void transfer_kv_per_layer_mla_direct(
-    const at::Tensor src,
-    at::Tensor dst,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_direct_impl<true, false>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size);
-}
-
-void transfer_kv_all_layer_mla_direct(
-    const at::Tensor src,
-    at::Tensor dst,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_direct_impl<true, true>(
-      src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size, num_layers);
-}
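The four removed *_direct entry points collapse into the single transfer_kv_direct, which takes per-layer tensor lists and copies whole pages on the host side with Tensor.copy_. A usage sketch through the new Python wrapper, with hypothetical pool shapes; the index tensors must be page aligned and their length a multiple of page_size:

import torch
from sgl_kernel.kvcacheio import transfer_kv_direct

page_size, num_slots, head_dim = 4, 64, 128          # hypothetical shapes
src_k = torch.randn(num_slots, head_dim)             # host-side source pool, one layer
src_v = torch.randn(num_slots, head_dim)
dst_k = torch.zeros(num_slots, head_dim, device="cuda")
dst_v = torch.zeros(num_slots, head_dim, device="cuda")

# One slot index per token; src and dst index tensors must be equally long.
src_indices = torch.arange(0, 16, dtype=torch.int64)
dst_indices = torch.arange(32, 48, dtype=torch.int64)

transfer_kv_direct([src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size)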
@@ -399,38 +399,42 @@ void transfer_kv_per_layer(
    int64_t block_quota,
    int64_t num_warps_per_block);

-void transfer_kv_per_layer_direct(
+void transfer_kv_per_layer_pf_lf(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
-    int64_t page_size);
+    int64_t item_size,
+    int64_t src_layout_dim,
+    int64_t block_quota,
+    int64_t num_warps_per_block);

void transfer_kv_all_layer(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
+    const at::Tensor src_k_layers,
+    const at::Tensor dst_k_layers,
+    const at::Tensor src_v_layers,
+    const at::Tensor dst_v_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
    int64_t block_quota,
    int64_t num_warps_per_block);

-void transfer_kv_all_layer_direct(
-    const at::Tensor src_k,
+void transfer_kv_all_layer_lf_pf(
+    const at::Tensor src_k_layers,
    at::Tensor dst_k,
-    const at::Tensor src_v,
+    const at::Tensor src_v_layers,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers);
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block);

void transfer_kv_per_layer_mla(
    const at::Tensor src,

@@ -441,32 +445,43 @@ void transfer_kv_per_layer_mla(
    int64_t block_quota,
    int64_t num_warps_per_block);

-void transfer_kv_per_layer_mla_direct(
+void transfer_kv_per_layer_mla_pf_lf(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
-    int64_t page_size);
+    int64_t item_size,
+    int64_t src_layout_dim,
+    int64_t block_quota,
+    int64_t num_warps_per_block);

void transfer_kv_all_layer_mla(
-    const at::Tensor src,
-    at::Tensor dst,
+    const at::Tensor src_layers,
+    const at::Tensor dst_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
    int64_t block_quota,
    int64_t num_warps_per_block);

-void transfer_kv_all_layer_mla_direct(
-    const at::Tensor src,
+void transfer_kv_all_layer_mla_lf_pf(
+    const at::Tensor src_layers,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers);
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
+
+void transfer_kv_direct(
+    const std::vector<at::Tensor>& src_layers,
+    std::vector<at::Tensor> dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size);

/*
 * From csrc/moe/cutlass_moe/w4a8
...
+from typing import List
+
import torch

@@ -22,57 +24,116 @@ def transfer_kv_per_layer(
            dst_v,
            src_indices,
            dst_indices,
-            item_size,
+            item_size * src_k.element_size(),  # todo, hot fix for compatibility
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_per_layer_direct(
-            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size
+        torch.ops.sgl_kernel.transfer_kv_direct(
+            [src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size
        )
    else:
        raise ValueError(f"Unsupported io backend")
-def transfer_kv_all_layer(
+def transfer_kv_per_layer_pf_lf(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
+    item_size: int,
+    src_layout_dim: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
+        src_k,
+        dst_k,
+        src_v,
+        dst_v,
+        src_indices,
+        dst_indices,
+        item_size,
+        src_layout_dim,
+        block_quota,
+        num_warps_per_block,
+    )
+
+
+def transfer_kv_all_layer(
+    src_k_layers: torch.Tensor,
+    dst_k_layers: torch.Tensor,
+    src_v_layers: torch.Tensor,
+    dst_v_layers: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
    io_backend: str,
-    page_size: int,
    item_size: int,
    num_layers: int,
-    src_layer_offset: int,
-    dst_layer_offset: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_all_layer(
-            src_k,
-            dst_k,
-            src_v,
-            dst_v,
+            src_k_layers,
+            dst_k_layers,
+            src_v_layers,
+            dst_v_layers,
            src_indices,
            dst_indices,
            item_size,
            num_layers,
-            src_layer_offset,
-            dst_layer_offset,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_all_layer_direct(
-            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers
-        )
+        raise NotImplementedError("Deprecated interface")
    else:
        raise ValueError(f"Unsupported io backend")
+def transfer_kv_all_layer_lf_pf(
+    src_k_layers: torch.Tensor,
+    dst_k: torch.Tensor,
+    src_v_layers: torch.Tensor,
+    dst_v: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    item_size: int,
+    dst_layout_dim: int,
+    num_layers: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
+        src_k_layers,
+        dst_k,
+        src_v_layers,
+        dst_v,
+        src_indices,
+        dst_indices,
+        item_size,
+        dst_layout_dim,
+        num_layers,
+        block_quota,
+        num_warps_per_block,
+    )
+
+
+def transfer_kv_direct(
+    src_layers: List[torch.Tensor],
+    dst_layers: List[torch.Tensor],
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    page_size: int,
+):
+    torch.ops.sgl_kernel.transfer_kv_direct(
+        src_layers, dst_layers, src_indices, dst_indices, page_size
+    )
def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,

@@ -90,48 +151,87 @@ def transfer_kv_per_layer_mla(
            dst,
            src_indices,
            dst_indices,
-            item_size,
+            item_size * src.element_size(),  # todo, hot fix for compatibility
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct(
-            src, dst, src_indices, dst_indices, page_size
+        torch.ops.sgl_kernel.transfer_kv_direct(
+            [src], [dst], src_indices, dst_indices, page_size
        )
    else:
        raise ValueError(f"Unsupported io backend")
-def transfer_kv_all_layer_mla(
+def transfer_kv_per_layer_mla_pf_lf(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
+    item_size: int,
+    src_layout_dim: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
+        src,
+        dst,
+        src_indices,
+        dst_indices,
+        item_size,
+        src_layout_dim,
+        block_quota,
+        num_warps_per_block,
+    )
+
+
+def transfer_kv_all_layer_mla(
+    src_layers: torch.Tensor,
+    dst_layers: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
    io_backend: str,
-    page_size: int,
    item_size: int,
    num_layers: int,
-    src_layer_offset: int,
-    dst_layer_offset: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
-            src,
-            dst,
+            src_layers,
+            dst_layers,
            src_indices,
            dst_indices,
            item_size,
            num_layers,
-            src_layer_offset,
-            dst_layer_offset,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct(
-            src, dst, src_indices, dst_indices, page_size, num_layers
-        )
+        raise NotImplementedError("Deprecated interface")
    else:
        raise ValueError(f"Unsupported io backend")
+def transfer_kv_all_layer_mla_lf_pf(
+    src_layers: torch.Tensor,
+    dst: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    item_size: int,
+    dst_layout_dim: int,
+    num_layers: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
+        src_layers,
+        dst,
+        src_indices,
+        dst_indices,
+        item_size,
+        dst_layout_dim,
+        num_layers,
+        block_quota,
+        num_warps_per_block,
+    )
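For the write-back direction, transfer_kv_all_layer_lf_pf gathers from per-layer device tensors (addressed through the uint64 pointer table) into one page-first destination pool; dst_layout_dim is presumably the byte stride between pages of that pool, i.e. num_layers * item_size for a densely packed layout. A hedged sketch under those assumptions (names and shapes are illustrative, not from this commit):

import torch
from sgl_kernel.kvcacheio import transfer_kv_all_layer_lf_pf

num_layers, num_pages, item_numel = 4, 64, 256
dtype = torch.bfloat16
item_size = item_numel * dtype.itemsize      # the raw ops take byte counts

# Per-layer K/V source tensors on the GPU, exposed to the kernel via pointer tables.
src_k = [torch.randn(num_pages, item_numel, dtype=dtype, device="cuda") for _ in range(num_layers)]
src_v = [torch.randn(num_pages, item_numel, dtype=dtype, device="cuda") for _ in range(num_layers)]
src_k_tbl = torch.tensor([t.data_ptr() for t in src_k], dtype=torch.uint64, device="cuda")
src_v_tbl = torch.tensor([t.data_ptr() for t in src_v], dtype=torch.uint64, device="cuda")

# Page-first host pools, pinned so the GPU kernel can address them directly.
dst_k = torch.empty(num_pages, num_layers, item_numel, dtype=dtype, pin_memory=True)
dst_v = torch.empty(num_pages, num_layers, item_numel, dtype=dtype, pin_memory=True)

indices = torch.arange(8, dtype=torch.int64, device="cuda")   # pages to move

transfer_kv_all_layer_lf_pf(
    src_k_tbl, dst_k, src_v_tbl, dst_v,
    indices, indices,
    item_size=item_size,
    dst_layout_dim=num_layers * item_size,
    num_layers=num_layers,
)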
@@ -3,6 +3,7 @@ import torch
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
+    transfer_kv_direct,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)

@@ -104,14 +105,12 @@ def test_transfer_kv(
                page_size=page_size,
                item_size=item_size,
            )
-            transfer_kv_per_layer_mla(
-                src_pool_host[layer_idx_to_test],
-                dst_pool_direct[layer_idx_to_test],
+            transfer_kv_direct(
+                [src_pool_host[layer_idx_to_test]],
+                [dst_pool_direct[layer_idx_to_test]],
                src_indices_host,
                dst_indices_device,
-                io_backend="direct",
                page_size=page_size,
-                item_size=item_size,
            )
        else:
            for layer_id in range(num_layers):
@@ -121,29 +120,34 @@ def test_transfer_kv(
                    src_indices_host,
                    dst_indices_device,
                )
+            src_layers_device = torch.tensor(
+                [src_pool_host[layer_id].data_ptr() for layer_id in range(num_layers)],
+                dtype=torch.uint64,
+                device=device,
+            )
+            dst_layers_device = torch.tensor(
+                [
+                    dst_pool_kernel[layer_id].data_ptr()
+                    for layer_id in range(num_layers)
+                ],
+                dtype=torch.uint64,
+                device=device,
+            )
            transfer_kv_all_layer_mla(
-                src_pool_host,
-                dst_pool_kernel,
+                src_layers_device,
+                dst_layers_device,
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
-                page_size=page_size,
-                item_size=item_size,
+                item_size=item_size * dtype.itemsize,
                num_layers=num_layers,
-                src_layer_offset=total_items_in_pool * item_size,
-                dst_layer_offset=total_items_in_pool * item_size,
            )
-            transfer_kv_all_layer_mla(
-                src_pool_host,
-                dst_pool_direct,
+            transfer_kv_direct(
+                [src_pool_host[layer_id] for layer_id in range(num_layers)],
+                [dst_pool_direct[layer_id] for layer_id in range(num_layers)],
                src_indices_host,
                dst_indices_device,
-                io_backend="direct",
                page_size=page_size,
-                item_size=item_size,
-                num_layers=num_layers,
-                src_layer_offset=total_items_in_pool * item_size,
-                dst_layer_offset=total_items_in_pool * item_size,
            )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
@@ -173,16 +177,15 @@ def test_transfer_kv(
                page_size=page_size,
                item_size=item_size,
            )
-            transfer_kv_per_layer(
-                src_k_pool[layer_idx_to_test],
-                dst_k_pool_direct[layer_idx_to_test],
-                src_v_pool[layer_idx_to_test],
-                dst_v_pool_direct[layer_idx_to_test],
+            transfer_kv_direct(
+                [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
+                [
+                    dst_k_pool_direct[layer_idx_to_test],
+                    dst_v_pool_direct[layer_idx_to_test],
+                ],
                src_indices_host,
                dst_indices_device,
-                io_backend="direct",
                page_size=page_size,
-                item_size=item_size,
            )
        else:
            for layer_id in range(num_layers):
@@ -198,33 +201,52 @@ def test_transfer_kv(
                    src_indices_host,
                    dst_indices_device,
                )
+            src_k_layers_device = torch.tensor(
+                [src_k_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
+                dtype=torch.uint64,
+                device=device,
+            )
+            src_v_layers_device = torch.tensor(
+                [src_v_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
+                dtype=torch.uint64,
+                device=device,
+            )
+            dst_k_layers_device = torch.tensor(
+                [
+                    dst_k_pool_kernel[layer_id].data_ptr()
+                    for layer_id in range(num_layers)
+                ],
+                dtype=torch.uint64,
+                device=device,
+            )
+            dst_v_layers_device = torch.tensor(
+                [
+                    dst_v_pool_kernel[layer_id].data_ptr()
+                    for layer_id in range(num_layers)
+                ],
+                dtype=torch.uint64,
+                device=device,
+            )
            transfer_kv_all_layer(
-                src_k_pool,
-                dst_k_pool_kernel,
-                src_v_pool,
-                dst_v_pool_kernel,
+                src_k_layers_device,
+                dst_k_layers_device,
+                src_v_layers_device,
+                dst_v_layers_device,
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
-                page_size=page_size,
-                item_size=item_size,
+                item_size=item_size * dtype.itemsize,
                num_layers=num_layers,
-                src_layer_offset=total_items_in_pool * item_size,
-                dst_layer_offset=total_items_in_pool * item_size,
            )
-            transfer_kv_all_layer(
-                src_k_pool,
-                dst_k_pool_direct,
-                src_v_pool,
-                dst_v_pool_direct,
+            transfer_kv_direct(
+                [src_k_pool[layer_id] for layer_id in range(num_layers)]
+                + [src_v_pool[layer_id] for layer_id in range(num_layers)],
+                [dst_k_pool_direct[layer_id] for layer_id in range(num_layers)]
+                + [dst_v_pool_direct[layer_id] for layer_id in range(num_layers)],
                src_indices_host,
                dst_indices_device,
-                io_backend="direct",
                page_size=page_size,
-                item_size=item_size,
-                num_layers=num_layers,
-                src_layer_offset=total_items_in_pool * item_size,
-                dst_layer_offset=total_items_in_pool * item_size,
            )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
...