sglang · Commits · 7d316991

Commit 7d316991 (unverified), authored Jun 14, 2025 by Byron Hsu, committed via GitHub on Jun 14, 2025.

[PD] Update prefill.py (#7190)

Parent: ab1a4fa5

Showing 11 changed files with 412 additions and 199 deletions (+412, -199).
python/sglang/srt/disaggregation/base/conn.py       +23   -9
python/sglang/srt/disaggregation/common/utils.py    +42   -0
python/sglang/srt/disaggregation/decode.py          +2    -2
python/sglang/srt/disaggregation/fake/conn.py       +8    -1
python/sglang/srt/disaggregation/mooncake/conn.py   +9    -4
python/sglang/srt/disaggregation/nixl/conn.py       +10   -5
python/sglang/srt/disaggregation/prefill.py         +125  -43
python/sglang/srt/disaggregation/utils.py           +125  -121
python/sglang/srt/entrypoints/http_server.py        +2    -2
python/sglang/srt/managers/scheduler.py             +18   -8
python/sglang/srt/server_args.py                    +48   -4
python/sglang/srt/disaggregation/base/conn.py

 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import numpy.typing as npt

-from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs

+if TYPE_CHECKING:
+    from sglang.srt.disaggregation.utils import DisaggregationMode
+

 class KVArgs:
     engine_rank: int
-    kv_data_ptrs: list[int]
-    kv_data_lens: list[int]
-    kv_item_lens: list[int]
-    aux_data_ptrs: list[int]
-    aux_data_lens: list[int]
-    aux_item_lens: list[int]
+    kv_data_ptrs: List[int]
+    kv_data_lens: List[int]
+    kv_item_lens: List[int]
+    aux_data_ptrs: List[int]
+    aux_data_lens: List[int]
+    aux_item_lens: List[int]
     ib_device: str
     ib_traffic_class: str
     gpu_id: int
+    # for different tp
+    decode_tp_size: int
+    # for pp prefill
+    prefill_pp_size: int


 class KVPoll:
     ...

@@ -45,7 +54,12 @@ class BaseKVSender(ABC):
     @abstractmethod
     def __init__(
-        self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: int,
+        dest_tp_ranks: List[int],
+        pp_rank: int,
     ): ...

     @abstractmethod
     ...
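Not part of the diff: a minimal sketch of what the widened BaseKVSender contract now asks of a concrete sender. DummyKVSender is hypothetical; the real implementations are the Mooncake, NIXL, and fake senders updated below.

from typing import List

class DummyKVSender(BaseKVSender):  # BaseKVSender/BaseKVManager as defined above
    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: int,
        dest_tp_ranks: List[int],  # decode-side tp ranks this prefill rank feeds
        pp_rank: int,              # pipeline stage of the sending prefill worker
    ):
        self.has_sent = False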
python/sglang/srt/disaggregation/common/utils.py (new file, mode 100644)

import threading
from collections import deque
from typing import List, Tuple

import numpy as np
import numpy.typing as npt


class FastQueue:
    def __init__(self):
        self._buf = deque()
        self._cond = threading.Condition()

    def put(self, item):
        with self._cond:
            self._buf.append(item)
            # wake up a thread of wait()
            self._cond.notify()

    def get(self):
        with self._cond:
            # if queue is empty ,block until is notified()
            while not self._buf:
                self._cond.wait()
            return self._buf.popleft()


def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
    """Vectorised NumPy implementation."""
    if src_indices.size == 0:
        return [], []

    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
    src_groups = np.split(src_indices, brk)
    dst_groups = np.split(dst_indices, brk)

    src_groups = [g.tolist() for g in src_groups]
    dst_groups = [g.tolist() for g in dst_groups]

    return src_groups, dst_groups
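Not part of the commit: a quick demonstration of the two new helpers. group_concurrent_contiguous splits the index pair into runs that are contiguous on both sides, so each run can be moved with one bulk KV transfer; FastQueue is a plain condition-variable queue.

import threading

import numpy as np

from sglang.srt.disaggregation.common.utils import FastQueue, group_concurrent_contiguous

# A break happens wherever either src or dst stops increasing by exactly 1.
src = np.array([0, 1, 2, 5, 6], dtype=np.int64)
dst = np.array([10, 11, 12, 20, 21], dtype=np.int64)
print(group_concurrent_contiguous(src, dst))
# -> ([[0, 1, 2], [5, 6]], [[10, 11, 12], [20, 21]])

# get() blocks until a producer calls put().
q = FastQueue()
threading.Thread(target=lambda: q.put("kv-chunk")).start()
print(q.get())  # "kv-chunk"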
python/sglang/srt/disaggregation/decode.py

@@ -33,8 +33,8 @@ from torch.distributed import ProcessGroup
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
+    FAKE_BOOTSTRAP_HOST,
     DisaggregationMode,
-    FakeBootstrapHost,
     KVClassType,
     MetadataBuffers,
     ReqToMetadataIdxAllocator,
     ...

@@ -207,7 +207,7 @@ class DecodePreallocQueue:
     def add(self, req: Req) -> None:
         """Add a request to the pending queue."""
-        if req.bootstrap_host == FakeBootstrapHost:
+        if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
             # Fake transfer for warmup reqs
             kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
         else:
             ...
python/sglang/srt/disaggregation/fake/conn.py

@@ -17,7 +17,14 @@ logger = logging.getLogger(__name__)
 # For warmup reqs, we don't kv transfer, we use the fake sender and receiver
 class FakeKVSender(BaseKVSender):
-    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: int,
+        dest_tp_ranks: List[int],
+        pp_rank: int,
+    ):
         self.has_sent = False

     def poll(self) -> KVPoll:
         ...
python/sglang/srt/disaggregation/mooncake/conn.py

@@ -28,12 +28,12 @@ from sglang.srt.disaggregation.base.conn import (
     KVArgs,
     KVPoll,
 )
-from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
+from sglang.srt.disaggregation.common.utils import (
     FastQueue,
     group_concurrent_contiguous,
 )
+from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_free_port,
     ...

@@ -677,7 +677,12 @@ class MooncakeKVManager(BaseKVManager):
 class MooncakeKVSender(BaseKVSender):
     def __init__(
-        self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
+        self,
+        mgr: MooncakeKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: int,
+        dest_tp_ranks: List[int],
+        pp_rank: int,
     ):
         self.kv_mgr = mgr
         self.bootstrap_room = bootstrap_room
         ...
python/sglang/srt/disaggregation/nixl/conn.py

@@ -24,10 +24,8 @@ from sglang.srt.disaggregation.common.conn import (
     CommonKVManager,
     CommonKVReceiver,
 )
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    group_concurrent_contiguous,
-)
+from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_local_ip_by_remote
 ...

@@ -350,7 +348,14 @@ class NixlKVManager(CommonKVManager):
 class NixlKVSender(BaseKVSender):
-    def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
+    def __init__(
+        self,
+        mgr: NixlKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: int,
+        dest_tp_ranks: List[int],
+        pp_rank: int,
+    ):
         self.kv_mgr = mgr
         self.bootstrap_room = bootstrap_room
         self.aux_index = None
         ...
python/sglang/srt/disaggregation/prefill.py

@@ -27,10 +27,10 @@ from typing import TYPE_CHECKING, List, Optional
 import torch

-from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
+from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
 from sglang.srt.disaggregation.utils import (
+    FAKE_BOOTSTRAP_HOST,
     DisaggregationMode,
-    FakeBootstrapHost,
     KVClassType,
     MetadataBuffers,
     ReqToMetadataIdxAllocator,
     ...

@@ -51,7 +51,6 @@ if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
     from sglang.srt.mem_cache.memory_pool import KVCache

-
 logger = logging.getLogger(__name__)
 ...
@@ -68,35 +67,45 @@ class PrefillBootstrapQueue:
         metadata_buffers: MetadataBuffers,
         tp_rank: int,
         tp_size: int,
         gpu_id: int,
         bootstrap_port: int,
         gloo_group: ProcessGroup,
-        transfer_backend: TransferBackend,
+        max_total_num_tokens: int,
+        decode_tp_size: int,
+        decode_dp_size: int,
         scheduler: Scheduler,
+        pp_rank: int,
+        pp_size: int,
+        transfer_backend: TransferBackend,
     ):
         self.token_to_kv_pool = token_to_kv_pool
         self.draft_token_to_kv_pool = draft_token_to_kv_pool
         self.is_mla_backend = is_mla_backend(token_to_kv_pool)
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.tp_rank = tp_rank
         self.tp_size = tp_size
-        self.transfer_backend = transfer_backend
-        self.scheduler = scheduler
-        self.kv_manager = self._init_kv_manager()
+        self.decode_tp_size = decode_tp_size
+        self.decode_dp_size = decode_dp_size
+        self.pp_rank = pp_rank
+        self.pp_size = pp_size
+        self.gpu_id = gpu_id
+        self.bootstrap_port = bootstrap_port
         self.queue: List[Req] = []
-        self.pp_rank = pp_rank
-        self.pp_size = pp_size
         self.gloo_group = gloo_group
-        self.bootstrap_port = bootstrap_port
-
-    def store_prefill_results(self, idx: int, token_id: int):
-        assert token_id >= 0, f"token_id: {token_id} is negative"
-        output_id_buffer = self.metadata_buffers[0]
-        output_id_buffer[idx] = token_id
+        self.max_total_num_tokens = max_total_num_tokens
+        self.scheduler = scheduler
+        self.transfer_backend = transfer_backend
+        self.kv_manager = self._init_kv_manager()

     def _init_kv_manager(self) -> BaseKVManager:
-        kv_args = KVArgs()
+        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
+        kv_args = kv_args_class()
         kv_args.engine_rank = self.tp_rank
+        kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
+        kv_args.prefill_pp_size = self.pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
         ...
@@ -115,12 +124,12 @@ class PrefillBootstrapQueue:
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
         # Define req -> input ids buffer
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
         )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
         kv_manager = kv_manager_class(
             kv_args,
             ...
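Not part of the diff: a note on the arithmetic in _init_kv_manager. decode_tp_size is the decode deployment's total tensor-parallel size and decode_dp_size its data-parallel group count, so the quotient is presumably the tp width of a single decode dp group, i.e. the rank set one KV transfer actually targets.

# Illustration: a decode side running tp=8 split into dp=2 replicas
decode_tp_size, decode_dp_size = 8, 2
assert decode_tp_size // decode_dp_size == 4  # each transfer targets a tp-4 group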
@@ -130,23 +139,39 @@ class PrefillBootstrapQueue:
         )
         return kv_manager

-    def add(self, req: Req) -> None:
-        if req.bootstrap_host == FakeBootstrapHost:
-            # Fake transfer for warmup reqs
+    def add(self, req: Req, num_kv_heads: int) -> None:
+        if self._check_if_req_exceed_kv_capacity(req):
+            return
+
+        if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
             kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
         else:
             kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
+
+        dest_tp_ranks = [self.tp_rank]
+
         req.disagg_kv_sender = kv_sender_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
+            dest_tp_ranks=dest_tp_ranks,
+            pp_rank=self.pp_rank,
         )
         self._process_req(req)
         self.queue.append(req)

-    def extend(self, reqs: List[Req]) -> None:
+    def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
         for req in reqs:
-            self.add(req)
+            self.add(req, num_kv_heads)
+
+    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
+        if len(req.origin_input_ids) > self.max_total_num_tokens:
+            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
+            logger.error(message)
+            prepare_abort(req, message)
+            self.scheduler.stream_output([req], req.return_logprob)
+            return True
+        return False

     def _process_req(self, req: Req) -> None:
         """
         ...
@@ -154,19 +179,40 @@ class PrefillBootstrapQueue:
         """
         req.sampling_params.max_new_tokens = 1

-    def pop_bootstrapped(self) -> List[Req]:
-        """pop the reqs which has finished bootstrapping"""
+    def pop_bootstrapped(
+        self,
+        return_failed_reqs: bool = False,
+        rids_to_check: Optional[List[str]] = None,
+    ) -> List[Req]:
+        """
+        pop the reqs which has finished bootstrapping
+
+        return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
+        rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
+        """
+
         bootstrapped_reqs = []
+        failed_reqs = []
         indices_to_remove = set()

         if len(self.queue) == 0:
-            return []
+            if return_failed_reqs is False:
+                return []
+            else:
+                return [], []

         polls = poll_and_all_reduce(
             [req.disagg_kv_sender for req in self.queue], self.gloo_group
         )

         for i, (req, poll) in enumerate(zip(self.queue, polls)):
+            if rids_to_check is not None:
+                # if req not in reqs_info_to_check, skip
+                if req.rid not in rids_to_check:
+                    continue
+                # Either waiting for input or failed
+                assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed
+
             if poll == KVPoll.Bootstrapping:
                 continue
             elif poll == KVPoll.Failed:
                 ...

@@ -181,9 +227,10 @@ class PrefillBootstrapQueue:
                 )
                 self.scheduler.stream_output([req], req.return_logprob)
                 indices_to_remove.add(i)
+                failed_reqs.append(req)
                 continue

-            # KV.WaitingForInput
+            # KV.WaitingForInput - init here
             num_kv_indices = len(req.origin_input_ids)
             if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
                 break
                 ...

@@ -192,9 +239,9 @@ class PrefillBootstrapQueue:
                 self.req_to_metadata_buffer_idx_allocator.alloc()
             )
             assert req.metadata_buffer_index is not None
             num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
             req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
             bootstrapped_reqs.append(req)
             indices_to_remove.add(i)
             ...

@@ -202,7 +249,10 @@ class PrefillBootstrapQueue:
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
         ]

-        return bootstrapped_reqs
+        if return_failed_reqs is False:
+            return bootstrapped_reqs
+        else:
+            return bootstrapped_reqs, failed_reqs
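Not part of the diff: the handshake the two new keyword arguments imply, as a hypothetical sketch (the transport helpers would be stand-ins, not sglang APIs). Rank 0 decides which requests advance or fail; later pipeline stages pop exactly the rid set rank 0 saw, so all stages stay in consensus within a tick.

from typing import List, Optional

def pop_bootstrapped_on_rank(queue, pp_rank: int,
                             rids_from_prev: Optional[List[str]] = None):
    # Hypothetical glue code, not sglang API.
    if pp_rank == 0:
        # Rank 0 also collects failed reqs so it can tell the next rank
        # to drop them rather than wait on them forever.
        bootstrapped, failed = queue.pop_bootstrapped(return_failed_reqs=True)
        rids_for_next = [r.rid for r in bootstrapped] + [r.rid for r in failed]
        return bootstrapped, rids_for_next
    # Ranks > 0 act only on reqs the previous rank already agreed on.
    bootstrapped = queue.pop_bootstrapped(rids_to_check=rids_from_prev)
    return bootstrapped, rids_from_prev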
class SchedulerDisaggregationPrefillMixin:
    ...

@@ -211,7 +261,7 @@ class SchedulerDisaggregationPrefillMixin:
     """

     @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self: Scheduler):
+    def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
         """A normal scheduler loop for prefill worker in disaggregation mode."""

         while True:
             ...

@@ -229,7 +279,6 @@ class SchedulerDisaggregationPrefillMixin:
                 or self.server_args.enable_sp_layernorm
             ):
                 batch, _ = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch

             if batch:
                 ...

@@ -250,7 +299,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.running_batch.batch_is_full = False

     @torch.no_grad()
-    def event_loop_overlap_disagg_prefill(self: Scheduler):
+    def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
         self.result_queue = deque()

         while True:
             ...

@@ -268,9 +317,7 @@ class SchedulerDisaggregationPrefillMixin:
                 or self.server_args.enable_sp_layernorm
             ):
                 batch, _ = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch
-
             if batch:
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))
                 ...

@@ -287,6 +334,9 @@ class SchedulerDisaggregationPrefillMixin:
             if self.last_batch:
                 tmp_batch, tmp_result = self.result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)

             if len(self.disagg_prefill_inflight_queue) > 0:
                 ...

@@ -309,7 +359,7 @@ class SchedulerDisaggregationPrefillMixin:
         launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
-        Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
+        Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
         Adapted from process_batch_result_prefill
         """
         (
         ...

@@ -325,7 +375,7 @@ class SchedulerDisaggregationPrefillMixin:
         )

         logprob_pt = 0
-        # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
+        # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
         if self.enable_overlap:
             # wait
             logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
             ...
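Not part of the diff: the overlap loop's one-step-delayed pattern in miniature (illustrative only). A launched batch's result is consumed on the next iteration, which lets CPU-side result processing overlap with GPU execution of the current batch.

from collections import deque

result_queue = deque()
last_batch = None
for batch in ["b0", "b1", "b2", None]:
    if batch is not None:
        result_queue.append((batch, f"result({batch})"))  # "launch" this step
    if last_batch is not None:
        prev_batch, prev_result = result_queue.popleft()  # process previous step
        print(prev_batch, prev_result)
    last_batch = batch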
@@ -397,11 +447,15 @@ class SchedulerDisaggregationPrefillMixin:
             # We need to remove the sync in the following function for overlap schedule.
             self.set_next_batch_sampling_info_done(batch)

-    def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
+    def process_disagg_prefill_inflight_queue(
+        self: Scheduler, rids_to_check: Optional[List[str]] = None
+    ) -> List[Req]:
         """
         Poll the requests in the middle of transfer. If done, return the request.
+        rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
         """
-        assert len(self.disagg_prefill_inflight_queue) > 0
+        if len(self.disagg_prefill_inflight_queue) == 0:
+            return []

         done_reqs = []
         ...

@@ -413,6 +467,14 @@ class SchedulerDisaggregationPrefillMixin:
         undone_reqs: List[Req] = []
         # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
         for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
+            if rids_to_check is not None:
+                if req.rid not in rids_to_check:
+                    undone_reqs.append(req)
+                    continue
+                assert poll == KVPoll.Success or poll == KVPoll.Failed
+
             if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
                 undone_reqs.append(req)
             elif poll == KVPoll.Success:
                 # transfer done
                 ...

@@ -434,11 +496,8 @@ class SchedulerDisaggregationPrefillMixin:
                     req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                 )
                 done_reqs.append(req)
-                for req in done_reqs:
-                    self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
-                        req.metadata_buffer_index
-                    )
             else:
                 assert False, f"Unexpected polling state {poll=}"

         # Stream requests which have finished transfer
         self.stream_output(
         ...

@@ -446,9 +505,32 @@ class SchedulerDisaggregationPrefillMixin:
             any(req.return_logprob for req in done_reqs),
             None,
         )

+        for req in done_reqs:
+            req: Req
+            self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
+            req.metadata_buffer_index = -1
+
         self.disagg_prefill_inflight_queue = undone_reqs

+        return done_reqs
+
+    def get_transferred_rids(self: Scheduler) -> List[str]:
+        """
+        Used by PP, get the transferred rids but **do not pop**
+        """
+        polls = poll_and_all_reduce(
+            [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
+            self.tp_worker.get_tp_group().cpu_group,
+        )
+
+        transferred_rids: List[str] = []
+        for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
+            if poll == KVPoll.Success or poll == KVPoll.Failed:
+                transferred_rids.append(req.rid)
+
+        return transferred_rids
+
     def process_prefill_chunk(self: Scheduler) -> None:
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.chunked_req:
                 ...
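Not part of the diff: how the transfer-side consensus plausibly composes with the two methods above (hypothetical helper; the rid transport between ranks is a stand-in, not sglang API). get_transferred_rids peeks without popping, so rank 0 can announce a finished set that later ranks then pop verbatim.

from typing import List, Optional

def finish_transfers_on_rank(scheduler, pp_rank: int,
                             rids_from_prev: Optional[List[str]] = None):
    # Hypothetical glue code, not sglang API.
    if pp_rank == 0:
        # Rank 0 peeks at which transfers completed, without popping them...
        rids = scheduler.get_transferred_rids()
        done = scheduler.process_disagg_prefill_inflight_queue()
    else:
        # ...and later ranks pop exactly that set once it is relayed to them.
        done = scheduler.process_disagg_prefill_inflight_queue(
            rids_to_check=rids_from_prev
        )
        rids = rids_from_prev
    return done, rids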
python/sglang/srt/disaggregation/utils.py

@@ -14,15 +14,15 @@ import requests
 import torch
 import torch.distributed as dist

-from sglang.srt.utils import get_ip, get_local_ip_by_remote
+from sglang.srt.utils import get_ip

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req

-FakeBootstrapHost = "2.2.2.2"
-
-# env var for testing failure, convert to float explicitly
-FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
+#########################
+# Constants & Enums
+#########################
+
+FAKE_BOOTSTRAP_HOST = "2.2.2.2"


 class DisaggregationMode(Enum):
     ...

@@ -31,6 +31,14 @@ class DisaggregationMode(Enum):
     DECODE = "decode"


+#########################
+# Synchronization
+#########################
+
+# env var for testing failure, convert to float explicitly
+FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
+
+
 def poll_and_all_reduce(pollers, gloo_group):
     # at a certain prob, the poll is failed to simulate failure
     if FAILURE_PROB > 0:
         ...

@@ -47,6 +55,11 @@ def poll_and_all_reduce(pollers, gloo_group):
     return tensor_to_reduce.tolist()


+#########################
+# Metadata Buffers
+#########################
+
+
 class ReqToMetadataIdxAllocator:
     """A memory pool that maps a request to its first output token location."""
     ...
@@ -70,6 +83,91 @@ class ReqToMetadataIdxAllocator:
         self.free_slots.append(free_index)


+class MetadataBuffers:
+    def __init__(self, size: int, max_top_logprobs_num: int = 128):
+        # TODO: abort top_logprobs_num > 128 in PD
+
+        # We transfer the metadata of first output token to decode
+        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
+        self.output_token_logprobs_val = torch.zeros(
+            (size, 16), dtype=torch.float32, device="cpu"
+        )
+        self.output_token_logprobs_idx = torch.zeros(
+            (size, 16), dtype=torch.int32, device="cpu"
+        )
+        self.output_top_logprobs_val = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
+        )
+        self.output_top_logprobs_idx = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
+        )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not none or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not none or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+            if req.output_top_logprobs_val:  # not none or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not none or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
+
+
+#########################
+# Transfer Backend
+#########################
+
+
 class TransferBackend(Enum):
     MOONCAKE = "mooncake"
     NIXL = "nixl"
     ...
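Not part of the diff: why the buffers are shaped (size, 16). A single int32 metadata row is 16 * 4 = 64 bytes, which meets the 64-byte RDMA minimum the comment mentions even though only the first element is actually used.

import torch

buf = torch.zeros((8, 16), dtype=torch.int32, device="cpu")
assert buf[0].nbytes == 64   # per-request item length, the RDMA minimum
assert buf.nbytes == 8 * 64  # whole-buffer data length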
@@ -77,6 +175,7 @@ class TransferBackend(Enum):
 class KVClassType(Enum):
+    KVARGS = "kvargs"
     MANAGER = "manager"
     SENDER = "sender"
     RECEIVER = "receiver"
     ...

@@ -87,6 +186,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
     from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

     if transfer_backend == TransferBackend.MOONCAKE:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.mooncake import (
             MooncakeKVBootstrapServer,
             MooncakeKVManager,
             ...

@@ -95,6 +195,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         )

         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.MANAGER: MooncakeKVManager,
             KVClassType.SENDER: MooncakeKVSender,
             KVClassType.RECEIVER: (MooncakeKVReceiver),
             ...

@@ -102,6 +203,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         }
         return class_mapping.get(class_type)

     if transfer_backend == TransferBackend.NIXL:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.nixl import (
             NixlKVBootstrapServer,
             NixlKVManager,
             ...

@@ -110,6 +212,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         )

         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.MANAGER: NixlKVManager,
             KVClassType.SENDER: NixlKVSender,
             KVClassType.RECEIVER: (NixlKVReceiver),
             ...

@@ -117,9 +220,11 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         }
         return class_mapping.get(class_type)

     if transfer_backend == TransferBackend.FAKE:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.SENDER: FakeKVSender,
             KVClassType.RECEIVER: (FakeKVReceiver),
         }
         ...

@@ -128,6 +233,11 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")

+#########################
+# KV Pages
+#########################
+
+
 def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
     # 1. The page is guaranteed to be full except the last page.
     # 2. page index = kv_index // page_size
     ...

@@ -143,6 +253,11 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
     return (num_kv_indices + page_size - 1) // page_size
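Not part of the diff: the page math in kv_to_page_num is plain ceil-division.

num_kv_indices, page_size = 1000, 64
assert (num_kv_indices + page_size - 1) // page_size == 16  # 15 full pages + 1 partial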

+#########################
+# PDLB Registry
+#########################
+
+
 @dataclasses.dataclass
 class PDRegistryRequest:
     """A request to register a machine itself to the LB."""
     ...

@@ -181,6 +296,11 @@ def register_disaggregation_server(
     )


+#########################
+# Misc
+#########################
+
+
 def is_mla_backend(target_kv_pool) -> bool:
     from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
     ...
@@ -200,119 +320,3 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
     req.input_top_logprobs_idx = []
     req.input_token_ids_logprobs_val = []
     req.input_token_ids_logprobs_idx = []
-
-
-class MetadataBuffers:
-    ...
-
-
-class FastQueue:
-    ...
-
-
-def group_concurrent_contiguous(
-    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
-) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    ...

(The removed bodies are byte-for-byte the definitions shown earlier: MetadataBuffers now lives in the "Metadata Buffers" section of this file, while FastQueue and group_concurrent_contiguous moved to python/sglang/srt/disaggregation/common/utils.py.)
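Not part of the diff: with the new KVClassType.KVARGS entry, a caller can fetch the args container through the same registry as managers and senders, which is what _init_kv_manager in prefill.py now does. The fake backend keeps the example dependency-free:

from sglang.srt.disaggregation.utils import KVClassType, TransferBackend, get_kv_class

kv_args_class = get_kv_class(TransferBackend.FAKE, KVClassType.KVARGS)
kv_args = kv_args_class()  # today this resolves to the shared KVArgs; a backend could later substitute its own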
python/sglang/srt/entrypoints/http_server.py

@@ -43,7 +43,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse

 from sglang.srt.disaggregation.utils import (
-    FakeBootstrapHost,
+    FAKE_BOOTSTRAP_HOST,
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 ...

@@ -878,7 +878,7 @@ def _wait_and_warmup(
                 "max_new_tokens": 8,
                 "ignore_eos": True,
             },
-            "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
+            "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
             # This is a hack to ensure fake transfer is enabled during prefill warmup
             # ensure each dp rank has a unique bootstrap_room during prefill warmup
             "bootstrap_room": [
             ...
python/sglang/srt/managers/scheduler.py

@@ -619,7 +619,7 @@ class Scheduler(
             self.disaggregation_mode == DisaggregationMode.DECODE
         ):
             # *2 for the headroom.
             buffer_size = (self.req_to_token_pool.size) * 2
-            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
             self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
             ...

@@ -627,7 +627,7 @@ class Scheduler(
             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
-                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
                 ...

@@ -642,7 +642,7 @@ class Scheduler(
                     if self.draft_worker is None
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
-                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
                 ...

@@ -660,7 +660,7 @@ class Scheduler(
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
             buffer_size = self.max_running_requests * 2
-            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
             self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
             ...

@@ -672,14 +672,20 @@ class Scheduler(
                     if self.draft_worker is None
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
-                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=self.disagg_metadata_buffers,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 gpu_id=self.gpu_id,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                 gloo_group=self.attn_tp_cpu_group,
-                transfer_backend=self.transfer_backend,
+                max_total_num_tokens=self.max_total_num_tokens,
+                decode_tp_size=self.server_args.disaggregation_decode_tp,
+                decode_dp_size=self.server_args.disaggregation_decode_dp,
+                scheduler=self,
+                pp_rank=self.pp_rank,
+                pp_size=self.pp_size,
+                transfer_backend=self.transfer_backend,
             )

             # The prefill requests that are in the middle of kv sending
             self.disagg_prefill_inflight_queue: List[Req] = []
             ...

@@ -1110,7 +1116,9 @@ class Scheduler(
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.disagg_prefill_bootstrap_queue.add(req)
+            self.disagg_prefill_bootstrap_queue.add(
+                req, self.model_config.num_key_value_heads
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
             ...

@@ -1118,7 +1126,9 @@ class Scheduler(
     def _extend_requests_to_queue(self, reqs: List[Req]):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.disagg_prefill_bootstrap_queue.extend(reqs)
+            self.disagg_prefill_bootstrap_queue.extend(
+                reqs, self.model_config.num_key_value_heads
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # If this is a decode server, we put the request to the decode pending prealloc queue
             self.disagg_decode_prealloc_queue.extend(reqs)
             ...
python/sglang/srt/server_args.py

@@ -227,6 +227,9 @@ class ServerArgs:
     disaggregation_mode: str = "null"
     disaggregation_transfer_backend: str = "mooncake"
     disaggregation_bootstrap_port: int = 8998
+    disaggregation_decode_tp: Optional[int] = None
+    disaggregation_decode_dp: Optional[int] = None
+    disaggregation_prefill_pp: Optional[int] = 1
     disaggregation_ib_device: Optional[str] = None
     num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
     pdlb_url: Optional[str] = None
     ...
@@ -505,12 +508,27 @@ class ServerArgs:
             self.triton_attention_num_kv_splits = 16

         # PD disaggregation
-        if self.disaggregation_mode == "prefill":
-            self.disable_cuda_graph = True
-            logger.warning("Cuda graph is disabled for prefill server")
-        elif self.disaggregation_mode == "decode":
+        if self.disaggregation_mode == "decode":
+            assert (
+                self.disaggregation_decode_tp is None
+            ), "Cannot set --disaggregation-decode-tp for the decode engine."
+            assert (
+                self.disaggregation_decode_dp is None
+            ), "Cannot set --disaggregation-decode-dp for the decode engine."
+
             self.disable_radix_cache = True
             logger.warning("KV cache is forced as chunk cache for decode server")
+        elif self.disaggregation_mode == "prefill":
+            if self.disaggregation_decode_tp is None:
+                self.disaggregation_decode_tp = self.tp_size
+            if self.disaggregation_decode_dp is None:
+                self.disaggregation_decode_dp = self.dp_size
+
+            self.disaggregation_prefill_pp = self.pp_size
+            self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)
+
+            self.disable_cuda_graph = True
+            logger.warning("Cuda graph is disabled for prefill server")

         os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
             "1" if self.enable_torch_compile else "0"
             ...
@@ -520,6 +538,14 @@ class ServerArgs:
             "1" if self.disable_outlines_disk_cache else "0"
         )

+    def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
+        larger_tp = max(decode_tp, prefill_tp)
+        smaller_tp = min(decode_tp, prefill_tp)
+        assert larger_tp % smaller_tp == 0, (
+            "Different tp size is supported only when one tp is multiple of the other. "
+            f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
+        )
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
         ...
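Not part of the diff: the divisibility rule in practice, assuming a constructed ServerArgs instance named args. One side's tp must be an integer multiple of the other's.

args.validate_disagg_tp_size(prefill_tp=4, decode_tp=2)   # ok: 4 % 2 == 0
args.validate_disagg_tp_size(prefill_tp=4, decode_tp=8)   # ok: 8 % 4 == 0
# args.validate_disagg_tp_size(prefill_tp=4, decode_tp=3) # AssertionError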
@@ -1512,6 +1538,24 @@ class ServerArgs:
             default=ServerArgs.disaggregation_bootstrap_port,
             help="Bootstrap server port on the prefill server. Default is 8998.",
         )
+        parser.add_argument(
+            "--disaggregation-decode-tp",
+            type=int,
+            default=ServerArgs.disaggregation_decode_tp,
+            help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
+        )
+        parser.add_argument(
+            "--disaggregation-decode-dp",
+            type=int,
+            default=ServerArgs.disaggregation_decode_dp,
+            help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
+        )
+        parser.add_argument(
+            "--disaggregation-prefill-pp",
+            type=int,
+            default=ServerArgs.disaggregation_prefill_pp,
+            help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
+        )
         parser.add_argument(
             "--disaggregation-ib-device",
             type=str,
             ...