change / sglang · Commits · 1b859295

Commit 1b859295 (unverified)
Authored Mar 16, 2025 by Ying Sheng; committed via GitHub on Mar 16, 2025
Parent: 9971dc22

[Eagle] Remove the greedy branch and some redundant code (#4363)

Co-authored-by: Sehoon Kim <sehoon@x.ai>
Showing 14 changed files with 383 additions and 675 deletions (+383 −675):

  python/pyproject.toml                                            +1    −1
  python/sglang/srt/entrypoints/http_server.py                     +1    −1
  python/sglang/srt/managers/scheduler.py                          +0    −2
  python/sglang/srt/model_executor/cuda_graph_runner.py            +10   −11
  python/sglang/srt/server_args.py                                 +0    −1
  python/sglang/srt/speculative/build_eagle_tree.py                +7    −347
  python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py   +30   −5
  python/sglang/srt/speculative/eagle_utils.py                     +204  −250
  python/sglang/srt/speculative/eagle_worker.py                    +111  −46
  python/sglang/srt/utils.py                                       +11   −0
  scripts/ci_install_dependency.sh                                 +0    −3
  sgl-kernel/csrc/speculative/speculative_sampling.cuh             +4    −4
  test/srt/test_eagle_infer.py                                     +2    −2
  test/srt/test_mla_flashinfer.py                                  +2    −2
python/pyproject.toml

@@ -43,7 +43,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.5.post1",
+    "sgl-kernel==0.0.5.post2",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "vllm>=0.6.4.post1,<=0.7.2",
python/sglang/srt/entrypoints/http_server.py

@@ -283,7 +283,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)
 
 
-@app.post("/flush_cache")
+@app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
     _global_state.tokenizer_manager.flush_cache()
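The `/flush_cache` endpoint now accepts both GET and POST. A minimal client sketch, assuming a server listening on `http://localhost:30000` (the host/port here are illustrative, not part of the diff) and the third-party `requests` package:

    import requests

    # Either verb now reaches the same handler that flushes the radix cache.
    resp_get = requests.get("http://localhost:30000/flush_cache")
    resp_post = requests.post("http://localhost:30000/flush_cache")
    print(resp_get.status_code, resp_post.status_code)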
python/sglang/srt/managers/scheduler.py

@@ -895,7 +895,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"largest-len: {self._largest_prefill_decode_len}, "
                 f"#queue-req: {len(self.waiting_queue)}, "
             )
             spec_accept_length = 0

@@ -913,7 +912,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"accept len: {spec_accept_length:.2f}, "
                 f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"largest-len: {self._largest_prefill_decode_len}, "
                 f"#queue-req: {len(self.waiting_queue)}, "
             )
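With the `largest-len` field dropped, the decode log message is assembled only from the remaining f-string fragments shown above. An illustrative rendering of just those fields (all values made up):

    token usage: 0.25, accept len: 3.20, gen throughput (token/s): 512.00, #queue-req: 0,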
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -117,7 +117,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-        capture_bs = list(range(1, 33))
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture less.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
 
     if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]

@@ -125,16 +127,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs = list(
-            sorted(
-                set(
-                    capture_bs
-                    + [model_runner.req_to_token_pool.size - 1]
-                    + [model_runner.req_to_token_pool.size]
-                )
-            )
-        )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]
+
+    capture_bs = list(sorted(set(capture_bs)))
     capture_bs = [
         bs
         for bs in capture_bs

@@ -508,7 +505,9 @@ class CudaGraphRunner:
         self.raw_num_token = raw_num_token
         self.bs = bs
 
-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.replay_prepare(forward_batch)
         else:
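The speculative-decoding branch now captures a sparser set of batch sizes. A standalone sketch of how that list behaves, independent of ModelRunner (the `pad_to_captured` helper is hypothetical, added only to show how a runtime batch size would map onto a captured graph when padding is enabled):

    # Dense coverage for small batches, stride-2 up to 32, then a few large sizes.
    capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]

    def pad_to_captured(bs, sizes=sorted(set(capture_bs))):
        # Return the smallest captured batch size that can hold `bs`.
        for c in sizes:
            if c >= bs:
                return c
        return sizes[-1]

    print(capture_bs[:8])          # [1, 2, 3, 4, 5, 6, 7, 8]
    print(pad_to_captured(14))     # 14 is not captured, so it pads up to 15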
python/sglang/srt/server_args.py

@@ -285,7 +285,6 @@ class ServerArgs:
         if self.speculative_algorithm == "EAGLE":
             if self.max_running_requests is None:
                 self.max_running_requests = 32
-            self.disable_cuda_graph_padding = True
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
python/sglang/srt/speculative/build_eagle_tree.py

@@ -3,8 +3,13 @@
 from typing import List
 
 import torch
-from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
+
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )
 
 
 def build_tree_kernel_efficient_preprocess(

@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
     top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
     top_scores_index = top_scores.indices
     top_scores_index = torch.sort(top_scores_index).values
 
     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
     )
 
 
-def build_tree_kernel(
-    verified_id: torch.Tensor,
-    score_list: List[torch.Tensor],
-    token_list: List[torch.Tensor],
-    parents_list: List[torch.Tensor],
-    seq_lens: torch.Tensor,
-    seq_lens_sum: int,
-    topk: int,
-    spec_steps: int,
-    num_verify_tokens: int,
-):
-    parent_list, top_scores_index, draft_tokens = (
-        build_tree_kernel_efficient_preprocess(
-            verified_id,
-            score_list,
-            token_list,
-            parents_list,
-            num_verify_tokens,
-        )
-    )
-
-    bs = seq_lens.numel()
-    device = seq_lens.device
-
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
-    retrive_index = torch.full(
-        (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
-    )
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list,
-        top_scores_index,
-        seq_lens.to(torch.int32),
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        spec_steps,
-        num_verify_tokens,
-    )
-
-    index = retrive_index.sum(dim=-1) != -spec_steps - 2
-    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
-    retrive_cum_len = torch.zeros(
-        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
-    )
-    retrive_cum_len[1:] = cum_len
-
-    # TODO: this indexing cause a synchronization, optimize this
-    retrive_index = retrive_index[index]
-
-    return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
-
-
-def test_build_tree_kernel():
-    def findp(p_i, index, parent_list):
-        pos = index // 10
-        index_list = index.tolist()
-        parent_list = parent_list.tolist()
-        res = [p_i]
-        while True:
-            p = pos[p_i]
-            if p == 0:
-                break
-            token_idx = parent_list[p]
-            p_i = index_list.index(token_idx)
-            res.append(p_i)
-        return res
-
-    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
-        mask = []
-        positions = []
-        retrive_index = []
-        for i, lens in enumerate(seq_len.tolist()):
-            first_mask = torch.full((lens + draft_token,), True)
-            first_mask[-(draft_token - 1):] = False
-            positions.append(lens)
-            mask.append(first_mask)
-            seq_order = []
-            first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
-            r_index = [first_index]
-            for j in range(draft_token - 1):
-                mask.append(torch.full((lens + 1,), True))
-                idx = findp(j, index, parent_list)
-                seq_order.append(idx)
-                positions.append(len(idx) + seq_len)
-                t = torch.full((draft_token - 1,), False)
-                t[idx] = True
-                mask.append(t)
-            for i in range(1, draft_token - 1):
-                is_leaf = 0
-                for j in range(draft_token - 1):
-                    if i in seq_order[j]:
-                        is_leaf += 1
-                if is_leaf == 1:
-                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
-                    for _ in range(max_depth + 1 - len(seq_order[i])):
-                        order_list.append(-1)
-                    order = torch.Tensor(order_list).cuda().to(torch.long)
-                    r_index.append(order)
-            retrive_index.append(torch.stack(r_index))
-        return (
-            torch.cat(mask).cuda(),
-            torch.Tensor(positions).cuda().to(torch.long),
-            torch.stack(retrive_index),
-        )
-
-    index = (
-        torch.Tensor(
-            [0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 30, 110, 130, 150, 160,
-             210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 230,
-             310, 311, 312, 313, 314, 315, 316, 317, 320, 321, 322, 330, 360, 380, 390,
-             410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423,
-             430, 431, 440, 441, 460, 470]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    parent_list = (
-        torch.Tensor(
-            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 20, 30, 21, 13, 22, 40, 23,
-             110, 130, 160, 150, 190, 120, 111, 121, 200, 180,
-             210, 211, 212, 213, 214, 215, 216, 220, 230, 217,
-             310, 311, 312, 313, 320, 314, 321, 315, 316, 317]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
-    bs = verified_seq_len.shape[0]
-    topk = 10
-    depth = 5  # depth <= 10
-    num_draft_token = 64
-
-    tree_mask = torch.full(
-        (
-            torch.sum(verified_seq_len).item() * num_draft_token
-            + num_draft_token * num_draft_token * bs,
-        ),
-        True,
-    ).cuda()
-    retrive_index = torch.full(
-        (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
-    )
-    positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list.unsqueeze(0),
-        index.unsqueeze(0),
-        verified_seq_len,
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        depth,
-        num_draft_token,
-    )
-
-    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
-
-    c_mask, c_positions, c_retive_index = create_mask(
-        verified_seq_len, num_draft_token, index, parent_list, depth
-    )
-
-    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
-    assert torch.allclose(positions, c_positions), "positions has error."
-    assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
-
-
 def test_build_tree_kernel_efficient():
     verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
     score_list = [

@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
     depth = 4
     num_draft_token = 8
 
-    tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-        build_tree_kernel(
-            verified_id=verified_id,
-            score_list=score_list,
-            token_list=token_list,
-            parents_list=parents_list,
-            seq_lens=seq_lens,
-            seq_lens_sum=torch.sum(seq_lens).item(),
-            topk=topk,
-            spec_steps=depth,
-            num_verify_tokens=num_draft_token,
-        )
-    )
-
-    from sglang.srt.utils import first_rank_print
-
-    first_rank_print("=========== build tree kernel ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_cum_len=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
-    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
-    assert retrive_index.tolist() == [
-        [0, -1, -1, -1, -1, -1],
-        [0, 2, 4, 6, -1, -1],
-        [0, 1, 3, 5, 7, -1],
-        [8, -1, -1, -1, -1, -1],
-        [8, 9, 10, -1, -1, -1],
-        [8, 9, 12, -1, -1, -1],
-        [8, 9, 13, -1, -1, -1],
-        [8, 9, 11, 14, 15, -1],
-    ]
-    assert retrive_cum_len.tolist() == [0, 3, 8]
-    assert draft_tokens.tolist() == [
-        29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,
-        13, 13, 22550, 4136, 16492, 8439, 29871, 29941,
-    ]
-
     (
         tree_mask,
         position,

@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient():
 if __name__ == "__main__":
     test_build_tree_kernel_efficient()
-    test_build_tree_kernel()
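`build_tree_kernel_efficient_preprocess` keeps the top-scoring draft tokens across all speculative steps and prepends the last verified token. A toy, CPU-only sketch of that selection step (the function name, shapes, and sample values are illustrative, not the exact ones used by the kernel):

    import torch

    def select_draft_tokens(verified_id, scores, tokens, num_verify_tokens):
        # scores/tokens: (batch, num_candidates) flattened over all draft steps.
        top = torch.topk(scores, num_verify_tokens - 1, dim=-1)
        top_index = torch.sort(top.indices).values          # keep tree order stable
        draft = torch.gather(tokens, index=top_index, dim=1)
        # Prepend the verified token so every request contributes num_verify_tokens ids.
        return torch.cat((verified_id.unsqueeze(1), draft), dim=1).flatten()

    verified_id = torch.tensor([7, 11])
    scores = torch.rand(2, 12)
    tokens = torch.randint(0, 100, (2, 12))
    print(select_draft_tokens(verified_id, scores, tokens, num_verify_tokens=4))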
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):

@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
     self.enable_torch_compile = model_runner.server_args.enable_torch_compile
     self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
     self.tp_size = self.model_runner.tp_size
     self.dp_size = model_runner.server_args.dp_size
     self.topk = model_runner.server_args.speculative_eagle_topk
     self.speculative_num_steps = model_runner.server_args.speculative_num_steps
     server_args = model_runner.server_args
     assert self.disable_padding

     # Batch sizes to capture
     self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
     self.num_tokens_per_bs = server_args.speculative_eagle_topk

@@ -169,6 +170,13 @@ class EAGLEDraftCudaGraphRunner:
             set_global_graph_memory_pool(graph.pool())
         return graph, out
 
+    def _postprocess_output_to_raw_bs(self, out, raw_bs):
+        score_list, token_list, parents_list = out
+        score_list = [x[:raw_bs] for x in score_list]
+        token_list = [x[:raw_bs] for x in token_list]
+        parents_list = [x[:raw_bs] for x in parents_list]
+        return (score_list, token_list, parents_list)
+
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size

@@ -180,6 +188,9 @@ class EAGLEDraftCudaGraphRunner:
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
+
+        num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

@@ -193,11 +204,25 @@ class EAGLEDraftCudaGraphRunner:
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
         # Attention backend
+        if bs != raw_bs:
+            forward_batch.batch_size = bs
+            forward_batch.seq_lens = self.seq_lens[:bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
+            forward_batch.positions = self.positions[:num_tokens]
+
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch, forward_batch.batch_size
+            forward_batch, bs
         )
 
         # Replay
         self.graphs[bs].replay()
+        out = self.output_buffers[bs]
+
+        if bs != raw_bs:
+            out = self._postprocess_output_to_raw_bs(out, raw_bs)
+            forward_batch.batch_size = raw_bs
+            forward_batch.positions = self.positions[:raw_num_token]
+            forward_batch.seq_lens = self.seq_lens[:raw_bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
 
-        return self.output_buffers[bs]
+        return out
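When CUDA-graph padding is enabled, the draft runner may replay a graph captured for a larger batch than the live one, and `_postprocess_output_to_raw_bs` then trims the padded outputs back down. A tiny illustration of that trimming (the helper name and tensor shapes are made up for the example):

    import torch

    def trim_to_raw_bs(out, raw_bs):
        # `out` mirrors the runner's output: lists of per-step tensors whose
        # first dimension is the padded batch size.
        score_list, token_list, parents_list = out
        return (
            [x[:raw_bs] for x in score_list],
            [x[:raw_bs] for x in token_list],
            [x[:raw_bs] for x in parents_list],
        )

    padded = ([torch.zeros(8, 4)], [torch.zeros(8, 4)], [torch.zeros(8, 4)])
    trimmed = trim_to_raw_bs(padded, raw_bs=5)
    print([t.shape for t in trimmed[0]])  # [torch.Size([5, 4])]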
python/sglang/srt/speculative/eagle_utils.py

 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch.nn.functional as F

@@ -13,18 +13,24 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
-from sglang.srt.speculative.build_eagle_tree import (
-    build_tree_kernel,
-    build_tree_kernel_efficient,
-)
+from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
 from sglang.srt.utils import is_cuda_available
 
 if is_cuda_available():
-    from sgl_kernel import tree_speculative_sampling_target_only
+    from sgl_kernel import (
+        top_k_renorm_prob,
+        top_p_renorm_prob,
+        tree_speculative_sampling_target_only,
+        verify_tree_greedy,
+    )
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class EagleDraftInput:

@@ -47,12 +53,9 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
-    # indices of unfinished requests during extend-after-decode
-    # e.g. [0, 2, 3, 4] if only the 1st request is finished
-    keep_indices: List[int] = None
+    all_padding_lens: Optional[torch.Tensor] = None
 
     def prepare_for_extend(self, batch: ScheduleBatch):
         assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)

@@ -64,27 +67,18 @@ class EagleDraftInput:
             )
             pt += extend_len
 
-    def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
-        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
+    def prepare_extend_after_decode(
+        self,
+        batch: ScheduleBatch,
+        speculative_num_steps: int,
+    ):
+        assert len(self.verified_id) == len(batch.out_cache_loc)
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
-        assert len(batch.req_pool_indices) == len(batch.reqs)
 
-        pt = 0
-        i = 0
-        self.keep_indices = []
-        for idx, req in enumerate(batch.reqs):
-            if req.finished():
-                continue
-            self.keep_indices.append(idx)
-            # assert seq_len - pre_len == req.extend_input_len
-            input_len = batch.extend_lens[i]
-            seq_len = seq_lens_cpu[i]
-            pt += input_len
-            i += 1
 
         self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
         new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)

@@ -112,10 +106,6 @@ class EagleDraftInput:
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
-
-        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
-        req_pool_indices = req_pool_indices[keep_indices]
-        assert req_pool_indices.shape[0] == bs
         assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
 
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)

@@ -172,7 +162,7 @@ class EagleVerifyOutput:
     # Accepeted token length per sequence in a batch in CPU.
     accept_length_per_req_cpu: List[int]
     # Accepeted indices from logits_output.next_token_logits
-    accepeted_indices_cpu: List[int]
+    accepeted_indices: torch.Tensor
 
 
 @dataclass

@@ -200,67 +190,38 @@ class EagleVerifyInput:
         topk: int,
         spec_steps: int,
         num_verify_tokens: int,
-        is_all_greedy: bool,
     ):
-        if is_all_greedy:
-            tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-                build_tree_kernel(
-                    verified_id,
-                    score_list,  # b, n, topk; n= 1 + (num_steps-1) * self.topk
-                    token_list,
-                    parents_list,
-                    seq_lens,
-                    seq_lens_sum,
-                    topk,
-                    spec_steps,
-                    num_verify_tokens,
-                )
-            )
-
-            return cls(
-                draft_tokens,
-                tree_mask,
-                position,
-                retrive_index,
-                None,
-                None,
-                retrive_cum_len,
-                num_verify_tokens,
-                spec_steps,
-                CaptureHiddenMode.FULL,
-            )
-        else:
-            (
-                tree_mask,
-                position,
-                retrive_index,
-                retrive_next_token,
-                retrive_next_sibling,
-                draft_tokens,
-            ) = build_tree_kernel_efficient(
-                verified_id,
-                score_list,
-                token_list,
-                parents_list,
-                seq_lens,
-                seq_lens_sum,
-                topk,
-                spec_steps,
-                num_verify_tokens,
-            )
+        (
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            draft_tokens,
+        ) = build_tree_kernel_efficient(
+            verified_id,
+            score_list,
+            token_list,
+            parents_list,
+            seq_lens,
+            seq_lens_sum,
+            topk,
+            spec_steps,
+            num_verify_tokens,
+        )
 
-            return cls(
-                draft_tokens,
-                tree_mask,
-                position,
-                retrive_index,
-                retrive_next_token,
-                retrive_next_sibling,
-                None,
-                num_verify_tokens,
-                spec_steps,
-                CaptureHiddenMode.FULL,
-            )
+        return cls(
+            draft_tokens,
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            None,
+            num_verify_tokens,
+            spec_steps,
+            CaptureHiddenMode.FULL,
+        )
 
     def prepare_for_verify(self, batch: ScheduleBatch):
         batch.input_ids = self.draft_token

@@ -291,7 +252,6 @@ class EagleVerifyInput:
             dtype=torch.int32,
             device="cuda",
         )
         cum_kv_seq_len = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")

@@ -304,7 +264,6 @@ class EagleVerifyInput:
             dtype=torch.int32,
             device="cuda",
         )
         create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
             req_pool_indices,

@@ -322,65 +281,79 @@ class EagleVerifyInput:
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ) -> torch.Tensor:
-        """WARNING: This API in-place modifies the states of logits_output"""
-        draft_token = torch.cat(
-            [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
-            dim=-1,
-        )
-        candidates = draft_token[self.retrive_index]
-        if batch.sampling_info.is_all_greedy:
-            # temp == 0
-            bs = self.retrive_cum_len.numel() - 1
-            predict = torch.argmax(logits_output.next_token_logits, dim=-1)
-            predict = torch.cat(
-                [predict, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1
-            )
-            target_predict = predict[self.retrive_index]
-            # logits = logits_output.next_token_logits[self.retrive_index]
-            # target_predict = torch.argmax(logits[:, :-1], dim=-1)
-            accept_mask = candidates[:, 1:] == target_predict[:, :-1]
-            accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
-            max_draft_len = self.retrive_index.shape[-1]
-            accept_index = torch.full(
-                (bs, max_draft_len), -1, dtype=torch.int32, device="cuda"
-            )
-            accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
-            extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
-            eagle_verify_retrive[(bs,)](
-                self.retrive_index.contiguous(),
-                accept_mask.contiguous(),
-                self.retrive_cum_len,
-                accept_index,
-                accept_length,
-                extract_index,
-                max_draft_len,
-                self.draft_token_num,
-                triton.next_power_of_2(max_draft_len),
-            )
+        """
+        Verify and find accepted tokens based on logits output and batch
+        (which contains spec decoding information).
+
+        WARNING: This API in-place modifies the states of logits_output
+
+        This API updates values inside logits_output based on the accepted
+        tokens. I.e., logits_output.next_token_logits only contains
+        accepeted token logits.
+        """
+        bs = self.retrive_index.shape[0]
+        candidates = self.draft_token.reshape(bs, self.draft_token_num)
+        sampling_info = batch.sampling_info
+
+        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
+        predict_shape[-1] += 1
+        predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
+        accept_index = torch.full(
+            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+        )
+        accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
+
+        if sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            linear_penalty = torch.zeros(
+                (bs, logits_output.next_token_logits.shape[1]),
+                dtype=torch.float32,
+                device="cuda",
+            )
+            sampling_info.apply_logits_bias(linear_penalty)
+            logits_output.next_token_logits.add_(
+                torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
+            )
+
+        if batch.sampling_info.is_all_greedy:
+            target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
+            target_predict = target_predict.reshape(bs, self.draft_token_num)
+
+            verify_tree_greedy(
+                predicts=predict,  # mutable
+                accept_index=accept_index,  # mutable
+                accept_token_num=accept_length,  # mutable
+                candidates=candidates.to(torch.int32),
+                retrive_index=self.retrive_index.to(torch.int32),
+                retrive_next_token=self.retrive_next_token.to(torch.int32),
+                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
+                target_predict=target_predict.to(torch.int32),
+            )
         else:
             # temp > 0
-            bs = self.retrive_index.shape[0]
-            predict_shape = list(logits_output.next_token_logits.shape)[:-1]
-            predict_shape[-1] += 1
-            target_logits = logits_output.next_token_logits[self.retrive_index]
-            predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")
-            accept_index = torch.full(
-                (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
-            )
-            accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
-            expanded_temperature = batch.sampling_info.temperatures.unsqueeze(1)
-            target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
-            draft_probs = torch.full_like(
-                target_probs, 0, dtype=torch.float32, device="cuda"
-            )
+            # apply temperature and get target probs
+            expanded_temperature = torch.repeat_interleave(
+                sampling_info.temperatures, self.draft_token_num, dim=0
+            )  # (bs * draft_token_num, 1)
+            target_probs = F.softmax(
+                logits_output.next_token_logits / expanded_temperature, dim=-1
+            )  # (bs * draft_token_num, vocab_size)
+            target_probs = top_k_renorm_prob(
+                target_probs,
+                torch.repeat_interleave(sampling_info.top_ks, self.draft_token_num, dim=0),
+            )  # (bs * draft_token_num, vocab_size)
+            target_probs = top_p_renorm_prob(
+                target_probs,
+                torch.repeat_interleave(sampling_info.top_ps, self.draft_token_num, dim=0),
+            )
+            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
+
+            draft_probs = torch.zeros(
+                target_probs.shape, dtype=torch.float32, device="cuda"
+            )
             coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
             tree_speculative_sampling_target_only(

@@ -394,6 +367,12 @@ class EagleVerifyInput:
                 uniform_samples=coins,
                 target_probs=target_probs,
                 draft_probs=draft_probs,
+                threshold_single=global_server_args_dict[
+                    "speculative_accept_threshold_single"
+                ],
+                threshold_acc=global_server_args_dict[
+                    "speculative_accept_threshold_acc"
+                ],
                 deterministic=True,
             )

@@ -425,119 +404,94 @@ class EagleVerifyInput:
                 new_accept_index.extend(new_accept_index_)
                 unfinished_index.append(i)
             req.spec_verify_ct += 1
 
-        accept_length = (accept_index != -1).sum(dim=1) - 1
-        accept_index = accept_index[accept_index != -1]
-        accept_length_cpu = accept_length.tolist()
-        verified_id = predict[accept_index]
-
-        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-        evict_mask[accept_index] = False
-        mem_need_free_idx = batch.out_cache_loc[evict_mask]
-        token_to_kv_pool_allocator.free(mem_need_free_idx)
-        assign_req_to_token_pool[(bs,)](
-            batch.req_pool_indices,
-            batch.req_to_token_pool.req_to_token,
-            batch.seq_lens,
-            batch.seq_lens + accept_length + 1,
-            batch.out_cache_loc[accept_index],
-            batch.req_to_token_pool.req_to_token.shape[1],
-            triton.next_power_of_2(bs),
-        )
-        batch.seq_lens.add_(accept_length + 1)
-
-        draft_input = EagleDraftInput()
-        if len(new_accept_index) > 0:
-            new_accept_index = torch.tensor(new_accept_index, device="cuda")
-            draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
-            draft_input.verified_id = predict[new_accept_index]
-            draft_input.accept_length = accept_length[unfinished_index]
-            draft_input.accept_length_cpu = [
-                accept_length_cpu[i] for i in unfinished_index
-            ]
-            if has_finished:
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
-            else:
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens
-            batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
-
-        return EagleVerifyOutput(
-            draft_input=draft_input,
-            logits_output=logits_output,
-            verified_id=verified_id,
-            accept_length_per_req_cpu=accept_length_cpu,
-            accepeted_indices_cpu=accept_index,
-        )
-
-
-@triton.jit
-def eagle_verify_retrive(
-    retrive_index,
-    accept_mask,
-    retrive_cum_len,
-    accept_index,
-    accept_length,
-    extract_index,
-    max_len: tl.constexpr,
-    draft_token_num: tl.constexpr,
-    max_len_upper: tl.constexpr,
-):
-    """
-    Args:
-        retrive_index: Pointer to indices of draft tokens
-        accept_mask: Mask indicating which tokens were accepted
-        retrive_cum_len: Cumulative lengths of token sequences in a batch
-        accept_index (out): Accept token indices
-        accept_length (out): Length of accepted tokens per sequence in a batch
-        extract_index (out): Index for last accepted tokens
-        max_len: Maximum length in a batch
-        draft_token_num: Number of tokens speculatively generated
-        max_len_upper: An upper bound for token sequence length
-    """
-    pid = tl.program_id(axis=0)
-    retrive_end = tl.load(retrive_cum_len + pid + 1)
-    retrive_start = tl.load(retrive_cum_len + pid)
-    retrive_len = retrive_end - retrive_start
-    accept_ptr = accept_mask + retrive_start
-    accept_offset = tl.arange(0, draft_token_num)
-    accept_load_mask = accept_offset < retrive_len
-    accept_len_list = tl.load(accept_ptr + accept_offset, mask=accept_load_mask, other=-1)
-    accept_len = tl.max(accept_len_list)
-    max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
-    # triton is not support argmax with tie_break_right, so I need implement it by some way
-    mask_max = accept_len_list == accept_len
-    count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
-    count = tl.sum(tl.where(mask_max, 1, count_mask))
-    if count > 1:
-        index = tl.arange(0, draft_token_num)
-        mask_left = index != max_index
-        remained_index = tl.where(mask_max and mask_left, index, 0)
-        max_index = tl.max(remained_index)
-    tl.store(accept_length + pid, accept_len)
-    retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
-    retrive_offset = tl.arange(0, max_len_upper)
-    retrive_load_mask = retrive_offset < accept_len + 1
-    data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
-    tl.store(accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask)
-    extract_load_ptr = accept_index + pid * max_len + accept_len
-    if accept_len == max_len - 1:
-        extract_data = tl.load(extract_load_ptr - 1)
-        tl.store(extract_index + pid * 2, extract_data)
-        extract_data = tl.load(extract_load_ptr)
-        tl.store(extract_index + pid * 2 + 1, extract_data)
-    else:
-        extract_data = tl.load(extract_load_ptr)
-        tl.store(extract_index + pid * 2, extract_data)
+        if not has_finished:
+            accept_index = accept_index[accept_index != -1]
+            verified_id = predict[accept_index]
+            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+            evict_mask[accept_index] = False
+            mem_need_free_idx = batch.out_cache_loc[evict_mask]
+            token_to_kv_pool_allocator.free(mem_need_free_idx)
+            batch.out_cache_loc = batch.out_cache_loc[accept_index]
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + accept_length + 1,
+                batch.out_cache_loc,
+                batch.req_to_token_pool.req_to_token.shape[1],
+                triton.next_power_of_2(bs),
+            )
+            batch.seq_lens.add_(accept_length + 1)
+            accept_length_cpu = accept_length.tolist()
+
+            draft_input = EagleDraftInput()
+            draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
+            draft_input.verified_id = verified_id
+            draft_input.accept_length = accept_length
+            draft_input.accept_length_cpu = accept_length_cpu
+            draft_input.seq_lens_for_draft_extend = batch.seq_lens
+            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
+
+            return EagleVerifyOutput(
+                draft_input=draft_input,
+                logits_output=logits_output,
+                verified_id=verified_id,
+                accept_length_per_req_cpu=accept_length_cpu,
+                accepeted_indices=accept_index,
+            )
+        else:
+            accept_length = (accept_index != -1).sum(dim=1) - 1
+            accept_index = accept_index[accept_index != -1]
+            verified_id = predict[accept_index]
+            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+            evict_mask[accept_index] = False
+            mem_need_free_idx = batch.out_cache_loc[evict_mask]
+            token_to_kv_pool_allocator.free(mem_need_free_idx)
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + accept_length + 1,
+                batch.out_cache_loc[accept_index],
+                batch.req_to_token_pool.req_to_token.shape[1],
+                triton.next_power_of_2(bs),
+            )
+            batch.seq_lens.add_(accept_length + 1)
+            accept_length_cpu = accept_length.tolist()
+
+            draft_input = EagleDraftInput()
+            if len(new_accept_index) > 0:
+                new_accept_index = torch.tensor(new_accept_index, device="cuda")
+                draft_input.hidden_states = batch.spec_info.hidden_states[
+                    new_accept_index
+                ]
+                draft_input.verified_id = predict[new_accept_index]
+                draft_input.accept_length = accept_length[unfinished_index]
+                draft_input.accept_length_cpu = [
+                    accept_length_cpu[i] for i in unfinished_index
+                ]
+                if has_finished:
+                    draft_input.seq_lens_for_draft_extend = batch.seq_lens[
+                        unfinished_index
+                    ]
+                    draft_input.req_pool_indices_for_draft_extend = (
+                        batch.req_pool_indices[unfinished_index]
+                    )
+                else:
+                    draft_input.seq_lens_for_draft_extend = batch.seq_lens
+                    draft_input.req_pool_indices_for_draft_extend = (
+                        batch.req_pool_indices
+                    )
+                batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
+
+            return EagleVerifyOutput(
+                draft_input=draft_input,
+                logits_output=logits_output,
+                verified_id=verified_id,
+                accept_length_per_req_cpu=accept_length_cpu,
+                accepeted_indices=accept_index,
+            )
 
 
 @triton.jit
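The greedy path now calls the `verify_tree_greedy` kernel from sgl-kernel instead of the removed Triton `eagle_verify_retrive` route. As a rough mental model only (a simplified single-chain walk, not the real tree traversal over retrive_next_token / retrive_next_sibling), greedy verification accepts the longest prefix of draft tokens that matches the target model's argmax predictions and then appends one bonus token; the helper below and its sample values are purely illustrative:

    import torch

    def greedy_chain_verify(candidates, target_logits):
        # candidates: (draft_len,) draft tokens, position 0 is the verified root.
        # target_logits: (draft_len, vocab) target logits at each draft position.
        target_predict = torch.argmax(target_logits, dim=-1)
        accepted = [candidates[0].item()]
        for i in range(1, candidates.numel()):
            if candidates[i] != target_predict[i - 1]:
                break
            accepted.append(candidates[i].item())
        # The bonus token comes from the target model at the last accepted position.
        accepted.append(target_predict[len(accepted) - 1].item())
        return accepted

    candidates = torch.tensor([5, 9, 3, 7])
    logits = torch.randn(4, 32)
    print(greedy_chain_verify(candidates, logits))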
python/sglang/srt/speculative/eagle_worker.py

 import logging
 import os
 import time
 from contextlib import contextmanager
 from typing import List, Optional, Tuple
 
 import torch
 from huggingface_hub import snapshot_download
 
 from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
 from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch

@@ -27,11 +30,22 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.utils import get_available_gpu_memory
+from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import segment_packbits
 
 logger = logging.getLogger(__name__)
 
 
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+        yield
+
+
 class EAGLEWorker(TpModelWorker):
     def __init__(

@@ -76,16 +90,17 @@ class EAGLEWorker(TpModelWorker):
         self.hot_token_id = None
 
         # Init draft worker
-        super().__init__(
-            gpu_id=gpu_id,
-            tp_rank=tp_rank,
-            server_args=server_args,
-            nccl_port=nccl_port,
-            dp_rank=dp_rank,
-            is_draft_worker=True,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-        )
+        with empty_context():
+            super().__init__(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=nccl_port,
+                dp_rank=dp_rank,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
 
         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()

@@ -94,12 +109,17 @@ class EAGLEWorker(TpModelWorker):
             self.hot_token_id = self.hot_token_id.to(head.device)
             head.data = head.data[self.hot_token_id]
         self.draft_model_runner.model.set_embed_and_head(embed, head)
 
         # Init attention backend and cuda graphs
         self.draft_model_runner.server_args.disable_cuda_graph = (
             backup_disable_cuda_graph
         )
-        self.init_attention_backend()
-        self.init_cuda_graphs()
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
 
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners

@@ -109,52 +129,70 @@ class EAGLEWorker(TpModelWorker):
             )
             self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
 
             self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashinfer_mla":
             from sglang.srt.layers.attention.flashinfer_mla_backend import (
                 FlashInferMLAMultiStepDraftBackend,
             )
 
             self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
            )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
 
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_runner_for_draft_extend = None
 
         if self.server_args.disable_cuda_graph:
             return
 
         # Capture draft
         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )
 
+        # Capture extend
+        if self.draft_extend_attn_backend:
+            raise NotImplementedError()
+
     @property
     def draft_model_runner(self):
         return self.model_runner

@@ -164,8 +202,8 @@ class EAGLEWorker(TpModelWorker):
     ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
         """Run speculative decoding forward.
 
-        NOTE: Many states of batch is modified as you go through. It is not guaranteed
-        the final output batch doesn't have the same state as the input.
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed
+        that the final output batch have the same state as the input.
 
         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.

@@ -173,30 +211,42 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepeted,
             the batch id (used for overlap schedule), and number of accepeted tokens.
         """
         assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-            spec_info, to_free_cache_loc = self.draft(batch)
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                spec_info, to_free_cache_loc = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(
                 batch, spec_info
             )
 
             # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
             self.token_to_kv_pool_allocator.free(to_free_cache_loc)
 
-            # if it is None, means all requests are finished
-            if batch.spec_info.verified_id is not None:
-                self.forward_draft_extend_after_decode(batch)
+            # If it is None, it means all requests are finished
+            if batch.spec_info.verified_id is not None:
+                with self.draft_tp_context(self.draft_model_runner.tp_group):
+                    self.forward_draft_extend_after_decode(batch)
 
             return (
                 logits_output,
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
             )
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
             logits_output, next_token_ids, _ = (
                 self.target_worker.forward_batch_generation(
                     ForwardBatch.init_new(
                         model_worker_batch, self.target_worker.model_runner
                     )
                 )
             )
 
             return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
-            self.forward_draft_extend(
-                batch, logits_output.hidden_states, next_token_ids
-            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids
+                )
             return logits_output, next_token_ids, bid, 0

@@ -226,6 +276,13 @@ class EAGLEWorker(TpModelWorker):
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
 
+        # Accumulate penalty
+        if batch.sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                spec_info.verified_id.to(torch.int64)
+            )
+
         # Allocate cache locations
         out_cache_loc = batch.alloc_token_slots(
             num_seqs * self.topk * self.speculative_num_steps

@@ -275,9 +332,7 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
-            batch.sampling_info.is_all_greedy,
         )
 
         return ret, out_cache_loc
 
     def draft_forward(self, forward_batch: ForwardBatch):

@@ -307,7 +362,7 @@ class EAGLEWorker(TpModelWorker):
             token_list.append(tree_info[1])
             parents_list.append(tree_info[2])
 
-            # we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
+            # We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
             if i == self.speculative_num_steps - 1:
                 break

@@ -322,7 +377,7 @@ class EAGLEWorker(TpModelWorker):
             spec_info.hidden_states = hidden_states
 
             # Run forward
-            logits_output = self.model_runner.model.forward(
+            logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
             self._detect_nan_if_needed(logits_output)

@@ -351,11 +406,10 @@ class EAGLEWorker(TpModelWorker):
         # Post process based on verified outputs.
         # Pick indices that we care (accepeted)
         logits_output.next_token_logits = logits_output.next_token_logits[
-            res.accepeted_indices_cpu
-        ]
-        logits_output.hidden_states = logits_output.hidden_states[
-            res.accepeted_indices_cpu
+            res.accepeted_indices
         ]
+        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
 
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input

@@ -407,7 +461,7 @@ class EAGLEWorker(TpModelWorker):
             batch_next_token_ids,
         ]
 
-        # Add output logprobs to the request .
+        # Add output logprobs to the request
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()

@@ -456,27 +510,38 @@ class EAGLEWorker(TpModelWorker):
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        seq_lens_backup = batch.seq_lens
+        # Backup fileds that will be modified in-place
+        seq_lens_backup = batch.seq_lens.clone()
+        req_pool_indices_backup = batch.req_pool_indices
+        accept_length_backup = batch.spec_info.accept_length
+        return_logprob_backup = batch.return_logprob
 
         # Prepare metadata
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-
-        # We don't need logprob for this extend.
-        original_return_logprob = batch.return_logprob
         batch.return_logprob = False
+
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.draft_model_runner)
 
         # Run
         logits_output = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
+        batch.req_pool_indices = req_pool_indices_backup
+        batch.spec_info.accept_length = accept_length_backup
+        batch.return_logprob = return_logprob_backup
 
     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput

@@ -489,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
         if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
-                logger.warning("Detected errors during sampling! NaN in the logits.")
+                logger.error("Detected errors during sampling! NaN in the logits.")
                 raise ValueError("Detected errors during sampling! NaN in the logits.")
python/sglang/srt/utils.py

@@ -36,6 +36,7 @@ import tempfile
 import threading
 import time
 import warnings
 from contextlib import contextmanager
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec

@@ -1577,6 +1578,16 @@ def next_power_of_2(n: int):
 setattr(triton, "next_power_of_2", next_power_of_2)
 
 
+@contextmanager
+def empty_context(*args, **kwargs):
+    try:
+        # Setup code goes here
+        yield
+    finally:
+        # Cleanup code goes here
+        pass
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.
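`empty_context` is a no-op context manager, so call sites can always write a with-block and only swap in a real context (such as `draft_tp_context` in eagle_worker.py) when needed. A small self-contained usage sketch; `verbose_context`, `enable_special`, and the printed strings are illustrative stand-ins, not part of the diff:

    from contextlib import contextmanager

    @contextmanager
    def empty_context(*args, **kwargs):
        yield  # no setup, no cleanup

    @contextmanager
    def verbose_context():
        print("enter")  # stands in for a real context such as draft_tp_context
        yield
        print("exit")

    enable_special = False  # e.g. server_args.enable_dp_attention in the worker
    ctx = verbose_context if enable_special else empty_context
    with ctx():
        print("work")  # the with-block shape stays the same either way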
scripts/ci_install_dependency.sh

@@ -24,6 +24,3 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa
 # For compling xgrammar kernels
 pip install cuda-python nvidia-cuda-nvrtc-cu12
-
-# reinstall sgl-kernel
-pip install sgl-kernel==0.0.5.post1 --force-reinstall --no-deps
sgl-kernel/csrc/speculative/speculative_sampling.cuh

@@ -36,8 +36,8 @@ template <typename DType, typename IdType>
 __global__ void TreeSpeculativeSamplingTargetOnly(
-    IdType* predicts,
-    IdType* accept_index,
+    IdType* predicts,          // mutable
+    IdType* accept_index,      // mutable
     IdType* accept_token_num,  // mutable
     IdType* candidates,
     IdType* retrive_index,

@@ -158,8 +158,8 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
 template <typename DType, typename IdType>
 cudaError_t TreeSpeculativeSamplingTargetOnly(
-    IdType* predicts,
-    IdType* output_token_ids,
+    IdType* predicts,                   // mutable
+    IdType* output_token_ids,           // mutable
     IdType* output_accepted_token_num,  // mutable
     IdType* candidates,
     IdType* retrive_index,
test/srt/test_eagle_infer.py

@@ -122,8 +122,8 @@ class TestEAGLEEngine(unittest.TestCase):
     def _test_acc_length(self, engine):
         prompt = [
-            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
+            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
         ] * 5
         # test batched generation
         sampling_params = {"temperature": 0, "max_new_tokens": 512}
         output = engine.generate(prompt, sampling_params)
         output = output[0]
test/srt/test_mla_flashinfer.py

@@ -67,7 +67,7 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
             "--enable-torch-compile",
             "--disable-cuda-graph",
             "--cuda-graph-max-bs",
-            "2",
+            "4",
             "--enable-flashinfer-mla",
             "--flashinfer-mla-disable-ragged",
         ]

@@ -109,7 +109,7 @@ class TestFlashinferMLAMTP(unittest.TestCase):
         other_args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--disable-radix",
                 "--enable-torch-compile",
                 "--torch-compile-max-bs",