change / sglang / Commits / dc0705a5

Commit dc0705a5 (unverified)
Authored Jun 09, 2025 by Lianmin Zheng; committed by GitHub on Jun 09, 2025

Simplify prepare_extend_after_decode (#6987)

Parent: a968c888
Changes: 9 changed files with 140 additions and 176 deletions (+140 −176)

python/sglang/srt/managers/schedule_batch.py                            +10   −4
python/sglang/srt/managers/scheduler.py                                  +3   −4
python/sglang/srt/model_executor/cuda_graph_runner.py                   +13  −10
python/sglang/srt/server_args.py                                         +4   −4
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py           +3   −1
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py   +11   −3
python/sglang/srt/speculative/eagle_utils.py                            +41 −130
python/sglang/srt/speculative/eagle_worker.py                           +53  −18
test/srt/test_full_deepseek_v3.py                                        +2   −2
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -1636,7 +1636,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.spec_info:
             self.spec_info.merge_batch(other.spec_info)
 
-    def get_model_worker_batch(self) -> ModelWorkerBatch:
+    def get_model_worker_batch(
+        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
+    ) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
@@ -1646,16 +1648,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Create seq_lens_cpu when needed
         if (
-            global_server_args_dict["attention_backend"] == "fa3"
-            or (
+            (
                 global_server_args_dict["use_mla_backend"]
                 and global_server_args_dict["attention_backend"] == "flashinfer"
             )
             or global_server_args_dict["attention_backend"] == "flashmla"
+            or global_server_args_dict["attention_backend"] == "fa3"
             or global_server_args_dict["attention_backend"] == "cutlass_mla"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
-            seq_lens_cpu = self.seq_lens.cpu()
+            seq_lens_cpu = (
+                seq_lens_cpu_cache
+                if seq_lens_cpu_cache is not None
+                else self.seq_lens.cpu()
+            )
         else:
             seq_lens_cpu = None
```
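The point of the new parameter: `Tensor.cpu()` on a CUDA tensor forces a device-to-host synchronization, so a caller that already holds a host copy of `seq_lens` can now hand it in rather than paying for a second transfer. A minimal sketch of the pattern (tensor values are made up for illustration):

```python
import torch

def get_seq_lens_cpu(seq_lens: torch.Tensor, seq_lens_cpu_cache=None) -> torch.Tensor:
    # Reuse the caller's host copy when available; fall back to a
    # (synchronizing) device-to-host copy otherwise.
    return seq_lens_cpu_cache if seq_lens_cpu_cache is not None else seq_lens.cpu()

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([17, 42, 9], device=device)
cached = seq_lens.cpu()                     # paid once, e.g. for a previous batch
print(get_seq_lens_cpu(seq_lens, cached))   # no extra transfer on this call
```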
python/sglang/srt/managers/scheduler.py

```diff
@@ -1575,10 +1575,9 @@ class Scheduler(
                 num_accepted_tokens,
                 can_run_cuda_graph,
             ) = self.draft_worker.forward_batch_speculative_generation(batch)
-            self.spec_num_total_accepted_tokens += (
-                num_accepted_tokens + batch.batch_size()
-            )
-            self.spec_num_total_forward_ct += batch.batch_size()
+            bs = batch.batch_size()
+            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+            self.spec_num_total_forward_ct += bs
             self.num_generated_tokens += num_accepted_tokens
 
         if self.pp_group.is_last_rank:
```
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -56,6 +56,16 @@ def get_is_capture_mode():
     return is_capture_mode
 
 
+@contextmanager
+def model_capture_mode():
+    global is_capture_mode
+    is_capture_mode = True
+
+    yield
+
+    is_capture_mode = False
+
+
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
@@ -291,22 +301,13 @@ class CudaGraphRunner:
         # Capture
         try:
-            with self.model_capture_mode():
+            with model_capture_mode():
                 self.capture()
         except RuntimeError as e:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )
 
-    @contextmanager
-    def model_capture_mode(self):
-        global is_capture_mode
-        is_capture_mode = True
-
-        yield
-
-        is_capture_mode = False
-
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
             total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
@@ -650,6 +651,8 @@ class CudaGraphRunner:
                 topk=self.model_runner.server_args.speculative_eagle_topk,
                 draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                 capture_hidden_mode=CaptureHiddenMode.FULL,
+                seq_lens_sum=None,
+                seq_lens_cpu=None,
             )
         return spec_info
```
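`model_capture_mode` moves from a `CudaGraphRunner` method to a module-level helper so the EAGLE draft and draft-extend runners below can flip the same global flag without holding a `CudaGraphRunner` instance. A self-contained sketch of the pattern; the `try/finally` is an addition of this sketch (the committed helper resets the flag unconditionally after the `yield`):

```python
from contextlib import contextmanager

is_capture_mode = False  # module-level flag that model code can branch on

@contextmanager
def model_capture_mode():
    global is_capture_mode
    is_capture_mode = True
    try:
        yield
    finally:
        is_capture_mode = False  # reset even if capture raises

with model_capture_mode():
    assert is_capture_mode       # graph capture runs here
assert not is_capture_mode
```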
python/sglang/srt/server_args.py

```diff
@@ -1013,13 +1013,13 @@ class ServerArgs:
             type=str,
             choices=[
                 "aiter",
+                "cutlass_mla",
+                "fa3",
                 "flashinfer",
-                "triton",
-                "torch_native",
-                "fa3",
                 "flashmla",
-                "cutlass_mla",
                 "intel_amx",
+                "torch_native",
+                "triton",
             ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
```
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

```diff
@@ -10,6 +10,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
     CudaGraphRunner,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
+    model_capture_mode,
     set_global_graph_memory_pool,
     set_torch_compile_config,
 )
@@ -80,6 +81,7 @@ class EAGLEDraftCudaGraphRunner:
         # Capture
         try:
-            self.capture()
+            with model_capture_mode():
+                self.capture()
         except RuntimeError as e:
             raise Exception(
```
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

```diff
@@ -11,6 +11,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
     LogitsProcessorOutput,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
+    model_capture_mode,
     set_global_graph_memory_pool,
     set_torch_compile_config,
 )
@@ -19,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -37,6 +38,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+        self.topk = model_runner.server_args.speculative_eagle_topk
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.padded_static_len = -1
@@ -87,6 +89,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         # Capture
         try:
-            self.capture()
+            with model_capture_mode():
+                self.capture()
         except RuntimeError as e:
             raise Exception(
@@ -170,6 +173,8 @@ class EAGLEDraftExtendCudaGraphRunner:
             forward_batch.positions,
             forward_batch,
         )
+        probs = torch.softmax(ret.next_token_logits, dim=-1)
+        ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)
         forward_batch.out_cache_loc = output_cache_loc_backup
         forward_batch.spec_info.hidden_states = hidden_states_backup
@@ -198,7 +203,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
-        if bs != raw_bs:
+        if bs * self.num_tokens_per_bs != num_tokens:
             self.seq_lens.fill_(1)
             self.accept_length.fill_(1)
             self.out_cache_loc.zero_()
@@ -238,8 +243,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         out = self.output_buffers[bs]
         if bs != raw_bs:
             forward_batch.spec_info.accept_length = self.accept_length[:raw_bs]
+            out_copy = out
             out = LogitsProcessorOutput(
                 next_token_logits=out.next_token_logits[:raw_bs],
                 hidden_states=out.hidden_states[:raw_bs],
             )
+            out.topk_p = out_copy.topk_p[:raw_bs]
+            out.topk_index = out_copy.topk_index[:raw_bs]
         return out
```
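With `topk_p`/`topk_index` now produced inside the captured region, a graph replay hands back ready-to-use draft candidates, and the padded outputs only need slicing back to the real batch size. A sketch of the computation, using `torch.topk` as a stand-in for sglang's `fast_topk` (assumed interchangeable here; `fast_topk` merely specializes the k == 1 case):

```python
import torch

def fast_topk_sketch(values: torch.Tensor, topk: int, dim: int):
    # Assumption: returns the same (values, indices) pair as sglang's fast_topk.
    return torch.topk(values, topk, dim=dim)

padded_bs, raw_bs, vocab, k = 4, 3, 32000, 2
logits = torch.randn(padded_bs, vocab)       # replay output over the padded batch

probs = torch.softmax(logits, dim=-1)
topk_p, topk_index = fast_topk_sketch(probs, k, dim=-1)

# Slice padding off before handing the candidates to the next decode round.
topk_p, topk_index = topk_p[:raw_bs], topk_index[:raw_bs]
```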
python/sglang/srt/speculative/eagle_utils.py

```diff
@@ -22,8 +22,7 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
-from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
 
 if is_cuda():
@@ -86,77 +85,28 @@ class EagleDraftInput:
     def prepare_extend_after_decode(
         self,
         batch: ScheduleBatch,
         speculative_num_steps: int,
-        context_length: int,
-        pad_input: bool = False,
     ):
-        accept_length_cpu = batch.spec_info.accept_length_cpu
-        batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.forward_mode = ForwardMode.DRAFT_EXTEND
+        batch.input_ids = self.verified_id
+        batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
         batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
-        seq_lens_cpu = batch.seq_lens.tolist()
+        batch.return_logprob = False
 
-        self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
-        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
+        self.capture_hidden_mode = CaptureHiddenMode.LAST
         self.accept_length.add_(1)
+        self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
+        self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
 
-        create_extend_spec_info[(self.accept_length.numel(),)](
-            self.verified_id,
+        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
+            batch.input_ids,
             batch.seq_lens,
             self.accept_length,
-            torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
             self.positions,
-            new_verified_id,
-            next_power_of_2(speculative_num_steps + 1),
+            self.verified_id,
+            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
         )
-
-        batch.seq_lens_sum = sum(seq_lens_cpu)
-        batch.input_ids = self.verified_id
-        self.verified_id = new_verified_id
-
-        if not pad_input:
-            return
-
-        batch_size = sum(not req.finished() for req in batch.reqs)
-        # Total constant input length after padding
-        static_len = speculative_num_steps + 1
-        # Total size after padding
-        padded_input_size = batch_size * static_len
-        padded_len = padded_input_size - batch.input_ids.shape[0]
-
-        if padded_len > 0:
-            new_input_ids = torch.nn.functional.pad(
-                batch.input_ids, (0, padded_len), value=0
-            )
-            position_padding = torch.arange(padded_len, device=self.positions.device)
-            new_positions = torch.cat([self.positions, position_padding])
-            # need dummy hidden states for the padded positions
-            hidden_states_dim = self.hidden_states.shape[-1]
-            new_hidden_states = torch.cat(
-                [
-                    self.hidden_states,
-                    torch.zeros(
-                        (padded_len, hidden_states_dim),
-                        dtype=self.hidden_states.dtype,
-                        device=self.hidden_states.device,
-                    ),
-                ],
-                dim=0,
-            )
-            # allocate KV cache location for the padded tokens
-            padded_cache_loc = torch.zeros(
-                padded_len,
-                dtype=batch.out_cache_loc.dtype,
-                device=batch.out_cache_loc.device,
-            )
-            new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
-
-            batch.input_ids = new_input_ids
-            self.hidden_states = new_hidden_states
-            self.positions = new_positions
-            batch.out_cache_loc = new_out_cache_loc
-
     def generate_attn_arg_prefill(
         self,
```
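For reference, the per-request computation the renamed Triton kernel performs (positions for each request's run of accepted tokens, plus its last verified token) can be written in plain PyTorch. This is an illustrative reimplementation, not code from the commit:

```python
import torch

def extend_after_decode_reference(verified_id, seq_lens, accept_lens):
    # verified_id holds accept_lens[i] accepted tokens per request, flattened.
    positions, last_verified, offset = [], [], 0
    for seq_len, accept_len in zip(seq_lens.tolist(), accept_lens.tolist()):
        # Accepted tokens occupy the last accept_len positions of the sequence.
        positions.extend(range(seq_len - accept_len, seq_len))
        offset += accept_len
        last_verified.append(verified_id[offset - 1].item())
    return torch.tensor(positions), torch.tensor(last_verified)

verified_id = torch.tensor([70, 80, 90, 40, 50])   # accept_lens = [3, 2]
pos, new_verified = extend_after_decode_reference(
    verified_id, seq_lens=torch.tensor([10, 6]), accept_lens=torch.tensor([3, 2])
)
print(pos.tolist())           # [7, 8, 9, 4, 5]
print(new_verified.tolist())  # [90, 50]
```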
```diff
@@ -173,8 +123,9 @@ class EagleDraftInput:
         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
-        # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
-        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )
         create_flashinfer_kv_indices_triton[(bs,)](
             req_to_token,
@@ -238,54 +189,10 @@ class EagleVerifyInput:
     topk: int
     draft_token_num: int
     capture_hidden_mode: CaptureHiddenMode
+    seq_lens_sum: int
+    seq_lens_cpu: torch.Tensor
     grammar: BaseGrammarObject = None
 
-    @classmethod
-    def create(
-        cls,
-        verified_id: torch.Tensor,
-        score_list: List[torch.Tensor],
-        token_list: List[torch.Tensor],
-        parents_list: List[torch.Tensor],
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        topk: int,
-        spec_steps: int,
-        num_verify_tokens: int,
-    ):
-        (
-            tree_mask,
-            position,
-            retrive_index,
-            retrive_next_token,
-            retrive_next_sibling,
-            draft_tokens,
-        ) = build_tree_kernel_efficient(
-            verified_id,
-            score_list,
-            token_list,
-            parents_list,
-            seq_lens,
-            seq_lens_sum,
-            topk,
-            spec_steps,
-            num_verify_tokens,
-        )
-
-        return cls(
-            draft_token=draft_tokens,
-            custom_mask=tree_mask,
-            positions=position,
-            retrive_index=retrive_index,
-            retrive_next_token=retrive_next_token,
-            retrive_next_sibling=retrive_next_sibling,
-            retrive_cum_len=None,
-            spec_steps=spec_steps,
-            topk=topk,
-            draft_token_num=num_verify_tokens,
-            capture_hidden_mode=CaptureHiddenMode.FULL,
-        )
-
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
         batch.input_ids = self.draft_token
@@ -614,26 +521,28 @@ class EagleVerifyInput:
 @triton.jit
-def create_extend_spec_info(
+def create_extend_after_decode_spec_info(
     verified_id,
-    seq_len,
-    accept_len,
-    accept_len_cum,
+    seq_lens,
+    accept_lens,
     positions,
     new_verified_id,
-    accept_len_upper: tl.constexpr,
+    bs_upper: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
-    offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
-    seq_length = tl.load(seq_len + pid)
-    accept_length = tl.load(accept_len + pid)
-    positions_ptr = positions + offset
-    data = tl.arange(0, accept_len_upper)
-    mask = data < accept_length
-    tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
+    offsets = tl.arange(0, bs_upper)
+    seq_length = tl.load(seq_lens + pid)
+    accept_length = tl.load(accept_lens + pid)
 
-    offset = tl.load(accept_len_cum + pid) - 1
-    verified_id_data = tl.load(verified_id + offset)
+    accept_len_cumsum = tl.sum(
+        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
+    )
+    positions_ptr = positions + accept_len_cumsum
+    mask = offsets < accept_length
+    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
+
+    accept_len_cumsum += accept_length - 1
+    verified_id_data = tl.load(verified_id + accept_len_cumsum)
     tl.store(new_verified_id + pid, verified_id_data)
```
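The rewritten kernel no longer takes `accept_len_cum`: instead of a host-launched `torch.cumsum`, each program derives its own exclusive prefix sum with a masked `tl.sum` over the entries before `pid`. The two are equivalent, as this small check illustrates:

```python
import torch

accept_lens = torch.tensor([2, 4, 1, 3])

# Old scheme: host-side inclusive cumsum, shifted to make it exclusive.
inclusive = torch.cumsum(accept_lens, dim=0)
old_offsets = torch.cat([torch.zeros(1, dtype=inclusive.dtype), inclusive[:-1]])

# New scheme: program pid sums accept_lens[j] for j < pid (other=0 padding).
new_offsets = torch.stack(
    [accept_lens[:pid].sum() for pid in range(len(accept_lens))]
)

assert torch.equal(old_offsets, new_offsets)  # [0, 2, 6, 7]
```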
```diff
@@ -654,8 +563,8 @@ def assign_req_to_token_pool(
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
     length_offset = tl.arange(0, bs_upper)
-    start = tl.load(start_offset + length_offset, mask=length_offset < pid)
-    end = tl.load(end_offset + length_offset, mask=length_offset < pid)
+    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
+    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
     out_offset = tl.sum(end - start, axis=0)
 
     out_cache_ptr = out_cache_loc + out_offset
@@ -736,7 +645,7 @@ def generate_draft_decode_kv_indices(
     iters += 1
     load_offset = tl.arange(0, bs_upper)
-    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
+    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
     seq_len = tl.load(paged_kernel_lens + bid)
     cum_seq_len = tl.sum(seq_lens)
@@ -765,7 +674,7 @@ def generate_draft_decode_kv_indices(
     zid = bid * topk + topk_id
     if zid == 0:
         zid = num_seqs * topk
-    positions = tl.load(positions + bs_offset, mask=bs_offset < zid)
+    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
     base = tl.sum(positions)
     tl.store(kv_indptr + zid, base + zid * iters)
@@ -783,7 +692,9 @@ def align_evict_mask_to_page_size(
     bid = tl.program_id(axis=0)
     seq_len = tl.load(seq_lens + bid)
     io_mask = t_range < num_draft_tokens
-    mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
+    mask_row = tl.load(
+        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
+    )
     num_trues = tl.sum(mask_row)
     num_false = num_draft_tokens - num_trues
```
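Several kernels here feed masked `tl.load` results straight into `tl.sum`. Without `other=`, masked-off lanes hold undefined values, so the reduction can fold in garbage; `other=0` makes them contribute nothing. A minimal sketch of the pattern (runs only on a CUDA GPU with Triton installed):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def masked_sum_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offsets = tl.arange(0, BLOCK)
    # other=0 gives masked-off lanes a defined value, so the sum is exact.
    vals = tl.load(x_ptr + offsets, mask=offsets < n, other=0)
    tl.store(out_ptr, tl.sum(vals, axis=0))

x = torch.arange(1, 6, device="cuda", dtype=torch.int32)  # 1..5
out = torch.empty(1, device="cuda", dtype=torch.int32)
masked_sum_kernel[(1,)](x, out, x.numel(), BLOCK=8)
assert out.item() == 15
```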
python/sglang/srt/speculative/eagle_worker.py

```diff
@@ -23,6 +23,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
 )
@@ -69,7 +70,6 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
-        self.padded_static_len = self.speculative_num_steps + 1
         self.enable_nan_detection = server_args.enable_nan_detection
         self.gpu_id = gpu_id
         self.device = server_args.device
@@ -78,6 +78,7 @@ class EAGLEWorker(TpModelWorker):
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.padded_static_len = -1
 
         # Override context length with target model's context length
         server_args.context_length = target_worker.model_runner.model_config.context_len
@@ -184,7 +185,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
@@ -201,7 +201,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
@@ -218,7 +217,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
@@ -231,7 +229,6 @@ class EAGLEWorker(TpModelWorker):
                 self.speculative_num_steps,
             )
             self.draft_extend_attn_backend = None
-            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
@@ -319,10 +316,12 @@ class EAGLEWorker(TpModelWorker):
             return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
-            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
+            logits_output, next_token_ids, bid, seq_lens_cpu = (
+                self.forward_target_extend(batch)
+            )
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
-                    batch, logits_output.hidden_states, next_token_ids
+                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
             return logits_output, next_token_ids, bid, 0, False
@@ -346,7 +345,12 @@ class EAGLEWorker(TpModelWorker):
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
-        return logits_output, next_token_ids, model_worker_batch.bid
+        return (
+            logits_output,
+            next_token_ids,
+            model_worker_batch.bid,
+            model_worker_batch.seq_lens_cpu,
+        )
 
     def draft(self, batch: ScheduleBatch):
         # Parse args
@@ -452,7 +456,14 @@ class EAGLEWorker(TpModelWorker):
         self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
 
-        ret = EagleVerifyInput.create(
+        (
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            draft_tokens,
+        ) = build_tree_kernel_efficient(
             spec_info.verified_id,
             score_list,
             token_list,
@@ -463,7 +474,22 @@ class EAGLEWorker(TpModelWorker):
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
         )
-        return ret
+
+        return EagleVerifyInput(
+            draft_token=draft_tokens,
+            custom_mask=tree_mask,
+            positions=position,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            retrive_cum_len=None,
+            spec_steps=self.speculative_num_steps,
+            topk=self.topk,
+            draft_token_num=self.server_args.speculative_num_draft_tokens,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
+            seq_lens_sum=forward_batch.seq_lens_sum,
+            seq_lens_cpu=forward_batch.seq_lens_cpu,
+        )
 
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -523,7 +549,9 @@ class EAGLEWorker(TpModelWorker):
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
-        model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch = batch.get_model_worker_batch(
+            seq_lens_cpu_cache=spec_info.seq_lens_cpu
+        )
         if batch.has_grammar:
             retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
@@ -650,6 +678,7 @@ class EAGLEWorker(TpModelWorker):
         batch: ScheduleBatch,
         hidden_states: torch.Tensor,
         next_token_ids: List[int],
+        seq_lens_cpu: torch.Tensor,
     ):
         """Run draft model extend. This API modifies the states of the batch.
@@ -664,7 +693,9 @@ class EAGLEWorker(TpModelWorker):
         )
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch = batch.get_model_worker_batch(
+            seq_lens_cpu_cache=seq_lens_cpu
+        )
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -683,19 +714,18 @@ class EAGLEWorker(TpModelWorker):
         return_logprob_backup = batch.return_logprob
 
         # Prepare metadata
-        batch.forward_mode = ForwardMode.DRAFT_EXTEND
         batch.spec_info.prepare_extend_after_decode(
             batch,
             self.speculative_num_steps,
-            self.server_args.context_length,
-            pad_input=self.cuda_graph_runner_for_draft_extend is not None,
         )
-        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
+        if forward_batch.seq_lens_cpu is not None:
+            forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item()
+        else:
+            forward_batch.seq_lens_sum = batch.seq_lens.sum().item()
 
         # Run
         can_cuda_graph = (
```
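`prepare_extend_after_decode` used to refresh `batch.seq_lens_sum` itself via a `tolist()` round-trip; now the worker recomputes the sum after building the forward batch, preferring the host-side copy when one exists. A sketch of that fallback, with stand-in tensors:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([18, 43, 10], device=device)  # device-resident lengths
seq_lens_cpu = seq_lens.cpu()                         # may be None for some backends

# Prefer the host copy: summing it never touches the device. Summing the
# device tensor works too, but .item() then forces a synchronization.
if seq_lens_cpu is not None:
    seq_lens_sum = seq_lens_cpu.sum().item()
else:
    seq_lens_sum = seq_lens.sum().item()
print(seq_lens_sum)  # 71
```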
```diff
@@ -706,14 +736,19 @@ class EAGLEWorker(TpModelWorker):
             logits_output = self.cuda_graph_runner_for_draft_extend.replay(
                 forward_batch
             )
+            forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
+                logits_output.topk_p,
+                logits_output.topk_index,
+            )
+            forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
             self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
             logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
+            self.capture_for_decode(logits_output, forward_batch.spec_info)
 
         self._detect_nan_if_needed(logits_output)
-        self.capture_for_decode(logits_output, forward_batch.spec_info)
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
```
test/srt/test_full_deepseek_v3.py

```diff
@@ -87,7 +87,7 @@ class TestDeepseekV3MTP(CustomTestCase):
             "--speculative-num-steps",
             "3",
             "--speculative-eagle-topk",
-            "2",
+            "1",
             "--speculative-num-draft-tokens",
             "4",
         ]
@@ -155,7 +155,7 @@ class TestDeepseekV3MTP(CustomTestCase):
         if is_in_amd_ci():
             self.assertGreater(speed, 15)
         else:
-            self.assertGreater(speed, 105)
+            self.assertGreater(speed, 130)
 
 
 if __name__ == "__main__":
```