Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
86a0be65
Unverified
Commit
86a0be65
authored
Aug 12, 2025
by
DarkSharpness
Committed by
GitHub
Aug 12, 2025
Browse files
[Feature] Support custom set kv buffer kernel (#8884)
parent
0edda320
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
178 additions
and
0 deletions
+178
-0
sgl-kernel/CMakeLists.txt
sgl-kernel/CMakeLists.txt
+1
-0
sgl-kernel/csrc/common_extension.cc
sgl-kernel/csrc/common_extension.cc
+6
-0
sgl-kernel/csrc/memory/store.cu
sgl-kernel/csrc/memory/store.cu
+147
-0
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/include/sgl_kernel_ops.h
+5
-0
sgl-kernel/python/sgl_kernel/__init__.py
sgl-kernel/python/sgl_kernel/__init__.py
+1
-0
sgl-kernel/python/sgl_kernel/memory.py
sgl-kernel/python/sgl_kernel/memory.py
+18
-0
No files found.
sgl-kernel/CMakeLists.txt
View file @
86a0be65
...
...
@@ -280,6 +280,7 @@ set(SOURCES
"csrc/speculative/packbit.cu"
"csrc/spatial/greenctx_stream.cu"
"csrc/speculative/speculative_sampling.cu"
"csrc/memory/store.cu"
"
${
repo-flashinfer_SOURCE_DIR
}
/csrc/norm.cu"
"
${
repo-flashinfer_SOURCE_DIR
}
/csrc/renorm.cu"
"
${
repo-flashinfer_SOURCE_DIR
}
/csrc/sampling.cu"
...
...
sgl-kernel/csrc/common_extension.cc
View file @
86a0be65
...
...
@@ -413,6 +413,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
*/
m
.
def
(
"create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]"
);
m
.
impl
(
"create_greenctx_stream_by_value"
,
&
create_greenctx_stream_by_value
);
/*
* From csrc/memory
*/
m
.
def
(
"store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()"
);
m
.
impl
(
"store_kv_cache"
,
&
store_kv_cache
);
}
REGISTER_EXTENSION
(
common_ops
)
sgl-kernel/csrc/memory/store.cu
0 → 100644
View file @
86a0be65
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBody.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Exception.h>
#include <cstddef>
#include <cstdint>
#include <type_traits>
namespace
{
using
std
::
size_t
;
using
std
::
uint64_t
;
// Each warp will process 256 bytes per loop iteration
template
<
typename
T
>
__global__
void
store_kv_cache_256x1
(
uint64_t
*
__restrict__
k_cache
,
uint64_t
*
__restrict__
v_cache
,
const
T
*
__restrict__
out_loc
,
const
size_t
length
,
const
uint64_t
*
__restrict__
k
,
const
uint64_t
*
__restrict__
v
,
const
size_t
kv_cache_stride
,
const
size_t
kv_input_stride
,
const
size_t
num_items
)
{
const
auto
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
auto
warp_id
=
idx
/
32
;
const
auto
lane_id
=
idx
%
32
;
if
(
warp_id
>=
length
)
return
;
const
auto
offset
=
out_loc
[
warp_id
];
const
auto
k_dst
=
k_cache
+
offset
*
kv_cache_stride
;
const
auto
v_dst
=
v_cache
+
offset
*
kv_cache_stride
;
const
auto
k_src
=
k
+
warp_id
*
kv_input_stride
;
const
auto
v_src
=
v
+
warp_id
*
kv_input_stride
;
for
(
size_t
i
=
0
;
i
<
num_items
;
++
i
)
{
k_dst
[
lane_id
+
i
*
32
]
=
k_src
[
lane_id
+
i
*
32
];
v_dst
[
lane_id
+
i
*
32
]
=
v_src
[
lane_id
+
i
*
32
];
}
}
// Each warp will process 128 bytes per loop iteration
template
<
typename
T
>
__global__
void
store_kv_cache_128x2
(
uint64_t
*
__restrict__
k_cache
,
uint64_t
*
__restrict__
v_cache
,
const
T
*
__restrict__
out_loc
,
const
size_t
length
,
const
uint64_t
*
__restrict__
k
,
const
uint64_t
*
__restrict__
v
,
const
size_t
kv_cache_stride
,
const
size_t
kv_input_stride
,
const
size_t
num_items
)
{
const
auto
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
auto
warp_id
=
idx
/
32
;
const
auto
lane_id
=
idx
%
32
;
if
(
warp_id
>=
length
)
return
;
const
auto
offset
=
out_loc
[
warp_id
];
const
auto
copy_k
=
lane_id
<
16
;
const
auto
copy_id
=
lane_id
%
16
;
const
auto
cache
=
copy_k
?
k_cache
:
v_cache
;
const
auto
input
=
copy_k
?
k
:
v
;
const
auto
dst
=
cache
+
offset
*
kv_cache_stride
;
const
auto
src
=
input
+
warp_id
*
kv_input_stride
;
for
(
size_t
i
=
0
;
i
<
num_items
;
++
i
)
{
dst
[
copy_id
+
i
*
16
]
=
src
[
copy_id
+
i
*
16
];
}
}
}
// namespace
auto
store_kv_cache
(
at
::
Tensor
k_cache
,
at
::
Tensor
v_cache
,
at
::
Tensor
out_loc
,
at
::
Tensor
k
,
at
::
Tensor
v
)
->
void
{
const
auto
max_tokens
=
k_cache
.
size
(
0
);
const
auto
num_tokens
=
out_loc
.
size
(
0
);
k_cache
=
k_cache
.
view
({
max_tokens
,
-
1
});
v_cache
=
v_cache
.
view
({
max_tokens
,
-
1
});
k
=
k
.
view
({
num_tokens
,
-
1
});
v
=
v
.
view
({
num_tokens
,
-
1
});
TORCH_CHECK
(
k_cache
.
is_cuda
()
&&
v_cache
.
is_cuda
()
&&
out_loc
.
is_cuda
()
&&
k
.
is_cuda
()
&&
v
.
is_cuda
(),
"All tensors must be CUDA tensors"
);
TORCH_CHECK
(
k_cache
.
sizes
()
==
v_cache
.
sizes
(),
"k_cache and v_cache must have the same size"
);
TORCH_CHECK
(
k_cache
.
strides
()
==
v_cache
.
strides
(),
"k_cache and v_cache must have the same strides"
);
TORCH_CHECK
(
k
.
sizes
()
==
v
.
sizes
(),
"k and v must have the same size"
);
TORCH_CHECK
(
k
.
strides
()
==
v
.
strides
(),
"k and v must have the same strides"
);
TORCH_CHECK
(
k
.
stride
(
-
1
)
==
1
&&
k_cache
.
stride
(
-
1
)
==
1
,
"k and k_cache must be contiguous in head."
);
TORCH_CHECK
(
k
.
size
(
-
1
)
==
k_cache
.
size
(
-
1
),
"k and k_cache must have the same head size"
);
TORCH_CHECK
(
out_loc
.
dim
()
==
1
&&
out_loc
.
is_contiguous
(),
"out_loc must be a 1D contiguous tensor"
);
static_assert
(
sizeof
(
uint64_t
)
==
8
,
"uint64_t must be 8 bytes, our code assumes that"
);
const
auto
length
=
out_loc
.
size
(
0
);
const
auto
elem_size
=
k
.
element_size
();
const
auto
size_bytes
=
elem_size
*
k
.
size
(
-
1
);
const
auto
kv_cache_stride_bytes
=
elem_size
*
k_cache
.
stride
(
-
2
);
const
auto
kv_input_stride_bytes
=
elem_size
*
k
.
stride
(
-
2
);
const
auto
kv_cache_stride
=
kv_cache_stride_bytes
/
8
;
const
auto
kv_input_stride
=
kv_input_stride_bytes
/
8
;
const
auto
k_cache_ptr
=
static_cast
<
uint64_t
*>
(
k_cache
.
data_ptr
());
const
auto
v_cache_ptr
=
static_cast
<
uint64_t
*>
(
v_cache
.
data_ptr
());
const
auto
k_ptr
=
static_cast
<
const
uint64_t
*>
(
k
.
data_ptr
());
const
auto
v_ptr
=
static_cast
<
const
uint64_t
*>
(
v
.
data_ptr
());
const
auto
num_threads
=
256
;
const
auto
num_warps
=
num_threads
/
32
;
const
auto
num_blocks
=
(
length
+
num_warps
-
1
)
/
num_warps
;
const
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_INTEGRAL_TYPES
(
out_loc
.
scalar_type
(),
"store_kv_cache"
,
[
&
]
{
if
constexpr
(
!
std
::
is_same_v
<
scalar_t
,
int32_t
>
&&
!
std
::
is_same_v
<
scalar_t
,
int64_t
>
)
{
// do not instantiate the kernel if out_loc is not int32 or int64
TORCH_CHECK
(
false
,
"out_loc must be of type int32 or int64, got: "
,
out_loc
.
scalar_type
());
}
else
{
if
(
size_bytes
%
256
==
0
)
{
const
auto
items_per_warp
=
size_bytes
/
256
;
store_kv_cache_256x1
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
k_cache_ptr
,
v_cache_ptr
,
out_loc
.
data_ptr
<
scalar_t
>
(),
length
,
k_ptr
,
v_ptr
,
kv_cache_stride
,
kv_input_stride
,
items_per_warp
);
}
else
if
(
size_bytes
%
128
==
0
)
{
const
auto
items_per_warp
=
size_bytes
/
128
;
store_kv_cache_128x2
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
k_cache_ptr
,
v_cache_ptr
,
out_loc
.
data_ptr
<
scalar_t
>
(),
length
,
k_ptr
,
v_ptr
,
kv_cache_stride
,
kv_input_stride
,
items_per_warp
);
}
else
{
TORCH_CHECK
(
false
,
"The last dimension size bytes of k and v must be"
" divisible by 128 at least, got: "
,
size_bytes
);
}
}
});
}
sgl-kernel/include/sgl_kernel_ops.h
View file @
86a0be65
...
...
@@ -699,3 +699,8 @@ void qserve_w4a8_per_group_gemm(
* From csrc/spatial
*/
std
::
vector
<
int64_t
>
create_greenctx_stream_by_value
(
int64_t
smA
,
int64_t
smB
,
int64_t
device
);
/*
* From csrc/memory
*/
void
store_kv_cache
(
at
::
Tensor
k_cache
,
at
::
Tensor
v_cache
,
at
::
Tensor
out_loc
,
at
::
Tensor
k
,
at
::
Tensor
v
);
sgl-kernel/python/sgl_kernel/__init__.py
View file @
86a0be65
...
...
@@ -67,6 +67,7 @@ from sgl_kernel.marlin import (
awq_marlin_repack
,
gptq_marlin_repack
,
)
from
sgl_kernel.memory
import
set_kv_buffer_kernel
from
sgl_kernel.moe
import
(
apply_shuffle_mul_sum
,
cutlass_fp4_group_mm
,
...
...
sgl-kernel/python/sgl_kernel/memory.py
0 → 100644
View file @
86a0be65
import
torch
def
set_kv_buffer_kernel
(
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
loc
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
fallback
:
bool
=
False
,
):
try
:
if
fallback
:
raise
RuntimeError
(
"Fallback to torch implementation"
)
torch
.
ops
.
sgl_kernel
.
store_kv_cache
(
k_cache
,
v_cache
,
loc
,
k
,
v
)
except
RuntimeError
:
# ok, fallback to torch implementation
k_cache
[
loc
]
=
k
v_cache
[
loc
]
=
v
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment