sglang · Commits · commit 78f13981 (Unverified)
Authored Sep 08, 2025 by Liangsheng Yin; committed by GitHub on Sep 08, 2025

[1/N] DP-Refactor: move communicators into `tokenizer_communicator_mixin` (#10028)
Parent: bfd7a18d
Showing 5 changed files with 503 additions and 459 deletions (+503, -459)

python/sglang/srt/managers/multi_tokenizer_mixin.py            +2    -1
python/sglang/srt/managers/tokenizer_communicator_mixin.py     +491  -0
python/sglang/srt/managers/tokenizer_manager.py                +5    -457
python/sglang/srt/weight_sync/utils.py                         +1    -1
python/sglang/utils.py                                         +4    -0
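Summary (not part of the diff itself): the fourteen per-request `_Communicator` objects and their dispatcher entries move out of `TokenizerManager.__init__` into the new `TokenizerCommunicatorMixin`, which `TokenizerManager` now inherits and initializes with a single `init_communicators(server_args)` call. A minimal sketch of the resulting wiring, with everything not visible in this diff elided or marked as a placeholder:

# Sketch only: fields other than those shown in the diff are placeholders.
from sglang.utils import TypeBasedDispatcher
from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin


class TokenizerManager(TokenizerCommunicatorMixin):
    def __init__(self, server_args, port_args):
        ...  # tokenizer setup, ZMQ sockets (including self.send_to_scheduler), etc.
        self._result_dispatcher = TypeBasedDispatcher(
            [
                # handlers that stay in TokenizerManager (health checks, disk weight updates, ...)
            ]
        )
        # New in this commit: the mixin creates one _Communicator per control
        # request type (fan_out = server_args.dp_size) and appends their
        # handle_recv callbacks to self._result_dispatcher via __iadd__.
        self.init_communicators(server_args)
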
python/sglang/srt/managers/multi_tokenizer_mixin.py (view file @ 78f13981)

@@ -36,7 +36,8 @@ from sglang.srt.managers.io_struct import (
     MultiTokenizerRegisterReq,
     MultiTokenizerWrapper,
 )
-from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
+from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback
python/sglang/srt/managers/tokenizer_communicator_mixin.py (new file, mode 100644, view file @ 78f13981)

from __future__ import annotations

import asyncio
import logging
import os
import time
from collections import deque
from typing import (
    TYPE_CHECKING,
    Any,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
)

import fastapi

from sglang.srt.managers.io_struct import (
    ClearHiCacheReqInput,
    ClearHiCacheReqOutput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
    FlushCacheReqInput,
    FlushCacheReqOutput,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    LoRAUpdateResult,
    MultiTokenizerWrapper,
    ProfileReq,
    ProfileReqOutput,
    ProfileReqType,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,
    SlowDownReqOutput,
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.server_args import LoRARef, ServerArgs
from sglang.srt.utils import get_bool_env_var
from sglang.utils import TypeBasedDispatcher

if TYPE_CHECKING:
    from sglang.srt.managers.tokenizer_manager import TokenizerManager

T = TypeVar("T")

logger = logging.getLogger(__name__)
class _Communicator(Generic[T]):
    """Note: The communicator now only run up to 1 in-flight request at any time."""

    enable_multi_tokenizer = False

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List[T]] = None
        self._ready_queue: Deque[asyncio.Future] = deque()

    async def __call__(self, obj):
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            if _Communicator.enable_multi_tokenizer:
                obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()
        result_values = self._result_values
        self._result_event = self._result_values = None

        if len(self._ready_queue) > 0:
            self._ready_queue.popleft().set()

        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()
class TokenizerCommunicatorMixin:
    """Mixin class for TokenizerManager to handle communication with the scheduler."""

    def init_communicators(self: TokenizerManager, server_args: ServerArgs):
        # Communicators
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.slow_down_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.flush_cache_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.clear_hicache_storage_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.set_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.expert_distribution_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_lora_adapter_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        self._result_dispatcher += self._get_communicator_dispatcher()

    def _get_communicator_dispatcher(self: TokenizerManager):
        return TypeBasedDispatcher(
            [
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    ClearHiCacheReqOutput,
                    self.clear_hicache_storage_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                (
                    LoRAUpdateResult,
                    self.update_lora_adapter_communicator.handle_recv,
                ),
            ]
        )
    async def flush_cache(self: TokenizerManager) -> FlushCacheReqOutput:
        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

    async def clear_hicache_storage(self: TokenizerManager) -> ClearHiCacheReqOutput:
        """Clear the hierarchical cache storage."""
        # Delegate to the scheduler to handle HiCacheStorage clearing
        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[0]

    async def start_profile(
        self: TokenizerManager,
        output_dir: Optional[str] = None,
        start_step: Optional[int] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
        profile_by_stage: bool = False,
    ):
        self.auto_create_handle_loop()
        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
        with_stack = False if with_stack is False or env_with_stack is False else True
        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            start_step=start_step,
            num_steps=num_steps,
            activities=activities,
            with_stack=with_stack,
            record_shapes=record_shapes,
            profile_by_stage=profile_by_stage,
            profile_id=str(time.time()),
        )
        return await self._execute_profile(req)

    async def stop_profile(self: TokenizerManager):
        self.auto_create_handle_loop()
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        return await self._execute_profile(req)

    async def _execute_profile(self: TokenizerManager, req: ProfileReq):
        result = (await self.profile_communicator(req))[0]
        if not result.success:
            raise RuntimeError(result.message)
        return result

    async def start_expert_distribution_record(self: TokenizerManager):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)

    async def stop_expert_distribution_record(self: TokenizerManager):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)

    async def dump_expert_distribution_record(self: TokenizerManager):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
    async def init_weights_update_group(
        self: TokenizerManager,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self: TokenizerManager,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self: TokenizerManager,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def load_lora_adapter(
        self: TokenizerManager,
        obj: LoadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> LoadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start load Lora adapter. Lora name=%s, path=%s",
                obj.lora_name,
                obj.lora_path,
            )

            async with self.lora_update_lock:
                if (
                    self.server_args.max_loaded_loras is not None
                    and self.lora_registry.num_registered_loras
                    >= self.server_args.max_loaded_loras
                ):
                    raise ValueError(
                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
                        "Please unload some LoRA adapters before loading new ones."
                    )

                # Generate new uniquely identifiable LoRARef object.
                new_adapter = LoRARef(
                    lora_name=obj.lora_name,
                    lora_path=obj.lora_path,
                    pinned=obj.pinned,
                )

                # Trigger the actual loading operation at the backend processes.
                obj.lora_id = new_adapter.lora_id
                result = (await self.update_lora_adapter_communicator(obj))[0]

                # Register the LoRA adapter only after loading is successful.
                if result.success:
                    await self.lora_registry.register(new_adapter)

                return result
        except ValueError as e:
            return LoadLoRAAdapterReqOutput(
                success=False,
                error_message=str(e),
            )

    async def unload_lora_adapter(
        self: TokenizerManager,
        obj: UnloadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> UnloadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            assert (
                obj.lora_name is not None
            ), "lora_name must be provided to unload LoRA adapter"

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start unload Lora adapter. Lora name=%s",
                obj.lora_name,
            )

            async with self.lora_update_lock:
                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
                # from being started.
                lora_id = await self.lora_registry.unregister(obj.lora_name)
                obj.lora_id = lora_id

                # Initiate the actual unloading operation at the backend processes only after all
                # ongoing requests using this LoRA adapter are finished.
                await self.lora_registry.wait_for_unload(lora_id)
                result = (await self.update_lora_adapter_communicator(obj))[0]

                return result
        except ValueError as e:
            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
    async def get_weights_by_name(
        self: TokenizerManager,
        obj: GetWeightsByNameReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

    async def release_memory_occupation(
        self: TokenizerManager,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self: TokenizerManager,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def slow_down(
        self: TokenizerManager,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.slow_down_communicator(obj)

    async def get_internal_state(self: TokenizerManager) -> List[Dict[Any, Any]]:
        req = GetInternalStateReq()
        responses: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        # Many DP ranks
        return [res.internal_state for res in responses]

    async def set_internal_state(
        self: TokenizerManager, obj: SetInternalStateReq
    ) -> List[bool]:
        responses: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        return [res.updated for res in responses]

    async def get_load(self: TokenizerManager) -> dict:
        # TODO(lsyin): fake load report server
        if not self.current_load_lock.locked():
            async with self.current_load_lock:
                internal_state = await self.get_internal_state()
                self.current_load = internal_state[0]["load"]
        return {"load": self.current_load}
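
How a `_Communicator` round-trip behaves, as a self-contained sketch (illustrative only: `_FakeSender` stands in for the ZMQ socket, and `handle_recv` is normally invoked by the result dispatcher rather than by hand):

# Hypothetical usage sketch of _Communicator, not part of the commit.
import asyncio

from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator


class _FakeSender:
    """Stand-in for the scheduler socket; just records what was sent."""

    def __init__(self):
        self.sent = []

    def send_pyobj(self, obj):
        self.sent.append(obj)


async def demo():
    comm = _Communicator(_FakeSender(), fan_out=2)  # e.g. dp_size == 2
    call = asyncio.create_task(comm("ping"))        # sends once, then waits for fan_out replies
    await asyncio.sleep(0)                          # let the task send and park on its event
    comm.handle_recv("reply-from-dp-rank-0")        # first reply: not enough yet
    comm.handle_recv("reply-from-dp-rank-1")        # second reply completes the fan-in
    print(await call)                               # ['reply-from-dp-rank-0', 'reply-from-dp-rank-1']


asyncio.run(demo())
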
python/sglang/srt/managers/tokenizer_manager.py (view file @ 78f13981)

@@ -31,19 +31,7 @@ from contextlib import nullcontext
 from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    Any,
-    Awaitable,
-    Deque,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 import fastapi
 import torch
@@ -70,57 +58,26 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
-    ClearHiCacheReqInput,
-    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
-    ExpertDistributionReq,
-    ExpertDistributionReqOutput,
-    FlushCacheReqInput,
-    FlushCacheReqOutput,
     FreezeGCReq,
     GenerateReqInput,
-    GetInternalStateReq,
-    GetInternalStateReqOutput,
-    GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
     HealthCheckOutput,
-    InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
-    LoadLoRAAdapterReqInput,
-    LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
     MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
-    ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
-    ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
-    ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
-    SlowDownReqInput,
-    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UnloadLoRAAdapterReqInput,
-    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
-    UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
-    UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -177,7 +134,7 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


-class TokenizerManager:
+class TokenizerManager(TokenizerCommunicatorMixin):
     """TokenizerManager is a process that tokenizes the text."""

     def __init__(
@@ -343,50 +300,6 @@ class TokenizerManager:
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)

-        # Communicators
-        self.init_weights_update_group_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_distributed_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_tensor_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_weights_by_name_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.release_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.resume_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.slow_down_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.flush_cache_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.clear_hicache_storage_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.profile_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.expert_distribution_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_lora_adapter_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
@@ -404,70 +317,16 @@ class TokenizerManager:
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    InitWeightsUpdateGroupReqOutput,
-                    self.init_weights_update_group_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromDistributedReqOutput,
-                    self.update_weights_from_distributed_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromTensorReqOutput,
-                    self.update_weights_from_tensor_communicator.handle_recv,
-                ),
-                (
-                    GetWeightsByNameReqOutput,
-                    self.get_weights_by_name_communicator.handle_recv,
-                ),
-                (
-                    ReleaseMemoryOccupationReqOutput,
-                    self.release_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    ResumeMemoryOccupationReqOutput,
-                    self.resume_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    SlowDownReqOutput,
-                    self.slow_down_communicator.handle_recv,
-                ),
-                (
-                    ClearHiCacheReqOutput,
-                    self.clear_hicache_storage_communicator.handle_recv,
-                ),
-                (
-                    FlushCacheReqOutput,
-                    self.flush_cache_communicator.handle_recv,
-                ),
-                (
-                    ProfileReqOutput,
-                    self.profile_communicator.handle_recv,
-                ),
                 (
                     FreezeGCReq,
                     lambda x: None,
                 ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
-                (
-                    GetInternalStateReqOutput,
-                    self.get_internal_state_communicator.handle_recv,
-                ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
-                (
-                    ExpertDistributionReqOutput,
-                    self.expert_distribution_communicator.handle_recv,
-                ),
-                (
-                    LoRAUpdateResult,
-                    self.update_lora_adapter_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )

+        self.init_communicators(server_args)
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -983,16 +842,6 @@ class TokenizerManager:
         except StopAsyncIteration:
             pass

-    async def flush_cache(self) -> FlushCacheReqOutput:
-        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
-
-    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
-        """Clear the hierarchical cache storage."""
-        # Delegate to the scheduler to handle HiCacheStorage clearing
-        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[0]
-
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
@@ -1002,55 +851,6 @@ class TokenizerManager:
         if self.enable_metrics:
             self.metrics_collector.observe_one_aborted_request()

-    async def start_profile(
-        self,
-        output_dir: Optional[str] = None,
-        start_step: Optional[int] = None,
-        num_steps: Optional[int] = None,
-        activities: Optional[List[str]] = None,
-        with_stack: Optional[bool] = None,
-        record_shapes: Optional[bool] = None,
-        profile_by_stage: bool = False,
-    ):
-        self.auto_create_handle_loop()
-        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
-        with_stack = False if with_stack is False or env_with_stack is False else True
-        req = ProfileReq(
-            type=ProfileReqType.START_PROFILE,
-            output_dir=output_dir,
-            start_step=start_step,
-            num_steps=num_steps,
-            activities=activities,
-            with_stack=with_stack,
-            record_shapes=record_shapes,
-            profile_by_stage=profile_by_stage,
-            profile_id=str(time.time()),
-        )
-        return await self._execute_profile(req)
-
-    async def stop_profile(self):
-        self.auto_create_handle_loop()
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        return await self._execute_profile(req)
-
-    async def _execute_profile(self, req: ProfileReq):
-        result = (await self.profile_communicator(req))[0]
-        if not result.success:
-            raise RuntimeError(result.message)
-        return result
-
-    async def start_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
-
-    async def stop_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
-
-    async def dump_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
-
     async def pause_generation(self):
         async with self.is_pause_cond:
             self.is_pause = True
@@ -1111,191 +911,6 @@ class TokenizerManager:
         all_paused_requests = [r.num_paused_requests for r in result]
         return all_success, all_message, all_paused_requests

-    async def init_weights_update_group(
-        self,
-        obj: InitWeightsUpdateGroupReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
-
-    async def update_weights_from_distributed(
-        self,
-        obj: UpdateWeightsFromDistributedReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
-
-    async def update_weights_from_tensor(
-        self,
-        obj: UpdateWeightsFromTensorReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_tensor_communicator(obj))[0]
-            return result.success, result.message
-
-    async def load_lora_adapter(
-        self,
-        obj: LoadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> LoadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start load Lora adapter. Lora name=%s, path=%s",
-                obj.lora_name,
-                obj.lora_path,
-            )
-
-            async with self.lora_update_lock:
-                if (
-                    self.server_args.max_loaded_loras is not None
-                    and self.lora_registry.num_registered_loras
-                    >= self.server_args.max_loaded_loras
-                ):
-                    raise ValueError(
-                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
-                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
-                        "Please unload some LoRA adapters before loading new ones."
-                    )
-
-                # Generate new uniquely identifiable LoRARef object.
-                new_adapter = LoRARef(
-                    lora_name=obj.lora_name,
-                    lora_path=obj.lora_path,
-                    pinned=obj.pinned,
-                )
-
-                # Trigger the actual loading operation at the backend processes.
-                obj.lora_id = new_adapter.lora_id
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                # Register the LoRA adapter only after loading is successful.
-                if result.success:
-                    await self.lora_registry.register(new_adapter)
-
-                return result
-        except ValueError as e:
-            return LoadLoRAAdapterReqOutput(
-                success=False,
-                error_message=str(e),
-            )
-
-    async def unload_lora_adapter(
-        self,
-        obj: UnloadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> UnloadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            assert (
-                obj.lora_name is not None
-            ), "lora_name must be provided to unload LoRA adapter"
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start unload Lora adapter. Lora name=%s",
-                obj.lora_name,
-            )
-
-            async with self.lora_update_lock:
-                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-                # from being started.
-                lora_id = await self.lora_registry.unregister(obj.lora_name)
-                obj.lora_id = lora_id
-
-                # Initiate the actual unloading operation at the backend processes only after all
-                # ongoing requests using this LoRA adapter are finished.
-                await self.lora_registry.wait_for_unload(lora_id)
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                return result
-        except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -1320,28 +935,6 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)

-    async def get_internal_state(self) -> List[Dict[Any, Any]]:
-        req = GetInternalStateReq()
-        responses: List[GetInternalStateReqOutput] = (
-            await self.get_internal_state_communicator(req)
-        )
-        # Many DP ranks
-        return [res.internal_state for res in responses]
-
-    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
-        responses: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return [res.updated for res in responses]
-
-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -2108,51 +1701,6 @@ class SignalHandler:
         kill_process_tree(os.getpid())


-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    enable_multi_tokenizer = False
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
 # Note: request abort handling logic
 # We should handle all of the following cases correctly.
 #
python/sglang/srt/weight_sync/utils.py (view file @ 78f13981)

@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor

 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
+from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.utils import MultiprocessingSerializer
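
With this one-line change, the weight-sync helper imports the request dataclass from the module that defines it rather than reaching through `tokenizer_manager` (presumably to avoid pulling in the much heavier tokenizer module for a single dataclass). The import now reads:

# Import as it appears after this commit.
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
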
python/sglang/utils.py (view file @ 78f13981)

@@ -473,6 +473,10 @@ class TypeBasedDispatcher:
     def __init__(self, mapping: List[Tuple[Type, Callable]]):
         self._mapping = mapping

+    def __iadd__(self, other: "TypeBasedDispatcher"):
+        self._mapping.extend(other._mapping)
+        return self
+
     def __call__(self, obj: Any):
         for ty, fn in self._mapping:
             if isinstance(obj, ty):
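
The new `__iadd__` is what lets `TokenizerCommunicatorMixin.init_communicators` merge its handler table into the existing dispatcher with `self._result_dispatcher += self._get_communicator_dispatcher()`. A small illustration, assuming (as the surrounding code suggests) that calling the dispatcher routes an object to the first handler whose type matches and returns that handler's result:

# Illustration of merging two dispatchers via `+=`; the handlers are placeholders.
from sglang.utils import TypeBasedDispatcher

base = TypeBasedDispatcher([(int, lambda x: f"int handler got {x}")])
extra = TypeBasedDispatcher([(str, lambda x: f"str handler got {x}")])

base += extra          # extends base's (type, callable) mapping in place
print(base(3))         # -> "int handler got 3"
print(base("hello"))   # -> "str handler got hello"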