sglang · Commit 78f13981 (unverified)
Authored Sep 08, 2025 by Liangsheng Yin; committed via GitHub on Sep 08, 2025

[1/N] DP-Refactor: move communicators into `tokenizer_communicator_mixin` (#10028)
Parent: bfd7a18d

Showing 5 changed files, with 503 additions and 459 deletions (+503, −459):
    python/sglang/srt/managers/multi_tokenizer_mixin.py           +2    −1
    python/sglang/srt/managers/tokenizer_communicator_mixin.py    +491  −0
    python/sglang/srt/managers/tokenizer_manager.py               +5    −457
    python/sglang/srt/weight_sync/utils.py                        +1    −1
    python/sglang/utils.py                                        +4    −0
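The commit pulls the `_Communicator` helper and all communicator setup out of `tokenizer_manager.py` into a new `TokenizerCommunicatorMixin`. The `tokenizer_manager.py` diff below is collapsed in the page, so the exact wiring there is not shown; the following is a minimal sketch (not the commit's code) of how such a mixin is typically consumed, assuming `TokenizerManager` creates `send_to_scheduler` and `_result_dispatcher` before calling `init_communicators`:

from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin

class TokenizerManager(TokenizerCommunicatorMixin):
    def __init__(self, server_args, port_args):
        ...  # set up ZMQ sockets, self.send_to_scheduler, self._result_dispatcher, etc.
        # Attach one _Communicator per control-plane request type (see the mixin below).
        self.init_communicators(server_args)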
python/sglang/srt/managers/multi_tokenizer_mixin.py

@@ -36,7 +36,8 @@ from sglang.srt.managers.io_struct import (
     MultiTokenizerRegisterReq,
     MultiTokenizerWrapper,
 )
-from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
+from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback
python/sglang/srt/managers/tokenizer_communicator_mixin.py (new file, mode 100644)
from __future__ import annotations

import asyncio
import logging
import os
import time
from collections import deque
from typing import (
    TYPE_CHECKING,
    Any,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
)

import fastapi

from sglang.srt.managers.io_struct import (
    ClearHiCacheReqInput,
    ClearHiCacheReqOutput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
    FlushCacheReqInput,
    FlushCacheReqOutput,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    LoRAUpdateResult,
    MultiTokenizerWrapper,
    ProfileReq,
    ProfileReqOutput,
    ProfileReqType,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,
    SlowDownReqOutput,
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.server_args import LoRARef, ServerArgs
from sglang.srt.utils import get_bool_env_var
from sglang.utils import TypeBasedDispatcher

if TYPE_CHECKING:
    from sglang.srt.managers.tokenizer_manager import TokenizerManager

T = TypeVar("T")

logger = logging.getLogger(__name__)
class _Communicator(Generic[T]):
    """Note: The communicator now only runs up to 1 in-flight request at any time."""

    enable_multi_tokenizer = False

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List[T]] = None
        self._ready_queue: Deque[asyncio.Future] = deque()

    async def __call__(self, obj):
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            # Another request is in flight; queue up and wait for our turn.
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            if _Communicator.enable_multi_tokenizer:
                obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()
        result_values = self._result_values
        self._result_event = self._result_values = None

        if len(self._ready_queue) > 0:
            # Wake the next queued request.
            self._ready_queue.popleft().set()

        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            # All fan_out replies have arrived; release the waiting caller.
            self._result_event.set()
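The class above is a fan-out/fan-in barrier: one caller sends a request and blocks until `handle_recv` has been invoked `fan_out` times (once per DP rank). A minimal runnable sketch of that contract, where `FakeSender` is our stand-in for the real ZMQ socket (only `send_pyobj` is assumed):

import asyncio

class FakeSender:
    def send_pyobj(self, obj):
        print(f"sent: {obj!r}")

async def main():
    comm = _Communicator(FakeSender(), fan_out=2)

    async def scheduler_replies():
        await asyncio.sleep(0.1)
        comm.handle_recv("reply from rank 0")
        comm.handle_recv("reply from rank 1")  # second reply releases the caller

    asyncio.create_task(scheduler_replies())
    results = await comm("some request")
    print(results)  # ['reply from rank 0', 'reply from rank 1']

asyncio.run(main())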
class TokenizerCommunicatorMixin:
    """Mixin class for TokenizerManager to handle communication with the scheduler."""

    def init_communicators(self: TokenizerManager, server_args: ServerArgs):
        # Communicators: one per control-plane request type, each fanning out
        # to server_args.dp_size scheduler processes.
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.slow_down_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.flush_cache_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.clear_hicache_storage_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.set_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.expert_distribution_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_lora_adapter_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        # Merge the communicator handlers into the manager's dispatcher; this
        # relies on TypeBasedDispatcher.__iadd__, added in python/sglang/utils.py below.
        self._result_dispatcher += self._get_communicator_dispatcher()
    def _get_communicator_dispatcher(self: TokenizerManager):
        return TypeBasedDispatcher(
            [
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    ClearHiCacheReqOutput,
                    self.clear_hicache_storage_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                (
                    LoRAUpdateResult,
                    self.update_lora_adapter_communicator.handle_recv,
                ),
            ]
        )
    async def flush_cache(self: TokenizerManager) -> FlushCacheReqOutput:
        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

    async def clear_hicache_storage(self: TokenizerManager) -> ClearHiCacheReqOutput:
        """Clear the hierarchical cache storage."""
        # Delegate to the scheduler to handle HiCacheStorage clearing
        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
            0
        ]
    async def start_profile(
        self: TokenizerManager,
        output_dir: Optional[str] = None,
        start_step: Optional[int] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
        profile_by_stage: bool = False,
    ):
        self.auto_create_handle_loop()
        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
        with_stack = False if with_stack is False or env_with_stack is False else True
        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            start_step=start_step,
            num_steps=num_steps,
            activities=activities,
            with_stack=with_stack,
            record_shapes=record_shapes,
            profile_by_stage=profile_by_stage,
            profile_id=str(time.time()),
        )
        return await self._execute_profile(req)

    async def stop_profile(self: TokenizerManager):
        self.auto_create_handle_loop()
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        return await self._execute_profile(req)
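The `with_stack` resolution in `start_profile` is easy to misread: stack recording stays on by default, and either the caller or the `SGLANG_PROFILE_WITH_STACK` environment variable can switch it off (the env var overrides even an explicit `with_stack=True`). A tiny standalone check of that exact expression; the helper name is ours, not part of the commit:

def resolve_with_stack(with_stack, env_with_stack):
    # Mirrors: with_stack = False if with_stack is False or env_with_stack is False else True
    return False if with_stack is False or env_with_stack is False else True

assert resolve_with_stack(None, True) is True    # default: stacks recorded
assert resolve_with_stack(False, True) is False  # caller opts out
assert resolve_with_stack(True, False) is False  # env var overrides the caller
assert resolve_with_stack(None, False) is False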
    async def _execute_profile(self: TokenizerManager, req: ProfileReq):
        result = (await self.profile_communicator(req))[0]
        if not result.success:
            raise RuntimeError(result.message)
        return result
    async def start_expert_distribution_record(self: TokenizerManager):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)

    async def stop_expert_distribution_record(self: TokenizerManager):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)

    async def dump_expert_distribution_record(self: TokenizerManager):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
    async def init_weights_update_group(
        self: TokenizerManager,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message
    async def update_weights_from_distributed(
        self: TokenizerManager,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message
    async def update_weights_from_tensor(
        self: TokenizerManager,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message
    async def load_lora_adapter(
        self: TokenizerManager,
        obj: LoadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> LoadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start load Lora adapter. Lora name=%s, path=%s",
                obj.lora_name,
                obj.lora_path,
            )

            async with self.lora_update_lock:
                if (
                    self.server_args.max_loaded_loras is not None
                    and self.lora_registry.num_registered_loras
                    >= self.server_args.max_loaded_loras
                ):
                    raise ValueError(
                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
                        "Please unload some LoRA adapters before loading new ones."
                    )

                # Generate new uniquely identifiable LoRARef object.
                new_adapter = LoRARef(
                    lora_name=obj.lora_name,
                    lora_path=obj.lora_path,
                    pinned=obj.pinned,
                )

                # Trigger the actual loading operation at the backend processes.
                obj.lora_id = new_adapter.lora_id
                result = (await self.update_lora_adapter_communicator(obj))[0]

                # Register the LoRA adapter only after loading is successful.
                if result.success:
                    await self.lora_registry.register(new_adapter)

                return result
        except ValueError as e:
            return LoadLoRAAdapterReqOutput(
                success=False,
                error_message=str(e),
            )
    async def unload_lora_adapter(
        self: TokenizerManager,
        obj: UnloadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> UnloadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            assert (
                obj.lora_name is not None
            ), "lora_name must be provided to unload LoRA adapter"

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start unload Lora adapter. Lora name=%s",
                obj.lora_name,
            )

            async with self.lora_update_lock:
                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
                # from being started.
                lora_id = await self.lora_registry.unregister(obj.lora_name)
                obj.lora_id = lora_id

                # Initiate the actual unloading operation at the backend processes only after all
                # ongoing requests using this LoRA adapter are finished.
                await self.lora_registry.wait_for_unload(lora_id)
                result = (await self.update_lora_adapter_communicator(obj))[0]

                return result
        except ValueError as e:
            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
    async def get_weights_by_name(
        self: TokenizerManager,
        obj: GetWeightsByNameReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters
    async def release_memory_occupation(
        self: TokenizerManager,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self: TokenizerManager,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def slow_down(
        self: TokenizerManager,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.slow_down_communicator(obj)
    async def get_internal_state(self: TokenizerManager) -> List[Dict[Any, Any]]:
        req = GetInternalStateReq()
        responses: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        # Many DP ranks
        return [res.internal_state for res in responses]

    async def set_internal_state(
        self: TokenizerManager, obj: SetInternalStateReq
    ) -> List[bool]:
        responses: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        return [res.updated for res in responses]
    async def get_load(self: TokenizerManager) -> dict:
        # TODO(lsyin): fake load report server
        if not self.current_load_lock.locked():
            async with self.current_load_lock:
                internal_state = await self.get_internal_state()
                self.current_load = internal_state[0]["load"]
        return {"load": self.current_load}
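`get_load` above uses a deliberately non-blocking pattern: if another coroutine is already refreshing the load (the lock is held), callers return the cached value immediately instead of queuing behind the lock. A self-contained sketch of that pattern, with illustrative names and a sleep standing in for the scheduler round-trip:

import asyncio

class LoadCache:
    def __init__(self):
        self._lock = asyncio.Lock()
        self.current_load = 0

    async def _fetch_load(self):
        await asyncio.sleep(0.1)  # stands in for the scheduler round-trip
        return 42

    async def get_load(self):
        # Only refresh when no refresh is in flight; otherwise serve the cache.
        if not self._lock.locked():
            async with self._lock:
                self.current_load = await self._fetch_load()
        return {"load": self.current_load}

async def main():
    cache = LoadCache()
    results = await asyncio.gather(*[cache.get_load() for _ in range(3)])
    print(results)  # one caller refreshes; the others may see the stale 0

asyncio.run(main())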
python/sglang/srt/managers/tokenizer_manager.py

(This diff is collapsed in the web view and not reproduced here. Per the file stats above it adds 5 lines and removes 457, which is consistent with the communicator code now living in the mixin.)
python/sglang/srt/weight_sync/utils.py

@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor

 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
+from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.utils import MultiprocessingSerializer
python/sglang/utils.py

@@ -473,6 +473,10 @@ class TypeBasedDispatcher:
     def __init__(self, mapping: List[Tuple[Type, Callable]]):
         self._mapping = mapping

+    def __iadd__(self, other: "TypeBasedDispatcher"):
+        self._mapping.extend(other._mapping)
+        return self
+
     def __call__(self, obj: Any):
         for ty, fn in self._mapping:
             if isinstance(obj, ty):
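The new `__iadd__` is what makes `self._result_dispatcher += self._get_communicator_dispatcher()` in the mixin work: it splices one dispatcher's (type, handler) mapping into another. A small demo against the class above, assuming (as the elided `__call__` body suggests) that the matched handler is invoked on the object:

base = TypeBasedDispatcher([(int, lambda x: print("int:", x))])
extra = TypeBasedDispatcher([(str, lambda x: print("str:", x))])

base += extra   # merge the mappings via the new __iadd__
base(3)         # -> int: 3
base("hello")   # -> str: hello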