Unverified Commit 3e6281d0, authored by huangtingwei, committed by GitHub

[HiCache]Page head layout IO kernel (#11615)

parent 6371f7af
......@@ -370,6 +370,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
m.def(
"transfer_kv_per_layer_ph_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int layer_id, int item_size, int src_layout_dim, int page_size, int head_num, int block_quota, int "
"num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_ph_lf", torch::kCUDA, &transfer_kv_per_layer_ph_lf);
m.def(
"transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
"Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
......@@ -380,6 +385,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
"num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf);
m.def(
"transfer_kv_all_layer_lf_ph(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
"Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int page_size, int "
"head_num, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_lf_ph", torch::kCUDA, &transfer_kv_all_layer_lf_ph);
m.def(
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
"block_quota, int num_warps_per_block) -> ()");
......
......@@ -68,6 +68,140 @@ __device__ __forceinline__ T* get_global_offset_lf_tbl(
return reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes;
}
template <typename T>
__device__ __forceinline__ T* get_global_offset_per_head_lf(
T* base,
const uintptr_t* __restrict__ /*unused*/,
int64_t layer_id,
int64_t layer_dim,
int64_t page_id,
int64_t item_size_bytes,
int64_t head_id,
int64_t head_num,
int64_t /*unused*/) {
  // layer-first layout offset, computed per head
return base + layer_id * layer_dim + page_id * item_size_bytes + item_size_bytes / head_num * head_id;
}
template <typename T>
__device__ __forceinline__ T* get_global_offset_per_head_lf_tbl(
T* /*unused*/,
const uintptr_t* __restrict__ layer_base_tbl,
int64_t layer_id,
int64_t /*unused*/,
int64_t page_id,
int64_t item_size_bytes,
int64_t head_id,
int64_t head_num,
int64_t /*unused*/) {
return reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes +
item_size_bytes / head_num * head_id;
}
template <typename T>
__device__ __forceinline__ T* get_global_offset_ph(
T* base,
const uintptr_t* __restrict__ /*unused*/,
int64_t layer_id,
int64_t page_dim,
int64_t page_id,
int64_t item_size_bytes,
int64_t head_id,
int64_t head_num,
int64_t page_size) {
// page head layout: [page_num, head_num, page_size, layer_num, head_dim]
return base + page_id / page_size * page_size * page_dim + // page_num dimension offset
page_dim / head_num * head_id * page_size + // head_num dimension offset
page_id % page_size * page_dim / head_num + // page_size dimension offset
layer_id * item_size_bytes / head_num; // layer_num dimension offset
}
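
The offset arithmetic above can be cross-checked against PyTorch strides. The following is a minimal illustrative sketch, not part of the change itself: the pool shape, dtype, and indices are assumptions, and `page_dim` stands in for the layout-dim argument the callers pass (bytes per token across all layers).

```python
# Minimal sketch: check get_global_offset_ph against the strides of a contiguous
# [page_num, head_num, page_size, layer_num, head_dim] pool. Sizes are assumptions.
import torch

page_num, head_num, page_size, layer_num, head_dim = 4, 8, 16, 2, 64
itemsize = torch.bfloat16.itemsize
pool = torch.empty(page_num, head_num, page_size, layer_num, head_dim, dtype=torch.bfloat16)

item_size_bytes = head_num * head_dim * itemsize  # one token, one layer, all heads
page_dim = item_size_bytes * layer_num            # layout dim: one token, all layers

def offset_ph(page_id, layer_id, head_id):
    """Byte offset from the pool base, mirroring get_global_offset_ph."""
    return (page_id // page_size * page_size * page_dim   # page_num dimension
            + page_dim // head_num * head_id * page_size  # head_num dimension
            + page_id % page_size * page_dim // head_num  # page_size dimension
            + layer_id * item_size_bytes // head_num)     # layer_num dimension

page_id, layer_id, head_id = 37, 1, 5  # page_id is a global token index, as in the kernel's indices
byte_strides = [s * itemsize for s in pool.stride()]
index = (page_id // page_size, head_id, page_id % page_size, layer_id, 0)
assert offset_ph(page_id, layer_id, head_id) == sum(i * s for i, s in zip(index, byte_strides))
```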
template <auto SrcOffsetFn, auto DstOffsetFn>
__global__ void transfer_page_head_kernel_impl(
const void* __restrict__ src_k,
void* __restrict__ dst_k,
const void* __restrict__ src_v,
void* __restrict__ dst_v,
const int64_t* __restrict__ src_indices,
const int64_t* __restrict__ dst_indices,
int64_t start_layer_id,
int64_t num_layers_to_process,
int64_t num_items,
int64_t items_per_warp,
int64_t item_size_bytes,
int64_t src_layout_dim,
int64_t dst_layout_dim,
const uintptr_t* __restrict__ src_k_layer_tbl,
const uintptr_t* __restrict__ dst_k_layer_tbl,
const uintptr_t* __restrict__ src_v_layer_tbl,
const uintptr_t* __restrict__ dst_v_layer_tbl,
const int64_t page_size,
const int64_t head_num) {
int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
int32_t lane_id = tid % WARP_SIZE;
int32_t warp_id = tid / WARP_SIZE;
const int64_t head_size_bytes = item_size_bytes / head_num;
for (int i = 0; i < items_per_warp; ++i) {
int64_t item_id = warp_id * items_per_warp + i;
if (item_id >= num_items) {
break;
}
const int64_t src_page_id = src_indices[item_id];
const int64_t dst_page_id = dst_indices[item_id];
// Loop over layers if necessary
for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
      // In the page-head layout, a token's per-head caches are not contiguous, so loop over heads
for (int64_t head_id = 0; head_id < head_num; ++head_id) {
const char* src_k_ptr = SrcOffsetFn(
static_cast<const char*>(src_k),
src_k_layer_tbl,
layer_id,
src_layout_dim,
src_page_id,
item_size_bytes,
head_id,
head_num,
page_size);
char* dst_k_ptr = DstOffsetFn(
static_cast<char*>(dst_k),
dst_k_layer_tbl,
layer_id,
dst_layout_dim,
dst_page_id,
item_size_bytes,
head_id,
head_num,
page_size);
transfer_item_warp(lane_id, src_k_ptr, dst_k_ptr, head_size_bytes);
const char* src_v_ptr = SrcOffsetFn(
static_cast<const char*>(src_v),
src_v_layer_tbl,
layer_id,
src_layout_dim,
src_page_id,
item_size_bytes,
head_id,
head_num,
page_size);
char* dst_v_ptr = DstOffsetFn(
static_cast<char*>(dst_v),
dst_v_layer_tbl,
layer_id,
dst_layout_dim,
dst_page_id,
item_size_bytes,
head_id,
head_num,
page_size);
transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, head_size_bytes);
}
}
}
}
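
For intuition, the copy that this warp loop performs for the ph -> lf direction on a single layer can be written at the tensor level. This is a hedged sketch with assumed shapes and tensor names; the actual kernel copies `head_size_bytes` per (token, layer, head) with one warp per item.

```python
# Hedged sketch of the per-head copy semantics for the ph -> lf direction on one layer.
# Shapes, indices, and tensor names are assumptions for illustration.
import torch

num_pages, head_num, page_size, num_layers, head_dim = 32, 8, 16, 4, 64
src_ph = torch.randn(num_pages, head_num, page_size, num_layers, head_dim)   # page-head layout
dst_lf = torch.zeros(num_layers, num_pages * page_size, head_num, head_dim)  # layer-first layout

src_indices = torch.arange(page_size)   # global token indices to read (page 0 here)
dst_indices = torch.arange(page_size)   # token slots to write
layer_id = 2

for head_id in range(head_num):  # per-head loop: heads are not contiguous in the ph layout
    dst_lf[layer_id, dst_indices, head_id] = src_ph[
        src_indices // page_size, head_id, src_indices % page_size, layer_id
    ]
```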
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
__global__ void transfer_kernel_impl(
const void* __restrict__ src_k,
......@@ -118,7 +252,7 @@ __global__ void transfer_kernel_impl(
}
}
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA, bool PageHeadLayout = false>
void transfer_kv_launcher(
const at::Tensor& src_k,
at::Tensor& dst_k,
......@@ -136,7 +270,9 @@ void transfer_kv_launcher(
const at::Tensor& src_v_layers,
const at::Tensor& dst_v_layers,
int64_t block_quota,
int64_t num_warps_per_block) {
int64_t num_warps_per_block,
const int64_t page_size = 16,
const int64_t head_num = 1) {
TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
......@@ -161,24 +297,47 @@ void transfer_kv_launcher(
const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? nullptr : dst_v_layers.data_ptr<uintptr_t>();
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
src_k_ptr,
dst_k_ptr,
src_v_ptr,
dst_v_ptr,
src_indices.data_ptr<int64_t>(),
dst_indices.data_ptr<int64_t>(),
start_layer_id,
num_layers_to_process,
num_items,
items_per_warp,
item_size,
src_layout_dim,
dst_layout_dim,
src_k_tbl_ptr,
dst_k_tbl_ptr,
src_v_tbl_ptr,
dst_v_tbl_ptr);
if constexpr (PageHeadLayout) {
transfer_page_head_kernel_impl<SrcOffsetFn, DstOffsetFn><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
src_k_ptr,
dst_k_ptr,
src_v_ptr,
dst_v_ptr,
src_indices.data_ptr<int64_t>(),
dst_indices.data_ptr<int64_t>(),
start_layer_id,
num_layers_to_process,
num_items,
items_per_warp,
item_size,
src_layout_dim,
dst_layout_dim,
src_k_tbl_ptr,
dst_k_tbl_ptr,
src_v_tbl_ptr,
dst_v_tbl_ptr,
page_size,
head_num);
} else {
transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
src_k_ptr,
dst_k_ptr,
src_v_ptr,
dst_v_ptr,
src_indices.data_ptr<int64_t>(),
dst_indices.data_ptr<int64_t>(),
start_layer_id,
num_layers_to_process,
num_items,
items_per_warp,
item_size,
src_layout_dim,
dst_layout_dim,
src_k_tbl_ptr,
dst_k_tbl_ptr,
src_v_tbl_ptr,
dst_v_tbl_ptr);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
......@@ -246,6 +405,43 @@ void transfer_kv_per_layer_pf_lf(
num_warps_per_block);
}
void transfer_kv_per_layer_ph_lf(
const at::Tensor src_k,
at::Tensor dst_k,
const at::Tensor src_v,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t layer_id,
int64_t item_size,
int64_t src_layout_dim,
int64_t page_size,
int64_t head_num,
int64_t block_quota,
int64_t num_warps_per_block) {
at::Tensor empty;
transfer_kv_launcher<get_global_offset_ph<const char>, get_global_offset_per_head_lf<char>, false, true>(
src_k,
dst_k,
src_v,
dst_v,
src_indices,
dst_indices,
layer_id,
1,
item_size,
src_layout_dim,
0,
empty,
empty,
empty,
empty,
block_quota,
num_warps_per_block,
page_size,
head_num);
}
void transfer_kv_all_layer(
const at::Tensor src_k_layers,
const at::Tensor dst_k_layers,
......@@ -313,6 +509,44 @@ void transfer_kv_all_layer_lf_pf(
num_warps_per_block);
}
void transfer_kv_all_layer_lf_ph(
const at::Tensor src_k_layers,
at::Tensor dst_k,
const at::Tensor src_v_layers,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t dst_layout_dim,
int64_t num_layers,
int64_t page_size,
int64_t head_num,
int64_t block_quota,
int64_t num_warps_per_block) {
TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
at::Tensor empty;
transfer_kv_launcher<get_global_offset_per_head_lf_tbl<const char>, get_global_offset_ph<char>, false, true>(
empty,
dst_k,
empty,
dst_v,
src_indices,
dst_indices,
0,
num_layers,
item_size,
0,
dst_layout_dim,
src_k_layers,
empty,
src_v_layers,
empty,
block_quota,
num_warps_per_block,
page_size,
head_num);
}
void transfer_kv_per_layer_mla(
const at::Tensor src,
at::Tensor dst,
......
......@@ -562,6 +562,21 @@ void transfer_kv_per_layer_pf_lf(
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_per_layer_ph_lf(
const at::Tensor src_k,
at::Tensor dst_k,
const at::Tensor src_v,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t layer_id,
int64_t item_size,
int64_t src_layout_dim,
int64_t page_size,
int64_t head_num,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer(
const at::Tensor src_k_layers,
const at::Tensor dst_k_layers,
......@@ -587,6 +602,21 @@ void transfer_kv_all_layer_lf_pf(
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_all_layer_lf_ph(
const at::Tensor src_k_layers,
at::Tensor dst_k,
const at::Tensor src_v_layers,
at::Tensor dst_v,
const at::Tensor src_indices,
const at::Tensor dst_indices,
int64_t item_size,
int64_t dst_layout_dim,
int64_t num_layers,
int64_t page_size,
int64_t head_num,
int64_t block_quota,
int64_t num_warps_per_block);
void transfer_kv_per_layer_mla(
const at::Tensor src,
at::Tensor dst,
......
......@@ -21,7 +21,7 @@ def transfer_kv_per_layer(
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer(
torch.ops.sgl_kernel.transfer_kv_per_layer.default(
src_k,
dst_k,
src_v,
......@@ -47,7 +47,7 @@ def transfer_kv_per_layer_pf_lf(
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf.default(
src_k,
dst_k,
src_v,
......@@ -62,6 +62,38 @@ def transfer_kv_per_layer_pf_lf(
)
def transfer_kv_per_layer_ph_lf(
src_k: torch.Tensor,
dst_k: torch.Tensor,
src_v: torch.Tensor,
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
layer_id: int,
item_size: int,
src_layout_dim: int,
page_size: int,
head_num: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_ph_lf.default(
src_k,
dst_k,
src_v,
dst_v,
src_indices,
dst_indices,
layer_id,
item_size,
src_layout_dim,
page_size,
head_num,
block_quota,
num_warps_per_block,
)
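
The `item_size` and `src_layout_dim` arguments are byte counts derived from the pool geometry, as exercised in the page-head test further below. A hedged call sketch follows; all shapes, sizes, and variable names are assumptions.

```python
# Hedged call sketch for transfer_kv_per_layer_ph_lf; shapes and sizes are assumptions.
# src pools use the page-head layout [pages, head_num, page_size, num_layers, head_dim];
# dst tensors are one layer of a layer-first pool [tokens, head_num, head_dim] on the GPU.
import torch
from sgl_kernel.kvcacheio import transfer_kv_per_layer_ph_lf

dtype = torch.bfloat16
pages, head_num, page_size, num_layers, head_dim = 32, 8, 16, 4, 64
src_k = torch.randn(pages, head_num, page_size, num_layers, head_dim, dtype=dtype).pin_memory()
src_v = torch.randn_like(src_k).pin_memory()
dst_k = torch.zeros(pages * page_size, head_num, head_dim, dtype=dtype, device="cuda")
dst_v = torch.zeros_like(dst_k)
indices = torch.arange(page_size, dtype=torch.int64, device="cuda")  # one full page of tokens

transfer_kv_per_layer_ph_lf(
    src_k, dst_k, src_v, dst_v,
    indices, indices,
    layer_id=0,
    item_size=head_num * head_dim * dtype.itemsize,                    # bytes: one token, one layer
    src_layout_dim=head_num * head_dim * num_layers * dtype.itemsize,  # bytes: one token, all layers
    page_size=page_size,
    head_num=head_num,
)
```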
def transfer_kv_all_layer(
src_k_layers: torch.Tensor,
dst_k_layers: torch.Tensor,
......@@ -74,7 +106,7 @@ def transfer_kv_all_layer(
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer(
torch.ops.sgl_kernel.transfer_kv_all_layer.default(
src_k_layers,
dst_k_layers,
src_v_layers,
......@@ -101,7 +133,37 @@ def transfer_kv_all_layer_lf_pf(
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf.default(
src_k_layers,
dst_k,
src_v_layers,
dst_v,
src_indices,
dst_indices,
item_size,
dst_layout_dim,
num_layers,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer_lf_ph(
src_k_layers: torch.Tensor,
dst_k: torch.Tensor,
src_v_layers: torch.Tensor,
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
dst_layout_dim: int,
num_layers: int,
page_size: int,
head_num: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_lf_ph.default(
src_k_layers,
dst_k,
src_v_layers,
......@@ -111,6 +173,8 @@ def transfer_kv_all_layer_lf_pf(
item_size,
dst_layout_dim,
num_layers,
page_size,
head_num,
block_quota,
num_warps_per_block,
)
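
For the layer-first source, `src_k_layers` and `src_v_layers` are per-layer base-pointer tables (uint64 `data_ptr` values), built the same way as in the test below. A hedged sketch, with shapes and sizes as assumptions:

```python
# Hedged sketch: building the per-layer pointer tables consumed by
# transfer_kv_all_layer_lf_ph. Shapes, sizes, and names are assumptions.
import torch
from sgl_kernel.kvcacheio import transfer_kv_all_layer_lf_ph

dtype = torch.bfloat16
num_layers, tokens, head_num, head_dim, page_size = 4, 512, 8, 64, 16
src_k = torch.randn(num_layers, tokens, head_num, head_dim, dtype=dtype, device="cuda")
src_v = torch.randn_like(src_k)
# One base pointer per layer, packed into a uint64 tensor on the device.
src_k_layers = torch.tensor([src_k[i].data_ptr() for i in range(num_layers)],
                            dtype=torch.uint64, device="cuda")
src_v_layers = torch.tensor([src_v[i].data_ptr() for i in range(num_layers)],
                            dtype=torch.uint64, device="cuda")
dst_k = torch.zeros(tokens // page_size, head_num, page_size, num_layers, head_dim,
                    dtype=dtype).pin_memory()
dst_v = torch.zeros_like(dst_k).pin_memory()
indices = torch.arange(page_size, dtype=torch.int64, device="cuda")  # one full page of tokens

transfer_kv_all_layer_lf_ph(
    src_k_layers, dst_k, src_v_layers, dst_v,
    indices, indices,
    item_size=head_num * head_dim * dtype.itemsize,
    dst_layout_dim=head_num * head_dim * num_layers * dtype.itemsize,
    num_layers=num_layers,
    page_size=page_size,
    head_num=head_num,
)
```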
......@@ -123,7 +187,7 @@ def transfer_kv_direct(
dst_indices: torch.Tensor,
page_size: int,
):
torch.ops.sgl_kernel.transfer_kv_direct(
torch.ops.sgl_kernel.transfer_kv_direct.default(
src_layers, dst_layers, src_indices, dst_indices, page_size
)
......@@ -136,7 +200,7 @@ def transfer_kv_per_layer_direct_pf_lf(
layer_id: int,
page_size: int,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf(
torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf.default(
src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size
)
......@@ -148,7 +212,7 @@ def transfer_kv_all_layer_direct_lf_pf(
dst_indices: torch.Tensor,
page_size: int,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf(
torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf.default(
src_ptrs, dst_ptrs, src_indices, dst_indices, page_size
)
......@@ -162,7 +226,7 @@ def transfer_kv_per_layer_mla(
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
torch.ops.sgl_kernel.transfer_kv_per_layer_mla.default(
src,
dst,
src_indices,
......@@ -184,7 +248,7 @@ def transfer_kv_per_layer_mla_pf_lf(
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf.default(
src,
dst,
src_indices,
......@@ -207,7 +271,7 @@ def transfer_kv_all_layer_mla(
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
torch.ops.sgl_kernel.transfer_kv_all_layer_mla.default(
src_layers,
dst_layers,
src_indices,
......@@ -230,7 +294,7 @@ def transfer_kv_all_layer_mla_lf_pf(
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf.default(
src_layers,
dst,
src_indices,
......
......@@ -3,11 +3,13 @@ import torch
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_direct_lf_pf,
transfer_kv_all_layer_lf_ph,
transfer_kv_all_layer_mla,
transfer_kv_direct,
transfer_kv_per_layer,
transfer_kv_per_layer_direct_pf_lf,
transfer_kv_per_layer_mla,
transfer_kv_per_layer_ph_lf,
)
......@@ -30,6 +32,32 @@ def ref_copy_with_indices_pf_direct(
][layer_id].to(dst_pool.device)
def ref_copy_with_indices_page_head(
src_pool,
dst_pool,
src_indices,
dst_indices,
page_size,
layer_id,
head_num,
lf_to_ph=False,
):
if lf_to_ph:
for head_id in range(head_num):
for i in range(0, len(src_indices)):
dst_pool[dst_indices[i] // page_size][head_id][
dst_indices[i] % page_size
][layer_id] = src_pool[layer_id][src_indices[i]][head_id].to(
dst_pool.device
)
else:
for head_id in range(head_num):
for i in range(0, len(src_indices)):
dst_pool[layer_id][dst_indices[i]][head_id] = src_pool[
src_indices[i] // page_size
][head_id][src_indices[i] % page_size][layer_id].to(dst_pool.device)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
@pytest.mark.parametrize("page_size", [1, 16, 64])
......@@ -481,5 +509,182 @@ def test_transfer_kv_pf_direct(
torch.set_default_dtype(original_dtype)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [256, 1024])
@pytest.mark.parametrize("page_size", [16, 64, 128])
@pytest.mark.parametrize("item_size", [1024])
@pytest.mark.parametrize("head_num", [8, 16])
@pytest.mark.parametrize("total_items_in_pool", [4096])
@pytest.mark.parametrize("lf_to_ph", [False, True])
def test_transfer_kv_page_head(
dtype: torch.dtype,
num_items_to_transfer: int,
page_size: int,
item_size: int,
head_num: int,
total_items_in_pool: int,
lf_to_ph: bool,
):
original_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
device = "cuda"
torch.cuda.manual_seed(42)
num_layers = 4
total_pages_in_pool = total_items_in_pool // page_size
num_pages_to_transfer = num_items_to_transfer // page_size
if num_pages_to_transfer == 0:
torch.set_default_dtype(original_dtype)
return
assert item_size % head_num == 0
head_dim = item_size // head_num
page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
src_indices_host = torch.cat(
[
torch.arange(p * page_size, (p + 1) * page_size)
for p in page_indices[:num_pages_to_transfer]
]
)
src_indices_device = src_indices_host.to(device)
dst_indices_host = torch.cat(
[
torch.arange(p * page_size, (p + 1) * page_size)
for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
]
)
dst_indices_device = dst_indices_host.to(device)
# We will test the per-layer function on the first layer (index 0) of the pool.
layer_idx_to_test = 0
if lf_to_ph:
src_k_pool = torch.randn(
num_layers, total_items_in_pool, head_num, head_dim
).to(device)
src_v_pool = torch.randn(
num_layers, total_items_in_pool, head_num, head_dim
).to(device)
src_k_pool_ptrs = [src_k_pool[i] for i in range(num_layers)]
src_k_pool_ptrs = torch.tensor(
[x.data_ptr() for x in src_k_pool_ptrs],
dtype=torch.uint64,
device=device,
)
src_v_pool_ptrs = [src_v_pool[i] for i in range(num_layers)]
src_v_pool_ptrs = torch.tensor(
[x.data_ptr() for x in src_v_pool_ptrs],
dtype=torch.uint64,
device=device,
)
dst_k_pool_ref = torch.zeros(
total_pages_in_pool, head_num, page_size, num_layers, head_dim
).pin_memory()
dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref).pin_memory()
dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref).pin_memory()
dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref).pin_memory()
torch.cuda.synchronize()
transfer_kv_all_layer_lf_ph(
src_k_pool_ptrs,
dst_k_pool_kernel,
src_v_pool_ptrs,
dst_v_pool_kernel,
src_indices_device,
dst_indices_device,
item_size * dtype.itemsize,
item_size * num_layers * dtype.itemsize,
num_layers,
page_size,
head_num,
)
torch.cuda.synchronize()
for i in range(num_layers):
ref_copy_with_indices_page_head(
src_k_pool,
dst_k_pool_ref,
src_indices_device,
dst_indices_host,
page_size,
i,
head_num,
lf_to_ph=True,
)
ref_copy_with_indices_page_head(
src_v_pool,
dst_v_pool_ref,
src_indices_device,
dst_indices_host,
page_size,
i,
head_num,
lf_to_ph=True,
)
torch.cuda.synchronize()
torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
else:
src_k_pool = torch.randn(
total_pages_in_pool, head_num, page_size, num_layers, head_dim
).pin_memory()
src_v_pool = torch.randn(
total_pages_in_pool, head_num, page_size, num_layers, head_dim
).pin_memory()
dst_k_pool_ref = torch.zeros(
num_layers, total_items_in_pool, head_num, head_dim
).to(device)
dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
dst_k_pool_kernel_ptrs = [dst_k_pool_kernel[i] for i in range(num_layers)]
dst_v_pool_kernel_ptrs = [dst_v_pool_kernel[i] for i in range(num_layers)]
torch.cuda.synchronize()
transfer_kv_per_layer_ph_lf(
src_k_pool,
dst_k_pool_kernel_ptrs[layer_idx_to_test],
src_v_pool,
dst_v_pool_kernel_ptrs[layer_idx_to_test],
src_indices_device,
dst_indices_device,
layer_idx_to_test,
item_size * dtype.itemsize,
item_size * num_layers * dtype.itemsize,
page_size,
head_num,
)
ref_copy_with_indices_page_head(
src_k_pool,
dst_k_pool_ref,
src_indices_host,
dst_indices_device,
page_size,
layer_idx_to_test,
head_num,
lf_to_ph=False,
)
ref_copy_with_indices_page_head(
src_v_pool,
dst_v_pool_ref,
src_indices_host,
dst_indices_device,
page_size,
layer_idx_to_test,
head_num,
lf_to_ph=False,
)
torch.cuda.synchronize()
torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
torch.set_default_dtype(original_dtype)
if __name__ == "__main__":
pytest.main([__file__])