sglang · Commit 8cda5a62 (Unverified)

Standalone speculative decoding (#10090)

Authored Sep 07, 2025 by Qiaolin Yu; committed via GitHub on Sep 07, 2025.
Parent commit: 400d3b97

11 changed files with 285 additions and 9 deletions (+285 −9):
- python/sglang/srt/managers/schedule_batch.py (+1 −1)
- python/sglang/srt/managers/scheduler.py (+12 −0)
- python/sglang/srt/model_executor/cuda_graph_runner.py (+8 −2)
- python/sglang/srt/server_args.py (+17 −5)
- python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py (+5 −1)
- python/sglang/srt/speculative/eagle_worker.py (+8 −0)
- python/sglang/srt/speculative/spec_info.py (+5 −0)
- python/sglang/srt/speculative/standalone_worker.py (+109 −0, new file)
- python/sglang/test/test_utils.py (+4 −0)
- test/srt/run_suite.py (+1 −0)
- test/srt/test_standalone_speculative_decoding.py (+115 −0, new file)
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -1539,7 +1539,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)
 
-        if self.spec_algorithm.is_eagle():
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
             return
```
python/sglang/srt/managers/scheduler.py

```diff
@@ -349,6 +349,18 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None
```
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -271,7 +271,10 @@ class CudaGraphRunner:
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
-        if model_runner.spec_algorithm.is_eagle():
+        if (
+            model_runner.spec_algorithm.is_eagle()
+            or model_runner.spec_algorithm.is_standalone()
+        ):
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen")
             else:
```

```diff
@@ -827,7 +830,10 @@ class CudaGraphRunner:
     def get_spec_info(self, num_tokens: int):
         spec_info = None
-        if self.model_runner.spec_algorithm.is_eagle():
+        if (
+            self.model_runner.spec_algorithm.is_eagle()
+            or self.model_runner.spec_algorithm.is_standalone()
+        ):
             from sglang.srt.speculative.eagle_utils import EagleVerifyInput
 
             if self.model_runner.is_draft_worker:
```
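As an aside, the `is_eagle() or is_standalone()` test now recurs across this diff (ScheduleBatch once, CudaGraphRunner twice). A hypothetical consolidation, not part of this commit, sketched on a trimmed mirror of the enum rather than the real `spec_info.py`:

```python
from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):
    # Trimmed mirror of spec_info.py, for illustration only.
    NONE = auto()
    EAGLE = auto()
    STANDALONE = auto()

    def is_eagle(self):
        return self == SpeculativeAlgorithm.EAGLE

    def is_standalone(self):
        return self == SpeculativeAlgorithm.STANDALONE

    def is_draft_model_based(self):
        # Single predicate for every algorithm that runs a separate draft model.
        return self.is_eagle() or self.is_standalone()

assert SpeculativeAlgorithm.STANDALONE.is_draft_model_based()
```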
python/sglang/srt/server_args.py
```diff
@@ -473,9 +473,14 @@ class ServerArgs:
             # B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
             reserved_mem = 32 * 1024
 
-        # draft model and larger cuda graph buffers
         if self.speculative_algorithm is not None:
-            reserved_mem += 2 * 1024
+            # draft model and larger cuda graph buffers
+            if self.speculative_algorithm == "STANDALONE":
+                # Standalone speculative decoding needs more memory than other speculative
+                # decoding algorithms since the draft model is typically larger.
+                reserved_mem += 6 * 1024
+            else:
+                reserved_mem += 2 * 1024
         if self.enable_dp_attention:
             reserved_mem += 4 * 1024
```
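A worked example of the new branch, for a STANDALONE run without dp attention (treating the values as MiB, which the 32 * 1024 base suggests; the units themselves are not stated in this hunk):

```python
# Worked example (values treated as MiB): STANDALONE without dp attention.
reserved_mem = 32 * 1024           # base reservation (B200 / MI300 class)
reserved_mem += 6 * 1024           # standalone draft model + cuda graph buffers
assert reserved_mem == 38 * 1024   # i.e. 38 GiB reserved overall
```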
```diff
@@ -704,7 +709,12 @@ class ServerArgs:
             # NEXTN shares the same implementation of EAGLE
             self.speculative_algorithm = "EAGLE"
 
-        if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
+        if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"):
+            if self.speculative_algorithm == "STANDALONE":
+                # TODO: support dp attention for standalone speculative decoding
+                assert (
+                    self.enable_dp_attention is False
+                ), "Currently standalone speculative decoding does not support dp attention."
             if self.max_running_requests is None:
                 self.max_running_requests = 48
             self.disable_overlap_schedule = True
```
```diff
@@ -1499,7 +1509,7 @@ class ServerArgs:
         parser.add_argument(
             "--speculative-algorithm",
             type=str,
-            choices=["EAGLE", "EAGLE3", "NEXTN"],
+            choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE"],
             help="Speculative algorithm.",
         )
         parser.add_argument(
```
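For reference, the new choice is exercised like any other speculative algorithm at launch time. A minimal sketch using sglang's existing `popen_launch_server` test helper (also used by the new test file below); the endpoint URL and timeout are placeholder assumptions, and the two model names are the defaults this commit adds to `test_utils.py`:

```python
from sglang.test.test_utils import popen_launch_server

# Hypothetical local endpoint and timeout; model names are the new test defaults.
process = popen_launch_server(
    "meta-llama/Llama-3.1-8B-Instruct",   # target model
    "http://127.0.0.1:30000",             # assumed local URL
    timeout=600,                          # assumed launch timeout (seconds)
    other_args=[
        "--speculative-algorithm", "STANDALONE",
        "--speculative-draft-model-path", "meta-llama/Llama-3.2-1B-Instruct",
    ],
)
```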
```diff
@@ -2635,7 +2645,9 @@ def auto_choose_speculative_params(self: ServerArgs):
     """
     hf_config = self.get_hf_config()
     arch = hf_config.architectures[0]
+    if self.speculative_algorithm == "STANDALONE":
+        # The default value for standalone speculative decoding
+        return (3, 1, 4)
     if arch in ["LlamaForCausalLM"]:
         # The default value for llama
         return (5, 4, 8)
```
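The caller that unpacks this triple is outside this diff; assuming it maps, in order, to the three speculative knobs used elsewhere in this commit, the STANDALONE defaults read as:

```python
# Assumed unpacking of the (3, 1, 4) STANDALONE default (caller not shown here):
speculative_num_steps = 3          # draft forward passes per round
speculative_eagle_topk = 1         # branches kept per draft step
speculative_num_draft_tokens = 4   # tokens verified by the target (3 * 1 + 1)
```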
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

```diff
@@ -341,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
-        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+        if (
+            forward_batch.spec_info.hidden_states.shape[1]
+            == self.hidden_states.shape[1]
+        ):
+            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
         if forward_batch.spec_info.accept_length is not None:
             self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
```
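The guard presumably exists because a standalone draft model is a full network whose hidden size can differ from the target's (e.g. a 1B draft for an 8B target), whereas an EAGLE head matches it by construction; that reading is an inference from this diff. A minimal torch illustration with made-up sizes:

```python
import torch

# Made-up sizes: buffer sized for the target's hidden dim, source from a
# smaller draft model. The copy is skipped on mismatch instead of erroring.
buf = torch.zeros(16, 4096)   # CUDA-graph input buffer
src = torch.randn(8, 2048)    # draft hidden states
if src.shape[1] == buf.shape[1]:
    buf[:8].copy_(src)        # only runs when the widths agree
```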
python/sglang/srt/speculative/eagle_worker.py

```diff
@@ -730,6 +730,14 @@ class EAGLEWorker(TpModelWorker):
             # Set inputs
             forward_batch.input_ids = input_ids
+            # This is a temporary fix for the case that the user is using standalone
+            # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
+            # rope kernel needs cache_loc to be contiguous.
+            if (
+                self.server_args.speculative_algorithm == "STANDALONE"
+                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+            ):
+                out_cache_loc = out_cache_loc.contiguous()
             forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
```
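For background on the fix above: slicing or transposing can yield a non-contiguous view, and `.contiguous()` materializes a dense row-major copy that layout-sensitive kernels (here, the gpt-oss rope kernel) can consume. A standalone illustration:

```python
import torch

t = torch.arange(12).reshape(3, 4).t()  # a transposed view is non-contiguous
assert not t.is_contiguous()
c = t.contiguous()                      # dense copy with row-major layout
assert c.is_contiguous() and torch.equal(c, t)
```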
python/sglang/srt/speculative/spec_info.py

```diff
@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
     EAGLE3 = auto()
+    STANDALONE = auto()
 
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
```

```diff
@@ -15,11 +16,15 @@
     def is_eagle3(self):
         return self == SpeculativeAlgorithm.EAGLE3
 
+    def is_standalone(self):
+        return self == SpeculativeAlgorithm.STANDALONE
+
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
+            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
```
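A quick sanity check of the new enum plumbing, relying only on behavior visible in these hunks:

```python
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

algo = SpeculativeAlgorithm.from_string("STANDALONE")
assert algo == SpeculativeAlgorithm.STANDALONE
assert algo.is_standalone() and not algo.is_eagle3()
```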
python/sglang/srt/speculative/standalone_worker.py (new file, mode 100644)

```python
import logging
from contextlib import contextmanager
from typing import Optional

import torch

from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda

if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)

RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")


@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    with patch_tensor_parallel_group(tp_group):
        yield


class StandaloneWorker(EAGLEWorker):

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.padded_static_len = -1

        # Override the context length of the draft model to be the same as the target model.
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture cuda graph in `super().__init__()`.
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        # Share the allocator with the target worker.
        # Draft and target workers own their own KV cache pools.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init draft worker
        with empty_context():
            TpModelWorker.__init__(
                self,
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        # Init attention backend and cuda graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Some dummy tensors
        self.num_new_pages_per_topk = torch.empty(
            (), dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
```
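The `draft_tp_context` manager above temporarily patches the global tensor-parallel group so that draft-model collectives run on the draft's own TP group and are restored afterwards. A minimal, self-contained mock of the same save/patch/restore pattern (plain strings stand in for `GroupCoordinator` objects):

```python
from contextlib import contextmanager

_CURRENT_TP_GROUP = "target-tp-group"

@contextmanager
def patch_tp_group(group):
    # Swap in the draft group for the duration of the block, then restore.
    global _CURRENT_TP_GROUP
    saved, _CURRENT_TP_GROUP = _CURRENT_TP_GROUP, group
    try:
        yield
    finally:
        _CURRENT_TP_GROUP = saved

with patch_tp_group("draft-tp-group"):
    assert _CURRENT_TP_GROUP == "draft-tp-group"  # draft ops would run here
assert _CURRENT_TP_GROUP == "target-tp-group"     # restored afterwards
```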
python/sglang/test/test_utils.py

```diff
@@ -72,6 +72,10 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
 DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
+DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
+    "meta-llama/Llama-3.1-8B-Instruct"
+)
+DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
 
 # Other use cases
 DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
```
test/srt/run_suite.py

```diff
@@ -76,6 +76,7 @@ suites = {
         TestFile("test_harmony_parser.py", 20),
         TestFile("test_hidden_states.py", 55),
         TestFile("test_hybrid_attn_backend.py", 100),
+        TestFile("test_standalone_speculative_decoding.py", 250),
         TestFile("test_input_embeddings.py", 38),
         TestFile("test_io_struct.py", 8),
         TestFile("test_jinja_template_utils.py", 1),
```
test/srt/test_standalone_speculative_decoding.py (new file, mode 100644)

```python
import os
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

GSM_DATASET_PATH = None

# Default server arguments shared across all tests
DEFAULT_SERVER_ARGS = [
    "--trust-remote-code",
    "--cuda-graph-max-bs",
    "8",
    "--speculative-algorithm",
    "STANDALONE",
    "--speculative-draft-model-path",
    DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST,
    "--speculative-num-steps",
    "4",
    "--speculative-eagle-topk",
    "2",
    "--speculative-num-draft-tokens",
    "7",
    "--mem-fraction-static",
    0.7,
]


class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
    model = DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST
    draft_model = DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    accuracy_threshold = 0.7  # derived tests need to override this
    spec_decode_threshold = 3.6  # derived spec decoding tests need to override this

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses."""
        return DEFAULT_SERVER_ARGS + ["--attention-backend", "fa3"]

    @classmethod
    def setUpClass(cls):
        # disable deep gemm precompile to make launch server faster
        # please don't do this if you want to make your inference workload faster
        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
        model = cls.model
        cls.process = popen_launch_server(
            model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
            data_path=GSM_DATASET_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        # Use the appropriate metric key based on the test class
        metric_key = "accuracy"
        self.assertGreater(metrics[metric_key], self.accuracy_threshold)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)


class TestStandaloneSpeculativeDecodingTriton(TestStandaloneSpeculativeDecodingBase):
    @classmethod
    def get_server_args(cls):
        return DEFAULT_SERVER_ARGS + ["--attention-backend", "triton"]


class TestStandaloneSpeculativeDecodingFlashinfer(TestStandaloneSpeculativeDecodingBase):
    @classmethod
    def get_server_args(cls):
        return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"]


if __name__ == "__main__":
    unittest.main()
```
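Because the module ends with `unittest.main()`, the file can be run directly with `python3 test/srt/test_standalone_speculative_decoding.py`, and a single backend variant can be selected from the `test/srt` directory with unittest's standard dotted-name filter, e.g. `python3 -m unittest test_standalone_speculative_decoding.TestStandaloneSpeculativeDecodingTriton`.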