sglang · commit 34c3f9b2 (Unverified)

kvcache io kernels and test case (#7382)

Authored Jun 23, 2025 by Zhiqiang Xie, committed by GitHub on Jun 23, 2025
Parent: 76139bfb

Showing 7 changed files with 845 additions and 0 deletions (+845, -0)
sgl-kernel/CMakeLists.txt                    +1    -0
sgl-kernel/csrc/common_extension.cc          +37   -0
sgl-kernel/csrc/kvcacheio/transfer.cu        +342  -0
sgl-kernel/include/sgl_kernel_ops.h          +83   -0
sgl-kernel/python/sgl_kernel/__init__.py     +6    -0
sgl-kernel/python/sgl_kernel/kvcacheio.py    +137  -0
sgl-kernel/tests/test_kvcacheio.py           +239  -0
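The commit threads the new kvcacheio module through the CMake source list, the Torch op registry, the C++ header, a Python wrapper module, and a pytest case. As a hedged orientation sketch (pool sizes, dtypes, and index ranges below are illustrative assumptions, not values fixed by the commit), the public wrappers are intended to be called like this:

# Illustrative sketch; shapes and indices are assumptions, not part of the commit.
import torch
from sgl_kernel import transfer_kv_per_layer

item_size = 256  # elements per KV item for one layer (e.g. num_heads * head_dim)
src_k = torch.randn(10240, item_size, dtype=torch.bfloat16).pin_memory()
src_v = torch.randn(10240, item_size, dtype=torch.bfloat16).pin_memory()
dst_k = torch.zeros(10240, item_size, dtype=torch.bfloat16, device="cuda")
dst_v = torch.zeros_like(dst_k)
src_indices = torch.arange(0, 128, dtype=torch.int64, device="cuda")
dst_indices = torch.arange(128, 256, dtype=torch.int64, device="cuda")

# io_backend="kernel" uses the new CUDA copy kernel; "direct" falls back to torch copy_().
transfer_kv_per_layer(
    src_k, dst_k, src_v, dst_v, src_indices, dst_indices,
    io_backend="kernel", page_size=1, item_size=item_size,
)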
sgl-kernel/CMakeLists.txt

@@ -250,6 +250,7 @@ set(SOURCES
    "csrc/speculative/packbit.cu"
    "csrc/speculative/speculative_sampling.cu"
    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
    "csrc/kvcacheio/transfer.cu"
    "csrc/common_extension.cc"
    "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
    "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
sgl-kernel/csrc/common_extension.cc

@@ -230,6 +230,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "int cuda_stream) -> ()");
  m.impl("segment_packbits", torch::kCUDA, &segment_packbits);

  /*
   * From csrc/kvcacheio
   */
  m.def(
      "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
      "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
  m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
  m.def(
      "transfer_kv_per_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
      "dst_indices, int page_size) -> ()");
  m.impl("transfer_kv_per_layer_direct", torch::kCUDA, &transfer_kv_per_layer_direct);
  m.def(
      "transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
      "dst_indices, int item_size, int num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int "
      "num_warps_per_block) -> ()");
  m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
  m.def(
      "transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
      "dst_indices, int page_size, int num_layers) -> ()");
  m.impl("transfer_kv_all_layer_direct", torch::kCUDA, &transfer_kv_all_layer_direct);
  m.def(
      "transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
      "block_quota, int num_warps_per_block) -> ()");
  m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
  m.def(
      "transfer_kv_per_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size) "
      "-> ()");
  m.impl("transfer_kv_per_layer_mla_direct", torch::kCUDA, &transfer_kv_per_layer_mla_direct);
  m.def(
      "transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
      "num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int num_warps_per_block) -> ()");
  m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
  m.def(
      "transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, "
      "int num_layers) -> ()");
  m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);

  /*
   * From FlashInfer
   */
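Once the extension library is loaded, each m.def/m.impl pair above is reachable through the dispatcher as torch.ops.sgl_kernel.<name>, with arguments following the schema string. A hedged sketch (tensor sizes are assumptions) of calling one registered op directly, rather than through the Python wrappers added below:

# Sketch only; sizes are illustrative. Importing sgl_kernel loads the extension and registers the ops.
import torch
import sgl_kernel  # noqa: F401

src_k = torch.randn(1024, 256, dtype=torch.float16).pin_memory()
src_v = torch.randn(1024, 256, dtype=torch.float16).pin_memory()
dst_k = torch.zeros(1024, 256, dtype=torch.float16, device="cuda")
dst_v = torch.zeros_like(dst_k)
src_idx = torch.arange(0, 64, dtype=torch.int64, device="cuda")
dst_idx = torch.arange(64, 128, dtype=torch.int64, device="cuda")

# Follows the registered schema: (..., int item_size, int block_quota, int num_warps_per_block) -> ()
torch.ops.sgl_kernel.transfer_kv_per_layer(src_k, dst_k, src_v, dst_v, src_idx, dst_idx, 256, 2, 32)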
sgl-kernel/csrc/kvcacheio/transfer.cu (new file)

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/irange.h>

#include <cstdint>

#include "pytorch_extension_utils.h"

__device__ __forceinline__ void
transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
  // todo, different chunk size
  int total_chunks = item_size_bytes / 8;
  const int64_t* src_8 = reinterpret_cast<const int64_t*>(src_addr);
  int64_t* dst_8 = reinterpret_cast<int64_t*>(dst_addr);
#pragma unroll
  for (int j = lane_id; j < total_chunks; j += 32) {
    const int64_t* src_addr_lane = &src_8[j];
    int64_t* dst_addr_lane = &dst_8[j];
    int64_t temp_val;
    asm volatile("ld.global.nc.b64 %0, [%1];" : "=l"(temp_val) : "l"(src_addr_lane) : "memory");
    asm volatile("st.global.cg.b64 [%0], %1;" ::"l"(dst_addr_lane), "l"(temp_val) : "memory");
  }
}

// todo, structs for different memory layout
__device__ __forceinline__ int64_t
get_global_offset_lf(int64_t layer_id, int64_t layer_dim, int64_t page_id, int64_t item_size_bytes) {
  // layer first
  return layer_id * layer_dim + page_id * item_size_bytes;
}

__device__ __forceinline__ int64_t
get_global_offset_pf(int64_t layer_id, int64_t page_dim, int64_t page_id, int64_t item_size_bytes) {
  // page first
  return page_id * page_dim + layer_id * item_size_bytes;
}

template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
__global__ void transfer_kernel_impl(
    const void* __restrict__ src_k,
    void* __restrict__ dst_k,
    const void* __restrict__ src_v,
    void* __restrict__ dst_v,
    const int64_t* __restrict__ src_indices,
    const int64_t* __restrict__ dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t num_items,
    int64_t items_per_warp,
    int64_t item_size_bytes,
    int64_t src_layout_dim,
    int64_t dst_layout_dim) {
  int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  int32_t lane_id = tid % 32;
  int32_t warp_id = tid / 32;

  for (int i = 0; i < items_per_warp; ++i) {
    int32_t item_id = warp_id * items_per_warp + i;
    if (item_id >= num_items) {
      return;
    }
    const int64_t src_page_id = src_indices[item_id];
    const int64_t dst_page_id = dst_indices[item_id];

    // Loop over layers if necessary
    for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
      // Calculate offsets using the provided function pointers
      const int64_t src_offset = SrcOffsetFn(layer_id, src_layout_dim, src_page_id, item_size_bytes);
      const int64_t dst_offset = DstOffsetFn(layer_id, dst_layout_dim, dst_page_id, item_size_bytes);

      if constexpr (IsMLA) {
        transfer_item_warp(
            lane_id,
            static_cast<const char*>(src_k) + src_offset,
            static_cast<char*>(dst_k) + dst_offset,
            item_size_bytes);
      } else {
        transfer_item_warp(
            lane_id,
            static_cast<const char*>(src_k) + src_offset,
            static_cast<char*>(dst_k) + dst_offset,
            item_size_bytes);
        transfer_item_warp(
            lane_id,
            static_cast<const char*>(src_v) + src_offset,
            static_cast<char*>(dst_v) + dst_offset,
            item_size_bytes);
      }
    }
  }
}

template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
void transfer_kv_launcher(
    const at::Tensor& src_k,
    at::Tensor& dst_k,
    const at::Tensor& src_v,
    at::Tensor& dst_v,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(src_k.scalar_type() == dst_k.scalar_type(), "Source and destination keys must have the same type");
  TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
  TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
  TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
  TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  if (!IsMLA) {
    TORCH_CHECK(src_v.scalar_type() == dst_v.scalar_type(), "Source and destination values must have the same type");
  }
  int dtype_size = src_k.element_size();
  TORCH_CHECK((item_size * dtype_size) % 8 == 0, "Item byte size must be divisible by 8");

  auto div_up = [](int32_t x, int32_t y) { return (x + y - 1) / y; };
  const int64_t num_items = src_indices.numel();
  const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
  const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
  dim3 grid_dim(num_blocks, 1, 1);
  const int32_t threads_per_block = num_warps_per_block * 32;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();

  transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
      src_k.data_ptr(),
      dst_k.data_ptr(),
      (IsMLA ? nullptr : src_v.data_ptr()),
      (IsMLA ? nullptr : dst_v.data_ptr()),
      src_indices.data_ptr<int64_t>(),
      dst_indices.data_ptr<int64_t>(),
      start_layer_id,
      num_layers_to_process,
      num_items,
      items_per_warp,
      item_size * dtype_size,
      src_layout_dim * dtype_size,
      dst_layout_dim * dtype_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

void transfer_kv_per_layer(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0, block_quota, num_warps_per_block);
}

void transfer_kv_all_layer(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t src_layer_offset,
    int64_t dst_layer_offset,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
      src_k,
      dst_k,
      src_v,
      dst_v,
      src_indices,
      dst_indices,
      0,
      num_layers,
      item_size,
      src_layer_offset,
      dst_layer_offset,
      block_quota,
      num_warps_per_block);
}

void transfer_kv_per_layer_mla(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor empty_tensor = at::Tensor();
  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
      src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, 0, 1, item_size, 0, 0, block_quota, num_warps_per_block);
}

void transfer_kv_all_layer_mla(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t src_layer_offset,
    int64_t dst_layer_offset,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor empty_tensor = at::Tensor();
  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
      src,
      dst,
      empty_tensor,
      empty_tensor,
      src_indices,
      dst_indices,
      0,
      num_layers,
      item_size,
      src_layer_offset,
      dst_layer_offset,
      block_quota,
      num_warps_per_block);
}

inline void transfer_page_direct(
    const at::Tensor src_buffer,
    at::Tensor dst_buffer,
    int64_t src_page_index,
    int64_t dst_page_index,
    int64_t page_size) {
  dst_buffer.slice(0, dst_page_index, dst_page_index + page_size)
      .copy_(src_buffer.slice(0, src_page_index, src_page_index + page_size), /* non_blocking= */ true);
}

template <bool IsMLA, bool AllLayers>
inline void transfer_kv_direct_impl(
    const at::Tensor& src_k,
    at::Tensor& dst_k,
    const at::Tensor& src_v_opt,  // Only used when IsMLA is false (for src_v)
    at::Tensor& dst_v_opt,        // Only used when IsMLA is false (for dst_v)
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t page_size,
    int64_t num_layers = 1) {
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  auto src_indices_cpu = src_indices.cpu();
  auto dst_indices_cpu = dst_indices.cpu();
  const int64_t num_pages = src_indices_cpu.size(0) / page_size;

  for (const auto i : c10::irange(num_pages)) {
    auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
    auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();

    if constexpr (AllLayers) {
      for (const auto j : c10::irange(num_layers)) {
        if constexpr (IsMLA) {
          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
        } else {
          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
          transfer_page_direct(src_v_opt.select(0, j), dst_v_opt.select(0, j), s_index, d_index, page_size);
        }
      }
    } else {
      // Per-layer
      if constexpr (IsMLA) {
        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
      } else {
        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
        transfer_page_direct(src_v_opt, dst_v_opt, s_index, d_index, page_size);
      }
    }
  }
}

void transfer_kv_per_layer_direct(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size) {
  transfer_kv_direct_impl<false, false>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size);
}

void transfer_kv_all_layer_direct(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size,
    int64_t num_layers) {
  transfer_kv_direct_impl<false, true>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers);
}

void transfer_kv_per_layer_mla_direct(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size) {
  at::Tensor empty_tensor = at::Tensor();
  transfer_kv_direct_impl<true, false>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size);
}

void transfer_kv_all_layer_mla_direct(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size,
    int64_t num_layers) {
  at::Tensor empty_tensor = at::Tensor();
  transfer_kv_direct_impl<true, true>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size, num_layers);
}
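To make the indexing above concrete, here is a hedged Python reference model of the layer-first offset function and the launcher's work split; the names mirror the kernel parameters, and the example numbers are assumptions, not values from the commit.

# Reference model (sketch) of the layer-first offset math and work split in transfer.cu.
def global_offset_lf(layer_id: int, layer_dim: int, page_id: int, item_size_bytes: int) -> int:
    # layer first: all items of layer 0, then all items of layer 1, ...
    return layer_id * layer_dim + page_id * item_size_bytes

def launch_shape(num_items: int, block_quota: int = 2, num_warps_per_block: int = 32):
    # Mirrors transfer_kv_launcher: cap the grid at block_quota blocks,
    # then give each warp a contiguous run of items_per_warp items.
    div_up = lambda x, y: (x + y - 1) // y
    items_per_warp = div_up(num_items, block_quota * num_warps_per_block)
    num_blocks = div_up(num_items, items_per_warp * num_warps_per_block)
    return num_blocks, num_warps_per_block * 32, items_per_warp

# Example: a pool of 10240 items of 256 bf16 elements (512 bytes) per layer.
layer_dim_bytes = 10240 * 256 * 2
assert global_offset_lf(1, layer_dim_bytes, 0, 256 * 2) == layer_dim_bytes
print(launch_shape(1024))  # (2, 1024, 16): 2 blocks, 1024 threads each, 16 items per warp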
sgl-kernel/include/sgl_kernel_ops.h

@@ -371,6 +371,89 @@ void segment_packbits(
    int64_t batch_size,
    int64_t cuda_stream = 0);

/*
 * From csrc/kvcacheio
 */
void transfer_kv_per_layer(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block);
void transfer_kv_per_layer_direct(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size);
void transfer_kv_all_layer(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t src_layer_offset,
    int64_t dst_layer_offset,
    int64_t block_quota,
    int64_t num_warps_per_block);
void transfer_kv_all_layer_direct(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size,
    int64_t num_layers);
void transfer_kv_per_layer_mla(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block);
void transfer_kv_per_layer_mla_direct(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size);
void transfer_kv_all_layer_mla(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t src_layer_offset,
    int64_t dst_layer_offset,
    int64_t block_quota,
    int64_t num_warps_per_block);
void transfer_kv_all_layer_mla_direct(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size,
    int64_t num_layers);

/*
 * From FlashInfer
 */
sgl-kernel/python/sgl_kernel/__init__.py

@@ -47,6 +47,12 @@ from sgl_kernel.gemm import (
    shuffle_rows,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)
from sgl_kernel.moe import (
    apply_shuffle_mul_sum,
    cutlass_fp4_group_mm,
sgl-kernel/python/sgl_kernel/kvcacheio.py (new file)

import torch


def transfer_kv_per_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_per_layer(
            src_k,
            dst_k,
            src_v,
            dst_v,
            src_indices,
            dst_indices,
            item_size,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_per_layer_direct(
            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size
        )
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")


def transfer_kv_all_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    num_layers: int,
    src_layer_offset: int,
    dst_layer_offset: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_all_layer(
            src_k,
            dst_k,
            src_v,
            dst_v,
            src_indices,
            dst_indices,
            item_size,
            num_layers,
            src_layer_offset,
            dst_layer_offset,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_all_layer_direct(
            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers
        )
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")


def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
            src,
            dst,
            src_indices,
            dst_indices,
            item_size,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct(
            src, dst, src_indices, dst_indices, page_size
        )
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")


def transfer_kv_all_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    num_layers: int,
    src_layer_offset: int,
    dst_layer_offset: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
            src,
            dst,
            src_indices,
            dst_indices,
            item_size,
            num_layers,
            src_layer_offset,
            dst_layer_offset,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct(
            src, dst, src_indices, dst_indices, page_size, num_layers
        )
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")
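A hedged usage sketch of the all-layer wrapper, mirroring the test below: the layer offsets are given in elements (the launcher multiplies by the dtype size), so for a contiguous [num_layers, num_items, item_size] pool the stride between layers is num_items * item_size. Shapes and index ranges are illustrative assumptions.

# Sketch, assuming a [num_layers, num_items, item_size] KV pool; values are illustrative.
import torch
from sgl_kernel import transfer_kv_all_layer

num_layers, num_items, item_size = 4, 10240, 256
src_k = torch.randn(num_layers, num_items, item_size, dtype=torch.bfloat16).pin_memory()
src_v = torch.randn(num_layers, num_items, item_size, dtype=torch.bfloat16).pin_memory()
dst_k = torch.zeros(num_layers, num_items, item_size, dtype=torch.bfloat16, device="cuda")
dst_v = torch.zeros_like(dst_k)
src_indices = torch.arange(0, 128, dtype=torch.int64, device="cuda")
dst_indices = torch.arange(128, 256, dtype=torch.int64, device="cuda")

transfer_kv_all_layer(
    src_k, dst_k, src_v, dst_v, src_indices, dst_indices,
    io_backend="kernel",
    page_size=1,
    item_size=item_size,
    num_layers=num_layers,
    src_layer_offset=num_items * item_size,  # element stride between layers
    dst_layer_offset=num_items * item_size,
)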
sgl-kernel/tests/test_kvcacheio.py (new file)

import pytest
import torch
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)


def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
    dst_pool[dst_indices] = src_pool[src_indices].to(dst_pool.device)


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
@pytest.mark.parametrize("page_size", [1, 16, 64])
@pytest.mark.parametrize("item_size", [256])
@pytest.mark.parametrize("total_items_in_pool", [10240])
@pytest.mark.parametrize("is_mla", [False, True])
@pytest.mark.parametrize("all_layers", [False, True])
def test_transfer_kv(
    dtype: torch.dtype,
    num_items_to_transfer: int,
    item_size: int,
    page_size: int,
    total_items_in_pool: int,
    is_mla: bool,
    all_layers: bool,
):
    """
    Tests the per-layer transfer functions, treating tensors as memory pools.
    """
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    device = "cuda"
    torch.cuda.manual_seed(42)

    num_layers = 4  # A small number of layers for pool creation
    total_pages_in_pool = total_items_in_pool // page_size
    num_pages_to_transfer = num_items_to_transfer // page_size
    if num_pages_to_transfer == 0:
        torch.set_default_dtype(original_dtype)
        return

    page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
    src_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[:num_pages_to_transfer]
        ]
    )
    src_indices_device = src_indices_host.to(device)
    dst_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
        ]
    )
    dst_indices_device = dst_indices_host.to(device)

    # Prepare memory pools based on whether it's an MLA case.
    if is_mla:
        src_pool_host = torch.randn(num_layers, total_items_in_pool, item_size).pin_memory()
        dst_pool_ref = torch.zeros_like(src_pool_host).to(device)
        dst_pool_kernel = torch.zeros_like(dst_pool_ref)
        dst_pool_direct = torch.zeros_like(dst_pool_ref)
    else:
        src_k_pool = torch.randn(num_layers, total_items_in_pool, item_size).pin_memory()
        src_v_pool = torch.randn(num_layers, total_items_in_pool, item_size).pin_memory()
        dst_k_pool_ref = torch.zeros_like(src_k_pool).to(device)
        dst_v_pool_ref = torch.zeros_like(src_v_pool).to(device)
        dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
        dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
        dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
        dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
    torch.cuda.synchronize()

    # We will test the per-layer function on the first layer (index 0) of the pool.
    layer_idx_to_test = 0
    if is_mla:
        if not all_layers:
            ref_copy_with_indices(
                src_pool_host[layer_idx_to_test],
                dst_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            transfer_kv_per_layer_mla(
                src_pool_host[layer_idx_to_test],
                dst_pool_kernel[layer_idx_to_test],
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
            )
            transfer_kv_per_layer_mla(
                src_pool_host[layer_idx_to_test],
                dst_pool_direct[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
            )
        else:
            for layer_id in range(num_layers):
                ref_copy_with_indices(
                    src_pool_host[layer_id],
                    dst_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
            transfer_kv_all_layer_mla(
                src_pool_host,
                dst_pool_kernel,
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
            transfer_kv_all_layer_mla(
                src_pool_host,
                dst_pool_direct,
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
        torch.cuda.synchronize()
        torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
        torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
    else:
        if not all_layers:
            ref_copy_with_indices(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            ref_copy_with_indices(
                src_v_pool[layer_idx_to_test],
                dst_v_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            transfer_kv_per_layer(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_kernel[layer_idx_to_test],
                src_v_pool[layer_idx_to_test],
                dst_v_pool_kernel[layer_idx_to_test],
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
            )
            transfer_kv_per_layer(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_direct[layer_idx_to_test],
                src_v_pool[layer_idx_to_test],
                dst_v_pool_direct[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
            )
        else:
            for layer_id in range(num_layers):
                ref_copy_with_indices(
                    src_k_pool[layer_id],
                    dst_k_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
                ref_copy_with_indices(
                    src_v_pool[layer_id],
                    dst_v_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
            transfer_kv_all_layer(
                src_k_pool,
                dst_k_pool_kernel,
                src_v_pool,
                dst_v_pool_kernel,
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
            transfer_kv_all_layer(
                src_k_pool,
                dst_k_pool_direct,
                src_v_pool,
                dst_v_pool_direct,
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
        torch.cuda.synchronize()
        torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
        torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
        torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
        torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)

    torch.set_default_dtype(original_dtype)


if __name__ == "__main__":
    pytest.main([__file__])