Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
56222658
Unverified
Commit
56222658
authored
Oct 14, 2025
by
yinghui
Committed by
GitHub
Oct 14, 2025
Browse files
move eagle draft post process to cuda graph (#11434)
Co-authored-by:
Lianmin Zheng
<
lianminzheng@gmail.com
>
parent
dc965db0
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
469 additions
and
549 deletions
+469
-549
python/sglang/srt/speculative/build_eagle_tree.py
python/sglang/srt/speculative/build_eagle_tree.py
+0
-427
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
...n/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+3
-5
python/sglang/srt/speculative/eagle_info_v2.py
python/sglang/srt/speculative/eagle_info_v2.py
+1
-107
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+138
-0
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+16
-7
python/sglang/srt/speculative/eagle_worker_v2.py
python/sglang/srt/speculative/eagle_worker_v2.py
+2
-3
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_build_eagle_tree.py
test/srt/test_build_eagle_tree.py
+308
-0
No files found.
python/sglang/srt/speculative/build_eagle_tree.py
deleted
100644 → 0
View file @
dc965db0
# NOTE: Please run this file to make sure the test cases are correct.
import
math
from
enum
import
IntEnum
from
typing
import
List
,
Optional
import
torch
from
sglang.srt.utils
import
is_cuda
,
is_hip
if
is_cuda
()
or
is_hip
():
from
sgl_kernel
import
(
build_tree_kernel_efficient
as
sgl_build_tree_kernel_efficient
,
)
def build_tree_kernel_efficient_preprocess(
    verified_id: torch.Tensor,
    score_list: List[torch.Tensor],
    token_list: List[torch.Tensor],
    parents_list: List[torch.Tensor],
    num_verify_tokens: int,
):
    """Select the best draft candidates and assemble inputs for the tree kernel.

    Args:
        verified_id: Last verified token id per batch entry, shape (batch,).
        score_list: Per-step candidate scores; concatenated along dim 1 and
            flattened to (batch, total_candidates).
        token_list: Per-step candidate token ids, laid out like the scores.
        parents_list: Per-step parent indices of each candidate.
        num_verify_tokens: Total number of tokens to verify (including the
            verified root token).

    Returns:
        Tuple of (parent_list, top_scores_index, draft_tokens) where
        draft_tokens is a flattened 1-D stream with the verified token
        prepended for each batch entry.
    """
    # Flatten all candidate scores into one (batch, total_candidates) matrix.
    flat_scores = torch.cat(score_list, dim=1).flatten(1)
    # Candidate token ids share the same column layout as the scores.
    candidate_tokens = torch.cat(token_list, dim=1)

    # Keep the (num_verify_tokens - 1) highest-scoring candidates per row,
    # with their column indices restored to ascending order.
    chosen = flat_scores.topk(num_verify_tokens - 1, dim=-1).indices
    chosen = chosen.sort().values

    # Fetch the selected token ids, prepend the verified token of each
    # sequence, then flatten to a single 1-D token stream.
    picked = candidate_tokens.gather(1, chosen)
    draft_tokens = torch.cat((verified_id.unsqueeze(1), picked), dim=1).flatten()

    # The final step's parents are never consumed downstream, so drop them.
    # With a single step there is nothing to concatenate: return an empty
    # (batch, 0) tensor on the same device.
    if len(parents_list) > 1:
        parent_list = torch.cat(parents_list[:-1], dim=1)
    else:
        parent_list = torch.empty(
            parents_list[0].shape[0], 0, device=parents_list[0].device
        )

    return parent_list, chosen, draft_tokens
class TreeMaskMode(IntEnum):
    """Layout of the attention tree mask passed to the tree-build kernel."""

    # Full mask covering prefix + draft tokens (one row per draft token).
    FULL_MASK = 0
    # Mask restricted to the draft (query) tokens only.
    QLEN_ONLY = 1
    # Query-only mask, bit-packed into uint8/uint16/uint32 words.
    QLEN_ONLY_BITPACKING = 2
def build_tree_kernel_efficient(
    verified_id: torch.Tensor,
    score_list: List[torch.Tensor],
    token_list: List[torch.Tensor],
    parents_list: List[torch.Tensor],
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    topk: int,
    spec_steps: int,
    num_verify_tokens: int,
    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
    tree_mask_buf: Optional[torch.Tensor] = None,
    position_buf: Optional[torch.Tensor] = None,
):
    """Build the speculative-decoding verification tree.

    Selects the top draft candidates (via
    ``build_tree_kernel_efficient_preprocess``), allocates the tree mask /
    position / retrieval buffers, and fills them with the
    ``sgl_build_tree_kernel_efficient`` CUDA kernel.

    Args:
        verified_id: Last verified token per batch entry.
        score_list / token_list / parents_list: Per-step draft outputs.
        seq_lens: Sequence lengths WITHOUT draft tokens, shape (bs,).
        seq_lens_sum: Must equal ``seq_lens.sum()`` (passed in to avoid a sync).
        topk: Draft branching factor.
        spec_steps: Number of speculative draft steps.
        num_verify_tokens: Number of draft tokens to verify per sequence.
        tree_mask_mode: Mask layout; see ``TreeMaskMode``.
        tree_mask_buf / position_buf: Optional preallocated output buffers
            (e.g. for CUDA-graph capture); freshly allocated when ``None``.
            NOTE(review): when ``tree_mask_buf`` is supplied it is used as-is
            without being reset here — caller presumably clears it; confirm.

    Returns:
        (tree_mask, positions, retrive_index, retrive_next_token,
        retrive_next_sibling, draft_tokens).
    """
    parent_list, top_scores_index, draft_tokens = (
        build_tree_kernel_efficient_preprocess(
            verified_id,
            score_list,
            token_list,
            parents_list,
            num_verify_tokens,
        )
    )

    # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
    bs = seq_lens.numel()
    device = seq_lens.device

    # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
    # where each row indicates the attending pattern of each draft token
    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
    if tree_mask_buf is not None:
        # Reuse the caller-provided buffer (CUDA-graph friendly).
        tree_mask = tree_mask_buf
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
        tree_mask = torch.full(
            (num_verify_tokens * bs * num_verify_tokens,),
            True,
            dtype=torch.bool,
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
        # Pick the narrowest unsigned dtype that holds num_verify_tokens bits.
        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
        tree_mask = torch.zeros(
            (num_verify_tokens * bs,),
            dtype=packed_dtypes[packed_dtype_idx],
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
        tree_mask = torch.full(
            (
                seq_lens_sum * num_verify_tokens
                + num_verify_tokens * num_verify_tokens * bs,
            ),
            True,
            device=device,
        )
    else:
        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")

    # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
    # Retrieval structures: -1 marks "no entry" (leaf / no sibling).
    retrive_index = torch.full(
        (bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    retrive_next_token = torch.full(
        (bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    retrive_next_sibling = torch.full(
        (bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )

    # position: where each token belongs to
    # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
    # then, positions = [7, 8, 8, 9]
    if position_buf is not None:
        positions = position_buf
    else:
        positions = torch.empty(
            (bs * num_verify_tokens,), device=device, dtype=torch.long
        )

    # CUDA kernel fills tree_mask, positions and the retrieval tensors in place.
    sgl_build_tree_kernel_efficient(
        parent_list,
        top_scores_index,
        seq_lens,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        spec_steps,
        num_verify_tokens,
        tree_mask_mode,
    )
    return (
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    )
def test_build_tree_kernel_efficient():
    """Self-check for build_tree_kernel_efficient with a fixed bs=2 fixture.

    Requires a CUDA device (all fixtures are allocated on "cuda"); the
    expected outputs below were captured from a known-good run.
    """
    verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
    # Per-step candidate scores: step 0 is (bs, 1, topk); later steps (bs, topk, topk).
    score_list = [
        torch.tensor(
            [
                [[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]],
                [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]],
            ],
            dtype=torch.float32,
            device="cuda",
        ),
        torch.tensor(
            [
                [
                    [6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03],
                    [2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04],
                    [2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08],
                    [1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05],
                ],
                [
                    [8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02],
                    [2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04],
                    [2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05],
                    [1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08],
                ],
            ],
            dtype=torch.float32,
            device="cuda",
        ),
        torch.tensor(
            [
                [
                    [6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06],
                    [2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04],
                    [1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05],
                    [1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05],
                ],
                [
                    [2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03],
                    [6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05],
                    [4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04],
                    [7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04],
                ],
            ],
            dtype=torch.float32,
            device="cuda",
        ),
        torch.tensor(
            [
                [
                    [6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06],
                    [1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03],
                    [2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05],
                    [1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05],
                ],
                [
                    [5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04],
                    [1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03],
                    [6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03],
                    [1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04],
                ],
            ],
            dtype=torch.float32,
            device="cuda",
        ),
    ]
    # Candidate token ids, matching the score layout column-for-column.
    token_list = [
        torch.tensor(
            [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]],
            dtype=torch.int64,
            device="cuda",
        ),
        torch.tensor(
            [
                [
                    29889, 29974, 29945, 29900, 29974, 29922, 29930, 29958,
                    29889, 29974, 29930, 29945, 29974, 29922, 29930, 29958,
                ],
                [
                    22550, 4136, 16492, 8439, 29871, 2, 3001, 13,
                    2, 13, 29906, 29946, 2, 13, 29871, 259,
                ],
            ],
            device="cuda",
        ),
        torch.tensor(
            [
                [
                    29946, 29945, 29953, 29906, 29896, 29945, 29900, 29906,
                    29896, 29945, 29906, 29953, 29896, 29945, 29906, 29946,
                ],
                [
                    29871, 2, 29901, 29889, 29871, 2, 395, 259,
                    29901, 29871, 2, 29889, 3001, 1234, 7146, 2186,
                ],
            ],
            device="cuda",
        ),
        torch.tensor(
            [
                [
                    29946, 29974, 29945, 29930, 29889, 29922, 29974, 29930,
                    29974, 29946, 29930, 29922, 29889, 29974, 29945, 29922,
                ],
                [
                    29941, 29906, 2, 29946, 29871, 450, 319, 14990,
                    29946, 29941, 2, 29906, 29871, 2, 3001, 13,
                ],
            ],
            device="cuda",
        ),
    ]
    # Parent index of each candidate; -1 marks the root (verified token).
    parents_list = [
        torch.tensor(
            [[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]], dtype=torch.int64, device="cuda"
        ),
        torch.tensor(
            [[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda"
        ),
        torch.tensor(
            [[20, 24, 21, 28], [24, 28, 20, 21]], dtype=torch.int64, device="cuda"
        ),
        torch.tensor(
            [[36, 40, 41, 44], [36, 40, 44, 45]], dtype=torch.int64, device="cuda"
        ),
    ]
    seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda")
    topk = 4
    depth = 4
    num_draft_token = 8

    (
        tree_mask,
        position,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    ) = build_tree_kernel_efficient(
        verified_id=verified_id,
        score_list=score_list,
        token_list=token_list,
        parents_list=parents_list,
        seq_lens=seq_lens,
        seq_lens_sum=torch.sum(seq_lens).item(),
        topk=topk,
        spec_steps=depth,
        num_verify_tokens=num_draft_token,
    )

    print("=========== build tree kernel efficient ==========")
    print(f"{tree_mask=}")
    print(f"{position=}")
    print(f"{retrive_index=}")
    print(f"{retrive_next_token=}")
    print(f"{retrive_next_sibling=}")
    print(f"{draft_tokens=}")

    # Expected values captured from a known-good run of the kernel.
    assert position.tolist() == [
        5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14,
    ]
    assert retrive_index.tolist() == [
        [0, 1, 2, 3, 4, 5, 6, 7],
        [8, 9, 10, 11, 12, 13, 14, 15],
    ]
    assert retrive_next_token.tolist() == [
        [1, 3, 4, 5, 6, 7, -1, -1],
        [1, 2, -1, 6, -1, -1, 7, -1],
    ]
    assert retrive_next_sibling.tolist() == [
        [-1, 2, -1, -1, -1, -1, -1, -1],
        [-1, -1, 3, 4, 5, -1, -1, -1],
    ]
    assert draft_tokens.tolist() == [
        29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
        13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
    ]
if __name__ == "__main__":
    # Run the self-check when this file is executed directly
    # (see the NOTE at the top of the file).
    test_build_tree_kernel_efficient()
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
View file @
56222658
...
@@ -276,11 +276,9 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -276,11 +276,9 @@ class EAGLEDraftCudaGraphRunner:
return
graph
,
out
return
graph
,
out
def
_postprocess_output_to_raw_bs
(
self
,
out
,
raw_bs
):
def
_postprocess_output_to_raw_bs
(
self
,
out
,
raw_bs
):
score_list
,
token_list
,
parents_list
=
out
# Keep the variables name for readability
score_list
=
[
x
[:
raw_bs
]
for
x
in
score_list
]
parent_list
,
top_scores_index
,
draft_tokens
=
(
t
[:
raw_bs
]
for
t
in
out
)
token_list
=
[
x
[:
raw_bs
]
for
x
in
token_list
]
return
parent_list
,
top_scores_index
,
draft_tokens
parents_list
=
[
x
[:
raw_bs
]
for
x
in
parents_list
]
return
(
score_list
,
token_list
,
parents_list
)
def
replay
(
self
,
forward_batch
:
ForwardBatch
):
def
replay
(
self
,
forward_batch
:
ForwardBatch
):
assert
forward_batch
.
out_cache_loc
is
not
None
assert
forward_batch
.
out_cache_loc
is
not
None
...
...
python/sglang/srt/speculative/eagle_info_v2.py
View file @
56222658
from
__future__
import
annotations
from
__future__
import
annotations
import
math
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -19,7 +18,6 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -19,7 +18,6 @@ from sglang.srt.model_executor.forward_batch_info import (
)
)
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.speculative.build_eagle_tree
import
TreeMaskMode
from
sglang.srt.speculative.spec_utils
import
(
from
sglang.srt.speculative.spec_utils
import
(
SIMULATE_ACC_LEN
,
SIMULATE_ACC_LEN
,
generate_simulated_accept_index
,
generate_simulated_accept_index
,
...
@@ -286,110 +284,6 @@ class EagleVerifyInputV2Mixin:
...
@@ -286,110 +284,6 @@ class EagleVerifyInputV2Mixin:
return
predict
,
accept_length
,
accept_index
return
predict
,
accept_length
,
accept_index
def build_tree_kernel_efficient_tmp(
    verified_id: torch.Tensor,
    parent_list: List[torch.Tensor],
    top_scores_index: torch.Tensor,
    draft_tokens: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    topk: int,
    spec_steps: int,
    num_verify_tokens: int,
    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
    tree_mask_buf: Optional[torch.Tensor] = None,
    position_buf: Optional[torch.Tensor] = None,
):
    """Build the verification tree from already-organized draft results.

    Unlike the non-_tmp variant, this takes the post-processed outputs
    (parent_list, top_scores_index, draft_tokens) directly, so the
    post-processing can be captured inside a CUDA graph by the caller.
    """
    # TODO(lsyin): make it compatible with default code path
    # TODO(lsyin): support cuda graph graph padding for eagle
    # Prepend each sequence's verified token and flatten to a 1-D stream.
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

    # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
    bs = seq_lens.numel()
    device = seq_lens.device

    # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
    # where each row indicates the attending pattern of each draft token
    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
    if tree_mask_buf is not None:
        # Reuse the caller-provided buffer, resetting it to the mode's
        # "empty" value (all-True, or all-zero bits when packed).
        tree_mask = tree_mask_buf
        if tree_mask_mode == TreeMaskMode.QLEN_ONLY:
            tree_mask.fill_(True)
        elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
            tree_mask.fill_(0)
        elif tree_mask_mode == TreeMaskMode.FULL_MASK:
            tree_mask.fill_(True)
        else:
            raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
        tree_mask = torch.full(
            (num_verify_tokens * bs * num_verify_tokens,),
            True,
            dtype=torch.bool,
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
        # Narrowest unsigned dtype that holds num_verify_tokens mask bits.
        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
        tree_mask = torch.zeros(
            (num_verify_tokens * bs,),
            dtype=packed_dtypes[packed_dtype_idx],
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
        tree_mask = torch.full(
            (
                seq_lens_sum * num_verify_tokens
                + num_verify_tokens * num_verify_tokens * bs,
            ),
            True,
            device=device,
        )
    else:
        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")

    # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
    # One (bs, num_verify_tokens) retrieval table per output; -1 = no entry.
    retrive_buf = torch.full(
        (3, bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf

    # position: where each token belongs to
    # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
    # then, positions = [7, 8, 8, 9]
    if position_buf is not None:
        positions = position_buf
    else:
        positions = torch.empty(
            (bs * num_verify_tokens,), device=device, dtype=torch.long
        )

    # Imported lazily so the module loads on builds without sgl_kernel.
    from sgl_kernel import (
        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
    )

    # CUDA kernel fills tree_mask, positions and the retrieval tensors in place.
    sgl_build_tree_kernel_efficient(
        parent_list,
        top_scores_index,
        seq_lens,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        spec_steps,
        num_verify_tokens,
        tree_mask_mode,
    )
    return (
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    )
@
torch
.
compile
(
dynamic
=
True
)
@
torch
.
compile
(
dynamic
=
True
)
def
select_top_k_tokens_tmp
(
def
select_top_k_tokens_tmp
(
i
:
int
,
i
:
int
,
...
...
python/sglang/srt/speculative/eagle_utils.py
0 → 100644
View file @
56222658
import
math
from
enum
import
IntEnum
from
typing
import
List
,
Optional
import
torch
from
sglang.srt.utils
import
is_cuda
,
is_hip
if
is_cuda
()
or
is_hip
():
from
sgl_kernel
import
(
build_tree_kernel_efficient
as
sgl_build_tree_kernel_efficient
,
)
def organize_draft_results(
    score_list: List[torch.Tensor],
    token_list: List[torch.Tensor],
    parents_list: List[torch.Tensor],
    num_draft_token: int,
):
    """Reduce per-step draft outputs to the kernel's three inputs.

    Args:
        score_list: Per-step candidate scores; concatenated along dim 1 and
            flattened to (batch, total_candidates).
        token_list: Per-step candidate token ids, same column layout.
        parents_list: Per-step parent indices of each candidate.
        num_draft_token: Total draft tokens per sequence (including the root,
            so only ``num_draft_token - 1`` candidates are selected here).

    Returns:
        Tuple of (parent_list, top_scores_index, draft_tokens); draft_tokens
        keeps its (batch, num_draft_token - 1) shape.
    """
    # One (batch, total_candidates) score matrix across all draft steps.
    merged_scores = torch.cat(score_list, dim=1).flatten(1)
    # Token ids share the same column layout as the merged scores.
    merged_tokens = torch.cat(token_list, dim=1)

    # Top (num_draft_token - 1) candidates per row, indices in ascending order.
    selected = merged_scores.topk(num_draft_token - 1, dim=-1).indices
    selected = selected.sort().values
    draft_tokens = merged_tokens.gather(1, selected)

    # The last step's parents are never consumed; drop them. With a single
    # step, return an empty (batch, 0) tensor on the same device instead.
    if len(parents_list) > 1:
        parent_list = torch.cat(parents_list[:-1], dim=1)
    else:
        parent_list = torch.empty(
            parents_list[0].shape[0], 0, device=parents_list[0].device
        )

    return parent_list, selected, draft_tokens
class TreeMaskMode(IntEnum):
    """Layout of the attention tree mask passed to the tree-build kernel."""

    # Full mask covering prefix + draft tokens (one row per draft token).
    FULL_MASK = 0
    # Mask restricted to the draft (query) tokens only.
    QLEN_ONLY = 1
    # Query-only mask, bit-packed into uint8/uint16/uint32 words.
    QLEN_ONLY_BITPACKING = 2
def build_tree_kernel_efficient(
    verified_id: torch.Tensor,
    parent_list: List[torch.Tensor],
    top_scores_index: torch.Tensor,
    draft_tokens: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    topk: int,
    spec_steps: int,
    num_verify_tokens: int,
    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
    tree_mask_buf: Optional[torch.Tensor] = None,
    position_buf: Optional[torch.Tensor] = None,
):
    """Build the speculative-decoding verification tree on the GPU.

    Consumes the outputs of ``organize_draft_results`` (so that step can be
    captured in a CUDA graph), allocates or reuses the tree-mask / position /
    retrieval buffers, and fills them with ``sgl_build_tree_kernel_efficient``.

    Args:
        verified_id: Last verified token id per batch entry.
        parent_list: Parent indices from ``organize_draft_results``.
            NOTE(review): annotated List[torch.Tensor] but the upstream helper
            returns a single concatenated tensor — confirm the annotation.
        top_scores_index: Sorted indices of the selected candidates.
        draft_tokens: Selected draft token ids, shape (bs, num_verify_tokens - 1).
        seq_lens: Sequence lengths WITHOUT draft tokens, shape (bs,).
        seq_lens_sum: Must equal ``seq_lens.sum()`` (passed to avoid a sync).
        topk: Draft branching factor.
        spec_steps: Number of speculative draft steps.
        num_verify_tokens: Draft tokens to verify per sequence (incl. root).
        tree_mask_mode: Mask layout; see ``TreeMaskMode``.
        tree_mask_buf / position_buf: Optional preallocated output buffers
            (CUDA-graph capture); freshly allocated when ``None``.

    Returns:
        (tree_mask, positions, retrive_index, retrive_next_token,
        retrive_next_sibling, draft_tokens) — draft_tokens now flattened with
        the verified root token prepended per sequence.
    """
    # Prepend each sequence's verified token and flatten to a 1-D stream.
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

    # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
    bs = seq_lens.numel()
    device = seq_lens.device

    # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
    # where each row indicates the attending pattern of each draft token
    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
    if tree_mask_buf is not None:
        # Reuse the caller-provided buffer, resetting it to the mode's
        # "empty" value (all-True, or all-zero bits when packed).
        tree_mask = tree_mask_buf
        if tree_mask_mode == TreeMaskMode.QLEN_ONLY:
            tree_mask.fill_(True)
        elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
            tree_mask.fill_(0)
        elif tree_mask_mode == TreeMaskMode.FULL_MASK:
            tree_mask.fill_(True)
        else:
            raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
        tree_mask = torch.full(
            (num_verify_tokens * bs * num_verify_tokens,),
            True,
            dtype=torch.bool,
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
        # Narrowest unsigned dtype that holds num_verify_tokens mask bits.
        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
        tree_mask = torch.zeros(
            (num_verify_tokens * bs,),
            dtype=packed_dtypes[packed_dtype_idx],
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
        tree_mask = torch.full(
            (
                seq_lens_sum * num_verify_tokens
                + num_verify_tokens * num_verify_tokens * bs,
            ),
            True,
            device=device,
        )
    else:
        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")

    # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
    # One (bs, num_verify_tokens) retrieval table per output; -1 = no entry.
    retrive_buf = torch.full(
        (3, bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf

    # position: where each token belongs to
    # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
    # then, positions = [7, 8, 8, 9]
    if position_buf is not None:
        positions = position_buf
    else:
        positions = torch.empty(
            (bs * num_verify_tokens,), device=device, dtype=torch.long
        )

    # CUDA kernel fills tree_mask, positions and the retrieval tensors in place.
    sgl_build_tree_kernel_efficient(
        parent_list,
        top_scores_index,
        seq_lens,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        spec_steps,
        num_verify_tokens,
        tree_mask_mode,
    )
    return (
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    )
python/sglang/srt/speculative/eagle_worker.py
View file @
56222658
...
@@ -28,7 +28,6 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -28,7 +28,6 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode
,
ForwardMode
,
)
)
from
sglang.srt.server_args
import
ServerArgs
,
get_global_server_args
from
sglang.srt.server_args
import
ServerArgs
,
get_global_server_args
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.speculative.eagle_draft_cuda_graph_runner
import
(
from
sglang.srt.speculative.eagle_draft_cuda_graph_runner
import
(
EAGLEDraftCudaGraphRunner
,
EAGLEDraftCudaGraphRunner
,
)
)
...
@@ -40,6 +39,10 @@ from sglang.srt.speculative.eagle_info import (
...
@@ -40,6 +39,10 @@ from sglang.srt.speculative.eagle_info import (
EagleVerifyInput
,
EagleVerifyInput
,
EagleVerifyOutput
,
EagleVerifyOutput
,
)
)
from
sglang.srt.speculative.eagle_utils
import
(
build_tree_kernel_efficient
,
organize_draft_results
,
)
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_utils
import
(
from
sglang.srt.speculative.spec_utils
import
(
assign_draft_cache_locs
,
assign_draft_cache_locs
,
...
@@ -677,7 +680,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -677,7 +680,7 @@ class EAGLEWorker(TpModelWorker):
forward_batch
forward_batch
)
)
if
can_cuda_graph
:
if
can_cuda_graph
:
score
_list
,
to
ken_list
,
parents_list
=
self
.
cuda_graph_runner
.
replay
(
parent
_list
,
to
p_scores_index
,
draft_tokens
=
self
.
cuda_graph_runner
.
replay
(
forward_batch
forward_batch
)
)
else
:
else
:
...
@@ -686,7 +689,9 @@ class EAGLEWorker(TpModelWorker):
...
@@ -686,7 +689,9 @@ class EAGLEWorker(TpModelWorker):
# Initialize attention backend
# Initialize attention backend
self
.
draft_attn_backend
.
init_forward_metadata
(
forward_batch
)
self
.
draft_attn_backend
.
init_forward_metadata
(
forward_batch
)
# Run forward steps
# Run forward steps
score_list
,
token_list
,
parents_list
=
self
.
draft_forward
(
forward_batch
)
parent_list
,
top_scores_index
,
draft_tokens
=
self
.
draft_forward
(
forward_batch
)
if
batch
.
forward_mode
.
is_idle
():
if
batch
.
forward_mode
.
is_idle
():
return
EagleVerifyInput
.
create_idle_input
(
return
EagleVerifyInput
.
create_idle_input
(
...
@@ -704,9 +709,9 @@ class EAGLEWorker(TpModelWorker):
...
@@ -704,9 +709,9 @@ class EAGLEWorker(TpModelWorker):
draft_tokens
,
draft_tokens
,
)
=
build_tree_kernel_efficient
(
)
=
build_tree_kernel_efficient
(
spec_info
.
verified_id
,
spec_info
.
verified_id
,
score
_list
,
parent
_list
,
to
ken_list
,
to
p_scores_index
,
parents_list
,
draft_tokens
,
batch
.
seq_lens
,
batch
.
seq_lens
,
batch
.
seq_lens_sum
,
batch
.
seq_lens_sum
,
self
.
topk
,
self
.
topk
,
...
@@ -795,7 +800,11 @@ class EAGLEWorker(TpModelWorker):
...
@@ -795,7 +800,11 @@ class EAGLEWorker(TpModelWorker):
topk_index
=
self
.
hot_token_id
[
topk_index
]
topk_index
=
self
.
hot_token_id
[
topk_index
]
hidden_states
=
logits_output
.
hidden_states
hidden_states
=
logits_output
.
hidden_states
return
score_list
,
token_list
,
parents_list
parent_list
,
top_scores_index
,
draft_tokens
=
organize_draft_results
(
score_list
,
token_list
,
parents_list
,
self
.
speculative_num_draft_tokens
)
return
parent_list
,
top_scores_index
,
draft_tokens
def
clear_cache_pool
(
self
):
def
clear_cache_pool
(
self
):
self
.
model_runner
.
req_to_token_pool
.
clear
()
self
.
model_runner
.
req_to_token_pool
.
clear
()
...
...
python/sglang/srt/speculative/eagle_worker_v2.py
View file @
56222658
...
@@ -12,15 +12,14 @@ from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
...
@@ -12,15 +12,14 @@ from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardBatch
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.speculative.build_eagle_tree
import
TreeMaskMode
from
sglang.srt.speculative.eagle_info
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.eagle_info
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.eagle_info_v2
import
(
from
sglang.srt.speculative.eagle_info_v2
import
(
assign_extend_cache_locs
,
assign_extend_cache_locs
,
build_tree_kernel_efficient_tmp
,
fill_accepted_out_cache_loc
,
fill_accepted_out_cache_loc
,
fill_new_verified_id
,
fill_new_verified_id
,
select_top_k_tokens_tmp
,
select_top_k_tokens_tmp
,
)
)
from
sglang.srt.speculative.eagle_utils
import
TreeMaskMode
,
build_tree_kernel_efficient
from
sglang.srt.speculative.eagle_worker
import
EAGLEWorker
from
sglang.srt.speculative.eagle_worker
import
EAGLEWorker
from
sglang.srt.utils.common
import
fast_topk
,
next_power_of_2
from
sglang.srt.utils.common
import
fast_topk
,
next_power_of_2
...
@@ -116,7 +115,7 @@ class EAGLEWorkerV2(EAGLEWorker):
...
@@ -116,7 +115,7 @@ class EAGLEWorkerV2(EAGLEWorker):
retrive_next_token
,
retrive_next_token
,
retrive_next_sibling
,
retrive_next_sibling
,
draft_tokens
,
draft_tokens
,
)
=
build_tree_kernel_efficient
_tmp
(
)
=
build_tree_kernel_efficient
(
draft_input
.
verified_id
,
draft_input
.
verified_id
,
parent_list
,
parent_list
,
top_scores_index
,
top_scores_index
,
...
...
test/srt/run_suite.py
View file @
56222658
...
@@ -69,6 +69,7 @@ suites = {
...
@@ -69,6 +69,7 @@ suites = {
TestFile
(
"test_chunked_prefill.py"
,
313
),
TestFile
(
"test_chunked_prefill.py"
,
313
),
TestFile
(
"test_create_kvindices.py"
,
2
),
TestFile
(
"test_create_kvindices.py"
,
2
),
TestFile
(
"test_deterministic.py"
,
300
),
TestFile
(
"test_deterministic.py"
,
300
),
TestFile
(
"test_build_eagle_tree.py"
,
8
),
TestFile
(
"test_eagle_infer_a.py"
,
370
),
TestFile
(
"test_eagle_infer_a.py"
,
370
),
TestFile
(
"test_eagle_infer_b.py"
,
700
),
TestFile
(
"test_eagle_infer_b.py"
,
700
),
TestFile
(
"test_eagle_infer_beta.py"
,
300
),
TestFile
(
"test_eagle_infer_beta.py"
,
300
),
...
...
test/srt/test_build_eagle_tree.py
0 → 100644
View file @
56222658
import
unittest
import
torch
from
sglang.srt.speculative.eagle_utils
import
(
build_tree_kernel_efficient
,
organize_draft_results
,
)
class TestBuildEagleTree(unittest.TestCase):
    """Unit tests for build_eagle_tree functionality."""

    def test_build_tree_kernel_efficient(self):
        """Test the build_tree_kernel_efficient function with known inputs and expected outputs."""
        # Fixture: bs=2, topk=4, 4 draft steps; all tensors live on CUDA
        # because the tree is built by a GPU kernel.
        verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
        # Per-step candidate scores: step 0 is (bs, 1, topk); later steps (bs, topk, topk).
        score_list = [
            torch.tensor(
                [
                    [[7.1127e-01, 2.8292e-01, 2.2995e-03, 1.7357e-03]],
                    [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]],
                ],
                dtype=torch.float32,
                device="cuda",
            ),
            torch.tensor(
                [
                    [
                        [6.9142e-01, 1.2863e-02, 1.6873e-03, 1.1871e-03],
                        [2.4787e-01, 1.8818e-02, 1.4204e-02, 9.2235e-04],
                        [2.2971e-03, 1.6700e-06, 1.8737e-07, 8.3146e-08],
                        [1.2771e-03, 2.4374e-04, 1.7832e-04, 1.1947e-05],
                    ],
                    [
                        [8.4832e-02, 6.6068e-02, 5.8304e-02, 5.7851e-02],
                        [2.3616e-03, 1.1243e-03, 5.4368e-04, 2.7768e-04],
                        [2.5286e-04, 1.5578e-04, 2.8817e-05, 1.2888e-05],
                        [1.2834e-04, 2.5417e-06, 1.1279e-06, 1.6088e-08],
                    ],
                ],
                dtype=torch.float32,
                device="cuda",
            ),
            torch.tensor(
                [
                    [
                        [6.6438e-01, 2.6997e-02, 2.4236e-05, 4.0821e-06],
                        [2.4402e-01, 2.8409e-03, 5.0935e-04, 2.9022e-04],
                        [1.6178e-02, 2.0567e-03, 4.5892e-04, 3.0034e-05],
                        [1.3023e-02, 5.0497e-04, 3.6371e-04, 8.7750e-05],
                    ],
                    [
                        [2.3263e-02, 2.0054e-02, 9.3990e-03, 2.7783e-03],
                        [6.4156e-02, 5.5506e-04, 1.0429e-04, 9.7211e-05],
                        [4.9950e-02, 5.0630e-03, 9.0068e-04, 3.3656e-04],
                        [7.5817e-03, 8.5731e-04, 6.9972e-04, 6.0793e-04],
                    ],
                ],
                dtype=torch.float32,
                device="cuda",
            ),
            torch.tensor(
                [
                    [
                        [6.6420e-01, 1.0525e-04, 6.5864e-05, 1.2253e-06],
                        [1.3019e-01, 1.0461e-01, 5.2083e-03, 1.6777e-03],
                        [2.0103e-02, 6.7335e-03, 1.2625e-04, 1.0364e-05],
                        [1.5142e-02, 7.0819e-04, 9.6595e-05, 8.7951e-05],
                    ],
                    [
                        [5.8608e-02, 1.8840e-03, 7.8535e-04, 4.4400e-04],
                        [1.2185e-02, 2.0684e-03, 1.7418e-03, 1.4327e-03],
                        [6.2455e-03, 6.1487e-03, 2.6862e-03, 1.8034e-03],
                        [1.8590e-03, 1.6151e-03, 1.2481e-03, 3.6038e-04],
                    ],
                ],
                dtype=torch.float32,
                device="cuda",
            ),
        ]
        # Candidate token ids, matching the score layout column-for-column.
        token_list = [
            torch.tensor(
                [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]],
                dtype=torch.int64,
                device="cuda",
            ),
            torch.tensor(
                [
                    [
                        29889, 29974, 29945, 29900, 29974, 29922, 29930, 29958,
                        29889, 29974, 29930, 29945, 29974, 29922, 29930, 29958,
                    ],
                    [
                        22550, 4136, 16492, 8439, 29871, 2, 3001, 13,
                        2, 13, 29906, 29946, 2, 13, 29871, 259,
                    ],
                ],
                device="cuda",
            ),
            torch.tensor(
                [
                    [
                        29946, 29945, 29953, 29906, 29896, 29945, 29900, 29906,
                        29896, 29945, 29906, 29953, 29896, 29945, 29906, 29946,
                    ],
                    [
                        29871, 2, 29901, 29889, 29871, 2, 395, 259,
                        29901, 29871, 2, 29889, 3001, 1234, 7146, 2186,
                    ],
                ],
                device="cuda",
            ),
            torch.tensor(
                [
                    [
                        29946, 29974, 29945, 29930, 29889, 29922, 29974, 29930,
                        29974, 29946, 29930, 29922, 29889, 29974, 29945, 29922,
                    ],
                    [
                        29941, 29906, 2, 29946, 29871, 450, 319, 14990,
                        29946, 29941, 2, 29906, 29871, 2, 3001, 13,
                    ],
                ],
                device="cuda",
            ),
        ]
        # Parent index of each candidate; -1 marks the root (verified token).
        parents_list = [
            torch.tensor(
                [[-1, 0, 1, 2, 3], [-1, 0, 1, 2, 3]],
                dtype=torch.int64,
                device="cuda",
            ),
            torch.tensor(
                [[4, 8, 9, 10], [4, 5, 6, 7]], dtype=torch.int64, device="cuda"
            ),
            torch.tensor(
                [[20, 24, 21, 28], [24, 28, 20, 21]],
                dtype=torch.int64,
                device="cuda",
            ),
            torch.tensor(
                [[36, 40, 41, 44], [36, 40, 44, 45]],
                dtype=torch.int64,
                device="cuda",
            ),
        ]
        seq_lens = torch.tensor([5, 10], dtype=torch.int64, device="cuda")
        topk = 4
        depth = 4
        num_draft_token = 8

        # First reduce the raw draft outputs, then build the tree from them.
        parent_list, top_scores_index, draft_tokens = organize_draft_results(
            score_list, token_list, parents_list, num_draft_token
        )
        (
            tree_mask,
            position,
            retrieve_index,
            retrieve_next_token,
            retrieve_next_sibling,
            draft_tokens,
        ) = build_tree_kernel_efficient(
            verified_id=verified_id,
            parent_list=parent_list,
            top_scores_index=top_scores_index,
            draft_tokens=draft_tokens,
            seq_lens=seq_lens,
            seq_lens_sum=torch.sum(seq_lens).item(),
            topk=topk,
            spec_steps=depth,
            num_verify_tokens=num_draft_token,
        )

        # Verify expected outputs
        self.assertEqual(
            position.tolist(),
            [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14],
            "Position tensor does not match expected values",
        )
        self.assertEqual(
            retrieve_index.tolist(),
            [
                [0, 1, 2, 3, 4, 5, 6, 7],
                [8, 9, 10, 11, 12, 13, 14, 15],
            ],
            "Retrieve index tensor does not match expected values",
        )
        self.assertEqual(
            retrieve_next_token.tolist(),
            [
                [1, 3, 4, 5, 6, 7, -1, -1],
                [1, 2, -1, 6, -1, -1, 7, -1],
            ],
            "Retrieve next token tensor does not match expected values",
        )
        self.assertEqual(
            retrieve_next_sibling.tolist(),
            [
                [-1, 2, -1, -1, -1, -1, -1, -1],
                [-1, -1, 3, 4, 5, -1, -1, -1],
            ],
            "Retrieve next sibling tensor does not match expected values",
        )
        self.assertEqual(
            draft_tokens.tolist(),
            [
                29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
                13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
            ],
            "Draft tokens tensor does not match expected values",
        )
if __name__ == "__main__":
    # Allow running this test file directly with `python test_build_eagle_tree.py`.
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment