Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
33fbf3ca
Commit
33fbf3ca
authored
Nov 12, 2025
by
liucong
Browse files
增加dcu_create_extend_after_decode_spec_info_kernel实现
parent
cb4fb0ee
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
152 additions
and
10 deletions
+152
-10
python/sglang/srt/speculative/eagle_info.py
python/sglang/srt/speculative/eagle_info.py
+22
-9
sgl-kernel/csrc/common_extension_rocm.cc
sgl-kernel/csrc/common_extension_rocm.cc
+2
-0
sgl-kernel/csrc/kvcacheio/transfer.cu
sgl-kernel/csrc/kvcacheio/transfer.cu
+104
-1
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/include/sgl_kernel_ops.h
+8
-0
sgl-kernel/python/sgl_kernel/kvcacheio.py
sgl-kernel/python/sgl_kernel/kvcacheio.py
+16
-0
No files found.
python/sglang/srt/speculative/eagle_info.py
View file @
33fbf3ca
...
...
@@ -37,7 +37,8 @@ from sglang.srt.speculative.spec_utils import (
get_src_tgt_cache_loc
,
get_target_cache_loc
,
)
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
next_power_of_2
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
next_power_of_2
,
get_bool_env_var
from
sgl_kernel.kvcacheio
import
dcu_create_extend_after_decode_spec_info
if
is_cuda
():
from
sgl_kernel
import
(
...
...
@@ -615,6 +616,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
new_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
verify_done
:
Optional
[
torch
.
cuda
.
Event
]
=
None
use_sglang_create_extend_after_decode_spec_info
=
get_bool_env_var
(
"SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"
)
def
__post_init__
(
self
):
super
().
__init__
(
SpecInputType
.
EAGLE_DRAFT
)
...
...
@@ -679,14 +682,24 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self
.
positions
=
torch
.
empty_like
(
batch
.
input_ids
,
dtype
=
torch
.
long
)
self
.
verified_id
=
torch
.
empty_like
(
self
.
accept_length
,
dtype
=
torch
.
int32
)
create_extend_after_decode_spec_info
[(
len
(
batch
.
seq_lens
),)](
batch
.
input_ids
,
batch
.
seq_lens
,
self
.
accept_length
,
self
.
positions
,
self
.
verified_id
,
next_power_of_2
(
max
(
speculative_num_steps
+
1
,
len
(
batch
.
seq_lens
))),
)
if
self
.
use_sglang_create_extend_after_decode_spec_info
:
dcu_create_extend_after_decode_spec_info
(
verified_id
=
batch
.
input_ids
,
seq_lens
=
batch
.
seq_lens
,
accept_lens
=
self
.
accept_length
,
positions
=
self
.
positions
,
new_verified_id
=
self
.
verified_id
,
bs
=
max
(
speculative_num_steps
+
1
,
len
(
batch
.
seq_lens
)),
)
else
:
create_extend_after_decode_spec_info
[(
len
(
batch
.
seq_lens
),)](
batch
.
input_ids
,
batch
.
seq_lens
,
self
.
accept_length
,
self
.
positions
,
self
.
verified_id
,
next_power_of_2
(
max
(
speculative_num_steps
+
1
,
len
(
batch
.
seq_lens
))),
)
def
generate_attn_arg_prefill
(
self
,
...
...
sgl-kernel/csrc/common_extension_rocm.cc
View file @
33fbf3ca
...
...
@@ -125,6 +125,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From csrc/kvcacheio
*/
m
.
def
(
"dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()"
);
m
.
impl
(
"dcu_create_extend_after_decode_spec_info"
,
torch
::
kCUDA
,
&
dcu_create_extend_after_decode_spec_info
);
m
.
def
(
"dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()"
);
m
.
impl
(
"dcu_alloc_extend_kernel"
,
torch
::
kCUDA
,
&
dcu_alloc_extend_kernel
);
m
.
def
(
"dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()"
);
...
...
sgl-kernel/csrc/kvcacheio/transfer.cu
View file @
33fbf3ca
...
...
@@ -694,6 +694,66 @@ __global__ void launch_alloc_extend_kernel(
}
}
__global__
void
launch_create_extend_after_decode_spec_info_int32_kernel
(
const
int32_t
*
verified_id_ptr
,
const
int64_t
*
seq_lens_ptr
,
const
int32_t
*
accept_lens_ptr
,
int64_t
*
positions_ptr
,
int32_t
*
new_verified_id_ptr
,
int64_t
bs
)
{
int64_t
pid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
pid
>=
bs
)
return
;
int64_t
seq_length
=
seq_lens_ptr
[
pid
];
int32_t
accept_length
=
accept_lens_ptr
[
pid
];
int32_t
accept_len_cumsum
=
0
;
for
(
int32_t
offset
=
0
;
offset
<
pid
;
offset
++
)
{
accept_len_cumsum
+=
accept_lens_ptr
[
offset
];
}
int64_t
*
positions_ptr1
=
positions_ptr
+
accept_len_cumsum
;
for
(
int32_t
offset
=
0
;
offset
<
accept_length
&&
offset
<
bs
;
offset
++
)
{
positions_ptr1
[
offset
]
=
seq_length
-
accept_length
+
offset
;
}
int32_t
verified_idx
=
accept_len_cumsum
+
accept_length
-
1
;
new_verified_id_ptr
[
pid
]
=
verified_id_ptr
[
verified_idx
];
}
__global__
void
launch_create_extend_after_decode_spec_info_int64_kernel
(
const
int32_t
*
verified_id_ptr
,
const
int64_t
*
seq_lens_ptr
,
const
int64_t
*
accept_lens_ptr
,
int64_t
*
positions_ptr
,
int32_t
*
new_verified_id_ptr
,
int64_t
bs
)
{
int64_t
pid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
pid
>=
bs
)
return
;
int64_t
seq_length
=
seq_lens_ptr
[
pid
];
int64_t
accept_length
=
accept_lens_ptr
[
pid
];
int64_t
accept_len_cumsum
=
0
;
for
(
int64_t
offset
=
0
;
offset
<
pid
;
offset
++
)
{
accept_len_cumsum
+=
accept_lens_ptr
[
offset
];
}
int64_t
*
positions_ptr1
=
positions_ptr
+
accept_len_cumsum
;
for
(
int64_t
offset
=
0
;
offset
<
accept_length
&&
offset
<
bs
;
offset
++
)
{
positions_ptr1
[
offset
]
=
seq_length
-
accept_length
+
offset
;
}
int64_t
verified_idx
=
accept_len_cumsum
+
accept_length
-
1
;
new_verified_id_ptr
[
pid
]
=
verified_id_ptr
[
verified_idx
];
}
void
dcu_alloc_decode_kernel
(
const
at
::
Tensor
seq_lens_ptr
,
const
at
::
Tensor
last_loc_ptr
,
...
...
@@ -734,4 +794,47 @@ void dcu_alloc_extend_kernel(
cudaStream_t
torch_current_stream
=
at
::
cuda
::
getCurrentCUDAStream
();
launch_alloc_extend_kernel
<<<
grid_size
,
block_size
,
0
,
torch_current_stream
>>>
(
pre_lens_ptr1
,
seq_lens_ptr1
,
last_loc_ptr1
,
free_page_ptr1
,
out_indices1
,
bs
,
page_size
);
C10_CUDA_KERNEL_LAUNCH_CHECK
();
}
\ No newline at end of file
}
void
dcu_create_extend_after_decode_spec_info
(
const
at
::
Tensor
verified_id
,
const
at
::
Tensor
seq_lens
,
const
at
::
Tensor
accept_lens
,
at
::
Tensor
positions
,
at
::
Tensor
new_verified_id
,
int64_t
bs
)
{
const
int32_t
*
verified_id_ptr
;
const
int64_t
*
seq_lens_ptr
;
const
int32_t
*
accept_lens_ptr_int32
;
const
int64_t
*
accept_lens_ptr_int64
;
int64_t
*
positions_ptr
;
int32_t
*
new_verified_id_ptr
;
int64_t
block_size
=
64
;
int64_t
grid_size
=
(
bs
+
block_size
-
1
)
/
block_size
;
cudaStream_t
torch_current_stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
accept_lens
.
dtype
()
==
torch
::
kInt32
)
{
verified_id_ptr
=
static_cast
<
const
int32_t
*>
(
verified_id
.
data_ptr
());
seq_lens_ptr
=
static_cast
<
const
int64_t
*>
(
seq_lens
.
data_ptr
());
accept_lens_ptr_int32
=
static_cast
<
const
int32_t
*>
(
accept_lens
.
data_ptr
());
positions_ptr
=
static_cast
<
int64_t
*>
(
positions
.
data_ptr
());
new_verified_id_ptr
=
static_cast
<
int32_t
*>
(
new_verified_id
.
data_ptr
());
launch_create_extend_after_decode_spec_info_int32_kernel
<<<
grid_size
,
block_size
,
0
,
torch_current_stream
>>>
(
verified_id_ptr
,
seq_lens_ptr
,
accept_lens_ptr_int32
,
positions_ptr
,
new_verified_id_ptr
,
bs
);
C10_CUDA_KERNEL_LAUNCH_CHECK
();
}
else
{
verified_id_ptr
=
static_cast
<
const
int32_t
*>
(
verified_id
.
data_ptr
());
seq_lens_ptr
=
static_cast
<
const
int64_t
*>
(
seq_lens
.
data_ptr
());
accept_lens_ptr_int64
=
static_cast
<
const
int64_t
*>
(
accept_lens
.
data_ptr
());
positions_ptr
=
static_cast
<
int64_t
*>
(
positions
.
data_ptr
());
new_verified_id_ptr
=
static_cast
<
int32_t
*>
(
new_verified_id
.
data_ptr
());
launch_create_extend_after_decode_spec_info_int64_kernel
<<<
grid_size
,
block_size
,
0
,
torch_current_stream
>>>
(
verified_id_ptr
,
seq_lens_ptr
,
accept_lens_ptr_int64
,
positions_ptr
,
new_verified_id_ptr
,
bs
);
C10_CUDA_KERNEL_LAUNCH_CHECK
();
}
};
\ No newline at end of file
sgl-kernel/include/sgl_kernel_ops.h
View file @
33fbf3ca
...
...
@@ -538,6 +538,14 @@ void segment_packbits(
/*
* From csrc/kvcacheio
*/
void
dcu_create_extend_after_decode_spec_info
(
const
at
::
Tensor
verified_id
,
const
at
::
Tensor
seq_lens
,
const
at
::
Tensor
accept_lens
,
at
::
Tensor
positions
,
at
::
Tensor
new_verified_id
,
int64_t
bs
);
void
dcu_alloc_extend_kernel
(
const
at
::
Tensor
pre_lens_ptr
,
const
at
::
Tensor
seq_lens_ptr
,
...
...
sgl-kernel/python/sgl_kernel/kvcacheio.py
View file @
33fbf3ca
...
...
@@ -9,6 +9,22 @@ def is_hip() -> bool:
_is_hip
=
is_hip
()
def
dcu_create_extend_after_decode_spec_info
(
verified_id
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
accept_lens
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
new_verified_id
:
torch
.
Tensor
,
bs
:
int
,
):
torch
.
ops
.
sgl_kernel
.
dcu_create_extend_after_decode_spec_info
(
verified_id
,
seq_lens
,
accept_lens
,
positions
,
new_verified_id
,
bs
,
)
def
dcu_alloc_extend_kernel
(
pre_lens_ptr
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment