change / sglang · Commits

Commit c9bcffd2, authored Nov 04, 2025 by liucong
Add dcu_alloc_decode_kernel implementation
Parent: 46da9556

Showing 5 changed files with 117 additions and 8 deletions (+117 -8)
python/sglang/srt/mem_cache/allocator.py     +22  -8
sgl-kernel/csrc/common_extension_rocm.cc      +2  -0
sgl-kernel/csrc/kvcacheio/transfer.cu        +65  -0
sgl-kernel/include/sgl_kernel_ops.h           +9  -0
sgl-kernel/python/sgl_kernel/kvcacheio.py    +19  -0
python/sglang/srt/mem_cache/allocator.py

@@ -28,6 +28,7 @@ import triton.language as tl
 from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
+from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel

 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache

@@ -430,6 +431,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
+        self.use_dcu_decode_kernel = get_bool_env_var("USE_DCU_DECODE_KERNEL")
         self.seen_max_num_extend_tokens_next_power_of_2 = 1
         self.clear()

@@ -525,6 +527,18 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.merge_and_sort_free()

         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+        if self.use_dcu_decode_kernel:
+            dcu_alloc_decode_kernel(
+                seq_lens_ptr=seq_lens,
+                last_loc_ptr=last_loc,
+                free_page_ptr=self.free_pages,
+                out_indices=out_indices,
+                bs=bs,
+                bs_upper=next_power_of_2(bs),
+                page_size=self.page_size,
+            )
+        else:
-        alloc_decode_kernel[(bs,)](
-            seq_lens,
-            last_loc,
+            alloc_decode_kernel[(bs,)](
+                seq_lens,
+                last_loc,
...
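The new path is gated by the USE_DCU_DECODE_KERNEL environment variable read once in __init__. A minimal opt-in sketch, assuming only what the diff shows (get_bool_env_var typically treats "1"/"true" style values as enabled); the launch snippet itself is illustrative, not part of the commit:

# Illustrative only: enable the DCU allocation path before the allocator is built.
import os

os.environ["USE_DCU_DECODE_KERNEL"] = "1"

# From here on, PagedTokenToKVPoolAllocator.alloc_decode() dispatches to
# dcu_alloc_decode_kernel(...) instead of the Triton alloc_decode_kernel[(bs,)](...).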
sgl-kernel/csrc/common_extension_rocm.cc

@@ -125,6 +125,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/kvcacheio
    */
+  m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size) -> ()");
+  m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
   m.def(
       "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
       "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
...
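The two added lines register the op's schema in the sgl_kernel torch library and bind the ROCm/CUDA implementation. A short hedged check that the registration is visible from Python, assuming a build of sgl-kernel that includes this commit and that importing the sgl_kernel package loads the compiled extension:

# Illustrative check, not part of the commit.
import torch
import sgl_kernel  # noqa: F401  -- importing the package loads the extension and its ops

op = torch.ops.sgl_kernel.dcu_alloc_decode_kernel
# On recent PyTorch versions the parsed schema hangs off the default overload.
print(op.default._schema)
# sgl_kernel::dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, ...) -> ()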
sgl-kernel/csrc/kvcacheio/transfer.cu

@@ -571,3 +571,68 @@ void transfer_kv_all_layer_direct_lf_pf(
     int64_t page_size) {
   transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
 }
+
+__device__ int64_t ceil_div(int64_t a, int64_t b) {
+  return (a + b - 1) / b;
+}
+
+__global__ void launch_alloc_decode_kernel(
+    const int64_t* seq_lens_ptr,
+    const int32_t* last_loc_ptr,
+    const int64_t* free_page_ptr,
+    int64_t* out_indices,
+    int64_t bs_upper,
+    int64_t page_size) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs_upper) return;
+
+  int64_t seq_len = seq_lens_ptr[pid];
+  int64_t pre_len = seq_len - 1;
+  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
+
+  int64_t sum_num_new_pages = 0;
+  for (int64_t i = 0; i < pid; i++) {
+    int64_t other_seq_len = seq_lens_ptr[i];
+    int64_t other_pre_len = (i <= pid) ? (other_seq_len - 1) : other_seq_len;
+    int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
+    int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
+    int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
+    sum_num_new_pages += other_num_new_pages;
+  }
+
+  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
+
+  if (num_page_start_loc_self == 0) {
+    int32_t last_loc = last_loc_ptr[pid];
+    out_indices[pid] = last_loc + 1;
+  } else {
+    int64_t page = free_page_ptr[new_page_start_loc];
+    out_indices[pid] = page * page_size;
+  }
+}
+
+void dcu_alloc_decode_kernel(
+    const at::Tensor seq_lens_ptr,
+    const at::Tensor last_loc_ptr,
+    const at::Tensor free_page_ptr,
+    at::Tensor out_indices,
+    int64_t bs,
+    int64_t bs_upper,
+    int64_t page_size) {
+  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
+  const int32_t* last_loc_ptr1 = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
+  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
+  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());
+
+  int64_t block_size = 64;
+  int64_t grid_size = (bs + block_size - 1) / block_size;
+  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
+
+  launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
+      seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs_upper, page_size);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
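For reference, a host-side Python sketch of the allocation policy behind this kernel (written for this note, not taken from the commit; alloc_decode_reference and the example values are illustrative): each decoding request either appends into its current page, taking the slot right after last_loc, or claims the next page from free_page_ptr, where "next" is the running count of new pages claimed by the requests ahead of it in the batch, so no self-count needs to be subtracted here.

# Illustrative reference only.
import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def alloc_decode_reference(seq_lens, last_loc, free_pages, page_size):
    """Return one KV-cache slot index per decoding request."""
    bs = seq_lens.numel()
    out_indices = torch.empty(bs, dtype=torch.int64)
    pages_taken = 0  # new pages already claimed by earlier requests in this batch
    for pid in range(bs):
        seq_len = int(seq_lens[pid])
        pre_len = seq_len - 1
        # 1 if appending this token crosses into a fresh page, else 0
        needs_new_page = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size)
        if needs_new_page == 0:
            # stay in the current page: the slot right after the previous token
            out_indices[pid] = int(last_loc[pid]) + 1
        else:
            # claim the next unused page; its first slot receives the new token
            out_indices[pid] = int(free_pages[pages_taken]) * page_size
        pages_taken += needs_new_page
    return out_indices


# Tiny worked example (hypothetical values), page_size = 4:
seq_lens = torch.tensor([5, 8], dtype=torch.int64)   # request 0 just crossed a page boundary
last_loc = torch.tensor([3, 18], dtype=torch.int32)
free_pages = torch.tensor([7, 9], dtype=torch.int64)
print(alloc_decode_reference(seq_lens, last_loc, free_pages, 4))  # tensor([28, 19])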
sgl-kernel/include/sgl_kernel_ops.h

@@ -538,6 +538,15 @@ void segment_packbits(
 /*
  * From csrc/kvcacheio
  */
+void dcu_alloc_decode_kernel(
+    const at::Tensor seq_lens_ptr,
+    const at::Tensor last_loc_ptr,
+    const at::Tensor free_page_ptr,
+    at::Tensor out_indices,
+    int64_t bs,
+    int64_t bs_upper,
+    int64_t page_size);
+
 void transfer_kv_per_layer(
     const at::Tensor src_k,
     at::Tensor dst_k,
...
sgl-kernel/python/sgl_kernel/kvcacheio.py

@@ -10,6 +10,25 @@ def is_hip() -> bool:
 _is_hip = is_hip()


+def dcu_alloc_decode_kernel(
+    seq_lens_ptr: torch.Tensor,
+    last_loc_ptr: torch.Tensor,
+    free_page_ptr: torch.Tensor,
+    out_indices: torch.Tensor,
+    bs: int,
+    bs_upper: int,
+    page_size: int,
+):
+    torch.ops.sgl_kernel.dcu_alloc_decode_kernel(
+        seq_lens_ptr,
+        last_loc_ptr,
+        free_page_ptr,
+        out_indices,
+        bs,
+        bs_upper,
+        page_size,
+    )
+
+
 def transfer_kv_per_layer(
     src_k: torch.Tensor,
     dst_k: torch.Tensor,
...
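A hypothetical call site for the new wrapper, assuming a ROCm/DCU build of sgl-kernel that includes this commit. Shapes, values, and the free-page list are illustrative; dtypes follow the casts in transfer.cu (int64 seq_lens/free_pages/out_indices, int32 last_loc) and next_power_of_2 is imported as in the allocator diff.

# Illustrative usage, not part of the commit.
import torch
from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel
from sglang.srt.utils import next_power_of_2

device = "cuda"  # ROCm devices are exposed through the "cuda" device type in PyTorch
bs = 4
page_size = 16

seq_lens = torch.tensor([17, 33, 5, 64], dtype=torch.int64, device=device)
last_loc = torch.tensor([15, 527, 260, 1022], dtype=torch.int32, device=device)
free_pages = torch.arange(8, 40, dtype=torch.int64, device=device)  # ids of currently free pages
out_indices = torch.empty(bs, dtype=torch.int64, device=device)

dcu_alloc_decode_kernel(
    seq_lens_ptr=seq_lens,
    last_loc_ptr=last_loc,
    free_page_ptr=free_pages,
    out_indices=out_indices,
    bs=bs,
    bs_upper=next_power_of_2(bs),
    page_size=page_size,
)
print(out_indices)  # one KV-cache slot index per decoding request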