sglang commit 34c3f9b2 (Unverified)
Authored by Zhiqiang Xie on Jun 23, 2025; committed via GitHub on Jun 23, 2025.

kvcache io kernels and test case (#7382)
Parent commit: 76139bfb

Showing 7 changed files with 845 additions and 0 deletions (+845, -0):
- sgl-kernel/CMakeLists.txt (+1, -0)
- sgl-kernel/csrc/common_extension.cc (+37, -0)
- sgl-kernel/csrc/kvcacheio/transfer.cu (+342, -0)
- sgl-kernel/include/sgl_kernel_ops.h (+83, -0)
- sgl-kernel/python/sgl_kernel/__init__.py (+6, -0)
- sgl-kernel/python/sgl_kernel/kvcacheio.py (+137, -0)
- sgl-kernel/tests/test_kvcacheio.py (+239, -0)
sgl-kernel/CMakeLists.txt

@@ -250,6 +250,7 @@ set(SOURCES
     "csrc/speculative/packbit.cu"
     "csrc/speculative/speculative_sampling.cu"
     "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
+    "csrc/kvcacheio/transfer.cu"
     "csrc/common_extension.cc"
     "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
     "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
sgl-kernel/csrc/common_extension.cc

@@ -230,6 +230,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "int cuda_stream) -> ()");
   m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
 
+  /*
+   * From csrc/kvcacheio
+   */
+  m.def(
+      "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
+  m.def(
+      "transfer_kv_per_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int page_size) -> ()");
+  m.impl("transfer_kv_per_layer_direct", torch::kCUDA, &transfer_kv_per_layer_direct);
+  m.def(
+      "transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int item_size, int num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int "
+      "num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
+  m.def(
+      "transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int page_size, int num_layers) -> ()");
+  m.impl("transfer_kv_all_layer_direct", torch::kCUDA, &transfer_kv_all_layer_direct);
+  m.def(
+      "transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
+      "block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
+  m.def(
+      "transfer_kv_per_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size) "
+      "-> ()");
+  m.impl("transfer_kv_per_layer_mla_direct", torch::kCUDA, &transfer_kv_per_layer_mla_direct);
+  m.def(
+      "transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
+      "num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
+  m.def(
+      "transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, "
+      "int num_layers) -> ()");
+  m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
 
   /*
    * From FlashInfer
    */
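These schema strings are what the Python layer dispatches to. A minimal sketch of calling one registered op directly through the dispatcher, assuming the sgl_kernel extension is importable; the pool shapes, index ranges, and sizes below are illustrative assumptions, not requirements of the schema:

import torch
import sgl_kernel  # noqa: F401  (importing registers the sgl_kernel ops)

# Hypothetical single-layer pools: 1024 items of 256 fp16 elements each.
src_k = torch.randn(1024, 256, dtype=torch.float16, device="cuda")
src_v = torch.randn_like(src_k)
dst_k = torch.zeros_like(src_k)
dst_v = torch.zeros_like(src_v)
src_idx = torch.arange(0, 32, dtype=torch.int64, device="cuda")
dst_idx = torch.arange(32, 64, dtype=torch.int64, device="cuda")

# item_size is in elements (256 * 2 bytes per item is divisible by 8, as the
# kernel requires); block_quota=2 and num_warps_per_block=32 mirror the
# defaults of the Python wrappers added in this commit.
torch.ops.sgl_kernel.transfer_kv_per_layer(
    src_k, dst_k, src_v, dst_v, src_idx, dst_idx, 256, 2, 32
)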
sgl-kernel/csrc/kvcacheio/transfer.cu (new file, 0 → 100644)

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/irange.h>

#include <cstdint>

#include "pytorch_extension_utils.h"
__device__ __forceinline__ void
transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
  // todo, different chunk size
  int total_chunks = item_size_bytes / 8;
  const int64_t* src_8 = reinterpret_cast<const int64_t*>(src_addr);
  int64_t* dst_8 = reinterpret_cast<int64_t*>(dst_addr);
#pragma unroll
  for (int j = lane_id; j < total_chunks; j += 32) {
    const int64_t* src_addr_lane = &src_8[j];
    int64_t* dst_addr_lane = &dst_8[j];
    int64_t temp_val;
    asm volatile("ld.global.nc.b64 %0, [%1];" : "=l"(temp_val) : "l"(src_addr_lane) : "memory");
    asm volatile("st.global.cg.b64 [%0], %1;" ::"l"(dst_addr_lane), "l"(temp_val) : "memory");
  }
}

// todo, structs for different memory layout
__device__ __forceinline__ int64_t
get_global_offset_lf(int64_t layer_id, int64_t layer_dim, int64_t page_id, int64_t item_size_bytes) {
  // layer first
  return layer_id * layer_dim + page_id * item_size_bytes;
}

__device__ __forceinline__ int64_t
get_global_offset_pf(int64_t layer_id, int64_t page_dim, int64_t page_id, int64_t item_size_bytes) {
  // page first
  return page_id * page_dim + layer_id * item_size_bytes;
}
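The two helpers above encode the pool layouts the kernels understand: layer-first (all pages of a layer are contiguous) and page-first (all layers of a page are contiguous). The same byte arithmetic in Python, with an illustrative layer-first pool (the sizes are assumptions for the example):

def global_offset_lf(layer_id, layer_dim, page_id, item_size_bytes):
    # layer-first: skip whole layers, then pages within the layer
    return layer_id * layer_dim + page_id * item_size_bytes

# e.g. a pool of 4 layers x 10240 items x 512 bytes: layer_dim = 10240 * 512,
# so (layer 2, page 3) starts 2 layers plus 3 items into the buffer.
assert global_offset_lf(2, 10240 * 512, 3, 512) == 2 * 10240 * 512 + 3 * 512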
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
__global__ void transfer_kernel_impl(
    const void* __restrict__ src_k,
    void* __restrict__ dst_k,
    const void* __restrict__ src_v,
    void* __restrict__ dst_v,
    const int64_t* __restrict__ src_indices,
    const int64_t* __restrict__ dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t num_items,
    int64_t items_per_warp,
    int64_t item_size_bytes,
    int64_t src_layout_dim,
    int64_t dst_layout_dim) {
  int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  int32_t lane_id = tid % 32;
  int32_t warp_id = tid / 32;
  for (int i = 0; i < items_per_warp; ++i) {
    int32_t item_id = warp_id * items_per_warp + i;
    if (item_id >= num_items) {
      return;
    }
    const int64_t src_page_id = src_indices[item_id];
    const int64_t dst_page_id = dst_indices[item_id];
    // Loop over layers if necessary
    for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
      // Calculate offsets using the provided function pointers
      const int64_t src_offset = SrcOffsetFn(layer_id, src_layout_dim, src_page_id, item_size_bytes);
      const int64_t dst_offset = DstOffsetFn(layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
      if constexpr (IsMLA) {
        transfer_item_warp(
            lane_id,
            static_cast<const char*>(src_k) + src_offset,
            static_cast<char*>(dst_k) + dst_offset,
            item_size_bytes);
      } else {
        transfer_item_warp(
            lane_id,
            static_cast<const char*>(src_k) + src_offset,
            static_cast<char*>(dst_k) + dst_offset,
            item_size_bytes);
        transfer_item_warp(
            lane_id,
            static_cast<const char*>(src_v) + src_offset,
            static_cast<char*>(dst_v) + dst_offset,
            item_size_bytes);
      }
    }
  }
}
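For orientation, the work split in transfer_kernel_impl: each warp owns a run of items_per_warp consecutive items, and the 32 lanes of the warp then stride across the 8-byte chunks of each item. Sketched in Python (the helper names are ours, not part of the kernel):

def items_for_warp(warp_id, items_per_warp, num_items):
    # item ids one warp copies in transfer_kernel_impl
    start = warp_id * items_per_warp
    return range(start, min(start + items_per_warp, num_items))

def chunks_for_lane(lane_id, item_size_bytes):
    # indices of the 8-byte chunks one lane copies inside transfer_item_warp
    return range(lane_id, item_size_bytes // 8, 32)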
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
void transfer_kv_launcher(
    const at::Tensor& src_k,
    at::Tensor& dst_k,
    const at::Tensor& src_v,
    at::Tensor& dst_v,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(src_k.scalar_type() == dst_k.scalar_type(), "Source and destination keys must have the same type");
  TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
  TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
  TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
  TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  if (!IsMLA) {
    TORCH_CHECK(src_v.scalar_type() == dst_v.scalar_type(), "Source and destination values must have the same type");
  }
  int dtype_size = src_k.element_size();
  TORCH_CHECK((item_size * dtype_size) % 8 == 0, "Item byte size must be divisible by 8");

  auto div_up = [](int32_t x, int32_t y) { return (x + y - 1) / y; };
  const int64_t num_items = src_indices.numel();
  const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
  const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
  dim3 grid_dim(num_blocks, 1, 1);
  const int32_t threads_per_block = num_warps_per_block * 32;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();

  transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
      src_k.data_ptr(),
      dst_k.data_ptr(),
      (IsMLA ? nullptr : src_v.data_ptr()),
      (IsMLA ? nullptr : dst_v.data_ptr()),
      src_indices.data_ptr<int64_t>(),
      dst_indices.data_ptr<int64_t>(),
      start_layer_id,
      num_layers_to_process,
      num_items,
      items_per_warp,
      item_size * dtype_size,
      src_layout_dim * dtype_size,
      dst_layout_dim * dtype_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
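The launch configuration works backwards from block_quota: first find the smallest items_per_warp that keeps the grid within block_quota blocks, then shrink the grid to the minimum that still covers all items. The same arithmetic in Python (the function name is ours):

def launch_dims(num_items, block_quota, num_warps_per_block):
    div_up = lambda x, y: (x + y - 1) // y
    items_per_warp = div_up(num_items, block_quota * num_warps_per_block)
    num_blocks = div_up(num_items, items_per_warp * num_warps_per_block)
    return num_blocks, num_warps_per_block * 32  # grid size, threads per block

# e.g. 1024 items with the wrapper defaults (block_quota=2, 32 warps/block):
# items_per_warp = 16, so the launch is 2 blocks of 1024 threads.
assert launch_dims(1024, 2, 32) == (2, 1024)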
void transfer_kv_per_layer(
    const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
    const at::Tensor src_indices, const at::Tensor dst_indices,
    int64_t item_size, int64_t block_quota, int64_t num_warps_per_block) {
  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0, block_quota, num_warps_per_block);
}

void transfer_kv_all_layer(
    const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
    const at::Tensor src_indices, const at::Tensor dst_indices,
    int64_t item_size, int64_t num_layers, int64_t src_layer_offset, int64_t dst_layer_offset,
    int64_t block_quota, int64_t num_warps_per_block) {
  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, num_layers, item_size,
      src_layer_offset, dst_layer_offset, block_quota, num_warps_per_block);
}

void transfer_kv_per_layer_mla(
    const at::Tensor src, at::Tensor dst, const at::Tensor src_indices, const at::Tensor dst_indices,
    int64_t item_size, int64_t block_quota, int64_t num_warps_per_block) {
  at::Tensor empty_tensor = at::Tensor();
  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
      src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, 0, 1, item_size, 0, 0,
      block_quota, num_warps_per_block);
}

void transfer_kv_all_layer_mla(
    const at::Tensor src, at::Tensor dst, const at::Tensor src_indices, const at::Tensor dst_indices,
    int64_t item_size, int64_t num_layers, int64_t src_layer_offset, int64_t dst_layer_offset,
    int64_t block_quota, int64_t num_warps_per_block) {
  at::Tensor empty_tensor = at::Tensor();
  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
      src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, 0, num_layers, item_size,
      src_layer_offset, dst_layer_offset, block_quota, num_warps_per_block);
}
inline void transfer_page_direct(
    const at::Tensor src_buffer, at::Tensor dst_buffer,
    int64_t src_page_index, int64_t dst_page_index, int64_t page_size) {
  dst_buffer.slice(0, dst_page_index, dst_page_index + page_size)
      .copy_(src_buffer.slice(0, src_page_index, src_page_index + page_size), /* non_blocking= */ true);
}

template <bool IsMLA, bool AllLayers>
inline void transfer_kv_direct_impl(
    const at::Tensor& src_k,
    at::Tensor& dst_k,
    const at::Tensor& src_v_opt,  // Only used when IsMLA is false (for src_v)
    at::Tensor& dst_v_opt,        // Only used when IsMLA is false (for dst_v)
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t page_size,
    int64_t num_layers = 1) {
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  auto src_indices_cpu = src_indices.cpu();
  auto dst_indices_cpu = dst_indices.cpu();
  const int64_t num_pages = src_indices_cpu.size(0) / page_size;

  for (const auto i : c10::irange(num_pages)) {
    auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
    auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
    if constexpr (AllLayers) {
      for (const auto j : c10::irange(num_layers)) {
        if constexpr (IsMLA) {
          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
        } else {
          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
          transfer_page_direct(src_v_opt.select(0, j), dst_v_opt.select(0, j), s_index, d_index, page_size);
        }
      }
    } else {  // Per-layer
      if constexpr (IsMLA) {
        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
      } else {
        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
        transfer_page_direct(src_v_opt, dst_v_opt, s_index, d_index, page_size);
      }
    }
  }
}
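The direct backend stays entirely in ATen, so unlike the custom kernel it also handles cross-device copies. Per page, it is roughly equivalent to the following Python-level slice copy (tensor names match the C++ arguments; the copy is asynchronous when the source is pinned host memory):

# Rough Python equivalent of one transfer_page_direct(...) call:
# copy page_size consecutive rows from the source pool to the destination.
dst_buffer[dst_page_index : dst_page_index + page_size].copy_(
    src_buffer[src_page_index : src_page_index + page_size], non_blocking=True
)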
void transfer_kv_per_layer_direct(
    const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t page_size) {
  transfer_kv_direct_impl<false, false>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size);
}

void transfer_kv_all_layer_direct(
    const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t page_size, int64_t num_layers) {
  transfer_kv_direct_impl<false, true>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers);
}

void transfer_kv_per_layer_mla_direct(
    const at::Tensor src, at::Tensor dst,
    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t page_size) {
  at::Tensor empty_tensor = at::Tensor();
  transfer_kv_direct_impl<true, false>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size);
}

void transfer_kv_all_layer_mla_direct(
    const at::Tensor src, at::Tensor dst,
    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t page_size, int64_t num_layers) {
  at::Tensor empty_tensor = at::Tensor();
  transfer_kv_direct_impl<true, true>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size, num_layers);
}
sgl-kernel/include/sgl_kernel_ops.h

@@ -371,6 +371,89 @@ void segment_packbits(
     int64_t batch_size,
     int64_t cuda_stream = 0);
 
+/*
+ * From csrc/kvcacheio
+ */
+void transfer_kv_per_layer(
+    const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
+    const at::Tensor src_indices, const at::Tensor dst_indices,
+    int64_t item_size, int64_t block_quota, int64_t num_warps_per_block);
+void transfer_kv_per_layer_direct(
+    const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
+    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t page_size);
+void transfer_kv_all_layer(
+    const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
+    const at::Tensor src_indices, const at::Tensor dst_indices,
+    int64_t item_size, int64_t num_layers, int64_t src_layer_offset, int64_t dst_layer_offset,
+    int64_t block_quota, int64_t num_warps_per_block);
+void transfer_kv_all_layer_direct(
+    const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
+    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t page_size, int64_t num_layers);
+void transfer_kv_per_layer_mla(
+    const at::Tensor src, at::Tensor dst,
+    const at::Tensor src_indices, const at::Tensor dst_indices,
+    int64_t item_size, int64_t block_quota, int64_t num_warps_per_block);
+void transfer_kv_per_layer_mla_direct(
+    const at::Tensor src, at::Tensor dst,
+    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t page_size);
+void transfer_kv_all_layer_mla(
+    const at::Tensor src, at::Tensor dst,
+    const at::Tensor src_indices, const at::Tensor dst_indices,
+    int64_t item_size, int64_t num_layers, int64_t src_layer_offset, int64_t dst_layer_offset,
+    int64_t block_quota, int64_t num_warps_per_block);
+void transfer_kv_all_layer_mla_direct(
+    const at::Tensor src, at::Tensor dst,
+    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t page_size, int64_t num_layers);
 
 /*
  * From FlashInfer
  */
sgl-kernel/python/sgl_kernel/__init__.py

@@ -47,6 +47,12 @@ from sgl_kernel.gemm import (
     shuffle_rows,
 )
 from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
+from sgl_kernel.kvcacheio import (
+    transfer_kv_all_layer,
+    transfer_kv_all_layer_mla,
+    transfer_kv_per_layer,
+    transfer_kv_per_layer_mla,
+)
 from sgl_kernel.moe import (
     apply_shuffle_mul_sum,
     cutlass_fp4_group_mm,
sgl-kernel/python/sgl_kernel/kvcacheio.py (new file, 0 → 100644)

import torch


def transfer_kv_per_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_per_layer(
            src_k, dst_k, src_v, dst_v, src_indices, dst_indices,
            item_size, block_quota, num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_per_layer_direct(
            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size
        )
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")


def transfer_kv_all_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    num_layers: int,
    src_layer_offset: int,
    dst_layer_offset: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_all_layer(
            src_k, dst_k, src_v, dst_v, src_indices, dst_indices,
            item_size, num_layers, src_layer_offset, dst_layer_offset,
            block_quota, num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_all_layer_direct(
            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers
        )
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")


def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
            src, dst, src_indices, dst_indices,
            item_size, block_quota, num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct(
            src, dst, src_indices, dst_indices, page_size
        )
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")


def transfer_kv_all_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    num_layers: int,
    src_layer_offset: int,
    dst_layer_offset: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
            src, dst, src_indices, dst_indices,
            item_size, num_layers, src_layer_offset, dst_layer_offset,
            block_quota, num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct(
            src, dst, src_indices, dst_indices, page_size, num_layers
        )
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")
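A usage sketch of the per-layer wrapper above, assuming device-resident pools (the shapes and index choices are illustrative; for io_backend="direct" the indices must cover whole pages, which the C++ side checks):

import torch
from sgl_kernel.kvcacheio import transfer_kv_per_layer

src_k = torch.randn(4096, 256, dtype=torch.bfloat16, device="cuda")
src_v = torch.randn_like(src_k)
dst_k = torch.zeros_like(src_k)
dst_v = torch.zeros_like(src_v)
# 64 items = 4 pages of 16, page-aligned so both backends accept them
src_indices = torch.arange(0, 64, dtype=torch.int64, device="cuda")
dst_indices = torch.arange(64, 128, dtype=torch.int64, device="cuda")

transfer_kv_per_layer(
    src_k, dst_k, src_v, dst_v, src_indices, dst_indices,
    io_backend="kernel",  # or "direct"
    page_size=16,
    item_size=256,
)
torch.cuda.synchronize()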
sgl-kernel/tests/test_kvcacheio.py (new file, 0 → 100644)

import pytest
import torch
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)


def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
    dst_pool[dst_indices] = src_pool[src_indices].to(dst_pool.device)


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
@pytest.mark.parametrize("page_size", [1, 16, 64])
@pytest.mark.parametrize("item_size", [256])
@pytest.mark.parametrize("total_items_in_pool", [10240])
@pytest.mark.parametrize("is_mla", [False, True])
@pytest.mark.parametrize("all_layers", [False, True])
def test_transfer_kv(
    dtype: torch.dtype,
    num_items_to_transfer: int,
    item_size: int,
    page_size: int,
    total_items_in_pool: int,
    is_mla: bool,
    all_layers: bool,
):
    """
    Tests the per-layer and all-layer transfer functions, treating tensors as memory pools.
    """
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    device = "cuda"
    torch.cuda.manual_seed(42)
    num_layers = 4  # A small number of layers for pool creation

    total_pages_in_pool = total_items_in_pool // page_size
    num_pages_to_transfer = num_items_to_transfer // page_size
    if num_pages_to_transfer == 0:
        torch.set_default_dtype(original_dtype)
        return

    page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
    src_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[:num_pages_to_transfer]
        ]
    )
    src_indices_device = src_indices_host.to(device)
    dst_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
        ]
    )
    dst_indices_device = dst_indices_host.to(device)

    # Prepare memory pools based on whether it's an MLA case.
    if is_mla:
        src_pool_host = torch.randn(num_layers, total_items_in_pool, item_size).pin_memory()
        dst_pool_ref = torch.zeros_like(src_pool_host).to(device)
        dst_pool_kernel = torch.zeros_like(dst_pool_ref)
        dst_pool_direct = torch.zeros_like(dst_pool_ref)
    else:
        src_k_pool = torch.randn(num_layers, total_items_in_pool, item_size).pin_memory()
        src_v_pool = torch.randn(num_layers, total_items_in_pool, item_size).pin_memory()
        dst_k_pool_ref = torch.zeros_like(src_k_pool).to(device)
        dst_v_pool_ref = torch.zeros_like(src_v_pool).to(device)
        dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
        dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
        dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
        dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
    torch.cuda.synchronize()

    # We will test the per-layer function on the first layer (index 0) of the pool.
    layer_idx_to_test = 0
    if is_mla:
        if not all_layers:
            ref_copy_with_indices(
                src_pool_host[layer_idx_to_test],
                dst_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            transfer_kv_per_layer_mla(
                src_pool_host[layer_idx_to_test],
                dst_pool_kernel[layer_idx_to_test],
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
            )
            transfer_kv_per_layer_mla(
                src_pool_host[layer_idx_to_test],
                dst_pool_direct[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
            )
        else:
            for layer_id in range(num_layers):
                ref_copy_with_indices(
                    src_pool_host[layer_id],
                    dst_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
            transfer_kv_all_layer_mla(
                src_pool_host,
                dst_pool_kernel,
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
            transfer_kv_all_layer_mla(
                src_pool_host,
                dst_pool_direct,
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
        torch.cuda.synchronize()
        torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
        torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
    else:
        if not all_layers:
            ref_copy_with_indices(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            ref_copy_with_indices(
                src_v_pool[layer_idx_to_test],
                dst_v_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            transfer_kv_per_layer(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_kernel[layer_idx_to_test],
                src_v_pool[layer_idx_to_test],
                dst_v_pool_kernel[layer_idx_to_test],
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
            )
            transfer_kv_per_layer(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_direct[layer_idx_to_test],
                src_v_pool[layer_idx_to_test],
                dst_v_pool_direct[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
            )
        else:
            for layer_id in range(num_layers):
                ref_copy_with_indices(
                    src_k_pool[layer_id],
                    dst_k_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
                ref_copy_with_indices(
                    src_v_pool[layer_id],
                    dst_v_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
            transfer_kv_all_layer(
                src_k_pool,
                dst_k_pool_kernel,
                src_v_pool,
                dst_v_pool_kernel,
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
            transfer_kv_all_layer(
                src_k_pool,
                dst_k_pool_direct,
                src_v_pool,
                dst_v_pool_direct,
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
        torch.cuda.synchronize()
        torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
        torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
        torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
        torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
    torch.set_default_dtype(original_dtype)


if __name__ == "__main__":
    pytest.main([__file__])