sglang, commit 3e6281d0 (unverified)

[HiCache] Page head layout IO kernel (#11615)

Authored Oct 26, 2025 by huangtingwei; committed via GitHub on Oct 26, 2025
Parent: 6371f7af

Showing 5 changed files with 574 additions and 31 deletions (+574, -31)
Files changed:
  sgl-kernel/csrc/common_extension.cc        +10   -0
  sgl-kernel/csrc/kvcacheio/transfer.cu      +254  -20
  sgl-kernel/include/sgl_kernel_ops.h        +30   -0
  sgl-kernel/python/sgl_kernel/kvcacheio.py  +75   -11
  sgl-kernel/tests/test_kvcacheio.py         +205  -0
sgl-kernel/csrc/common_extension.cc

@@ -370,6 +370,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
      "dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
  m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
  m.def(
      "transfer_kv_per_layer_ph_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
      "dst_indices, int layer_id, int item_size, int src_layout_dim, int page_size, int head_num, int block_quota, int "
      "num_warps_per_block) -> ()");
  m.impl("transfer_kv_per_layer_ph_lf", torch::kCUDA, &transfer_kv_per_layer_ph_lf);
  m.def(
      "transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
      "Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
      ...

@@ -380,6 +385,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
      "num_warps_per_block) -> ()");
  m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf);
  m.def(
      "transfer_kv_all_layer_lf_ph(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
      "Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int page_size, int "
      "head_num, int block_quota, int num_warps_per_block) -> ()");
  m.impl("transfer_kv_all_layer_lf_ph", torch::kCUDA, &transfer_kv_all_layer_lf_ph);
  m.def(
      "transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
      "block_quota, int num_warps_per_block) -> ()");
  ...
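Once this library fragment is loaded on the Python side, the two new schemas expose transfer_kv_per_layer_ph_lf and transfer_kv_all_layer_lf_ph under torch.ops.sgl_kernel; the Python wrappers further down dispatch to their .default overloads. A minimal lookup sketch (assumption: importing the sgl_kernel package loads this extension):

import torch
import sgl_kernel  # assumption: importing the package registers the sgl_kernel op library

# The op objects exist once the TORCH_LIBRARY_FRAGMENT above has been registered.
print(torch.ops.sgl_kernel.transfer_kv_per_layer_ph_lf.default)
print(torch.ops.sgl_kernel.transfer_kv_all_layer_lf_ph.default)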
sgl-kernel/csrc/kvcacheio/transfer.cu

@@ -68,6 +68,140 @@ __device__ __forceinline__ T* get_global_offset_lf_tbl(
  return reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes;
}

template <typename T>
__device__ __forceinline__ T* get_global_offset_per_head_lf(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t layer_dim,
    int64_t page_id,
    int64_t item_size_bytes,
    int64_t head_id,
    int64_t head_num,
    int64_t /*unused*/) {
  // layer first offset func per head
  return base + layer_id * layer_dim + page_id * item_size_bytes + item_size_bytes / head_num * head_id;
}

template <typename T>
__device__ __forceinline__ T* get_global_offset_per_head_lf_tbl(
    T* /*unused*/,
    const uintptr_t* __restrict__ layer_base_tbl,
    int64_t layer_id,
    int64_t /*unused*/,
    int64_t page_id,
    int64_t item_size_bytes,
    int64_t head_id,
    int64_t head_num,
    int64_t /*unused*/) {
  return reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes +
         item_size_bytes / head_num * head_id;
}

template <typename T>
__device__ __forceinline__ T* get_global_offset_ph(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t page_dim,
    int64_t page_id,
    int64_t item_size_bytes,
    int64_t head_id,
    int64_t head_num,
    int64_t page_size) {
  // page head layout: [page_num, head_num, page_size, layer_num, head_dim]
  return base + page_id / page_size * page_size * page_dim +  // page_num dimension offset
         page_dim / head_num * head_id * page_size +          // head_num dimension offset
         page_id % page_size * page_dim / head_num +          // page_size dimension offset
         layer_id * item_size_bytes / head_num;               // layer_num dimension offset
}
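For orientation: the four summands in get_global_offset_ph are the flattened byte offset of element [page, head, slot, layer, 0] in a contiguous [page_num, head_num, page_size, layer_num, head_dim] pool, where page_id is a token index, page_dim is the byte stride of one token slot across all layers and heads, and item_size_bytes covers one token in one layer across all heads. A minimal Python sketch (not part of the commit; all sizes are illustrative assumptions) that cross-checks this arithmetic:

head_num, page_size, layer_num, head_dim, elem_size = 8, 16, 4, 128, 2

item_size_bytes = head_num * head_dim * elem_size  # one token, one layer, all heads
page_dim = item_size_bytes * layer_num             # one token slot across all layers (the *_layout_dim argument)

def offset_ph(token_id, layer_id, head_id):
    # Mirrors the four summands in get_global_offset_ph (token_id plays the role of page_id).
    return (token_id // page_size * page_size * page_dim      # page_num dimension offset
            + page_dim // head_num * head_id * page_size      # head_num dimension offset
            + token_id % page_size * page_dim // head_num     # page_size dimension offset
            + layer_id * item_size_bytes // head_num)         # layer_num dimension offset

def offset_flat(token_id, layer_id, head_id):
    # Flat byte offset of element [page, head, slot, layer, 0] in the contiguous 5-D pool.
    page, slot = divmod(token_id, page_size)
    return (((page * head_num + head_id) * page_size + slot) * layer_num + layer_id) * head_dim * elem_size

assert all(
    offset_ph(t, l, h) == offset_flat(t, l, h)
    for t in (0, 1, 17, 255)
    for l in range(layer_num)
    for h in range(head_num)
)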
template <auto SrcOffsetFn, auto DstOffsetFn>
__global__ void transfer_page_head_kernel_impl(
    const void* __restrict__ src_k,
    void* __restrict__ dst_k,
    const void* __restrict__ src_v,
    void* __restrict__ dst_v,
    const int64_t* __restrict__ src_indices,
    const int64_t* __restrict__ dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t num_items,
    int64_t items_per_warp,
    int64_t item_size_bytes,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    const uintptr_t* __restrict__ src_k_layer_tbl,
    const uintptr_t* __restrict__ dst_k_layer_tbl,
    const uintptr_t* __restrict__ src_v_layer_tbl,
    const uintptr_t* __restrict__ dst_v_layer_tbl,
    const int64_t page_size,
    const int64_t head_num) {
  int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  int32_t lane_id = tid % WARP_SIZE;
  int32_t warp_id = tid / WARP_SIZE;
  const int64_t head_size_bytes = item_size_bytes / head_num;
  for (int i = 0; i < items_per_warp; ++i) {
    int64_t item_id = warp_id * items_per_warp + i;
    if (item_id >= num_items) {
      break;
    }
    const int64_t src_page_id = src_indices[item_id];
    const int64_t dst_page_id = dst_indices[item_id];
    // Loop over layers if necessary
    for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
      // For page head layout, the cache of each head within a token is discontinuous, so loop over heads
      for (int64_t head_id = 0; head_id < head_num; ++head_id) {
        const char* src_k_ptr = SrcOffsetFn(
            static_cast<const char*>(src_k),
            src_k_layer_tbl,
            layer_id,
            src_layout_dim,
            src_page_id,
            item_size_bytes,
            head_id,
            head_num,
            page_size);
        char* dst_k_ptr = DstOffsetFn(
            static_cast<char*>(dst_k),
            dst_k_layer_tbl,
            layer_id,
            dst_layout_dim,
            dst_page_id,
            item_size_bytes,
            head_id,
            head_num,
            page_size);
        transfer_item_warp(lane_id, src_k_ptr, dst_k_ptr, head_size_bytes);
        const char* src_v_ptr = SrcOffsetFn(
            static_cast<const char*>(src_v),
            src_v_layer_tbl,
            layer_id,
            src_layout_dim,
            src_page_id,
            item_size_bytes,
            head_id,
            head_num,
            page_size);
        char* dst_v_ptr = DstOffsetFn(
            static_cast<char*>(dst_v),
            dst_v_layer_tbl,
            layer_id,
            dst_layout_dim,
            dst_page_id,
            item_size_bytes,
            head_id,
            head_num,
            page_size);
        transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, head_size_bytes);
      }
    }
  }
}
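The work decomposition above matches the existing transfer kernels: each warp owns items_per_warp consecutive tokens, but because the page-head layout keeps different heads of the same token apart in memory, every (token, layer, head) triple becomes one contiguous copy of item_size_bytes / head_num bytes. A small Python sketch of that mapping (not from the commit; sizes are illustrative):

def plan_copies(num_items, num_layers, head_num, item_size_bytes, num_warps, items_per_warp):
    head_size_bytes = item_size_bytes // head_num
    copies = []
    for warp_id in range(num_warps):
        for i in range(items_per_warp):
            item_id = warp_id * items_per_warp + i
            if item_id >= num_items:
                break
            for layer_id in range(num_layers):
                for head_id in range(head_num):
                    copies.append((warp_id, item_id, layer_id, head_id, head_size_bytes))
    return copies

# 256 tokens, 4 layers, 8 heads, 2 KiB per token-layer, 8 warps with 32 tokens each.
work = plan_copies(256, 4, 8, 2048, 8, 32)
assert len(work) == 256 * 4 * 8 and work[0][4] == 256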
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
__global__ void transfer_kernel_impl(
    const void* __restrict__ src_k,
    ...
@@ -118,7 +252,7 @@ __global__ void transfer_kernel_impl(
  }
}

-template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
+template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA, bool PageHeadLayout = false>
void transfer_kv_launcher(
    const at::Tensor& src_k,
    at::Tensor& dst_k,
    ...
@@ -136,7 +270,9 @@ void transfer_kv_launcher(
    const at::Tensor& src_v_layers,
    const at::Tensor& dst_v_layers,
    int64_t block_quota,
-   int64_t num_warps_per_block) {
+   int64_t num_warps_per_block,
+   const int64_t page_size = 16,
+   const int64_t head_num = 1) {
  TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
  TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
  TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
  ...
@@ -161,24 +297,47 @@ void transfer_kv_launcher(
  const uintptr_t* dst_v_tbl_ptr =
      IsMLA || !dst_v_layers.defined() ? nullptr : dst_v_layers.data_ptr<uintptr_t>();
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
-  transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
-      src_k_ptr,
-      dst_k_ptr,
-      src_v_ptr,
-      dst_v_ptr,
-      src_indices.data_ptr<int64_t>(),
-      dst_indices.data_ptr<int64_t>(),
-      start_layer_id,
-      num_layers_to_process,
-      num_items,
-      items_per_warp,
-      item_size,
-      src_layout_dim,
-      dst_layout_dim,
-      src_k_tbl_ptr,
-      dst_k_tbl_ptr,
-      src_v_tbl_ptr,
-      dst_v_tbl_ptr);
+  if constexpr (PageHeadLayout) {
+    transfer_page_head_kernel_impl<SrcOffsetFn, DstOffsetFn><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
+        src_k_ptr,
+        dst_k_ptr,
+        src_v_ptr,
+        dst_v_ptr,
+        src_indices.data_ptr<int64_t>(),
+        dst_indices.data_ptr<int64_t>(),
+        start_layer_id,
+        num_layers_to_process,
+        num_items,
+        items_per_warp,
+        item_size,
+        src_layout_dim,
+        dst_layout_dim,
+        src_k_tbl_ptr,
+        dst_k_tbl_ptr,
+        src_v_tbl_ptr,
+        dst_v_tbl_ptr,
+        page_size,
+        head_num);
+  } else {
+    transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
+        src_k_ptr,
+        dst_k_ptr,
+        src_v_ptr,
+        dst_v_ptr,
+        src_indices.data_ptr<int64_t>(),
+        dst_indices.data_ptr<int64_t>(),
+        start_layer_id,
+        num_layers_to_process,
+        num_items,
+        items_per_warp,
+        item_size,
+        src_layout_dim,
+        dst_layout_dim,
+        src_k_tbl_ptr,
+        dst_k_tbl_ptr,
+        src_v_tbl_ptr,
+        dst_v_tbl_ptr);
+  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
...
@@ -246,6 +405,43 @@ void transfer_kv_per_layer_pf_lf(
      num_warps_per_block);
}

void transfer_kv_per_layer_ph_lf(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t page_size,
    int64_t head_num,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_ph<const char>, get_global_offset_per_head_lf<char>, false, true>(
      src_k,
      dst_k,
      src_v,
      dst_v,
      src_indices,
      dst_indices,
      layer_id,
      1,
      item_size,
      src_layout_dim,
      0,
      empty,
      empty,
      empty,
      empty,
      block_quota,
      num_warps_per_block,
      page_size,
      head_num);
}

void transfer_kv_all_layer(
    const at::Tensor src_k_layers,
    const at::Tensor dst_k_layers,
    ...
@@ -313,6 +509,44 @@ void transfer_kv_all_layer_lf_pf(
      num_warps_per_block);
}

void transfer_kv_all_layer_lf_ph(
    const at::Tensor src_k_layers,
    at::Tensor dst_k,
    const at::Tensor src_v_layers,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t page_size,
    int64_t head_num,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_per_head_lf_tbl<const char>, get_global_offset_ph<char>, false, true>(
      empty,
      dst_k,
      empty,
      dst_v,
      src_indices,
      dst_indices,
      0,
      num_layers,
      item_size,
      0,
      dst_layout_dim,
      src_k_layers,
      empty,
      src_v_layers,
      empty,
      block_quota,
      num_warps_per_block,
      page_size,
      head_num);
}

void transfer_kv_per_layer_mla(
    const at::Tensor src,
    at::Tensor dst,
    ...
sgl-kernel/include/sgl_kernel_ops.h

@@ -562,6 +562,21 @@ void transfer_kv_per_layer_pf_lf(
    int64_t block_quota,
    int64_t num_warps_per_block);
void transfer_kv_per_layer_ph_lf(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t page_size,
    int64_t head_num,
    int64_t block_quota,
    int64_t num_warps_per_block);
void transfer_kv_all_layer(
    const at::Tensor src_k_layers,
    const at::Tensor dst_k_layers,
    ...

@@ -587,6 +602,21 @@ void transfer_kv_all_layer_lf_pf(
    int64_t block_quota,
    int64_t num_warps_per_block);
void transfer_kv_all_layer_lf_ph(
    const at::Tensor src_k_layers,
    at::Tensor dst_k,
    const at::Tensor src_v_layers,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t page_size,
    int64_t head_num,
    int64_t block_quota,
    int64_t num_warps_per_block);
void transfer_kv_per_layer_mla(
    const at::Tensor src,
    at::Tensor dst,
    ...
sgl-kernel/python/sgl_kernel/kvcacheio.py

@@ -21,7 +21,7 @@ def transfer_kv_per_layer(
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
-    torch.ops.sgl_kernel.transfer_kv_per_layer(
+    torch.ops.sgl_kernel.transfer_kv_per_layer.default(
        src_k,
        dst_k,
        src_v,
        ...

@@ -47,7 +47,7 @@ def transfer_kv_per_layer_pf_lf(
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
-    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
+    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf.default(
        src_k,
        dst_k,
        src_v,
        ...
@@ -62,6 +62,38 @@ def transfer_kv_per_layer_pf_lf(
    )


def transfer_kv_per_layer_ph_lf(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    page_size: int,
    head_num: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    torch.ops.sgl_kernel.transfer_kv_per_layer_ph_lf.default(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        page_size,
        head_num,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer(
    src_k_layers: torch.Tensor,
    dst_k_layers: torch.Tensor,
    ...
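A hedged usage sketch of the new per-layer wrapper; shapes and argument conventions mirror the unit test added in this commit, and all sizes are illustrative assumptions rather than a documented contract. item_size and src_layout_dim are byte counts, so the element counts are scaled by dtype.itemsize; the source is a pinned page-head pool and the destination is one layer-first layer slice on the GPU:

import torch
from sgl_kernel.kvcacheio import transfer_kv_per_layer_ph_lf

head_num, head_dim, page_size, num_layers = 8, 128, 16, 4
pages = 64
tokens = pages * page_size
item_size = head_num * head_dim  # elements per token per layer

# Pinned host pool in page-head layout, and one device layer slice in layer-first layout.
src_k = torch.randn(pages, head_num, page_size, num_layers, head_dim, dtype=torch.float16).pin_memory()
src_v = torch.randn_like(src_k).pin_memory()
dst_k = torch.zeros(tokens, head_num, head_dim, dtype=torch.float16, device="cuda")
dst_v = torch.zeros_like(dst_k)

# Token indices must be int64 CUDA tensors; here we move the first page worth of tokens.
src_idx = torch.arange(page_size, dtype=torch.int64, device="cuda")
dst_idx = torch.arange(page_size, dtype=torch.int64, device="cuda")

transfer_kv_per_layer_ph_lf(
    src_k, dst_k, src_v, dst_v, src_idx, dst_idx,
    layer_id=0,
    item_size=item_size * src_k.dtype.itemsize,                    # bytes per token per layer
    src_layout_dim=item_size * num_layers * src_k.dtype.itemsize,  # bytes per token slot in the page-head pool
    page_size=page_size,
    head_num=head_num,
)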
@@ -74,7 +106,7 @@ def transfer_kv_all_layer(
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
-    torch.ops.sgl_kernel.transfer_kv_all_layer(
+    torch.ops.sgl_kernel.transfer_kv_all_layer.default(
        src_k_layers,
        dst_k_layers,
        src_v_layers,
        ...

@@ -101,7 +133,37 @@ def transfer_kv_all_layer_lf_pf(
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
-    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
+    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf.default(
        src_k_layers,
        dst_k,
        src_v_layers,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_lf_ph(
    src_k_layers: torch.Tensor,
    dst_k: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    page_size: int,
    head_num: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_ph.default(
        src_k_layers,
        dst_k,
        src_v_layers,
        ...

@@ -111,6 +173,8 @@ def transfer_kv_all_layer_lf_pf(
        item_size,
        dst_layout_dim,
        num_layers,
        page_size,
        head_num,
        block_quota,
        num_warps_per_block,
    )
...
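And the opposite direction, transfer_kv_all_layer_lf_ph, gathers from per-layer device tensors (passed as a uint64 table of data pointers, as in the new test) into a pinned page-head pool. Again a hedged sketch with illustrative sizes, not a documented contract:

import torch
from sgl_kernel.kvcacheio import transfer_kv_all_layer_lf_ph

head_num, head_dim, page_size, num_layers = 8, 128, 16, 4
pages = 64
tokens = pages * page_size
item_size = head_num * head_dim

src_k = torch.randn(num_layers, tokens, head_num, head_dim, dtype=torch.float16, device="cuda")
src_v = torch.randn_like(src_k)
src_k_ptrs = torch.tensor([src_k[i].data_ptr() for i in range(num_layers)], dtype=torch.uint64, device="cuda")
src_v_ptrs = torch.tensor([src_v[i].data_ptr() for i in range(num_layers)], dtype=torch.uint64, device="cuda")

dst_k = torch.zeros(pages, head_num, page_size, num_layers, head_dim, dtype=torch.float16).pin_memory()
dst_v = torch.zeros_like(dst_k).pin_memory()

idx = torch.arange(page_size, dtype=torch.int64, device="cuda")  # copy one page worth of tokens

transfer_kv_all_layer_lf_ph(
    src_k_ptrs, dst_k, src_v_ptrs, dst_v, idx, idx,
    item_size=item_size * src_k.dtype.itemsize,
    dst_layout_dim=item_size * num_layers * src_k.dtype.itemsize,
    num_layers=num_layers,
    page_size=page_size,
    head_num=head_num,
)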
@@ -123,7 +187,7 @@ def transfer_kv_direct(
    dst_indices: torch.Tensor,
    page_size: int,
):
-    torch.ops.sgl_kernel.transfer_kv_direct(
+    torch.ops.sgl_kernel.transfer_kv_direct.default(
        src_layers, dst_layers, src_indices, dst_indices, page_size
    )
    ...

@@ -136,7 +200,7 @@ def transfer_kv_per_layer_direct_pf_lf(
    layer_id: int,
    page_size: int,
):
-    torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf(
+    torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf.default(
        src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size
    )
    ...

@@ -148,7 +212,7 @@ def transfer_kv_all_layer_direct_lf_pf(
    dst_indices: torch.Tensor,
    page_size: int,
):
-    torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf(
+    torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf.default(
        src_ptrs, dst_ptrs, src_indices, dst_indices, page_size
    )
    ...

@@ -162,7 +226,7 @@ def transfer_kv_per_layer_mla(
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
-    torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
+    torch.ops.sgl_kernel.transfer_kv_per_layer_mla.default(
        src,
        dst,
        src_indices,
        ...

@@ -184,7 +248,7 @@ def transfer_kv_per_layer_mla_pf_lf(
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
-    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
+    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf.default(
        src,
        dst,
        src_indices,
        ...

@@ -207,7 +271,7 @@ def transfer_kv_all_layer_mla(
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
-    torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
+    torch.ops.sgl_kernel.transfer_kv_all_layer_mla.default(
        src_layers,
        dst_layers,
        src_indices,
        ...

@@ -230,7 +294,7 @@ def transfer_kv_all_layer_mla_lf_pf(
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
-    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
+    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf.default(
        src_layers,
        dst,
        src_indices,
        ...
sgl-kernel/tests/test_kvcacheio.py

@@ -3,11 +3,13 @@ import torch
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_direct_lf_pf,
    transfer_kv_all_layer_lf_ph,
    transfer_kv_all_layer_mla,
    transfer_kv_direct,
    transfer_kv_per_layer,
    transfer_kv_per_layer_direct_pf_lf,
    transfer_kv_per_layer_mla,
    transfer_kv_per_layer_ph_lf,
)
...
@@ -30,6 +32,32 @@ def ref_copy_with_indices_pf_direct(
    ][layer_id].to(dst_pool.device)


def ref_copy_with_indices_page_head(
    src_pool,
    dst_pool,
    src_indices,
    dst_indices,
    page_size,
    layer_id,
    head_num,
    lf_to_ph=False,
):
    if lf_to_ph:
        for head_id in range(head_num):
            for i in range(0, len(src_indices)):
                dst_pool[dst_indices[i] // page_size][head_id][dst_indices[i] % page_size][
                    layer_id
                ] = src_pool[layer_id][src_indices[i]][head_id].to(dst_pool.device)
    else:
        for head_id in range(head_num):
            for i in range(0, len(src_indices)):
                dst_pool[layer_id][dst_indices[i]][head_id] = src_pool[
                    src_indices[i] // page_size
                ][head_id][src_indices[i] % page_size][layer_id].to(dst_pool.device)


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
@pytest.mark.parametrize("page_size", [1, 16, 64])
...
@@ -481,5 +509,182 @@ def test_transfer_kv_pf_direct(
    torch.set_default_dtype(original_dtype)


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [256, 1024])
@pytest.mark.parametrize("page_size", [16, 64, 128])
@pytest.mark.parametrize("item_size", [1024])
@pytest.mark.parametrize("head_num", [8, 16])
@pytest.mark.parametrize("total_items_in_pool", [4096])
@pytest.mark.parametrize("lf_to_ph", [False, True])
def test_transfer_kv_page_head(
    dtype: torch.dtype,
    num_items_to_transfer: int,
    page_size: int,
    item_size: int,
    head_num: int,
    total_items_in_pool: int,
    lf_to_ph: bool,
):
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    device = "cuda"
    torch.cuda.manual_seed(42)

    num_layers = 4
    total_pages_in_pool = total_items_in_pool // page_size
    num_pages_to_transfer = num_items_to_transfer // page_size
    if num_pages_to_transfer == 0:
        torch.set_default_dtype(original_dtype)
        return
    assert item_size % head_num == 0
    head_dim = item_size // head_num

    page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
    src_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[:num_pages_to_transfer]
        ]
    )
    src_indices_device = src_indices_host.to(device)
    dst_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
        ]
    )
    dst_indices_device = dst_indices_host.to(device)

    # We will test the per-layer function on the first layer (index 0) of the pool.
    layer_idx_to_test = 0

    if lf_to_ph:
        src_k_pool = torch.randn(num_layers, total_items_in_pool, head_num, head_dim).to(device)
        src_v_pool = torch.randn(num_layers, total_items_in_pool, head_num, head_dim).to(device)
        src_k_pool_ptrs = [src_k_pool[i] for i in range(num_layers)]
        src_k_pool_ptrs = torch.tensor(
            [x.data_ptr() for x in src_k_pool_ptrs],
            dtype=torch.uint64,
            device=device,
        )
        src_v_pool_ptrs = [src_v_pool[i] for i in range(num_layers)]
        src_v_pool_ptrs = torch.tensor(
            [x.data_ptr() for x in src_v_pool_ptrs],
            dtype=torch.uint64,
            device=device,
        )
        dst_k_pool_ref = torch.zeros(
            total_pages_in_pool, head_num, page_size, num_layers, head_dim
        ).pin_memory()
        dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref).pin_memory()
        dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref).pin_memory()
        dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref).pin_memory()

        torch.cuda.synchronize()
        transfer_kv_all_layer_lf_ph(
            src_k_pool_ptrs,
            dst_k_pool_kernel,
            src_v_pool_ptrs,
            dst_v_pool_kernel,
            src_indices_device,
            dst_indices_device,
            item_size * dtype.itemsize,
            item_size * num_layers * dtype.itemsize,
            num_layers,
            page_size,
            head_num,
        )
        torch.cuda.synchronize()
        for i in range(num_layers):
            ref_copy_with_indices_page_head(
                src_k_pool,
                dst_k_pool_ref,
                src_indices_device,
                dst_indices_host,
                page_size,
                i,
                head_num,
                lf_to_ph=True,
            )
            ref_copy_with_indices_page_head(
                src_v_pool,
                dst_v_pool_ref,
                src_indices_device,
                dst_indices_host,
                page_size,
                i,
                head_num,
                lf_to_ph=True,
            )
        torch.cuda.synchronize()
        torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
        torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
    else:
        src_k_pool = torch.randn(
            total_pages_in_pool, head_num, page_size, num_layers, head_dim
        ).pin_memory()
        src_v_pool = torch.randn(
            total_pages_in_pool, head_num, page_size, num_layers, head_dim
        ).pin_memory()
        dst_k_pool_ref = torch.zeros(num_layers, total_items_in_pool, head_num, head_dim).to(device)
        dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
        dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
        dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
        dst_k_pool_kernel_ptrs = [dst_k_pool_kernel[i] for i in range(num_layers)]
        dst_v_pool_kernel_ptrs = [dst_v_pool_kernel[i] for i in range(num_layers)]

        torch.cuda.synchronize()
        transfer_kv_per_layer_ph_lf(
            src_k_pool,
            dst_k_pool_kernel_ptrs[layer_idx_to_test],
            src_v_pool,
            dst_v_pool_kernel_ptrs[layer_idx_to_test],
            src_indices_device,
            dst_indices_device,
            layer_idx_to_test,
            item_size * dtype.itemsize,
            item_size * num_layers * dtype.itemsize,
            page_size,
            head_num,
        )
        ref_copy_with_indices_page_head(
            src_k_pool,
            dst_k_pool_ref,
            src_indices_host,
            dst_indices_device,
            page_size,
            layer_idx_to_test,
            head_num,
            lf_to_ph=False,
        )
        ref_copy_with_indices_page_head(
            src_v_pool,
            dst_v_pool_ref,
            src_indices_host,
            dst_indices_device,
            page_size,
            layer_idx_to_test,
            head_num,
            lf_to_ph=False,
        )
        torch.cuda.synchronize()
        torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
        torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)

    torch.set_default_dtype(original_dtype)


if __name__ == "__main__":
    pytest.main([__file__])