Unverified commit 1b859295, authored Mar 16, 2025 by Ying Sheng and committed via GitHub on Mar 16, 2025.
[Eagle] Remove the greedy branch and some redundant code (#4363)
Co-authored-by: Sehoon Kim <sehoon@x.ai>
Parent: 9971dc22
Showing 14 changed files with 383 additions and 675 deletions (+383, -675).
python/pyproject.toml (+1, -1)
python/sglang/srt/entrypoints/http_server.py (+1, -1)
python/sglang/srt/managers/scheduler.py (+0, -2)
python/sglang/srt/model_executor/cuda_graph_runner.py (+10, -11)
python/sglang/srt/server_args.py (+0, -1)
python/sglang/srt/speculative/build_eagle_tree.py (+7, -347)
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py (+30, -5)
python/sglang/srt/speculative/eagle_utils.py (+204, -250)
python/sglang/srt/speculative/eagle_worker.py (+111, -46)
python/sglang/srt/utils.py (+11, -0)
scripts/ci_install_dependency.sh (+0, -3)
sgl-kernel/csrc/speculative/speculative_sampling.cuh (+4, -4)
test/srt/test_eagle_infer.py (+2, -2)
test/srt/test_mla_flashinfer.py (+2, -2)
python/pyproject.toml

@@ -43,7 +43,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.5.post1",
+    "sgl-kernel==0.0.5.post2",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "vllm>=0.6.4.post1,<=0.7.2",
python/sglang/srt/entrypoints/http_server.py

@@ -283,7 +283,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.post("/flush_cache")
+@app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
     _global_state.tokenizer_manager.flush_cache()
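Note: with the handler now registered via api_route for both GET and POST, either verb reaches the same cache-flush code. A minimal sketch of exercising the endpoint, assuming a locally launched sglang server on a default-style address (URL and port are assumptions, not part of this commit):

    import requests

    base_url = "http://127.0.0.1:30000"  # assumed local server address

    # After this change, both verbs should trigger the radix-cache flush.
    resp_get = requests.get(f"{base_url}/flush_cache")
    resp_post = requests.post(f"{base_url}/flush_cache")
    print(resp_get.status_code, resp_post.status_code)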
python/sglang/srt/managers/scheduler.py

@@ -895,7 +895,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"largest-len: {self._largest_prefill_decode_len}, "
                 f"#queue-req: {len(self.waiting_queue)}, "
             )
             spec_accept_length = 0
@@ -913,7 +912,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"accept len: {spec_accept_length:.2f}, "
                 f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"largest-len: {self._largest_prefill_decode_len}, "
                 f"#queue-req: {len(self.waiting_queue)}, "
             )
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -117,7 +117,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-        capture_bs = list(range(1, 33))
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture less.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]

     if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
@@ -125,16 +127,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs = list(
-            sorted(
-                set(
-                    capture_bs
-                    + [model_runner.req_to_token_pool.size - 1]
-                    + [model_runner.req_to_token_pool.size]
-                )
-            )
-        )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]
+
+    capture_bs = list(sorted(set(capture_bs)))
     capture_bs = [
         bs
         for bs in capture_bs
@@ -508,7 +505,9 @@ class CudaGraphRunner:
         self.raw_num_token = raw_num_token
         self.bs = bs

-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.replay_prepare(forward_batch)
         else:
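For reference, the new default capture list for the speculative-decoding path covers every batch size up to 8, every other size up to 32, and a few large sizes. A quick standalone check of the expression introduced above:

    # Reproduces the batch-size list from the hunk above; no sglang import needed.
    capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
    print(capture_bs)
    # [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, ..., 31, 64, 96, 128, 160]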
python/sglang/srt/server_args.py

@@ -285,7 +285,6 @@ class ServerArgs:
         if self.speculative_algorithm == "EAGLE":
             if self.max_running_requests is None:
                 self.max_running_requests = 32
-            self.disable_cuda_graph_padding = True
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
python/sglang/srt/speculative/build_eagle_tree.py

@@ -3,8 +3,13 @@
 from typing import List

 import torch
-from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
+
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )


 def build_tree_kernel_efficient_preprocess(
@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
     top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
     top_scores_index = top_scores.indices
     top_scores_index = torch.sort(top_scores_index).values

     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
     )


-def build_tree_kernel(
-    verified_id: torch.Tensor,
-    score_list: List[torch.Tensor],
-    token_list: List[torch.Tensor],
-    parents_list: List[torch.Tensor],
-    seq_lens: torch.Tensor,
-    seq_lens_sum: int,
-    topk: int,
-    spec_steps: int,
-    num_verify_tokens: int,
-):
-    parent_list, top_scores_index, draft_tokens = (
-        build_tree_kernel_efficient_preprocess(
-            verified_id,
-            score_list,
-            token_list,
-            parents_list,
-            num_verify_tokens,
-        )
-    )
-
-    bs = seq_lens.numel()
-    device = seq_lens.device
-
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
-    retrive_index = torch.full(
-        (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
-    )
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list,
-        top_scores_index,
-        seq_lens.to(torch.int32),
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        spec_steps,
-        num_verify_tokens,
-    )
-
-    index = retrive_index.sum(dim=-1) != -spec_steps - 2
-    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
-    retrive_cum_len = torch.zeros(
-        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
-    )
-    retrive_cum_len[1:] = cum_len
-    # TODO: this indexing cause a synchronization, optimize this
-    retrive_index = retrive_index[index]
-    return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
-
-
-def test_build_tree_kernel():
-    def findp(p_i, index, parent_list):
-        pos = index // 10
-        index_list = index.tolist()
-        parent_list = parent_list.tolist()
-        res = [p_i]
-        while True:
-            p = pos[p_i]
-            if p == 0:
-                break
-            token_idx = parent_list[p]
-            p_i = index_list.index(token_idx)
-            res.append(p_i)
-        return res
-
-    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
-        mask = []
-        positions = []
-        retrive_index = []
-        for i, lens in enumerate(seq_len.tolist()):
-            first_mask = torch.full((lens + draft_token,), True)
-            first_mask[-(draft_token - 1):] = False
-            positions.append(lens)
-            mask.append(first_mask)
-            seq_order = []
-            first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
-            r_index = [first_index]
-            for j in range(draft_token - 1):
-                mask.append(torch.full((lens + 1,), True))
-                idx = findp(j, index, parent_list)
-                seq_order.append(idx)
-                positions.append(len(idx) + seq_len)
-                t = torch.full((draft_token - 1,), False)
-                t[idx] = True
-                mask.append(t)
-            for i in range(1, draft_token - 1):
-                is_leaf = 0
-                for j in range(draft_token - 1):
-                    if i in seq_order[j]:
-                        is_leaf += 1
-                if is_leaf == 1:
-                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
-                    for _ in range(max_depth + 1 - len(seq_order[i])):
-                        order_list.append(-1)
-                    order = torch.Tensor(order_list).cuda().to(torch.long)
-                    r_index.append(order)
-            retrive_index.append(torch.stack(r_index))
-        return (
-            torch.cat(mask).cuda(),
-            torch.Tensor(positions).cuda().to(torch.long),
-            torch.stack(retrive_index),
-        )
-
-    index = (
-        torch.Tensor(
-            [0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 30, 110, 130, 150, 160,
-             210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 230,
-             310, 311, 312, 313, 314, 315, 316, 317, 320, 321, 322, 330, 360, 380, 390,
-             410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423,
-             430, 431, 440, 441, 460, 470]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-    parent_list = (
-        torch.Tensor(
-            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 20, 30, 21, 13, 22, 40, 23,
-             110, 130, 160, 150, 190, 120, 111, 121, 200, 180,
-             210, 211, 212, 213, 214, 215, 216, 220, 230, 217,
-             310, 311, 312, 313, 320, 314, 321, 315, 316, 317]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
-    bs = verified_seq_len.shape[0]
-    topk = 10
-    depth = 5  # depth <= 10
-    num_draft_token = 64
-
-    tree_mask = torch.full(
-        (
-            torch.sum(verified_seq_len).item() * num_draft_token
-            + num_draft_token * num_draft_token * bs,
-        ),
-        True,
-    ).cuda()
-    retrive_index = torch.full(
-        (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
-    )
-    positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list.unsqueeze(0),
-        index.unsqueeze(0),
-        verified_seq_len,
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        depth,
-        num_draft_token,
-    )
-    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
-
-    c_mask, c_positions, c_retive_index = create_mask(
-        verified_seq_len, num_draft_token, index, parent_list, depth
-    )
-
-    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
-    assert torch.allclose(positions, c_positions), "positions has error."
-    assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
 def test_build_tree_kernel_efficient():
     verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
     score_list = [
@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
     depth = 4
     num_draft_token = 8

-    tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-        build_tree_kernel(
-            verified_id=verified_id,
-            score_list=score_list,
-            token_list=token_list,
-            parents_list=parents_list,
-            seq_lens=seq_lens,
-            seq_lens_sum=torch.sum(seq_lens).item(),
-            topk=topk,
-            spec_steps=depth,
-            num_verify_tokens=num_draft_token,
-        )
-    )
-
-    from sglang.srt.utils import first_rank_print
-
-    first_rank_print("=========== build tree kernel ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_cum_len=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
-    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
-    assert retrive_index.tolist() == [
-        [0, -1, -1, -1, -1, -1],
-        [0, 2, 4, 6, -1, -1],
-        [0, 1, 3, 5, 7, -1],
-        [8, -1, -1, -1, -1, -1],
-        [8, 9, 10, -1, -1, -1],
-        [8, 9, 12, -1, -1, -1],
-        [8, 9, 13, -1, -1, -1],
-        [8, 9, 11, 14, 15, -1],
-    ]
-    assert retrive_cum_len.tolist() == [0, 3, 8]
-    assert draft_tokens.tolist() == [
-        29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
-        13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
-    ]
-
     (
         tree_mask,
         position,
@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient():
 if __name__ == "__main__":
     test_build_tree_kernel_efficient()
-    test_build_tree_kernel()
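The preprocessing path that survives this cleanup selects the highest-scoring draft candidates, re-sorts their indices to preserve tree order, gathers the corresponding token ids, and prepends the verified token. A standalone sketch of that flow on toy CPU tensors (shapes and values are illustrative only, not taken from the kernel):

    import torch

    # Toy stand-ins for one request's flattened candidate scores and token ids;
    # in sglang these come from the draft model's tree expansion.
    score_list = torch.tensor([[0.9, 0.1, 0.7, 0.3, 0.5, 0.2]])
    ss_token_list = torch.tensor([[11, 12, 13, 14, 15, 16]])
    verified_id = torch.tensor([7])
    num_verify_tokens = 4

    top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
    top_scores_index = torch.sort(top_scores.indices).values          # keep tree order
    draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
    print(draft_tokens)  # tensor([ 7, 11, 13, 15])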
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker

+import logging
+
+logger = logging.getLogger(__name__)
+

 class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.tp_size = self.model_runner.tp_size
-        self.dp_size = model_runner.server_args.dp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         server_args = model_runner.server_args
-        assert self.disable_padding

         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.num_tokens_per_bs = server_args.speculative_eagle_topk
@@ -169,6 +170,13 @@ class EAGLEDraftCudaGraphRunner:
             set_global_graph_memory_pool(graph.pool())
         return graph, out

+    def _postprocess_output_to_raw_bs(self, out, raw_bs):
+        score_list, token_list, parents_list = out
+        score_list = [x[:raw_bs] for x in score_list]
+        token_list = [x[:raw_bs] for x in token_list]
+        parents_list = [x[:raw_bs] for x in parents_list]
+        return (score_list, token_list, parents_list)
+
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
@@ -180,6 +188,9 @@ class EAGLEDraftCudaGraphRunner:
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
+
+        num_tokens = bs * self.num_tokens_per_bs

         # Common inputs
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
@@ -193,11 +204,25 @@ class EAGLEDraftCudaGraphRunner:
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

         # Attention backend
+        if bs != raw_bs:
+            forward_batch.batch_size = bs
+            forward_batch.seq_lens = self.seq_lens[:bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
+            forward_batch.positions = self.positions[:num_tokens]
+
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch, forward_batch.batch_size
+            forward_batch, bs
         )

         # Replay
         self.graphs[bs].replay()
+        out = self.output_buffers[bs]
+
+        if bs != raw_bs:
+            out = self._postprocess_output_to_raw_bs(out, raw_bs)
+            forward_batch.batch_size = raw_bs
+            forward_batch.positions = self.positions[:raw_num_token]
+            forward_batch.seq_lens = self.seq_lens[:raw_bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]

-        return self.output_buffers[bs]
+        return out
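The new replay path pads a small batch up to the nearest captured graph size and then trims the outputs back to the real batch size, which is what _postprocess_output_to_raw_bs does above. A self-contained sketch of that pad/truncate idea with made-up shapes (not the runner's actual buffers or graph replay):

    import torch

    def pad_then_truncate(scores: torch.Tensor, captured_sizes):
        # Pick the smallest captured batch size that fits the raw batch.
        raw_bs = scores.shape[0]
        bs = min(s for s in captured_sizes if s >= raw_bs)
        padded = torch.zeros((bs, *scores.shape[1:]), dtype=scores.dtype)
        padded[:raw_bs] = scores
        out = padded * 2          # stand-in for replaying the captured graph
        return out[:raw_bs]       # truncate back to the raw batch size

    print(pad_then_truncate(torch.ones(3, 4), captured_sizes=[1, 2, 4, 8]).shape)  # torch.Size([3, 4])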
python/sglang/srt/speculative/eagle_utils.py

(This diff is collapsed on the original page and is not reproduced here.)
python/sglang/srt/speculative/eagle_worker.py

 import logging
 import os
 import time
+from contextlib import contextmanager
 from typing import List, Optional, Tuple

 import torch
 from huggingface_hub import snapshot_download

+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -27,11 +30,22 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.utils import get_available_gpu_memory
+from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)


+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+        yield
+
+
 class EAGLEWorker(TpModelWorker):

     def __init__(
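The worker later switches between this draft TP context and the no-op empty_context depending on whether DP attention is enabled. A minimal sketch of that select-a-context pattern, using stand-in context managers rather than sglang's real group plumbing:

    from contextlib import contextmanager

    @contextmanager
    def draft_tp_context(tp_group):      # stand-in for the real helper above
        print(f"patched tp group -> {tp_group}")
        yield

    @contextmanager
    def empty_context(*args, **kwargs):  # mirrors sglang.srt.utils.empty_context
        yield

    enable_dp_attention = False          # assumed server flag
    ctx = draft_tp_context if enable_dp_attention else empty_context
    with ctx("draft-tp-group"):
        pass                             # draft attention backend / cuda graph init would run here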
@@ -76,16 +90,17 @@ class EAGLEWorker(TpModelWorker):
         self.hot_token_id = None

         # Init draft worker
-        super().__init__(
-            gpu_id=gpu_id,
-            tp_rank=tp_rank,
-            server_args=server_args,
-            nccl_port=nccl_port,
-            dp_rank=dp_rank,
-            is_draft_worker=True,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-        )
+        with empty_context():
+            super().__init__(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=nccl_port,
+                dp_rank=dp_rank,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )

         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -94,12 +109,17 @@ class EAGLEWorker(TpModelWorker):
             self.hot_token_id = self.hot_token_id.to(head.device)
             head.data = head.data[self.hot_token_id]
         self.draft_model_runner.model.set_embed_and_head(embed, head)

-        # Init attention backend and cuda graphs
         self.draft_model_runner.server_args.disable_cuda_graph = (
             backup_disable_cuda_graph
         )
-        self.init_attention_backend()
-        self.init_cuda_graphs()
+
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()

     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
@@ -109,52 +129,70 @@ class EAGLEWorker(TpModelWorker):
             )
             self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
             self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashinfer_mla":
             from sglang.srt.layers.attention.flashinfer_mla_backend import (
                 FlashInferMLAMultiStepDraftBackend,
             )
             self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )

         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_runner_for_draft_extend = None

         if self.server_args.disable_cuda_graph:
             return

+        # Capture draft
         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

+        # Capture extend
+        if self.draft_extend_attn_backend:
+            raise NotImplementedError()
+
     @property
     def draft_model_runner(self):
         return self.model_runner
@@ -164,8 +202,8 @@ class EAGLEWorker(TpModelWorker):
     ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
         """Run speculative decoding forward.

-        NOTE: Many states of batch is modified as you go through. It is not guaranteed
-        that the final output batch doesn't have the same state as the input.
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed
+        that the final output batch have the same state as the input.

         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
@@ -173,30 +211,42 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepeted,
             the batch id (used for overlap schedule), and number of accepeted tokens.
         """
+        assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-            spec_info, to_free_cache_loc = self.draft(batch)
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                spec_info, to_free_cache_loc = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(
                 batch, spec_info
             )

             # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
             self.token_to_kv_pool_allocator.free(to_free_cache_loc)

-            # if it is None, means all requests are finished
+            # If it is None, it means all requests are finished
             if batch.spec_info.verified_id is not None:
-                self.forward_draft_extend_after_decode(batch)
+                with self.draft_tp_context(self.draft_model_runner.tp_group):
+                    self.forward_draft_extend_after_decode(batch)

             return (
                 logits_output,
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
             )
+        elif batch.forward_mode.is_idle():
+            model_worker_batch = batch.get_model_worker_batch()
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(
+                    ForwardBatch.init_new(
+                        model_worker_batch, self.target_worker.model_runner
+                    )
+                )
+            )
+
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
-            self.forward_draft_extend(
-                batch, logits_output.hidden_states, next_token_ids
-            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids
+                )
             return logits_output, next_token_ids, bid, 0

     def forward_target_extend(
@@ -226,6 +276,13 @@ class EAGLEWorker(TpModelWorker):
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info

+        # Accumulate penalty
+        if batch.sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                spec_info.verified_id.to(torch.int64)
+            )
+
         # Allocate cache locations
         out_cache_loc = batch.alloc_token_slots(
             num_seqs * self.topk * self.speculative_num_steps
@@ -275,9 +332,7 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
-            batch.sampling_info.is_all_greedy,
         )
         return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
@@ -307,7 +362,7 @@ class EAGLEWorker(TpModelWorker):
             token_list.append(tree_info[1])
             parents_list.append(tree_info[2])

-            # we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
+            # We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
             if i == self.speculative_num_steps - 1:
                 break
@@ -322,7 +377,7 @@ class EAGLEWorker(TpModelWorker):
         spec_info.hidden_states = hidden_states

         # Run forward
-        logits_output = self.model_runner.model.forward(
+        logits_output = self.draft_model_runner.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

         self._detect_nan_if_needed(logits_output)
@@ -351,11 +406,10 @@ class EAGLEWorker(TpModelWorker):
         # Post process based on verified outputs.
         # Pick indices that we care (accepeted)
-        logits_output.next_token_logits = logits_output.next_token_logits[
-            res.accepeted_indices_cpu
-        ]
-        logits_output.hidden_states = logits_output.hidden_states[
-            res.accepeted_indices_cpu
-        ]
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            res.accepeted_indices
+        ]
+        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]

         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input
@@ -407,7 +461,7 @@ class EAGLEWorker(TpModelWorker):
             batch_next_token_ids,
         ]

-        # Add output logprobs to the request
+        # Add output logprobs to the request.
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
@@ -456,27 +510,38 @@ class EAGLEWorker(TpModelWorker):
             self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        seq_lens_backup = batch.seq_lens
+        # Backup fileds that will be modified in-place
+        seq_lens_backup = batch.seq_lens.clone()
+        req_pool_indices_backup = batch.req_pool_indices
+        accept_length_backup = batch.spec_info.accept_length
+        return_logprob_backup = batch.return_logprob

+        # Prepare metadata
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        # We don't need logprob for this extend.
-        original_return_logprob = batch.return_logprob
         batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )

+        # Run
         logits_output = self.draft_model_runner.forward(forward_batch)

         self._detect_nan_if_needed(logits_output)
+        assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
+        batch.req_pool_indices = req_pool_indices_backup
+        batch.spec_info.accept_length = accept_length_backup
+        batch.return_logprob = return_logprob_backup

     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
@@ -489,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
         if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
-                logger.warning("Detected errors during sampling! NaN in the logits.")
+                logger.error("Detected errors during sampling! NaN in the logits.")
                 raise ValueError("Detected errors during sampling! NaN in the logits.")
python/sglang/srt/utils.py

@@ -36,6 +36,7 @@ import tempfile
 import threading
 import time
 import warnings
+from contextlib import contextmanager
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
@@ -1577,6 +1578,16 @@ def next_power_of_2(n: int):
     setattr(triton, "next_power_of_2", next_power_of_2)


+@contextmanager
+def empty_context(*args, **kwargs):
+    try:
+        # Setup code goes here
+        yield
+    finally:
+        # Cleanup code goes here
+        pass
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.
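empty_context is a do-nothing stand-in for any context manager, so call sites can keep a single with statement whether or not a real context is needed. A small usage sketch (the timed_section manager and the profile flag are made-up examples, not part of sglang):

    import time
    from contextlib import contextmanager

    @contextmanager
    def empty_context(*args, **kwargs):   # same shape as the helper added above
        yield

    @contextmanager
    def timed_section(name):              # hypothetical "real" context manager
        start = time.time()
        yield
        print(f"{name}: {time.time() - start:.3f}s")

    profile = False                       # assumed feature flag
    ctx = timed_section if profile else empty_context
    with ctx("draft_forward"):
        sum(range(1000))                  # the wrapped work runs either way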
scripts/ci_install_dependency.sh

@@ -24,6 +24,3 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa
 # For compling xgrammar kernels
 pip install cuda-python nvidia-cuda-nvrtc-cu12
-
-# reinstall sgl-kernel
-pip install sgl-kernel==0.0.5.post1 --force-reinstall --no-deps
sgl-kernel/csrc/speculative/speculative_sampling.cuh

@@ -36,8 +36,8 @@ template <
     typename DType,
     typename IdType>
 __global__ void TreeSpeculativeSamplingTargetOnly(
-    IdType* predicts,
-    IdType* accept_index,
+    IdType* predicts,  // mutable
+    IdType* accept_index,  // mutable
     IdType* accept_token_num,  // mutable
     IdType* candidates,
     IdType* retrive_index,
@@ -158,8 +158,8 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
 template <typename DType, typename IdType>
 cudaError_t TreeSpeculativeSamplingTargetOnly(
-    IdType* predicts,
-    IdType* output_token_ids,
+    IdType* predicts,  // mutable
+    IdType* output_token_ids,  // mutable
     IdType* output_accepted_token_num,  // mutable
     IdType* candidates,
     IdType* retrive_index,
test/srt/test_eagle_infer.py

@@ -122,8 +122,8 @@ class TestEAGLEEngine(unittest.TestCase):
     def _test_acc_length(self, engine):
         prompt = [
-            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
+            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
         ] * 5
+        # test batched generation
         sampling_params = {"temperature": 0, "max_new_tokens": 512}
         output = engine.generate(prompt, sampling_params)
         output = output[0]
test/srt/test_mla_flashinfer.py

@@ -67,7 +67,7 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
             "--enable-torch-compile",
             "--disable-cuda-graph",
             "--cuda-graph-max-bs",
-            "2",
+            "4",
             "--enable-flashinfer-mla",
             "--flashinfer-mla-disable-ragged",
         ]
@@ -109,7 +109,7 @@ class TestFlashinferMLAMTP(unittest.TestCase):
         other_args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--disable-radix",
                 "--enable-torch-compile",
                 "--torch-compile-max-bs",