Commit 305c9e8c (unverified)
Authored Sep 15, 2025 by Liangsheng Yin; committed by GitHub on Sep 15, 2025
[4/N]DP refactor: support watching mode `get_load` and shortest queue strategy (#10201)
Parent: ca63f075

Showing 12 changed files with 202 additions and 44 deletions (+202, -44)
python/sglang/srt/entrypoints/http_server.py                  +5   -2
python/sglang/srt/managers/data_parallel_controller.py        +74  -19
python/sglang/srt/managers/detokenizer_manager.py             +1   -1
python/sglang/srt/managers/io_struct.py                       +18  -0
python/sglang/srt/managers/multi_tokenizer_mixin.py           +4   -0
python/sglang/srt/managers/scheduler.py                       +23  -11
python/sglang/srt/managers/scheduler_metrics_mixin.py         +1   -1
python/sglang/srt/managers/tokenizer_communicator_mixin.py    +41  -9
python/sglang/srt/managers/tokenizer_manager.py               +19  -0
python/sglang/srt/server_args.py                              +8   -0
python/sglang/srt/utils.py                                    +1   -1
python/sglang/utils.py                                        +7   -0
python/sglang/srt/entrypoints/http_server.py

```diff
@@ -27,7 +27,7 @@ import tempfile
 import threading
 import time
 from http import HTTPStatus
-from typing import Any, AsyncIterator, Callable, Dict, List, Optional
+from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
 
 import setproctitle
@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
     MultiTokenizerManager,
+    MultiTokenizerRouter,
     get_main_process_id,
     monkey_patch_uvicorn_multiprocessing,
     read_from_shared_memory,
@@ -127,7 +128,9 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
-    tokenizer_manager: TokenizerManager
+    tokenizer_manager: Union[
+        TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
+    ]
     template_manager: TemplateManager
     scheduler_info: Dict
```
python/sglang/srt/managers/data_parallel_controller.py

```diff
@@ -21,6 +21,7 @@ import struct
 import sys
 import threading
 import time
+from collections import deque
 from enum import Enum, auto
 from multiprocessing import shared_memory
 from typing import Dict, List
@@ -34,6 +35,7 @@ from sglang.srt.managers.io_struct import (
     BlockReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
@@ -46,7 +48,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -67,6 +69,42 @@ class LoadBalanceMethod(Enum):
             raise ValueError(f"Invalid load balance method: {method}") from exc
 
 
+class DPBudget:
+    def __init__(self):
+        # TODO: support minimum tokens method
+        self.budget_queue = deque()
+
+    def update_budget(self, load_update: WatchLoadUpdateReq):
+        """Update the budget queue.
+
+        Use num_reqs instead of num_waiting_reqs to balance decode running batch.
+        """
+        loads = load_update.loads
+        self.budget_queue.clear()
+
+        num_reqs = [load.num_reqs for load in loads]
+        if not num_reqs:
+            return
+
+        max_num_reqs = max(num_reqs)
+        if all(x == max_num_reqs for x in num_reqs):
+            return
+
+        while any(x != num_reqs[0] for x in num_reqs):
+            min_load = min(num_reqs)
+            min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
+            second_min_load = min(x for x in num_reqs if x > min_load)
+            self.budget_queue.extend(
+                [loads[i].dp_rank for i in min_indices] * (second_min_load - min_load)
+            )
+            for idx in min_indices:
+                num_reqs[idx] = second_min_load
+
+    def dispatch(self):
+        if self.budget_queue:
+            return self.budget_queue.popleft()
+        return None
+
+
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""
@@ -104,6 +142,9 @@ class DataParallelController:
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
+        # Load balance budget
+        self.dp_budget = DPBudget()
+
         # Launch data parallel workers
         self.scheduler_procs = []
         self.workers: List[zmq.Socket] = [None] * server_args.dp_size
@@ -127,6 +168,31 @@ class DataParallelController:
             self.max_req_input_len = None
 
+        self.init_dispatcher()
+
+    def send_to_all_workers(self, obj):
+        for worker in self.workers:
+            worker.send_pyobj(obj)
+
+    def send_control_message(self, obj):
+        # Send control messages to first worker of tp group
+        for worker in self.workers[:: self.control_message_step]:
+            worker.send_pyobj(obj)
+
+    def handle_load_update_req(self, obj):
+        self.dp_budget.update_budget(obj)
+
+    def init_dispatcher(self):
+        self._request_dispatcher = TypeBasedDispatcher(
+            [
+                (TokenizedGenerateReqInput, self.dispatching),
+                (TokenizedEmbeddingReqInput, self.dispatching),
+                (BlockReqInput, self.send_to_all_workers),
+                (WatchLoadUpdateReq, self.handle_load_update_req),
+            ]
+        )
+        self._request_dispatcher.add_fallback_fn(self.send_control_message)
+
     def launch_dp_schedulers(self, server_args, port_args):
         base_gpu_id = 0
@@ -291,10 +357,14 @@ class DataParallelController:
         else:
             self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
 
-    def shortest_queue_scheduler(self, input_requests):
+    def shortest_queue_scheduler(self, req):
         if self.maybe_external_dp_rank_routing(req):
             return
-        raise NotImplementedError()
+        target_worker = self.dp_budget.dispatch()
+        if target_worker is None:
+            self.round_robin_scheduler(req)
+        else:
+            self.workers[target_worker].send_pyobj(req)
 
     def minimum_tokens_scheduler(self, req):
         if self.maybe_external_dp_rank_routing(req):
@@ -333,22 +403,7 @@ class DataParallelController:
                     recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                 except zmq.ZMQError:
                     break
-                if isinstance(
-                    recv_req,
-                    (
-                        TokenizedGenerateReqInput,
-                        TokenizedEmbeddingReqInput,
-                    ),
-                ):
-                    self.dispatching(recv_req)
-                elif isinstance(recv_req, BlockReqInput):
-                    for worker in self.workers:
-                        worker.send_pyobj(recv_req)
-                else:
-                    # Send other control messages to first worker of tp group
-                    for worker in self.workers[:: self.control_message_step]:
-                        worker.send_pyobj(recv_req)
+                self._request_dispatcher(recv_req)
 
 
 def run_data_parallel_controller_process(
```
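The DPBudget added above drives the new shortest-queue strategy: each load update rebuilds a queue of dispatch slots that would level every rank up to the most loaded one, and dispatch() pops slots from that queue, falling back to round robin once it runs dry. A minimal, self-contained sketch of that leveling logic; the _Load dataclass is an illustrative stand-in for GetLoadReqOutput, not part of the commit:

```python
from collections import deque
from dataclasses import dataclass
from typing import List


@dataclass
class _Load:
    # Illustrative stand-in for GetLoadReqOutput; only the fields the budget uses.
    dp_rank: int
    num_reqs: int


def build_budget_queue(loads: List[_Load]) -> deque:
    """Mirror DPBudget.update_budget: repeatedly lift the least-loaded ranks
    to the next-lowest load level, queueing one dispatch slot per lifted request."""
    budget_queue = deque()
    num_reqs = [load.num_reqs for load in loads]
    if not num_reqs:
        return budget_queue
    if all(x == max(num_reqs) for x in num_reqs):
        return budget_queue
    while any(x != num_reqs[0] for x in num_reqs):
        min_load = min(num_reqs)
        min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
        second_min_load = min(x for x in num_reqs if x > min_load)
        budget_queue.extend(
            [loads[i].dp_rank for i in min_indices] * (second_min_load - min_load)
        )
        for idx in min_indices:
            num_reqs[idx] = second_min_load
    return budget_queue


# Ranks 0, 1, 2 currently hold 5, 2, and 4 running+waiting requests.
queue = build_budget_queue([_Load(0, 5), _Load(1, 2), _Load(2, 4)])
print(list(queue))  # [1, 1, 1, 2]
```

With loads of 5, 2, and 4 requests, the next three dispatch slots go to rank 1 and the fourth to rank 2; after that every rank sits at five requests, the queue is empty, and dispatching reverts to round robin until the next load update.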
python/sglang/srt/managers/detokenizer_manager.py

```diff
@@ -297,7 +297,7 @@ def run_detokenizer_process(
         else:
             manager.event_loop()
     except Exception:
-        manager.socket_mapping.clear_all_sockets()
+        manager.maybe_clear_socket_mapping()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
```
python/sglang/srt/managers/io_struct.py

```diff
@@ -1374,3 +1374,21 @@ class BlockReqType(Enum):
 @dataclass
 class BlockReqInput:
     type: BlockReqType
+
+
+@dataclass
+class GetLoadReqInput:
+    pass
+
+
+@dataclass
+class GetLoadReqOutput:
+    dp_rank: int
+    num_reqs: int
+    num_waiting_reqs: int
+    num_tokens: int
+
+
+@dataclass
+class WatchLoadUpdateReq:
+    loads: List[GetLoadReqOutput]
```
python/sglang/srt/managers/multi_tokenizer_mixin.py

```diff
@@ -354,6 +354,10 @@ class MultiHttpWorkerDetokenizerMixin:
             worker_ids = []
         return worker_ids
 
+    def maybe_clear_socket_mapping(self):
+        if hasattr(self, "socket_mapping"):
+            self.socket_mapping.clear_all_sockets()
+
     def multi_http_worker_event_loop(self):
         """The event loop that handles requests, for multi multi-http-worker mode"""
         self.socket_mapping = SocketMapping()
```
python/sglang/srt/managers/scheduler.py

```diff
@@ -79,6 +79,8 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     HealthCheckOutput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
@@ -577,6 +579,7 @@ class Scheduler(
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                 (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+                (GetLoadReqInput, self.get_load),
             ]
         )
@@ -2279,39 +2282,50 @@ class Scheduler(
             if_success = False
         return if_success
 
-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
         if self.is_hybrid:
-            load_full = (
+            num_tokens_full = (
                 self.full_tokens_per_layer
                 - self.token_to_kv_pool_allocator.full_available_size()
                 - self.tree_cache.full_evictable_size()
             )
-            load_swa = (
+            num_tokens_swa = (
                 self.swa_tokens_per_layer
                 - self.token_to_kv_pool_allocator.swa_available_size()
                 - self.tree_cache.swa_evictable_size()
             )
-            load = max(load_full, load_swa)
+            num_tokens = max(num_tokens_full, num_tokens_swa)
         else:
-            load = (
+            num_tokens = (
                 self.max_total_num_tokens
                 - self.token_to_kv_pool_allocator.available_size()
                 - self.tree_cache.evictable_size()
             )
-        load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+
+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            load += sum(
+            num_tokens += sum(
                 len(req.origin_input_ids)
                 for req in self.disagg_prefill_bootstrap_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            load += sum(
+            num_tokens += sum(
                 len(req.req.origin_input_ids)
                 for req in self.disagg_decode_prealloc_queue.queue
            )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)
 
-        return load
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )
 
     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
@@ -2337,8 +2351,6 @@ class Scheduler(
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict
 
-        ret["load"] = self.get_load()
-
         return GetInternalStateReqOutput(internal_state=ret)
 
     def set_internal_state(self, recv_req: SetInternalStateReq):
```
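The reworked get_load returns a structured GetLoadReqOutput instead of a bare token count: num_tokens approximates KV-cache occupancy plus tokens still queued for prefill (and, under PD disaggregation, the bootstrap/prealloc queues), while num_reqs counts running plus waiting requests. A rough illustration of how the non-hybrid branch composes these values; every number below is made up:

```python
# Hypothetical values standing in for the scheduler's internal counters.
max_total_num_tokens = 8192      # size of the KV token pool
available_size = 6000            # token_to_kv_pool_allocator.available_size()
evictable_size = 1000            # tree_cache.evictable_size()
waiting_input_lens = [300, 150]  # len(req.origin_input_ids) per waiting request
num_running_reqs = 3             # len(self.running_batch.reqs), hypothetical

num_tokens = max_total_num_tokens - available_size - evictable_size  # 1192 occupied
num_tokens += sum(waiting_input_lens)                                # + 450 queued
num_waiting_reqs = len(waiting_input_lens)                           # 2

print(num_tokens, num_running_reqs + num_waiting_reqs)  # 1642 5
```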
python/sglang/srt/managers/scheduler_metrics_mixin.py

```diff
@@ -279,7 +279,7 @@ class SchedulerMetricsMixin:
             self.server_args.load_balance_method == "minimum_tokens"
             and self.forward_ct % 40 == 0
         ):
-            holding_tokens = self.get_load()
+            holding_tokens = self.get_load().num_tokens
             new_recv_dp_balance_id_list, holding_token_list = (
                 self.gather_dp_balance_info(holding_tokens)
```
python/sglang/srt/managers/tokenizer_communicator_mixin.py

```diff
 from __future__ import annotations
 
 import asyncio
+import copy
 import logging
 import os
 import time
@@ -18,6 +19,7 @@ from typing import (
 )
 
 import fastapi
+import zmq
 
 from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
@@ -28,6 +30,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
@@ -75,14 +79,17 @@ class _Communicator(Generic[T]):
     enable_multi_tokenizer = False
 
-    def __init__(self, sender, fan_out: int):
+    def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
         self._sender = sender
         self._fan_out = fan_out
+        self._mode = mode
         self._result_event: Optional[asyncio.Event] = None
         self._result_values: Optional[List[T]] = None
         self._ready_queue: Deque[asyncio.Future] = deque()
 
-    async def __call__(self, obj):
+        assert mode in ["queueing", "watching"]
+
+    async def queueing_call(self, obj: T):
         ready_event = asyncio.Event()
         if self._result_event is not None or len(self._ready_queue) > 0:
             self._ready_queue.append(ready_event)
@@ -106,6 +113,28 @@ class _Communicator(Generic[T]):
         return result_values
 
+    async def watching_call(self, obj):
+        if self._result_event is None:
+            assert self._result_values is None
+            self._result_values = []
+            self._result_event = asyncio.Event()
+            if obj:
+                if _Communicator.enable_multi_tokenizer:
+                    obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
+                self._sender.send_pyobj(obj)
+
+        await self._result_event.wait()
+        result_values = copy.deepcopy(self._result_values)
+        self._result_event = self._result_values = None
+        return result_values
+
+    async def __call__(self, obj):
+        if self._mode == "queueing":
+            return await self.queueing_call(obj)
+        else:
+            return await self.watching_call(obj)
+
     def handle_recv(self, recv_obj: T):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
@@ -165,6 +194,9 @@ class TokenizerCommunicatorMixin:
         self.update_lora_adapter_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.get_load_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size, mode="watching"
+        )
 
         self._result_dispatcher += self._get_communicator_dispatcher()
@@ -235,6 +267,10 @@ class TokenizerCommunicatorMixin:
                 (
                     LoRAUpdateResult,
                     self.update_lora_adapter_communicator.handle_recv,
                 ),
+                (
+                    GetLoadReqOutput,
+                    self.get_load_communicator.handle_recv,
+                ),
             ]
         )
@@ -528,10 +564,6 @@ class TokenizerCommunicatorMixin:
         )
         return [res.updated for res in responses]
 
-    async def get_load(self: TokenizerManager) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
+    async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
+        req = GetLoadReqInput()
+        return await self.get_load_communicator(req)
```
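The new "watching" mode is what lets get_load be polled repeatedly without the queueing discipline of the default mode: the initiating caller creates the result event and sends the request, handle_recv collects one reply per DP rank (the fan_out), and the waiter receives a deep copy of the full list. A stripped-down sketch of that flow, without zmq and with illustrative names (WatchingCall is not the shipped class):

```python
import asyncio
import copy
from typing import List, Optional


class WatchingCall:
    """Sketch of the 'watching' mode: send once, gather fan_out replies, no ready queue."""

    def __init__(self, send_fn, fan_out: int):
        self._send_fn = send_fn  # stands in for the zmq sender socket
        self._fan_out = fan_out  # one reply expected per DP rank
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List] = None

    async def __call__(self, obj):
        if self._result_event is None:
            # Only the initiating caller sends; a concurrent caller would wait
            # below and piggyback on the same in-flight request.
            self._result_values = []
            self._result_event = asyncio.Event()
            if obj:
                self._send_fn(obj)
        await self._result_event.wait()
        values = copy.deepcopy(self._result_values)
        self._result_event = self._result_values = None
        return values

    def handle_recv(self, recv_obj):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()


async def main():
    comm = WatchingCall(send_fn=lambda obj: None, fan_out=2)

    async def fake_scheduler_replies():
        await asyncio.sleep(0.01)
        comm.handle_recv({"dp_rank": 0, "num_reqs": 3})
        comm.handle_recv({"dp_rank": 1, "num_reqs": 7})

    asyncio.create_task(fake_scheduler_replies())
    loads = await comm("GetLoadReqInput placeholder")
    print(loads)  # one load report per DP rank


asyncio.run(main())
```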
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -64,6 +64,7 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FreezeGCReq,
     GenerateReqInput,
+    GetLoadReqInput,
     HealthCheckOutput,
     MultiTokenizerWrapper,
     OpenSessionReqInput,
@@ -73,6 +74,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
@@ -1240,6 +1242,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )
 
     def dump_requests_before_crash(self):
         if self.crash_dump_performed:
@@ -1844,6 +1849,20 @@ class TokenizerManager(TokenizerCommunicatorMixin):
 
         return scores
 
+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return
+
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_udpate_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_udpate_req)
+
 
 class ServerStatus(Enum):
     Up = "Up"
```
python/sglang/srt/server_args.py

```diff
@@ -233,6 +233,7 @@ class ServerArgs:
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+    load_watch_interval: float = 0.1
     # FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
     prefill_round_robin_balance: bool = False
@@ -663,6 +664,7 @@ class ServerArgs:
         if self.dp_size == 1:
             self.enable_dp_attention = False
+            self.enable_dp_lm_head = False
 
         # Data parallelism attention
         if self.enable_dp_attention:
@@ -1488,6 +1490,12 @@ class ServerArgs:
                 "minimum_tokens",
             ],
         )
+        parser.add_argument(
+            "--load-watch-interval",
+            type=float,
+            default=ServerArgs.load_watch_interval,
+            help="The interval of load watching in seconds.",
+        )
         parser.add_argument(
             "--prefill-round-robin-balance",
             default=ServerArgs.prefill_round_robin_balance,
```
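For reference, a hedged launch sketch showing how the new flag would sit next to the existing data-parallel options; the model path is a placeholder and the values are illustrative (shortest_queue is one of the existing --load-balance-method choices):

```bash
# Hypothetical launch: 4 DP ranks, shortest-queue balancing, load polled every 100 ms.
python -m sglang.launch_server \
    --model-path <your-model> \
    --dp-size 4 \
    --load-balance-method shortest_queue \
    --load-watch-interval 0.1
```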
python/sglang/srt/utils.py

```diff
@@ -1160,7 +1160,7 @@ def pytorch_profile(name, func, *args, data_size=-1):
 
 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-):
+) -> zmq.Socket:
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
```
python/sglang/utils.py

```diff
@@ -472,6 +472,10 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
 class TypeBasedDispatcher:
     def __init__(self, mapping: List[Tuple[Type, Callable]]):
         self._mapping = mapping
+        self._fallback_fn = None
+
+    def add_fallback_fn(self, fallback_fn: Callable):
+        self._fallback_fn = fallback_fn
 
     def __iadd__(self, other: "TypeBasedDispatcher"):
         self._mapping.extend(other._mapping)
@@ -481,6 +485,9 @@ class TypeBasedDispatcher:
         for ty, fn in self._mapping:
             if isinstance(obj, ty):
                 return fn(obj)
+
+        if self._fallback_fn is not None:
+            return self._fallback_fn(obj)
         raise ValueError(f"Invalid object: {obj}")
```
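The new fallback hook is what lets the data parallel controller above route only the explicitly mapped request types and funnel everything else to send_control_message instead of raising. A self-contained sketch of the same dispatch-with-fallback behaviour; MiniDispatcher is an illustrative stand-in, not the shipped class:

```python
from typing import Callable, List, Tuple, Type


class MiniDispatcher:
    """Minimal stand-in mirroring the dispatcher semantics above."""

    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        self._mapping = mapping
        self._fallback_fn = None

    def add_fallback_fn(self, fallback_fn: Callable):
        self._fallback_fn = fallback_fn

    def __call__(self, obj):
        for ty, fn in self._mapping:
            if isinstance(obj, ty):
                return fn(obj)
        if self._fallback_fn is not None:
            return self._fallback_fn(obj)
        raise ValueError(f"Invalid object: {obj}")


dispatcher = MiniDispatcher([(int, lambda o: f"int:{o}"), (str, lambda o: f"str:{o}")])
dispatcher.add_fallback_fn(lambda o: f"control:{o!r}")

print(dispatcher(3))         # int:3
print(dispatcher("hi"))      # str:hi
print(dispatcher(object()))  # control:<object ...>, fallback instead of ValueError
```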