sglang · Commits

Commit 1b859295 (unverified), authored Mar 16, 2025 by Ying Sheng, committed by GitHub on Mar 16, 2025.
Parent commit: 9971dc22

[Eagle] Remove the greedy branch and some redundant code (#4363)

Co-authored-by: Sehoon Kim <sehoon@x.ai>
Showing 14 changed files with 383 additions and 675 deletions (+383, -675):

  python/pyproject.toml                                            +1    -1
  python/sglang/srt/entrypoints/http_server.py                     +1    -1
  python/sglang/srt/managers/scheduler.py                          +0    -2
  python/sglang/srt/model_executor/cuda_graph_runner.py            +10   -11
  python/sglang/srt/server_args.py                                 +0    -1
  python/sglang/srt/speculative/build_eagle_tree.py                +7    -347
  python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py   +30   -5
  python/sglang/srt/speculative/eagle_utils.py                     +204  -250
  python/sglang/srt/speculative/eagle_worker.py                    +111  -46
  python/sglang/srt/utils.py                                       +11   -0
  scripts/ci_install_dependency.sh                                 +0    -3
  sgl-kernel/csrc/speculative/speculative_sampling.cuh             +4    -4
  test/srt/test_eagle_infer.py                                     +2    -2
  test/srt/test_mla_flashinfer.py                                  +2    -2
python/pyproject.toml
@@ -43,7 +43,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.5.post1",
+    "sgl-kernel==0.0.5.post2",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "vllm>=0.6.4.post1,<=0.7.2",
python/sglang/srt/entrypoints/http_server.py
@@ -283,7 +283,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)
 
 
-@app.post("/flush_cache")
+@app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
     _global_state.tokenizer_manager.flush_cache()
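For readers less familiar with FastAPI's routing decorators: `@app.post` registers a handler for POST only, while `@app.api_route(..., methods=[...])` binds one handler to several HTTP methods, which is what the /flush_cache change above switches to. A minimal standalone sketch (the app and route names here are invented for illustration, not sglang code):

```python
from fastapi import FastAPI

app = FastAPI()


# POST only: a GET request to /reset would get 405 Method Not Allowed.
@app.post("/reset")
async def reset_post_only():
    return {"ok": True}


# One handler reachable via both GET and POST, mirroring the new
# /flush_cache registration above.
@app.api_route("/flush", methods=["GET", "POST"])
async def flush_any_method():
    return {"ok": True}
```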
python/sglang/srt/managers/scheduler.py
@@ -895,7 +895,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
 
         spec_accept_length = 0
@@ -913,7 +912,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"accept len: {spec_accept_length:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -117,7 +117,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-        capture_bs = list(range(1, 33))
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture less.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
 
     if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
@@ -125,16 +127,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs = list(
-            sorted(
-                set(
-                    capture_bs
-                    + [model_runner.req_to_token_pool.size - 1]
-                    + [model_runner.req_to_token_pool.size]
-                )
-            )
-        )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]
+
+    capture_bs = list(sorted(set(capture_bs)))
     capture_bs = [
         bs
         for bs in capture_bs
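To make the new capture schedule concrete, here is a small, self-contained sketch that reproduces the batch-size list built by the speculative-decoding branch and the dedup-and-sort step that now follows it (`req_pool_size` is an invented value standing in for `model_runner.req_to_token_pool.size`):

```python
# Speculative decoding branch: every size 1-8, odd sizes 9-31, then a few large ones.
capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]

# Invented stand-in for model_runner.req_to_token_pool.size.
req_pool_size = 48

# Make sure the largest runnable batch size is always captured.
if max(capture_bs) > req_pool_size:
    capture_bs += [req_pool_size - 1] + [req_pool_size]

capture_bs = list(sorted(set(capture_bs)))
print(capture_bs)
# [1, 2, 3, ..., 8, 9, 11, ..., 31, 47, 48, 64, 96, 128, 160]
```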
@@ -508,7 +505,9 @@ class CudaGraphRunner:
         self.raw_num_token = raw_num_token
         self.bs = bs
 
-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.replay_prepare(forward_batch)
         else:
python/sglang/srt/server_args.py
@@ -285,7 +285,6 @@ class ServerArgs:
         if self.speculative_algorithm == "EAGLE":
             if self.max_running_requests is None:
                 self.max_running_requests = 32
-            self.disable_cuda_graph_padding = True
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
python/sglang/srt/speculative/build_eagle_tree.py
@@ -3,8 +3,13 @@
 from typing import List
 
 import torch
-from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
+
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )
 
 
 def build_tree_kernel_efficient_preprocess(
@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
     top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
     top_scores_index = top_scores.indices
     top_scores_index = torch.sort(top_scores_index).values
 
     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
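The preprocessing above keeps the highest-scoring draft candidates, restores their original tree order, and prepends the verified token. A tiny self-contained sketch with toy tensors (values and shapes invented purely for illustration):

```python
import torch

# Toy inputs: one request, 6 candidate draft tokens with scores,
# num_verify_tokens = 4, so we keep the top 3 candidates plus the verified token.
score_list = torch.tensor([[0.05, 0.40, 0.10, 0.30, 0.01, 0.14]])
ss_token_list = torch.tensor([[101, 102, 103, 104, 105, 106]])
verified_id = torch.tensor([7])
num_verify_tokens = 4

# Pick the best (num_verify_tokens - 1) positions by score...
top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
# ...then sort the chosen indices so the kept tokens stay in tree order.
top_scores_index = torch.sort(top_scores.indices).values

draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
print(draft_tokens)  # tensor([  7, 102, 104, 106])
```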
@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
     )
 
 
-def build_tree_kernel(
-    verified_id: torch.Tensor,
-    score_list: List[torch.Tensor],
-    token_list: List[torch.Tensor],
-    parents_list: List[torch.Tensor],
-    seq_lens: torch.Tensor,
-    seq_lens_sum: int,
-    topk: int,
-    spec_steps: int,
-    num_verify_tokens: int,
-):
-    parent_list, top_scores_index, draft_tokens = (
-        build_tree_kernel_efficient_preprocess(
-            verified_id,
-            score_list,
-            token_list,
-            parents_list,
-            num_verify_tokens,
-        )
-    )
-
-    bs = seq_lens.numel()
-    device = seq_lens.device
-
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
-    retrive_index = torch.full(
-        (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
-    )
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list,
-        top_scores_index,
-        seq_lens.to(torch.int32),
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        spec_steps,
-        num_verify_tokens,
-    )
-
-    index = retrive_index.sum(dim=-1) != -spec_steps - 2
-    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
-    retrive_cum_len = torch.zeros(
-        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
-    )
-    retrive_cum_len[1:] = cum_len
-    # TODO: this indexing cause a synchronization, optimize this
-    retrive_index = retrive_index[index]
-    return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
-
-
-def test_build_tree_kernel():
-    def findp(p_i, index, parent_list):
-        pos = index // 10
-        index_list = index.tolist()
-        parent_list = parent_list.tolist()
-        res = [p_i]
-        while True:
-            p = pos[p_i]
-            if p == 0:
-                break
-            token_idx = parent_list[p]
-            p_i = index_list.index(token_idx)
-            res.append(p_i)
-        return res
-
-    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
-        mask = []
-        positions = []
-        retrive_index = []
-        for i, lens in enumerate(seq_len.tolist()):
-            first_mask = torch.full((lens + draft_token,), True)
-            first_mask[-(draft_token - 1) :] = False
-            positions.append(lens)
-            mask.append(first_mask)
-            seq_order = []
-            first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
-            r_index = [first_index]
-            for j in range(draft_token - 1):
-                mask.append(torch.full((lens + 1,), True))
-                idx = findp(j, index, parent_list)
-                seq_order.append(idx)
-                positions.append(len(idx) + seq_len)
-                t = torch.full((draft_token - 1,), False)
-                t[idx] = True
-                mask.append(t)
-            for i in range(1, draft_token - 1):
-                is_leaf = 0
-                for j in range(draft_token - 1):
-                    if i in seq_order[j]:
-                        is_leaf += 1
-                if is_leaf == 1:
-                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
-                    for _ in range(max_depth + 1 - len(seq_order[i])):
-                        order_list.append(-1)
-                    order = torch.Tensor(order_list).cuda().to(torch.long)
-                    r_index.append(order)
-            retrive_index.append(torch.stack(r_index))
-        return (
-            torch.cat(mask).cuda(),
-            torch.Tensor(positions).cuda().to(torch.long),
-            torch.stack(retrive_index),
-        )
-
-    index = (
-        torch.Tensor(
-            [0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 30, 110, 130, 150, 160,
-             210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 230,
-             310, 311, 312, 313, 314, 315, 316, 317, 320, 321, 322, 330, 360, 380, 390,
-             410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423,
-             430, 431, 440, 441, 460, 470]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-    parent_list = (
-        torch.Tensor(
-            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 20, 30, 21, 13, 22, 40, 23,
-             110, 130, 160, 150, 190, 120, 111, 121, 200, 180,
-             210, 211, 212, 213, 214, 215, 216, 220, 230, 217,
-             310, 311, 312, 313, 320, 314, 321, 315, 316, 317]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
-    bs = verified_seq_len.shape[0]
-    topk = 10
-    depth = 5  # depth <= 10
-    num_draft_token = 64
-
-    tree_mask = torch.full(
-        (
-            torch.sum(verified_seq_len).item() * num_draft_token
-            + num_draft_token * num_draft_token * bs,
-        ),
-        True,
-    ).cuda()
-    retrive_index = torch.full(
-        (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
-    )
-    positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list.unsqueeze(0),
-        index.unsqueeze(0),
-        verified_seq_len,
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        depth,
-        num_draft_token,
-    )
-
-    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
-
-    c_mask, c_positions, c_retive_index = create_mask(
-        verified_seq_len, num_draft_token, index, parent_list, depth
-    )
-
-    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
-    assert torch.allclose(positions, c_positions), "positions has error."
-    assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
-
-
 def test_build_tree_kernel_efficient():
     verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
     score_list = [
@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
     depth = 4
     num_draft_token = 8
 
-    tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-        build_tree_kernel(
-            verified_id=verified_id,
-            score_list=score_list,
-            token_list=token_list,
-            parents_list=parents_list,
-            seq_lens=seq_lens,
-            seq_lens_sum=torch.sum(seq_lens).item(),
-            topk=topk,
-            spec_steps=depth,
-            num_verify_tokens=num_draft_token,
-        )
-    )
-
-    from sglang.srt.utils import first_rank_print
-
-    first_rank_print("=========== build tree kernel ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_cum_len=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
-    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
-    assert retrive_index.tolist() == [
-        [0, -1, -1, -1, -1, -1],
-        [0, 2, 4, 6, -1, -1],
-        [0, 1, 3, 5, 7, -1],
-        [8, -1, -1, -1, -1, -1],
-        [8, 9, 10, -1, -1, -1],
-        [8, 9, 12, -1, -1, -1],
-        [8, 9, 13, -1, -1, -1],
-        [8, 9, 11, 14, 15, -1],
-    ]
-    assert retrive_cum_len.tolist() == [0, 3, 8]
-    assert draft_tokens.tolist() == [
-        29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
-        13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
-    ]
-
     (
         tree_mask,
         position,
@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient():
 if __name__ == "__main__":
     test_build_tree_kernel_efficient()
-    test_build_tree_kernel()
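As a side note on the code deleted above: the legacy build_tree_kernel allocated its flat tree attention mask with size seq_lens_sum * num_verify_tokens + num_verify_tokens * num_verify_tokens * bs, i.e. every draft token can attend to all prefix tokens plus a draft-to-draft block per request. A quick worked computation with toy numbers (not taken from the tests above):

```python
# Toy sizes: 2 requests with prefix lengths 10 and 14, 8 draft tokens per request.
bs = 2
seq_lens = [10, 14]
seq_lens_sum = sum(seq_lens)  # 24
num_verify_tokens = 8

# Prefix part: each of the 8 draft tokens masks against every prefix token.
# Draft part: an 8 x 8 draft-to-draft block for each of the 2 requests.
mask_size = seq_lens_sum * num_verify_tokens + num_verify_tokens * num_verify_tokens * bs
print(mask_size)  # 24 * 8 + 8 * 8 * 2 = 320
```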
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.tp_size = self.model_runner.tp_size
+        self.dp_size = model_runner.server_args.dp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         server_args = model_runner.server_args
-        assert self.disable_padding
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.num_tokens_per_bs = server_args.speculative_eagle_topk
@@ -169,6 +170,13 @@ class EAGLEDraftCudaGraphRunner:
             set_global_graph_memory_pool(graph.pool())
         return graph, out
 
+    def _postprocess_output_to_raw_bs(self, out, raw_bs):
+        score_list, token_list, parents_list = out
+        score_list = [x[:raw_bs] for x in score_list]
+        token_list = [x[:raw_bs] for x in token_list]
+        parents_list = [x[:raw_bs] for x in parents_list]
+        return (score_list, token_list, parents_list)
+
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
@@ -180,6 +188,9 @@ class EAGLEDraftCudaGraphRunner:
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
+
+        num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
@@ -193,11 +204,25 @@ class EAGLEDraftCudaGraphRunner:
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
         # Attention backend
+        if bs != raw_bs:
+            forward_batch.batch_size = bs
+            forward_batch.seq_lens = self.seq_lens[:bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
+            forward_batch.positions = self.positions[:num_tokens]
+
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch, forward_batch.batch_size
+            forward_batch, bs
         )
 
         # Replay
         self.graphs[bs].replay()
+        out = self.output_buffers[bs]
+
+        if bs != raw_bs:
+            out = self._postprocess_output_to_raw_bs(out, raw_bs)
+            forward_batch.batch_size = raw_bs
+            forward_batch.positions = self.positions[:raw_num_token]
+            forward_batch.seq_lens = self.seq_lens[:raw_bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
 
-        return self.output_buffers[bs]
+        return out
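The new replay path above runs a small decode batch on the nearest captured graph size and then, via _postprocess_output_to_raw_bs, slices every per-step output back down to the real batch size. A minimal sketch of that trimming idea with fake tensors (names and shapes are illustrative only, not the runner's real buffers):

```python
import torch


def trim_to_raw_bs(out, raw_bs):
    # Same slicing idea as _postprocess_output_to_raw_bs: keep only the first
    # raw_bs rows of every per-step output computed at the padded batch size.
    score_list, token_list, parents_list = out
    return (
        [x[:raw_bs] for x in score_list],
        [x[:raw_bs] for x in token_list],
        [x[:raw_bs] for x in parents_list],
    )


raw_bs, padded_bs, topk = 3, 4, 8  # 3 real requests padded up to a bs-4 graph
fake_out = (
    [torch.zeros(padded_bs, topk)],                    # scores for one draft step
    [torch.zeros(padded_bs, topk, dtype=torch.long)],  # candidate tokens
    [torch.zeros(padded_bs, topk, dtype=torch.long)],  # parent indices
)
scores, tokens, parents = trim_to_raw_bs(fake_out, raw_bs)
print(scores[0].shape)  # torch.Size([3, 8])
```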
python/sglang/srt/speculative/eagle_utils.py
(This diff is collapsed in the original page and is not reproduced here; +204, -250.)
python/sglang/srt/speculative/eagle_worker.py
 import logging
 import os
 import time
+from contextlib import contextmanager
 from typing import List, Optional, Tuple
 
 import torch
 from huggingface_hub import snapshot_download
 
+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -27,11 +30,22 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.utils import get_available_gpu_memory
+from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import segment_packbits
 
 logger = logging.getLogger(__name__)
 
 
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+        yield
+
+
 class EAGLEWorker(TpModelWorker):
 
     def __init__(
@@ -76,16 +90,17 @@ class EAGLEWorker(TpModelWorker):
             self.hot_token_id = None
 
         # Init draft worker
-        super().__init__(
-            gpu_id=gpu_id,
-            tp_rank=tp_rank,
-            server_args=server_args,
-            nccl_port=nccl_port,
-            dp_rank=dp_rank,
-            is_draft_worker=True,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-        )
+        with empty_context():
+            super().__init__(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=nccl_port,
+                dp_rank=dp_rank,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
 
         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -94,12 +109,17 @@ class EAGLEWorker(TpModelWorker):
             self.hot_token_id = self.hot_token_id.to(head.device)
             head.data = head.data[self.hot_token_id]
         self.draft_model_runner.model.set_embed_and_head(embed, head)
 
         # Init attention backend and cuda graphs
         self.draft_model_runner.server_args.disable_cuda_graph = (
             backup_disable_cuda_graph
         )
-        self.init_attention_backend()
-        self.init_cuda_graphs()
+
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
 
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
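The pattern above picks a real context manager when DP attention is enabled and a no-op one otherwise, so every call site can use a plain `with` block. A standalone sketch of the same idiom (the `dp_enabled` flag and `fake_tp_group` value are invented for the example):

```python
from contextlib import contextmanager


@contextmanager
def empty_context(*args, **kwargs):
    # No-op stand-in used when no special TP/DP handling is needed.
    yield


@contextmanager
def draft_tp_context(tp_group):
    # Pretend setup/teardown around draft-model work.
    print(f"patching tensor-parallel group: {tp_group}")
    yield
    print("restoring tensor-parallel group")


dp_enabled = False
fake_tp_group = "draft-tp-group"

ctx = draft_tp_context if dp_enabled else empty_context
with ctx(fake_tp_group):
    print("init attention backend / capture cuda graphs")
```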
@@ -109,52 +129,70 @@ class EAGLEWorker(TpModelWorker):
             )
             self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
             self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashinfer_mla":
             from sglang.srt.layers.attention.flashinfer_mla_backend import (
                 FlashInferMLAMultiStepDraftBackend,
             )
             self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
 
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
         self.cuda_graph_runner_for_draft_extend = None
 
         if self.server_args.disable_cuda_graph:
             return
 
         # Capture draft
         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )
 
         # Capture extend
         if self.draft_extend_attn_backend:
             raise NotImplementedError()
 
     @property
     def draft_model_runner(self):
         return self.model_runner
@@ -164,8 +202,8 @@ class EAGLEWorker(TpModelWorker):
     ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
         """Run speculative decoding forward.
 
-        NOTE: Many states of batch is modified as you go through. It is not guaranteed
-        the final output batch doesn't have the same state as the input.
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed
+        that the final output batch have the same state as the input.
 
         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
@@ -173,30 +211,42 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepeted,
             the batch id (used for overlap schedule), and number of accepeted tokens.
         """
         assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-            spec_info, to_free_cache_loc = self.draft(batch)
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                spec_info, to_free_cache_loc = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(batch, spec_info)
 
             # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
             self.token_to_kv_pool_allocator.free(to_free_cache_loc)
 
-            # if it is None, means all requests are finished
-            if batch.spec_info.verified_id is not None:
-                self.forward_draft_extend_after_decode(batch)
+            # If it is None, it means all requests are finished
+            if batch.spec_info.verified_id is not None:
+                with self.draft_tp_context(self.draft_model_runner.tp_group):
+                    self.forward_draft_extend_after_decode(batch)
 
             return (
                 logits_output,
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
             )
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
             logits_output, next_token_ids, _ = (
                 self.target_worker.forward_batch_generation(
                     ForwardBatch.init_new(model_worker_batch, self.target_worker.model_runner)
                 )
             )
 
             return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
-            self.forward_draft_extend(batch, logits_output.hidden_states, next_token_ids)
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(batch, logits_output.hidden_states, next_token_ids)
             return logits_output, next_token_ids, bid, 0
 
     def forward_target_extend(
@@ -226,6 +276,13 @@ class EAGLEWorker(TpModelWorker):
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
 
+        # Accumulate penalty
+        if batch.sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                spec_info.verified_id.to(torch.int64)
+            )
+
         # Allocate cache locations
         out_cache_loc = batch.alloc_token_slots(
             num_seqs * self.topk * self.speculative_num_steps
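The allocation at the end of this hunk reserves one KV-cache slot for every candidate the draft model may touch: num_seqs requests times topk candidates times speculative_num_steps draft steps. A quick arithmetic sketch with invented numbers:

```python
# Invented sizes: 4 running requests, EAGLE topk 8, 5 draft steps.
num_seqs, topk, speculative_num_steps = 4, 8, 5

# Slots reserved up front for drafting; they are freed again after verification.
num_slots = num_seqs * topk * speculative_num_steps
print(num_slots)  # 160
```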
@@ -275,9 +332,7 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
-            batch.sampling_info.is_all_greedy,
         )
-
         return ret, out_cache_loc
 
     def draft_forward(self, forward_batch: ForwardBatch):
@@ -307,7 +362,7 @@ class EAGLEWorker(TpModelWorker):
             token_list.append(tree_info[1])
             parents_list.append(tree_info[2])
 
-            # we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
+            # We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
             if i == self.speculative_num_steps - 1:
                 break
@@ -322,7 +377,7 @@ class EAGLEWorker(TpModelWorker):
             spec_info.hidden_states = hidden_states
 
             # Run forward
-            logits_output = self.model_runner.model.forward(
+            logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
             self._detect_nan_if_needed(logits_output)
@@ -351,11 +406,10 @@ class EAGLEWorker(TpModelWorker):
 
         # Post process based on verified outputs.
         # Pick indices that we care (accepeted)
-        logits_output.next_token_logits = logits_output.next_token_logits[
-            res.accepeted_indices_cpu
-        ]
-        logits_output.hidden_states = logits_output.hidden_states[
-            res.accepeted_indices_cpu
-        ]
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            res.accepeted_indices
+        ]
+        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
 
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input
@@ -407,7 +461,7 @@ class EAGLEWorker(TpModelWorker):
             batch_next_token_ids,
         ]
 
-        # Add output logprobs to the request.
+        # Add output logprobs to the request
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
@@ -456,27 +510,38 @@ class EAGLEWorker(TpModelWorker):
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        seq_lens_backup = batch.seq_lens
+        # Backup fileds that will be modified in-place
+        seq_lens_backup = batch.seq_lens.clone()
+        req_pool_indices_backup = batch.req_pool_indices
+        accept_length_backup = batch.spec_info.accept_length
+        return_logprob_backup = batch.return_logprob
 
         # Prepare metadata
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        # We don't need logprob for this extend.
-        original_return_logprob = batch.return_logprob
         batch.return_logprob = False
+
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.draft_model_runner)
 
         # Run
         logits_output = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
+        batch.req_pool_indices = req_pool_indices_backup
+        batch.spec_info.accept_length = accept_length_backup
+        batch.return_logprob = return_logprob_backup
 
     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
@@ -489,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
         if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
-                logger.warning("Detected errors during sampling! NaN in the logits.")
+                logger.error("Detected errors during sampling! NaN in the logits.")
                 raise ValueError("Detected errors during sampling! NaN in the logits.")
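The backup-and-restore dance in forward_draft_extend_after_decode above clones seq_lens before the in-place mutation and restores every touched field afterwards. A plain-tensor sketch of why the .clone() matters there (a toy batch object, not the real ScheduleBatch):

```python
import torch


class ToyBatch:
    def __init__(self):
        self.seq_lens = torch.tensor([5, 9])


batch = ToyBatch()

# Without clone(), the backup aliases the same storage and is silently
# corrupted by the in-place update below.
aliased_backup = batch.seq_lens
safe_backup = batch.seq_lens.clone()

batch.seq_lens.add_(3)  # in-place change, like prepare_extend_after_decode

print(aliased_backup)  # tensor([ 8, 12])  <- changed along with batch.seq_lens
print(safe_backup)     # tensor([5, 9])    <- still the original values

batch.seq_lens = safe_backup  # restore, as the new code does
```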
python/sglang/srt/utils.py
@@ -36,6 +36,7 @@ import tempfile
 import threading
 import time
 import warnings
+from contextlib import contextmanager
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
@@ -1577,6 +1578,16 @@ def next_power_of_2(n: int):
 setattr(triton, "next_power_of_2", next_power_of_2)
 
 
+@contextmanager
+def empty_context(*args, **kwargs):
+    try:
+        # Setup code goes here
+        yield
+    finally:
+        # Cleanup code goes here
+        pass
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.
scripts/ci_install_dependency.sh
@@ -24,6 +24,3 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa
 # For compling xgrammar kernels
 pip install cuda-python nvidia-cuda-nvrtc-cu12
-
-# reinstall sgl-kernel
-pip install sgl-kernel==0.0.5.post1 --force-reinstall --no-deps
sgl-kernel/csrc/speculative/speculative_sampling.cuh
@@ -36,8 +36,8 @@ template <typename DType, typename IdType>
 __global__ void TreeSpeculativeSamplingTargetOnly(
-    IdType* predicts,
-    IdType* accept_index,
+    IdType* predicts,          // mutable
+    IdType* accept_index,      // mutable
     IdType* accept_token_num,  // mutable
     IdType* candidates,
     IdType* retrive_index,
@@ -158,8 +158,8 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
 template <typename DType, typename IdType>
 cudaError_t TreeSpeculativeSamplingTargetOnly(
-    IdType* predicts,
-    IdType* output_token_ids,
+    IdType* predicts,                   // mutable
+    IdType* output_token_ids,           // mutable
     IdType* output_accepted_token_num,  // mutable
     IdType* candidates,
     IdType* retrive_index,
...
test/srt/test_eagle_infer.py
View file @
1b859295
...
...
@@ -122,8 +122,8 @@ class TestEAGLEEngine(unittest.TestCase):
def
_test_acc_length
(
self
,
engine
):
prompt
=
[
"Human: Give me a fully functional FastAPI server. Show the python code.
\n\n
Assistant:"
]
*
5
"Human: Give me a fully functional FastAPI server. Show the python code.
\n\n
Assistant:"
,
]
*
5
# test batched generation
sampling_params
=
{
"temperature"
:
0
,
"max_new_tokens"
:
512
}
output
=
engine
.
generate
(
prompt
,
sampling_params
)
output
=
output
[
0
]
...
...
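For context, the test above drives batched generation through the offline engine. A hedged usage sketch along the same lines; the model paths and speculative settings below are placeholders, and the exact Engine keyword arguments may differ across sglang versions:

```python
import sglang as sgl

# Placeholder model paths and speculative settings; adjust to your setup.
engine = sgl.Engine(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    speculative_algorithm="EAGLE",
    speculative_draft_model_path="yuhuili/EAGLE-llama2-chat-7B",
    speculative_num_steps=5,
    speculative_eagle_topk=8,
    speculative_num_draft_tokens=64,
)

prompt = [
    "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
] * 5
sampling_params = {"temperature": 0, "max_new_tokens": 512}

outputs = engine.generate(prompt, sampling_params)
print(outputs[0]["text"])
engine.shutdown()
```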
test/srt/test_mla_flashinfer.py
@@ -67,7 +67,7 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
             "--enable-torch-compile",
             "--disable-cuda-graph",
             "--cuda-graph-max-bs",
-            "2",
+            "4",
             "--enable-flashinfer-mla",
             "--flashinfer-mla-disable-ragged",
         ]
@@ -109,7 +109,7 @@ class TestFlashinferMLAMTP(unittest.TestCase):
         other_args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--disable-radix",
                 "--enable-torch-compile",
                 "--torch-compile-max-bs",