Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d2e0881a
Unverified
Commit
d2e0881a
authored
May 23, 2025
by
Byron Hsu
Committed by
GitHub
May 23, 2025
Browse files
[PD] support spec decode (#6507)
Co-authored-by:
SangBin Cho
<
rkooo567@gmail.com
>
parent
2f427491
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
190 additions
and
5 deletions
+190
-5
.pre-commit-config.yaml
.pre-commit-config.yaml
+1
-1
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+11
-1
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+3
-0
python/sglang/srt/disaggregation/mooncake/transfer_engine.py
python/sglang/srt/disaggregation/mooncake/transfer_engine.py
+2
-1
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+13
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+17
-0
test/srt/run_suite.py
test/srt/run_suite.py
+1
-1
test/srt/test_disaggregation.py
test/srt/test_disaggregation.py
+142
-1
No files found.
.pre-commit-config.yaml
View file @
d2e0881a
...
...
@@ -23,7 +23,7 @@ repos:
hooks
:
-
id
:
isort
-
repo
:
https://github.com/astral-sh/ruff-pre-commit
rev
:
v0.11.
2
rev
:
v0.11.
7
hooks
:
-
id
:
ruff
args
:
[
--select=F401
,
--fixable=F401
]
...
...
python/sglang/srt/disaggregation/decode.py
View file @
d2e0881a
...
...
@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import (
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -76,6 +76,7 @@ class DecodePreallocQueue:
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
draft_token_to_kv_pool
:
Optional
[
KVCache
],
req_to_metadata_buffer_idx_allocator
:
ReqToMetadataIdxAllocator
,
metadata_buffers
:
List
[
torch
.
Tensor
],
aux_dtype
:
torch
.
dtype
,
...
...
@@ -91,6 +92,7 @@ class DecodePreallocQueue:
self
.
req_to_token_pool
=
req_to_token_pool
self
.
token_to_kv_pool_allocator
=
token_to_kv_pool_allocator
self
.
token_to_kv_pool
=
token_to_kv_pool_allocator
.
get_kvcache
()
self
.
draft_token_to_kv_pool
=
draft_token_to_kv_pool
self
.
is_mla_backend
=
is_mla_backend
(
self
.
token_to_kv_pool
)
self
.
aux_dtype
=
aux_dtype
self
.
metadata_buffers
=
metadata_buffers
...
...
@@ -119,6 +121,14 @@ class DecodePreallocQueue:
self
.
token_to_kv_pool
.
get_contiguous_buf_infos
()
)
if
self
.
draft_token_to_kv_pool
is
not
None
:
draft_kv_data_ptrs
,
draft_kv_data_lens
,
draft_kv_item_lens
=
(
self
.
draft_token_to_kv_pool
.
get_contiguous_buf_infos
()
)
kv_data_ptrs
+=
draft_kv_data_ptrs
kv_data_lens
+=
draft_kv_data_lens
kv_item_lens
+=
draft_kv_item_lens
kv_args
.
kv_data_ptrs
=
kv_data_ptrs
kv_args
.
kv_data_lens
=
kv_data_lens
kv_args
.
kv_item_lens
=
kv_item_lens
...
...
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
d2e0881a
...
...
@@ -51,6 +51,7 @@ def group_concurrent_contiguous(
return
src_groups
,
dst_groups
# prefill
@
dataclasses
.
dataclass
class
TransferKVChunk
:
room
:
int
...
...
@@ -60,6 +61,7 @@ class TransferKVChunk:
prefill_aux_index
:
Optional
[
int
]
# decode
@
dataclasses
.
dataclass
class
TransferInfo
:
room
:
int
...
...
@@ -93,6 +95,7 @@ class TransferInfo:
)
# decode
@
dataclasses
.
dataclass
class
KVArgsRegisterInfo
:
room
:
str
...
...
python/sglang/srt/disaggregation/mooncake/transfer_engine.py
View file @
d2e0881a
...
...
@@ -61,7 +61,8 @@ class MooncakeTransferEngine:
self
,
session_id
:
str
,
buffer
:
int
,
peer_buffer_address
:
int
,
length
:
int
)
->
int
:
"""Synchronously transfer data to the specified address."""
# the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
# later: based on the cached queue pair to send data
ret
=
self
.
engine
.
transfer_sync_write
(
session_id
,
buffer
,
peer_buffer_address
,
length
)
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
d2e0881a
...
...
@@ -61,6 +61,7 @@ class PrefillBootstrapQueue:
def
__init__
(
self
,
token_to_kv_pool
:
KVCache
,
draft_token_to_kv_pool
:
Optional
[
KVCache
],
req_to_metadata_buffer_idx_allocator
:
ReqToMetadataIdxAllocator
,
metadata_buffers
:
List
[
torch
.
Tensor
],
aux_dtype
:
torch
.
dtype
,
...
...
@@ -72,6 +73,8 @@ class PrefillBootstrapQueue:
scheduler
:
Scheduler
,
):
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
draft_token_to_kv_pool
=
draft_token_to_kv_pool
self
.
is_mla_backend
=
is_mla_backend
(
token_to_kv_pool
)
self
.
aux_dtype
=
aux_dtype
...
...
@@ -98,6 +101,16 @@ class PrefillBootstrapQueue:
self
.
token_to_kv_pool
.
get_contiguous_buf_infos
()
)
if
self
.
draft_token_to_kv_pool
is
not
None
:
# We should also transfer draft model kv cache. The indices are
# always shared with the target model.
draft_kv_data_ptrs
,
draft_kv_data_lens
,
draft_kv_item_lens
=
(
self
.
draft_token_to_kv_pool
.
get_contiguous_buf_infos
()
)
kv_data_ptrs
+=
draft_kv_data_ptrs
kv_data_lens
+=
draft_kv_data_lens
kv_item_lens
+=
draft_kv_item_lens
kv_args
.
kv_data_ptrs
=
kv_data_ptrs
kv_args
.
kv_data_lens
=
kv_data_lens
kv_args
.
kv_item_lens
=
kv_item_lens
...
...
python/sglang/srt/managers/scheduler.py
View file @
d2e0881a
...
...
@@ -591,6 +591,11 @@ class Scheduler(
self
.
disagg_decode_prealloc_queue
=
DecodePreallocQueue
(
req_to_token_pool
=
self
.
req_to_token_pool
,
token_to_kv_pool_allocator
=
self
.
token_to_kv_pool_allocator
,
draft_token_to_kv_pool
=
(
None
if
self
.
draft_worker
is
None
else
self
.
draft_worker
.
model_runner
.
token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator
=
req_to_metadata_buffer_idx_allocator
,
metadata_buffers
=
metadata_buffers
,
aux_dtype
=
aux_dtype
,
...
...
@@ -624,6 +629,11 @@ class Scheduler(
self
.
disagg_prefill_bootstrap_queue
=
PrefillBootstrapQueue
(
token_to_kv_pool
=
self
.
token_to_kv_pool_allocator
.
get_kvcache
(),
draft_token_to_kv_pool
=
(
None
if
self
.
draft_worker
is
None
else
self
.
draft_worker
.
model_runner
.
token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator
=
req_to_metadata_buffer_idx_allocator
,
metadata_buffers
=
metadata_buffers
,
aux_dtype
=
aux_dtype
,
...
...
@@ -1409,6 +1419,13 @@ class Scheduler(
self
.
running_batch
.
batch_is_full
=
True
break
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
# In prefill mode, prealloc queue and transfer queue can also take memory,
# so we need to check the number of runnable requests against the actual available size.
if
len
(
adder
.
can_run_list
)
>=
self
.
req_to_token_pool
.
available_size
():
self
.
running_batch
.
batch_is_full
=
True
break
req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
,
self
.
enable_hierarchical_cache
,
...
...
test/srt/run_suite.py
View file @
d2e0881a
...
...
@@ -115,7 +115,7 @@ suites = {
# TestFile("test_deepep_intranode.py", 50),
# TestFile("test_deepep_low_latency.py", 50),
# TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
#
TestFile("test_disaggregation.py", 210),
# disabled since we have different_tp test
TestFile
(
"test_disaggregation.py"
,
210
),
TestFile
(
"test_disaggregation_different_tp.py"
,
210
),
TestFile
(
"test_full_deepseek_v3.py"
,
250
),
],
...
...
test/srt/test_disaggregation.py
View file @
d2e0881a
...
...
@@ -8,6 +8,8 @@ import requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.test_utils
import
(
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
...
...
@@ -17,7 +19,9 @@ from sglang.test.test_utils import (
)
class
TestDisaggregationMooncake
(
CustomTestCase
):
# skip the test because we have different_tp test
@
unittest
.
skip
(
"skip the test because we have different_tp test"
)
class
TestDisaggregationAccuracy
(
CustomTestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
...
...
@@ -65,6 +69,8 @@ class TestDisaggregationMooncake(CustomTestCase):
str
(
cls
.
base_port
+
100
),
"--tp"
,
"4"
,
# "--disaggregation-ib-device",
# "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
]
cls
.
process_prefill
=
popen_launch_pd_server
(
cls
.
model
,
...
...
@@ -87,6 +93,8 @@ class TestDisaggregationMooncake(CustomTestCase):
"4"
,
"--base-gpu-id"
,
"4"
,
# "--disaggregation-ib-device",
# "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
]
cls
.
process_decode
=
popen_launch_pd_server
(
cls
.
model
,
...
...
@@ -136,5 +144,138 @@ class TestDisaggregationMooncake(CustomTestCase):
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.62
)
class TestDisaggregationSpecAccuracy(CustomTestCase):
    """GSM8K accuracy check for PD disaggregation with EAGLE speculative decoding.

    ``setUpClass`` launches a prefill server, a decode server (both with the
    EAGLE speculative-decoding arguments appended) and a ``mini_lb`` load
    balancer in front of them. ``test_gsm8k`` then runs the few-shot GSM8K
    evaluation through the load balancer and asserts a minimum accuracy.
    """

    # Class attributes that hold launched subprocess handles, listed in the
    # order in which they are stopped (load balancer first, then the servers).
    _PROCESS_ATTRS = ("process_lb", "process_decode", "process_prefill")

    @classmethod
    def setUpClass(cls):
        """Start the prefill server, the decode server and the load balancer.

        If any step fails, every process started so far is stopped before the
        error is re-raised: unittest does not call ``tearDownClass`` when
        ``setUpClass`` raises, so they would otherwise be leaked.
        """
        super().setUpClass()
        cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
        cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
        cls.base_host = "127.0.0.1"
        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
        cls.lb_url = DEFAULT_URL_FOR_TEST
        # Prefill and decode servers listen at fixed offsets from the LB port.
        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
        # Speculative-decoding flags shared by the prefill and decode servers.
        cls.spec_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            cls.draft_model,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "16",
            "--cuda-graph-max-bs",
            "8",
        ]

        try:
            cls._launch_servers()
        except BaseException:
            # Do not leave GPU-holding servers behind on a failed start-up.
            cls._stop_servers()
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the load balancer and both servers started by ``setUpClass``."""
        cls._stop_servers()
        super().tearDownClass()

    @classmethod
    def _launch_servers(cls):
        """Launch prefill, decode and the load balancer, waiting for each to be healthy."""
        run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
        run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)

        cls.wait_server_ready(cls.prefill_url + "/health")
        cls.wait_server_ready(cls.decode_url + "/health")

        lb_command = [
            "python3",
            "-m",
            "sglang.srt.disaggregation.mini_lb",
            "--prefill",
            cls.prefill_url,
            "--decode",
            cls.decode_url,
            "--host",
            cls.base_host,
            "--port",
            str(cls.base_port),
        ]

        print("Starting load balancer:", " ".join(lb_command))
        # Let the load balancer inherit this process's stdout/stderr. Capturing
        # them with PIPE without ever reading can block the child once the OS
        # pipe buffer fills up.
        cls.process_lb = subprocess.Popen(lb_command)
        cls.wait_server_ready(cls.lb_url + "/health")

    @classmethod
    def _stop_servers(cls):
        """Kill every launched process tree; safe to call when some never started."""
        for attr in cls._PROCESS_ATTRS:
            process = getattr(cls, attr, None)
            if process is None:
                continue
            try:
                kill_process_tree(process.pid)
            except Exception as e:
                # Best effort: keep going so the remaining processes are stopped too.
                print(f"Error killing process {process.pid}: {e}")
            setattr(cls, attr, None)

    @classmethod
    def wait_server_ready(cls, url, timeout=60):
        """Poll ``url`` once per second until it answers with HTTP 200.

        Args:
            url: Health-check URL to poll.
            timeout: Overall number of seconds to keep polling.

        Raises:
            RuntimeError: If no HTTP 200 response is received within ``timeout``.
        """
        start_time = time.perf_counter()
        while True:
            try:
                # Bound each request so one hung connection cannot defeat the
                # overall deadline checked below.
                response = requests.get(url, timeout=5)
                if response.status_code == 200:
                    print(f"Server {url} is ready")
                    return
            except requests.RequestException:
                # The server is not accepting connections yet; retry.
                pass

            if time.perf_counter() - start_time > timeout:
                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
            time.sleep(1)

    @classmethod
    def start_prefill(cls):
        """Launch the prefill-mode server and store its handle in ``cls.process_prefill``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--host",
            cls.base_host,
            "--port",
            str(cls.base_port + 100),
            "--tp",
            "4",
            # "--disaggregation-ib-device",
            # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
        ] + cls.spec_args
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode-mode server and store its handle in ``cls.process_decode``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--host",
            cls.base_host,
            "--port",
            str(cls.base_port + 200),
            "--tp",
            "4",
            "--base-gpu-id",
            "4",
            # "--disaggregation-ib-device",
            # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
        ] + cls.spec_args
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """Run few-shot GSM8K through the load balancer and require accuracy > 0.20."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=4,  # TODO: 128 crashes the decode
            host="http://127.0.0.1",
            port=int(self.lb_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.20)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment