sglang / Commits / a4c3b121

Commit a4c3b121 (unverified)
Authored Jul 29, 2025 by Lianmin Zheng, committed via GitHub on Jul 29, 2025

Split the scheduler into multiple mixin classes to reduce the file size (#8483)
Parent: 5973675b

Showing 12 changed files with 868 additions and 784 deletions (+868, -784)
  python/sglang/srt/disaggregation/decode.py                    +2    -8
  python/sglang/srt/disaggregation/prefill.py                   +2    -6
  python/sglang/srt/entrypoints/engine.py                       +11   -17
  python/sglang/srt/entrypoints/http_server.py                  +10   -2
  python/sglang/srt/managers/io_struct.py                       +0    -2
  python/sglang/srt/managers/scheduler.py                       +64   -661
  python/sglang/srt/managers/scheduler_metrics_mixin.py         +229  -0
  python/sglang/srt/managers/scheduler_profiler_mixin.py        +279  -0
  python/sglang/srt/managers/scheduler_update_weights_mixin.py  +142  -0
  python/sglang/srt/managers/tokenizer_manager.py               +123  -74
  python/sglang/srt/server_args.py                              +6    -3
  python/sglang/utils.py                                        +0    -11
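The core of the commit is that Scheduler now inherits its metrics, profiler, and weight-update methods from dedicated mixin classes instead of defining them inline. A minimal sketch of that composition pattern, for orientation only (the mixin names below are simplified stand-ins, not the classes from the repository; the real mixins are shown in the scheduler.py hunk further down):

import functools

class MetricsMixin:
    def init_metrics(self):
        # The mixin adds behavior only; state still lives on the concrete class.
        self.stats = {}

class ProfilerMixin:
    def init_profiler(self):
        self.profile_in_progress = False

class Scheduler(MetricsMixin, ProfilerMixin):
    def __init__(self):
        self.init_metrics()
        self.init_profiler()

s = Scheduler()
# Method resolution order: Scheduler, MetricsMixin, ProfilerMixin, object
print([cls.__name__ for cls in Scheduler.__mro__])

The design choice is the same as in the commit: the mixins are pure method bundles that reach shared state through self, so the large file can be split without changing runtime behavior.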
python/sglang/srt/disaggregation/decode.py

@@ -694,10 +694,7 @@ class SchedulerDisaggregationDecodeMixin:
                 + len(self.disagg_decode_prealloc_queue.queue) == 0
             ):
-                # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

             self.last_batch = batch

@@ -771,10 +768,7 @@ class SchedulerDisaggregationDecodeMixin:
                 + len(self.disagg_decode_prealloc_queue.queue) == 0
             ):
-                # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

             self.last_batch = batch
             self.last_batch_in_queue = last_batch_in_queue
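Both call sites above now delegate to a consolidated helper that this commit adds to Scheduler (it appears in the scheduler.py hunks further down). For reference, the helper simply bundles the idle-time maintenance steps:

    def self_check_during_idle(self):
        # Consolidated idle-time maintenance, as added to Scheduler in this commit.
        self.check_memory()
        self.check_tree_cache()
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()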
python/sglang/srt/disaggregation/prefill.py

@@ -287,9 +287,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_disagg_prefill_inflight_queue()

         if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-            self.check_memory()
-            self.new_token_ratio = self.init_new_token_ratio
-            self.maybe_sleep_on_idle()
+            self.self_check_during_idle()

         self.last_batch = batch
         # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it

@@ -337,9 +335,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_disagg_prefill_inflight_queue()

         if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-            self.check_memory()
-            self.new_token_ratio = self.init_new_token_ratio
-            self.maybe_sleep_on_idle()
+            self.self_check_during_idle()

         self.last_batch = batch
         # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
python/sglang/srt/entrypoints/engine.py

@@ -652,25 +652,19 @@ def _set_envs_and_config(server_args: ServerArgs):
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

-    def sigchld_handler(signum, frame):
-        pid, exitcode = os.waitpid(0, os.WNOHANG)
-        if exitcode != 0:
-            logger.warning(
-                f"Child process unexpectedly failed with {exitcode=}. {pid=}"
-            )
-
-    signal.signal(signal.SIGCHLD, sigchld_handler)
-
-    # Register the signal handler.
-    # The child processes will send SIGQUIT to this process when any error happens
-    # This process then clean up the whole process tree
-    def sigquit_handler(signum, frame):
-        logger.error(
-            "Received sigquit from a child process. It usually means the child failed."
-        )
-        kill_process_tree(os.getpid())
-
-    signal.signal(signal.SIGQUIT, sigquit_handler)
+    if True:
+        # Keep this check for internal code compatibility
+        # Register the signal handler.
+        # The child processes will send SIGQUIT to this process when any error happens
+        # This process then clean up the whole process tree
+        # Note: This sigquit handler is used in the launch phase, and may be replaced by
+        # the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched.
+        def launch_phase_sigquit_handler(signum, frame):
+            logger.error(
+                "Received sigquit from a child process. It usually means the child failed."
+            )
+            kill_process_tree(os.getpid())

+        signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)

     # Set mp start method
     mp.set_start_method("spawn", force=True)
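The comments in the new handler describe a simple parent/child protocol: a child that hits a fatal error sends SIGQUIT to the launcher, which then tears down the whole process tree. A minimal, self-contained sketch of that signal flow on a Unix-like system (kill_process_tree and the sglang process layout are not reproduced; the exit behavior here is illustrative only):

import multiprocessing as mp
import os
import signal
import time

def child():
    # A child that hits an unrecoverable error notifies the parent process.
    try:
        raise RuntimeError("fatal error in child")
    except RuntimeError:
        os.kill(os.getppid(), signal.SIGQUIT)

def main():
    def sigquit_handler(signum, frame):
        # In sglang this is where kill_process_tree(os.getpid()) would run;
        # here we just exit the parent.
        print("child reported failure, shutting down")
        raise SystemExit(1)

    signal.signal(signal.SIGQUIT, sigquit_handler)
    mp.set_start_method("spawn", force=True)
    p = mp.Process(target=child)
    p.start()
    time.sleep(5)  # the handler fires before this sleep completes

if __name__ == "__main__":
    main()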
python/sglang/srt/entrypoints/http_server.py

@@ -238,6 +238,9 @@ async def health() -> Response:
 @app.get("/health_generate")
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
+    if _global_state.tokenizer_manager.gracefully_exit:
+        logger.info("Health check request received during shutdown. Returning 503.")
+        return Response(status_code=503)

     sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
     rid = f"HEALTH_CHECK_{time.time()}"

@@ -260,9 +263,14 @@ async def health_generate(request: Request) -> Response:
         async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
             break

-    tic = time.perf_counter()
+    # This request is a special request.
+    # If the server already has something running, this request will be ignored, so it creates zero overhead.
+    # If the server is not running, this request will be run, so we know whether the server is healthy.
     task = asyncio.create_task(gen())
-    while time.perf_counter() < tic + HEALTH_CHECK_TIMEOUT:
+
+    # As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
+    tic = time.time()
+    while time.time() < tic + HEALTH_CHECK_TIMEOUT:
         await asyncio.sleep(1)
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()
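As the hunk above documents, /health_generate answers 200 once any token comes back within HEALTH_CHECK_TIMEOUT and 503 while the server is shutting down gracefully. A minimal sketch of probing the endpoint from a client, using only the standard library (the host and port are assumptions for illustration; adjust to your deployment):

import urllib.error
import urllib.request

URL = "http://127.0.0.1:30000/health_generate"  # assumed local server address

try:
    with urllib.request.urlopen(URL, timeout=30) as resp:
        print("healthy:", resp.status == 200)
except urllib.error.HTTPError as e:
    # 503 is returned while the server is gracefully shutting down.
    print("unhealthy, status:", e.code)
except (urllib.error.URLError, TimeoutError):
    print("unhealthy: no response")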
python/sglang/srt/managers/io_struct.py

@@ -152,8 +152,6 @@ class GenerateReqInput:
        else:
            self._normalize_batch_inputs()

        self._validate_session_params()

    def _validate_inputs(self):
        """Validate that the input configuration is valid."""
        if (
python/sglang/srt/managers/scheduler.py

@@ -13,7 +13,6 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""

 import datetime
 import faulthandler
 import logging
 import os

@@ -21,11 +20,10 @@ import signal
 import sys
 import threading
 import time
-from collections import defaultdict, deque
+from collections import deque
 from concurrent import futures
 from dataclasses import dataclass
 from http import HTTPStatus
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union

@@ -37,7 +35,6 @@ from torch.distributed import barrier
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
     create_grammar_backend,

@@ -47,7 +44,6 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
-from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,

@@ -78,21 +74,15 @@ from sglang.srt.managers.io_struct import (
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
     ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     RpcReqInput,
     RpcReqOutput,
     SetInternalStateReq,

@@ -104,11 +94,8 @@ from sglang.srt.managers.io_struct import (
     UnloadLoRAAdapterReqInput,
     UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
-    UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
     UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
 from sglang.srt.managers.schedule_batch import (

@@ -124,9 +111,17 @@ from sglang.srt.managers.schedule_policy import (
     SchedulePolicy,
 )
 from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
+from sglang.srt.managers.scheduler_metrics_mixin import (
+    RECORD_STEP_TIME,
+    SchedulerMetricsMixin,
+)
 from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
+from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+from sglang.srt.managers.scheduler_update_weights_mixin import (
+    SchedulerUpdateWeightsMixin,
+)
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient

@@ -135,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs

@@ -168,7 +162,6 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
-RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

 _is_cpu = is_cpu()

@@ -191,41 +184,11 @@ class EmbeddingBatchResult:
     bid: int


-class KvMetrics:
-    def __init__(self):
-        self.request_active_slots = None
-        self.request_total_slots = None
-        self.kv_active_blocks = None
-        self.kv_total_blocks = None
-        self.num_requests_waiting = None
-        self.gpu_cache_usage_perc = None
-        self.gpu_prefix_cache_hit_rate = None
-        self.data_parallel_rank = None
-
-
-class IdleSleeper:
-    """
-    In setups which have long inactivity periods it is desirable to reduce
-    system power consumption when sglang does nothing. This would lead not only
-    to power savings, but also to more CPU thermal headroom when a request
-    eventually comes. This is important in cases when multiple GPUs are connected
-    as each GPU would otherwise pin one thread at 100% CPU usage.
-
-    The simplest solution is to use zmq.Poller on all sockets that may receive
-    data that needs handling immediately.
-    """
-
-    def __init__(self, sockets):
-        self.poller = zmq.Poller()
-        for s in sockets:
-            self.poller.register(s, zmq.POLLIN)
-
-    def maybe_sleep(self):
-        self.poller.poll(1000)
-
-
 class Scheduler(
     SchedulerOutputProcessorMixin,
+    SchedulerUpdateWeightsMixin,
+    SchedulerProfilerMixin,
+    SchedulerMetricsMixin,
     SchedulerDisaggregationDecodeMixin,
     SchedulerDisaggregationPrefillMixin,
 ):

@@ -266,7 +229,7 @@ class Scheduler(
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
         self.enable_hicache_storage = server_args.hicache_storage_backend is not None
         self.page_size = server_args.page_size
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,

@@ -284,10 +247,13 @@ class Scheduler(
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
             self.send_to_tokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name, False
             )

             if server_args.skip_tokenizer_init:
                 # Directly send to the TokenizerManager
                 self.send_to_detokenizer = get_zmq_socket(

@@ -299,9 +265,6 @@ class Scheduler(
                     context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                 )

-            self.recv_from_rpc = get_zmq_socket(
-                context, zmq.DEALER, port_args.rpc_ipc_name, False
-            )
             if self.server_args.sleep_on_idle:
                 self.idle_sleeper = IdleSleeper(
                     [

@@ -398,7 +361,7 @@ class Scheduler(
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)

-        # Hybrid
+        # Hybrid memory pool
         self.is_hybrid = self.tp_worker.is_hybrid
         if self.is_hybrid:
             self.sliding_window_size = self.tp_worker.sliding_window_size

@@ -515,6 +478,15 @@ class Scheduler(
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)

+        # Init disaggregation
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
+        if get_bool_env_var("SGLANG_GC_LOG"):
+            configure_gc_logger()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [

@@ -545,22 +517,6 @@ class Scheduler(
             ]
         )

-        # Init disaggregation
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.init_disaggregation()
-
-        if get_bool_env_var("SGLANG_GC_LOG"):
-            configure_gc_logger()
-
-    def current_scheduler_metrics_enabled(self):
-        return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
-
-    def maybe_sleep_on_idle(self):
-        if self.idle_sleeper is not None:
-            self.idle_sleeper.maybe_sleep()
-
     def init_tokenizer(self):
         server_args = self.server_args

@@ -668,50 +624,6 @@ class Scheduler(
         embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
         init_embedding_cache(embedding_cache_size * 1024 * 1024)

-    def init_profier(self):
-        self.torch_profiler = None
-        self.torch_profiler_output_dir: Optional[str] = None
-        self.profiler_activities: Optional[List[str]] = None
-        self.profile_id: Optional[str] = None
-        self.profiler_start_forward_ct: Optional[int] = None
-        self.profiler_target_forward_ct: Optional[int] = None
-        self.profiler_target_prefill_ct: Optional[int] = None
-        self.profiler_target_decode_ct: Optional[int] = None
-        self.profiler_prefill_ct: Optional[int] = None
-        self.profiler_decode_ct: Optional[int] = None
-        self.profile_by_stage: bool = False
-        self.profile_steps: Optional[int] = None
-        self.profile_in_progress: bool = False
-        self.rpd_profiler = None
-
-    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
-        self.last_gen_throughput: float = 0.0
-        self.last_input_throughput: float = 0.0
-        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
-        self.total_retracted_reqs = 0
-        self.stats = SchedulerStats()
-        if self.enable_metrics:
-            engine_type = "unified"
-            labels = {
-                "model_name": self.server_args.served_model_name,
-                "engine_type": engine_type,
-                "tp_rank": tp_rank,
-                "pp_rank": pp_rank,
-            }
-            if dp_rank is not None:
-                labels["dp_rank"] = dp_rank
-            self.metrics_collector = SchedulerMetricsCollector(labels=labels)
-
-    def init_kv_events(self, kv_events_config: Optional[str]):
-        if self.enable_kv_cache_events:
-            self.kv_event_publisher = EventPublisherFactory.create(
-                kv_events_config, self.attn_dp_rank
-            )
-
     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend

@@ -820,10 +732,7 @@ class Scheduler(
                 self.process_batch_result(batch, result)
             else:
                 # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.check_tree_cache()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

             self.last_batch = batch

@@ -866,10 +775,7 @@ class Scheduler(
             )
         elif batch is None:
             # When the server is idle, do self-check and re-init some states
-            self.check_memory()
-            self.check_tree_cache()
-            self.new_token_ratio = self.init_new_token_ratio
-            self.maybe_sleep_on_idle()
+            self.self_check_during_idle()

         self.last_batch = batch

@@ -1003,10 +909,8 @@ class Scheduler(
         # When the server is idle, self-check and re-init some states
         if server_is_idle:
-            self.check_memory()
-            self.check_tree_cache()
-            self.new_token_ratio = self.init_new_token_ratio
-            self.maybe_sleep_on_idle()
+            # When the server is idle, do self-check and re-init some states
+            self.self_check_during_idle()

     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""

@@ -1355,170 +1259,11 @@ class Scheduler(
                 req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)

-    def _emit_kv_metrics(self):
-        kv_metrics = KvMetrics()
-        kv_metrics.request_active_slots = self.stats.num_running_reqs
-        kv_metrics.request_total_slots = self.max_running_requests
-        kv_metrics.kv_active_blocks = int(
-            self.stats.token_usage * self.max_total_num_tokens
-        )
-        kv_metrics.kv_total_blocks = self.max_total_num_tokens
-        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
-        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
-        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
-        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
-
-        if not self.send_metrics_from_scheduler.closed:
-            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
-
-    def log_prefill_stats(
-        self,
-        adder: PrefillAdder,
-        can_run_list: List[Req],
-        running_bs: int,
-    ):
-        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
-        self.last_prefill_stats_tic = time.perf_counter()
-        self.last_input_throughput = self.last_prefill_tokens / gap_latency
-        self.last_prefill_tokens = adder.log_input_tokens
-
-        if self.is_hybrid:
-            (
-                full_num_used,
-                swa_num_used,
-                full_token_usage,
-                swa_token_usage,
-                _,
-                _,
-                _,
-                _,
-            ) = self._get_swa_token_info()
-            num_used = max(full_num_used, swa_num_used)
-            token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
-                f"full token usage: {full_token_usage:.2f}, "
-                f"swa token usage: {swa_token_usage:.2f}, "
-            )
-        else:
-            num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"token usage: {token_usage:.2f}, "
-
-        num_new_seq = len(can_run_list)
-        f = (
-            f"Prefill batch. "
-            f"#new-seq: {num_new_seq}, "
-            f"#new-token: {adder.log_input_tokens}, "
-            f"#cached-token: {adder.log_hit_tokens}, "
-            f"{token_msg}"
-        )
-
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
-            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
-        else:
-            f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
-
-        logger.info(f)
-
-        if self.enable_metrics:
-            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
-
-            cache_hit_rate = (
-                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
-            )
-            self.stats.num_running_reqs = running_bs
-            self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
-            self.stats.num_queue_reqs = len(self.waiting_queue)
-            self.stats.cache_hit_rate = cache_hit_rate
-
-            total_queue_latency = 0
-            for req in can_run_list:
-                total_queue_latency += req.queue_time_end - req.queue_time_start
-            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
-
-            self.metrics_collector.log_stats(self.stats)
-            self._emit_kv_metrics()
-        self._publish_kv_events()
-
-    def log_decode_stats(
-        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
-    ):
-        batch = running_batch or self.running_batch
-        gap_latency = time.perf_counter() - self.last_decode_stats_tic
-        self.last_decode_stats_tic = time.perf_counter()
-        self.last_gen_throughput = self.num_generated_tokens / gap_latency
-        self.num_generated_tokens = 0
-        num_running_reqs = len(batch.reqs)
-        if self.is_hybrid:
-            (
-                full_num_used,
-                swa_num_used,
-                full_token_usage,
-                swa_token_usage,
-                _,
-                _,
-                _,
-                _,
-            ) = self._get_swa_token_info()
-            num_used = max(full_num_used, swa_num_used)
-            token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
-                f"#full token: {full_num_used}, "
-                f"full token usage: {full_token_usage:.2f}, "
-                f"#swa token: {swa_num_used}, "
-                f"swa token usage: {swa_token_usage:.2f}, "
-            )
-        else:
-            num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
-
-        if RECORD_STEP_TIME:
-            self.step_time_dict[num_running_reqs].append(
-                gap_latency / self.server_args.decode_log_interval
-            )
-
-        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
-
-        if self.spec_algorithm.is_none():
-            spec_accept_length = 0
-        else:
-            spec_accept_length = (
-                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
-            )
-            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
-            self.cum_spec_accept_count += self.spec_num_total_forward_ct
-            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
-            msg += f"accept len: {spec_accept_length:.2f}, "
-
-        if self.disaggregation_mode == DisaggregationMode.DECODE:
-            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
-            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
-
-        msg += (
-            f"cuda graph: {can_run_cuda_graph}, "
-            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}, "
-        )
-
-        logger.info(msg)
-        if self.enable_metrics:
-            self.stats.num_running_reqs = num_running_reqs
-            self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
-            self.stats.cache_hit_rate = 0.0
-            self.stats.gen_throughput = self.last_gen_throughput
-            self.stats.num_queue_reqs = len(self.waiting_queue)
-            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
-            self.stats.spec_accept_length = spec_accept_length
-            self.stats.total_retracted_reqs = self.total_retracted_reqs
-            self.metrics_collector.log_stats(self.stats)
-            self._emit_kv_metrics()
-        self._publish_kv_events()
+    def self_check_during_idle(self):
+        self.check_memory()
+        self.check_tree_cache()
+        self.new_token_ratio = self.init_new_token_ratio
+        self.maybe_sleep_on_idle()

     def check_memory(self):
         if self.is_hybrid:

@@ -2422,22 +2167,6 @@ class Scheduler(
                 barrier()
         return RpcReqOutput(success, "" if not exec else str(exec))

-    def save_remote_model(self, params):
-        url = params["url"]
-
-        worker = self.tp_worker.worker
-
-        worker.model_runner.save_remote_model(url)
-
-    def save_sharded_model(self, params):
-        worker = self.tp_worker.worker
-
-        worker.model_runner.save_sharded_model(
-            path=params["path"],
-            pattern=params["pattern"],
-            max_size=params["max_size"],
-        )
-
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []

@@ -2515,16 +2244,6 @@ class Scheduler(
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()

-    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-        """In-place update of the weights from disk."""
-        success, message = self.tp_worker.update_weights_from_disk(recv_req)
-        if success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after updating weights"
-        else:
-            logger.error(message)
-        return UpdateWeightFromDiskReqOutput(success, message, 0)
-
     def load_lora_adapter(
         self, recv_req: LoadLoRAAdapterReqInput
     ) -> LoadLoRAAdapterReqOutput:

@@ -2541,81 +2260,6 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

-    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
-        """Initialize the online model parameter update group."""
-        success, message = self.tp_worker.init_weights_update_group(recv_req)
-        return InitWeightsUpdateGroupReqOutput(success, message)
-
-    def update_weights_from_distributed(
-        self,
-        recv_req: UpdateWeightsFromDistributedReqInput,
-    ) -> Tuple[bool, str]:
-        """Update the online model parameter."""
-        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
-        if success:
-            if recv_req.flush_cache:
-                flush_cache_success = self.flush_cache()
-                assert flush_cache_success, "Cache flush failed after updating weights"
-        else:
-            logger.error(message)
-        return UpdateWeightsFromDistributedReqOutput(success, message)
-
-    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
-        """Update the online model parameter from tensors."""
-        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
-        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
-        if success:
-            if recv_req.flush_cache:
-                flush_cache_success = self.flush_cache()
-                assert flush_cache_success, "Cache flush failed after updating weights"
-        else:
-            logger.error(message)
-        barrier(group=self.tp_cpu_group)
-        return UpdateWeightsFromTensorReqOutput(success, message)
-
-    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
-        parameter = self.tp_worker.get_weights_by_name(recv_req)
-        return GetWeightsByNameReqOutput(parameter)
-
-    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
-        tags = recv_req.tags
-
-        if tags is None or len(tags) == 0:
-            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
-
-        if GPU_MEMORY_TYPE_KV_CACHE in tags:
-            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
-            self.flush_cache()
-
-        if GPU_MEMORY_TYPE_WEIGHTS in tags:
-            self.stashed_model_static_state = _export_static_state(
-                self.tp_worker.worker.model_runner.model
-            )
-            torch.distributed.barrier(self.tp_cpu_group)
-            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
-
-        return ReleaseMemoryOccupationReqOutput()
-
-    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
-        tags = recv_req.tags
-
-        if tags is None or len(tags) == 0:
-            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
-
-        if GPU_MEMORY_TYPE_WEIGHTS in tags:
-            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
-            torch.distributed.barrier(self.tp_cpu_group)
-            _import_static_state(
-                self.tp_worker.worker.model_runner.model,
-                self.stashed_model_static_state,
-            )
-            del self.stashed_model_static_state
-
-        if GPU_MEMORY_TYPE_KV_CACHE in tags:
-            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
-
-        return ResumeMemoryOccupationReqOutput()
-
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:

@@ -2623,254 +2267,6 @@ class Scheduler(
         self.forward_sleep_time = t
         return SlowDownReqOutput()

-    def profile(self, recv_req: ProfileReq):
-        if recv_req.type == ProfileReqType.START_PROFILE:
-            if recv_req.profile_by_stage or recv_req.start_step:
-                return self.init_profile(
-                    recv_req.output_dir,
-                    recv_req.start_step,
-                    recv_req.num_steps,
-                    recv_req.activities,
-                    recv_req.with_stack,
-                    recv_req.record_shapes,
-                    recv_req.profile_by_stage,
-                    recv_req.profile_id,
-                )
-            else:
-                self.init_profile(
-                    recv_req.output_dir,
-                    recv_req.start_step,
-                    recv_req.num_steps,
-                    recv_req.activities,
-                    recv_req.with_stack,
-                    recv_req.record_shapes,
-                    recv_req.profile_by_stage,
-                    recv_req.profile_id,
-                )
-                return self.start_profile(True)
-        else:
-            return self.stop_profile()
-
-    def init_profile(
-        self,
-        output_dir: Optional[str],
-        start_step: Optional[int],
-        num_steps: Optional[int],
-        activities: Optional[List[str]],
-        with_stack: Optional[bool],
-        record_shapes: Optional[bool],
-        profile_by_stage: bool,
-        profile_id: str,
-    ) -> ProfileReqOutput:
-        if self.profile_in_progress:
-            return ProfileReqOutput(
-                success=False,
-                message="Profiling is already in progress. Call /stop_profile first.",
-            )
-
-        self.profile_by_stage = profile_by_stage
-
-        if output_dir is None:
-            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
-        if activities is None:
-            activities = ["CPU", "GPU"]
-
-        self.torch_profiler_output_dir = output_dir
-        self.torch_profiler_with_stack = with_stack
-        self.torch_profiler_record_shapes = record_shapes
-        self.profiler_activities = activities
-        self.profile_id = profile_id
-
-        if start_step:
-            self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
-
-        if num_steps:
-            self.profile_steps = num_steps
-            if self.profile_by_stage:
-                self.profiler_target_prefill_ct = num_steps
-                self.profiler_target_decode_ct = num_steps
-                self.profiler_prefill_ct = 0
-                self.profiler_decode_ct = 0
-            elif start_step:
-                self.profiler_target_forward_ct = (
-                    self.profiler_start_forward_ct + num_steps
-                )
-            else:
-                self.profiler_target_forward_ct = self.forward_ct + num_steps
-            # The caller will be notified when reaching profiler_target_forward_ct
-        else:
-            self.profiler_target_forward_ct = None
-
-        return ProfileReqOutput(success=True, message="Succeeded")
-
-    def start_profile(
-        self, stage: Optional[ForwardMode] = None
-    ) -> ProfileReqOutput | None:
-        stage_str = f" for {stage.__str__()}" if stage else ""
-        logger.info(
-            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
-        )
-
-        activities = self.profiler_activities
-        with_stack = self.torch_profiler_with_stack
-        record_shapes = self.torch_profiler_record_shapes
-
-        activity_map = {
-            "CPU": torch.profiler.ProfilerActivity.CPU,
-            "GPU": torch.profiler.ProfilerActivity.CUDA,
-        }
-        torchprof_activities = [
-            activity_map[a] for a in activities if a in activity_map
-        ]
-
-        if "RPD" in activities:
-            from rpdTracerControl import rpdTracerControl
-
-            rpdTracerControl.skipCreate()
-
-            self.rpd_profile_path = os.path.join(
-                self.torch_profiler_output_dir,
-                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
-            )
-
-            if self.tp_rank == 0:
-                import sqlite3
-
-                from rocpd.schema import RocpdSchema
-
-                if os.path.exists("trace.rpd"):
-                    os.unlink("trace.rpd")
-                schema = RocpdSchema()
-                connection = sqlite3.connect("trace.rpd")
-                schema.writeSchema(connection)
-                connection.commit()
-                del connection
-
-            torch.distributed.barrier(self.tp_cpu_group)
-            self.rpd_profiler = rpdTracerControl()
-            self.rpd_profiler.setPythonTrace(True)
-            self.rpd_profiler.start()
-            self.rpd_profiler.rangePush("", "rpd profile range", "")
-            self.profile_in_progress = True
-        elif torchprof_activities:
-            self.torch_profiler = torch.profiler.profile(
-                activities=torchprof_activities,
-                with_stack=with_stack if with_stack is not None else True,
-                record_shapes=record_shapes if record_shapes is not None else False,
-            )
-            self.torch_profiler.start()
-            self.profile_in_progress = True
-
-        if "MEM" in activities:
-            torch.cuda.memory._record_memory_history(max_entries=100000)
-            self.profile_in_progress = True
-
-        if "CUDA_PROFILER" in activities:
-            torch.cuda.cudart().cudaProfilerStart()
-            self.profile_in_progress = True
-
-        return ProfileReqOutput(success=True, message="Succeeded")
-
-    def stop_profile(
-        self, stage: Optional[ForwardMode] = None
-    ) -> ProfileReqOutput | None:
-        if not self.profile_in_progress:
-            return ProfileReqOutput(
-                success=False,
-                message="Profiling is not in progress. Call /start_profile first.",
-            )
-
-        if not Path(self.torch_profiler_output_dir).exists():
-            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
-
-        stage_suffix = f"-{stage.__str__()}" if stage else ""
-        logger.info("Stop profiling" + stage_suffix + "...")
-        if self.torch_profiler is not None:
-            self.torch_profiler.stop()
-            self.torch_profiler.export_chrome_trace(
-                os.path.join(
-                    self.torch_profiler_output_dir,
-                    self.profile_id
-                    + f"-TP-{self.tp_rank}"
-                    + stage_suffix
-                    + ".trace.json.gz",
-                )
-            )
-            torch.distributed.barrier(self.tp_cpu_group)
-
-        if self.rpd_profiler is not None:
-            self.rpd_profiler.rangePop()
-            self.rpd_profiler.stop()
-            self.rpd_profiler.flush()
-
-            torch.distributed.barrier(self.tp_cpu_group)
-            if self.tp_rank == 0:
-                from sglang.srt.utils import rpd_to_chrome_trace
-
-                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
-            self.rpd_profiler = None
-            self.rpd_profiler_path = None
-
-        if self.profiler_activities is not None and "MEM" in self.profiler_activities:
-            memory_profile_path = os.path.join(
-                self.torch_profiler_output_dir,
-                str(time.time())
-                + f"-TP-{self.tp_rank}-memory"
-                + stage_suffix
-                + ".pickle",
-            )
-            torch.cuda.memory._dump_snapshot(memory_profile_path)
-            torch.cuda.memory._record_memory_history(enabled=None)
-
-        if "CUDA_PROFILER" in self.profiler_activities:
-            torch.cuda.cudart().cudaProfilerStop()
-
-        logger.info(
-            "Profiling done. Traces are saved to: %s",
-            self.torch_profiler_output_dir,
-        )
-        self.torch_profiler = None
-        self.profile_in_progress = False
-        self.profiler_start_forward_ct = None
-
-        return ProfileReqOutput(success=True, message="Succeeded.")
-
-    def _profile_batch_predicate(self, batch):
-        if self.profile_by_stage:
-            if batch.forward_mode.is_prefill():
-                if self.profiler_prefill_ct == 0:
-                    self.start_profile(batch.forward_mode)
-                self.profiler_prefill_ct += 1
-                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
-                    if self.profile_in_progress:
-                        self.stop_profile(stage=ForwardMode.EXTEND)
-            elif batch.forward_mode.is_decode():
-                if self.profiler_decode_ct == 0:
-                    if self.profile_in_progress:
-                        # force trace flush
-                        self.stop_profile(ForwardMode.EXTEND)
-                    self.start_profile(batch.forward_mode)
-                self.profiler_decode_ct += 1
-                if self.profiler_decode_ct > self.profiler_target_decode_ct:
-                    if self.profile_in_progress:
-                        self.stop_profile(stage=ForwardMode.DECODE)
-            elif batch.forward_mode.is_idle():
-                pass
-            else:
-                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
-        else:
-            # Check profiler
-            if (
-                self.profiler_target_forward_ct
-                and self.profiler_target_forward_ct <= self.forward_ct
-            ):
-                self.stop_profile()
-            if (
-                self.profiler_start_forward_ct
-                and self.profiler_start_forward_ct == self.forward_ct
-            ):
-                self.start_profile()
-
     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
             get_global_expert_distribution_recorder().start_record()

@@ -2879,7 +2275,7 @@ class Scheduler(
         elif recv_req == ExpertDistributionReq.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
-            raise ValueError("Unrecognized ExpertDistributionReq value")
+            raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
         return ExpertDistributionReqOutput()

     def open_session(self, recv_req: OpenSessionReqInput):

@@ -2915,34 +2311,41 @@ class Scheduler(
             prefix += f" PP{self.pp_rank}"
         return prefix

-    def _publish_kv_events(self):
-        if self.enable_kv_cache_events:
-            events = self.tree_cache.take_events()
-            if events:
-                batch = KVEventBatch(ts=time.time(), events=events)
-                self.kv_event_publisher.publish(batch)
+    def current_scheduler_metrics_enabled(self):
+        return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
+
+    def maybe_sleep_on_idle(self):
+        if self.idle_sleeper is not None:
+            self.idle_sleeper.maybe_sleep()


 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")


 def is_work_request(recv_req):
     return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))


-def _export_static_state(model):
-    return dict(
-        buffers=[
-            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
-        ]
-    )
-
-
-def _import_static_state(model, static_params):
-    self_named_buffers = dict(model.named_buffers())
-    for name, tensor in static_params["buffers"]:
-        self_named_buffers[name][...] = tensor
+class IdleSleeper:
+    """
+    In setups which have long inactivity periods it is desirable to reduce
+    system power consumption when sglang does nothing. This would lead not only
+    to power savings, but also to more CPU thermal headroom when a request
+    eventually comes. This is important in cases when multiple GPUs are connected
+    as each GPU would otherwise pin one thread at 100% CPU usage.
+
+    The simplest solution is to use zmq.Poller on all sockets that may receive
+    data that needs handling immediately.
+    """
+
+    def __init__(self, sockets):
+        self.poller = zmq.Poller()
+        for s in sockets:
+            self.poller.register(s, zmq.POLLIN)
+
+    def maybe_sleep(self):
+        self.poller.poll(1000)


 def run_scheduler_process(
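IdleSleeper is what makes sleep_on_idle avoid busy-spinning a CPU core per scheduler. A minimal, self-contained sketch of the same zmq.Poller pattern outside sglang (the endpoint name, message handling, and loop bound are illustrative; only the poller usage and the 1000 ms timeout mirror the class above):

import zmq

ctx = zmq.Context.instance()
pull = ctx.socket(zmq.PULL)
pull.bind("ipc:///tmp/example_scheduler_input")  # illustrative endpoint

poller = zmq.Poller()
poller.register(pull, zmq.POLLIN)

for _ in range(10):  # a bounded loop so the sketch terminates
    # Block for up to 1000 ms instead of busy-polling; returns early as soon
    # as any registered socket has data, so new work is handled immediately.
    events = dict(poller.poll(1000))
    if pull in events:
        msg = pull.recv_pyobj()
        print("got work:", msg)
    # otherwise: an idle tick, and the CPU stays mostly free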
python/sglang/srt/managers/scheduler_metrics_mixin.py (new file, 0 → 100644)

import logging
import time
from collections import defaultdict
from typing import List, Optional

from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.managers.schedule_policy import PrefillAdder
from sglang.srt.managers.scheduler import Req, ScheduleBatch
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.utils import get_bool_env_var

logger = logging.getLogger(__name__)

RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")


class KvMetrics:
    def __init__(self):
        self.request_active_slots = None
        self.request_total_slots = None
        self.kv_active_blocks = None
        self.kv_total_blocks = None
        self.num_requests_waiting = None
        self.gpu_cache_usage_perc = None
        self.gpu_prefix_cache_hit_rate = None
        self.data_parallel_rank = None


class SchedulerMetricsMixin:
    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.total_retracted_reqs = 0
        self.stats = SchedulerStats()
        if self.enable_metrics:
            engine_type = "unified"
            labels = {
                "model_name": self.server_args.served_model_name,
                "engine_type": engine_type,
                "tp_rank": tp_rank,
                "pp_rank": pp_rank,
            }
            if dp_rank is not None:
                labels["dp_rank"] = dp_rank
            self.metrics_collector = SchedulerMetricsCollector(labels=labels)

    def init_kv_events(self, kv_events_config: Optional[str]):
        if self.enable_kv_cache_events:
            self.kv_event_publisher = EventPublisherFactory.create(
                kv_events_config, self.attn_dp_rank
            )

    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.perf_counter()
        self.last_input_throughput = self.last_prefill_tokens / gap_latency
        self.last_prefill_tokens = adder.log_input_tokens

        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
            token_msg = (
                f"full token usage: {full_token_usage:.2f}, "
                f"swa token usage: {swa_token_usage:.2f}, "
            )
        else:
            num_used, token_usage, _, _ = self._get_token_info()
            token_msg = f"token usage: {token_usage:.2f}, "

        num_new_seq = len(can_run_list)
        f = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"{token_msg}"
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "
            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
        else:
            f += f"#running-req: {running_bs}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "

        logger.info(f)

        if self.enable_metrics:
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens

            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            total_queue_latency = 0
            for req in can_run_list:
                total_queue_latency += req.queue_time_end - req.queue_time_start
            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq

            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()

    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
    ):
        batch = running_batch or self.running_batch
        gap_latency = time.perf_counter() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.perf_counter()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
            token_msg = (
                f"#full token: {full_num_used}, "
                f"full token usage: {full_token_usage:.2f}, "
                f"#swa token: {swa_num_used}, "
                f"swa token usage: {swa_token_usage:.2f}, "
            )
        else:
            num_used, token_usage, _, _ = self._get_token_info()
            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.stats.total_retracted_reqs = self.total_retracted_reqs
            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()

    def _emit_kv_metrics(self):
        kv_metrics = KvMetrics()
        kv_metrics.request_active_slots = self.stats.num_running_reqs
        kv_metrics.request_total_slots = self.max_running_requests
        kv_metrics.kv_active_blocks = int(
            self.stats.token_usage * self.max_total_num_tokens
        )
        kv_metrics.kv_total_blocks = self.max_total_num_tokens
        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0

        if not self.send_metrics_from_scheduler.closed:
            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

    def _publish_kv_events(self):
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)
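The logging in this mixin boils down to a few ratios: input throughput is prefill tokens divided by the time since the last prefill log, generation throughput is generated tokens over the decode gap, and cache hit rate is cached tokens over cached plus new tokens. A small worked sketch of that arithmetic with made-up numbers (not measurements from sglang):

# Illustrative values only.
gap_latency = 0.5           # seconds since the last prefill stats tick
last_prefill_tokens = 4096  # tokens prefetched in that window
log_input_tokens = 3000     # new tokens in the current prefill batch
log_hit_tokens = 1000       # tokens served from the prefix cache

input_throughput = last_prefill_tokens / gap_latency  # 8192.0 token/s
total_tokens = log_input_tokens + log_hit_tokens
cache_hit_rate = log_hit_tokens / total_tokens if total_tokens > 0 else 0.0  # 0.25

print(f"input throughput (token/s): {input_throughput:.2f}, "
      f"cache hit rate: {cache_hit_rate:.2f}")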
python/sglang/srt/managers/scheduler_profiler_mixin.py (new file, 0 → 100644)

import logging
import os
import time
from pathlib import Path
from typing import List, Optional

import torch

from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
from sglang.srt.model_executor.forward_batch_info import ForwardMode

logger = logging.getLogger(__name__)


class SchedulerProfilerMixin:
    def init_profier(self):
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profile_id: Optional[str] = None
        self.profiler_start_forward_ct: Optional[int] = None
        self.profiler_target_forward_ct: Optional[int] = None
        self.profiler_target_prefill_ct: Optional[int] = None
        self.profiler_target_decode_ct: Optional[int] = None
        self.profiler_prefill_ct: Optional[int] = None
        self.profiler_decode_ct: Optional[int] = None
        self.profile_by_stage: bool = False
        self.profile_steps: Optional[int] = None
        self.profile_in_progress: bool = False
        self.rpd_profiler = None

    def init_profile(
        self,
        output_dir: Optional[str],
        start_step: Optional[int],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_by_stage: bool,
        profile_id: str,
    ) -> ProfileReqOutput:
        if self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        self.profile_by_stage = profile_by_stage

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_with_stack = with_stack
        self.torch_profiler_record_shapes = record_shapes
        self.profiler_activities = activities
        self.profile_id = profile_id

        if start_step:
            self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)

        if num_steps:
            self.profile_steps = num_steps
            if self.profile_by_stage:
                self.profiler_target_prefill_ct = num_steps
                self.profiler_target_decode_ct = num_steps
                self.profiler_prefill_ct = 0
                self.profiler_decode_ct = 0
            elif start_step:
                self.profiler_target_forward_ct = (
                    self.profiler_start_forward_ct + num_steps
                )
            else:
                self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
        else:
            self.profiler_target_forward_ct = None

        return ProfileReqOutput(success=True, message="Succeeded")

    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        stage_str = f" for {stage.__str__()}" if stage else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
        )

        activities = self.profiler_activities
        with_stack = self.torch_profiler_with_stack
        record_shapes = self.torch_profiler_record_shapes

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if "RPD" in activities:
            from rpdTracerControl import rpdTracerControl

            rpdTracerControl.skipCreate()

            self.rpd_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
            )

            if self.tp_rank == 0:
                import sqlite3

                from rocpd.schema import RocpdSchema

                if os.path.exists("trace.rpd"):
                    os.unlink("trace.rpd")
                schema = RocpdSchema()
                connection = sqlite3.connect("trace.rpd")
                schema.writeSchema(connection)
                connection.commit()
                del connection

            torch.distributed.barrier(self.tp_cpu_group)
            self.rpd_profiler = rpdTracerControl()
            self.rpd_profiler.setPythonTrace(True)
            self.rpd_profiler.start()
            self.rpd_profiler.rangePush("", "rpd profile range", "")
            self.profile_in_progress = True
        elif torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()
            self.profile_in_progress = True

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)
            self.profile_in_progress = True

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()
            self.profile_in_progress = True

        return ProfileReqOutput(success=True, message="Succeeded")

    def stop_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        if not self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        if not Path(self.torch_profiler_output_dir).exists():
            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

        stage_suffix = f"-{stage.__str__()}" if stage else ""
        logger.info("Stop profiling" + stage_suffix + "...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    self.profile_id
                    + f"-TP-{self.tp_rank}"
                    + stage_suffix
                    + ".trace.json.gz",
                )
            )
            torch.distributed.barrier(self.tp_cpu_group)

        if self.rpd_profiler is not None:
            self.rpd_profiler.rangePop()
            self.rpd_profiler.stop()
            self.rpd_profiler.flush()

            torch.distributed.barrier(self.tp_cpu_group)
            if self.tp_rank == 0:
                from sglang.srt.utils import rpd_to_chrome_trace

                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
            self.rpd_profiler = None
            self.rpd_profiler_path = None

        if self.profiler_activities is not None and "MEM" in self.profiler_activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time())
                + f"-TP-{self.tp_rank}-memory"
                + stage_suffix
                + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in self.profiler_activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.profile_in_progress = False
        self.profiler_start_forward_ct = None

        return ProfileReqOutput(success=True, message="Succeeded.")

    def _profile_batch_predicate(self, batch):
        if self.profile_by_stage:
            if batch.forward_mode.is_prefill():
                if self.profiler_prefill_ct == 0:
                    self.start_profile(batch.forward_mode)
                self.profiler_prefill_ct += 1
                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.EXTEND)
            elif batch.forward_mode.is_decode():
                if self.profiler_decode_ct == 0:
                    if self.profile_in_progress:
                        # force trace flush
                        self.stop_profile(ForwardMode.EXTEND)
                    self.start_profile(batch.forward_mode)
                self.profiler_decode_ct += 1
                if self.profiler_decode_ct > self.profiler_target_decode_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.DECODE)
            elif batch.forward_mode.is_idle():
                pass
            else:
                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
        else:
            # Check profiler
            if (
                self.profiler_target_forward_ct
                and self.profiler_target_forward_ct <= self.forward_ct
            ):
                self.stop_profile()
            if (
                self.profiler_start_forward_ct
                and self.profiler_start_forward_ct == self.forward_ct
            ):
                self.start_profile()

    def profile(self, recv_req: ProfileReq):
        if recv_req.type == ProfileReqType.START_PROFILE:
            if recv_req.profile_by_stage or recv_req.start_step:
                return self.init_profile(
                    recv_req.output_dir,
                    recv_req.start_step,
                    recv_req.num_steps,
                    recv_req.activities,
                    recv_req.with_stack,
                    recv_req.record_shapes,
                    recv_req.profile_by_stage,
                    recv_req.profile_id,
                )
            else:
                self.init_profile(
                    recv_req.output_dir,
                    recv_req.start_step,
                    recv_req.num_steps,
                    recv_req.activities,
                    recv_req.with_stack,
                    recv_req.record_shapes,
                    recv_req.profile_by_stage,
                    recv_req.profile_id,
                )
                return self.start_profile(True)
        else:
            return self.stop_profile()
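Outside of sglang, the torch-profiler path in start_profile/stop_profile corresponds to the standard torch.profiler workflow: build a profile object with CPU/CUDA activities, start it, run some work, stop it, and export a Chrome trace. A minimal standalone sketch of that workflow (the output path and the profiled workload are illustrative):

import torch
from torch.profiler import ProfilerActivity, profile

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

prof = profile(activities=activities, with_stack=True, record_shapes=False)
prof.start()

# Illustrative workload standing in for the scheduler's forward steps.
x = torch.randn(1024, 1024)
for _ in range(10):
    x = x @ x.T
    x = x / x.norm()

prof.stop()
prof.export_chrome_trace("/tmp/example-trace.json")  # open in chrome://tracing or Perfetto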
python/sglang/srt/managers/scheduler_update_weights_mixin.py
0 → 100644
View file @
a4c3b121
import
logging
from
typing
import
Tuple
import
torch
from
sglang.srt.constants
import
GPU_MEMORY_TYPE_KV_CACHE
,
GPU_MEMORY_TYPE_WEIGHTS
from
sglang.srt.managers.io_struct
import
(
GetWeightsByNameReqInput
,
GetWeightsByNameReqOutput
,
InitWeightsUpdateGroupReqInput
,
InitWeightsUpdateGroupReqOutput
,
ReleaseMemoryOccupationReqInput
,
ReleaseMemoryOccupationReqOutput
,
ResumeMemoryOccupationReqInput
,
ResumeMemoryOccupationReqOutput
,
UpdateWeightFromDiskReqInput
,
UpdateWeightFromDiskReqOutput
,
UpdateWeightsFromDistributedReqInput
,
UpdateWeightsFromDistributedReqOutput
,
UpdateWeightsFromTensorReqInput
,
UpdateWeightsFromTensorReqOutput
,
)
logger
=
logging
.
getLogger
(
__name__
)
class
SchedulerUpdateWeightsMixin
:
def
update_weights_from_disk
(
self
,
recv_req
:
UpdateWeightFromDiskReqInput
):
"""In-place update of the weights from disk."""
success
,
message
=
self
.
tp_worker
.
update_weights_from_disk
(
recv_req
)
if
success
:
flush_cache_success
=
self
.
flush_cache
()
assert
flush_cache_success
,
"Cache flush failed after updating weights"
else
:
logger
.
error
(
message
)
return
UpdateWeightFromDiskReqOutput
(
success
,
message
,
0
)
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group."""
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)
    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> Tuple[bool, str]:
        """Update the online model parameter."""
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
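init_weights_update_group and update_weights_from_distributed rely on a torch.distributed process group shared with an external trainer, into which updated parameters are broadcast. A single-process sketch of that kind of group (gloo backend, world_size=1; the address, port, and tensor are placeholders, not values sglang uses):

import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29512",
    rank=0,
    world_size=1,
)
weights = torch.zeros(4)
# In a real setup the trainer rank broadcasts updated parameters into this group.
dist.broadcast(weights, src=0)
dist.destroy_process_group()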
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors."""
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        torch.distributed.barrier(group=self.tp_cpu_group)
        return UpdateWeightsFromTensorReqOutput(success, message)
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(parameter)
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        tags = recv_req.tags

        if tags is None or len(tags) == 0:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
            self.flush_cache()

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            self.stashed_model_static_state = _export_static_state(
                self.tp_worker.worker.model_runner.model
            )
            torch.distributed.barrier(self.tp_cpu_group)
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

        return ReleaseMemoryOccupationReqOutput()
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        tags = recv_req.tags

        if tags is None or len(tags) == 0:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
            torch.distributed.barrier(self.tp_cpu_group)
            _import_static_state(
                self.tp_worker.worker.model_runner.model,
                self.stashed_model_static_state,
            )
            del self.stashed_model_static_state

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)

        return ResumeMemoryOccupationReqOutput()
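The release/resume pair frees GPU memory by tag (weights, KV cache) so that, for example, an external trainer can temporarily reclaim the GPU and hand it back later. A toy sketch of the tag-driven pause/resume flow; the adapter class below is an illustrative stand-in, not sglang's memory-saver adapter:

GPU_MEMORY_TYPE_WEIGHTS = "weights"
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"


class ToyMemorySaver:
    def __init__(self):
        self.paused = set()

    def pause(self, tag):
        self.paused.add(tag)
        print(f"released {tag}")

    def resume(self, tag):
        self.paused.discard(tag)
        print(f"restored {tag}")


def release(adapter, tags=None):
    for tag in tags or [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]:
        adapter.pause(tag)


def resume(adapter, tags=None):
    for tag in tags or [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]:
        adapter.resume(tag)


saver = ToyMemorySaver()
release(saver, tags=[GPU_MEMORY_TYPE_KV_CACHE])  # free only the KV cache
resume(saver, tags=[GPU_MEMORY_TYPE_KV_CACHE])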
    def save_remote_model(self, params):
        url = params["url"]
        worker = self.tp_worker.worker
        worker.model_runner.save_remote_model(url)

    def save_sharded_model(self, params):
        worker = self.tp_worker.worker
        worker.model_runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor
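_export_static_state and _import_static_state stash and restore a model's non-parameter buffers around a weight release. A runnable toy version of the same stash/restore pattern on a small nn.Module (TinyModel is illustrative only):

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(3))


model = TinyModel()
# Stash: clone every named buffer, detached from the graph.
stash = dict(
    buffers=[(name, buf.detach().clone()) for name, buf in model.named_buffers()]
)

model.scale.fill_(0.0)  # buffer gets clobbered while memory is released

# Restore: copy back in place, keeping the original storage.
named = dict(model.named_buffers())
for name, tensor in stash["buffers"]:
    named[name][...] = tensor

assert torch.equal(model.scale, torch.ones(3))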
python/sglang/srt/managers/tokenizer_manager.py
View file @
a4c3b121
...
...
@@ -170,16 +170,6 @@ class ReqState:
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
    is_cross_node = server_args.dist_init_addr

    if is_cross_node:
        # Fallback to default CPU transport for multi-node
        return "default"
    else:
        return "cuda_ipc"


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""
...
...
@@ -199,16 +189,6 @@ class TokenizerManager:
            else None
        )
        self.crash_dump_folder = server_args.crash_dump_folder
        self.crash_dump_performed = False  # Flag to ensure dump is only called once

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
...
...
@@ -218,8 +198,7 @@ class TokenizerManager:
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id
        self._updating = False
        self._cond = asyncio.Condition()
        self.max_req_input_len = None  # Will be set later in engine.py

        if self.model_config.is_multimodal:
            import_processors()
...
...
@@ -258,39 +237,57 @@ class TokenizerManager:
            revision=server_args.revision,
        )

        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
        # serves as the source of truth for available adapters and maps user-friendly LoRA names
        # to internally used unique LoRA IDs.
        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Store states
        # Request states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.asyncio_tasks = set()

        # Health check
        self.health_check_failed = False
        self.gracefully_exit = False
        self.last_receive_tstamp = 0

        # Dumping
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.crash_dump_request_list: deque[Tuple] = deque()
        self.log_request_metadata = self.get_log_request_metadata()
        self.crash_dump_request_list: deque[Tuple] = deque()
        self.crash_dump_performed = False  # Flag to ensure dump is only called once

        # Session
        self.session_futures = {}  # session_id -> asyncio event
        self.max_req_input_len = None
        self.asyncio_tasks = set()

        # Weight updates
        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self._is_updating = False
        self._is_updating_cond = asyncio.Condition()

        # LoRA
        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
        # serves as the source of truth for available adapters and maps user-friendly LoRA names
        # to internally used unique LoRA IDs.
        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
        # Lock to serialize LoRA update operations.
        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
        # LoRA updates and inference to overlap.
        self.lora_update_lock = asyncio.Lock()

        # For pd disaggregtion
        # For PD disaggregtion
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
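The __init__ above wires the tokenizer to the detokenizer and scheduler through ZeroMQ PUSH/PULL sockets. A self-contained sketch of that messaging pattern using pyzmq's asyncio API over an in-process transport (the endpoint name and payload are arbitrary, chosen only for the demo):

import asyncio
import zmq
import zmq.asyncio


async def main():
    ctx = zmq.asyncio.Context()
    pull = ctx.socket(zmq.PULL)
    pull.bind("inproc://tokenizer_demo")
    push = ctx.socket(zmq.PUSH)
    push.connect("inproc://tokenizer_demo")

    await push.send_pyobj({"rid": "req-0", "text": "hello"})  # sender side
    print(await pull.recv_pyobj())                            # receiver side

    push.close()
    pull.close()
    ctx.term()


asyncio.run(main())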
...
...
@@ -458,17 +455,11 @@ class TokenizerManager:
        request: Optional[fastapi.Request] = None,
    ):
        created_time = time.time()

        async with self._cond:
            await self._cond.wait_for(lambda: not self._updating)

        self.auto_create_handle_loop()
        obj.normalize_batch_and_arguments()

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        async with self._is_updating_cond:
            await self._is_updating_cond.wait_for(lambda: not self._is_updating)

        if self.log_requests:
            max_length, skip_names, _ = self.log_request_metadata
...
...
@@ -567,6 +558,12 @@ class TokenizerManager:
                f"model's context length ({self.context_len} tokens)."
            )

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        # Check total tokens (input + max_new_tokens)
        max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if (
...
...
@@ -959,14 +956,14 @@ class TokenizerManager:
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

    async def pause_generation(self):
        async with self._cond:
            self._updating = True
        async with self._is_updating_cond:
            self._is_updating = True
            self.abort_request(abort_all=True)

    async def continue_generation(self):
        async with self._cond:
            self._updating = False
            self._cond.notify_all()
        async with self._is_updating_cond:
            self._is_updating = False
            self._is_updating_cond.notify_all()
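pause_generation / continue_generation gate incoming requests on an asyncio.Condition: senders wait until the updating flag is cleared, then proceed. A standalone sketch of that gate pattern (the Gate class and names below are illustrative, not sglang code):

import asyncio


class Gate:
    def __init__(self):
        self._is_updating = False
        self._is_updating_cond = asyncio.Condition()

    async def pause(self):
        async with self._is_updating_cond:
            self._is_updating = True

    async def resume(self):
        async with self._is_updating_cond:
            self._is_updating = False
            self._is_updating_cond.notify_all()

    async def handle_request(self, rid):
        async with self._is_updating_cond:
            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
        print(f"{rid} admitted")


async def main():
    gate = Gate()
    await gate.pause()
    task = asyncio.create_task(gate.handle_request("req-0"))  # blocks at the gate
    await asyncio.sleep(0.1)
    await gate.resume()  # releases the waiter
    await task


asyncio.run(main())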
    async def update_weights_from_disk(
        self,
...
...
@@ -1208,14 +1205,6 @@ class TokenizerManager:
            # Many DP ranks
            return [res.internal_state for res in responses]

    async def get_load(self) -> dict:
        # TODO(lsyin): fake load report server
        if not self.current_load_lock.locked():
            async with self.current_load_lock:
                internal_state = await self.get_internal_state()
                self.current_load = internal_state[0]["load"]
        return {"load": self.current_load}

    async def set_internal_state(
        self, obj: SetInternalStateReq
    ) -> SetInternalStateReqOutput:
...
...
@@ -1224,6 +1213,14 @@ class TokenizerManager:
        )
        return [res.internal_state for res in responses]

    async def get_load(self) -> dict:
        # TODO(lsyin): fake load report server
        if not self.current_load_lock.locked():
            async with self.current_load_lock:
                internal_state = await self.get_internal_state()
                self.current_load = internal_state[0]["load"]
        return {"load": self.current_load}

    def get_log_request_metadata(self):
        max_length = None
        skip_names = None
...
...
@@ -1343,11 +1340,24 @@ class TokenizerManager:
"SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
)
return
logger
.
error
(
f
"Dumping requests before crash.
{
self
.
crash_dump_folder
=
}
"
)
self
.
crash_dump_performed
=
True
if
not
self
.
crash_dump_folder
:
return
logger
.
error
(
f
"Dumping requests before crash.
{
self
.
crash_dump_folder
=
}
"
)
self
.
crash_dump_performed
=
True
# Check if NFS directory is available
# expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
# use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
# expected_nfs_dir, os.W_OK
# )
use_nfs_dir
=
False
if
not
use_nfs_dir
:
logger
.
error
(
f
"Expected NFS directory is not available or writable. Uploading to GCS."
)
data_to_dump
=
[]
if
self
.
crash_dump_request_list
:
data_to_dump
.
extend
(
self
.
crash_dump_request_list
)
...
...
@@ -1357,7 +1367,12 @@ class TokenizerManager:
        for rid, state in self.rid_to_state.items():
            if not state.finished:
                unfinished_requests.append(
                    (state.obj, {}, state.created_time, time.time())
                    (
                        state.obj,
                        state.out_list[-1] if state.out_list else {},
                        state.created_time,
                        time.time(),
                    )
                )
        if unfinished_requests:
            data_to_dump.extend(unfinished_requests)
...
...
@@ -1365,10 +1380,11 @@ class TokenizerManager:
        if not data_to_dump:
            return

        object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
        filename = os.path.join(
            self.crash_dump_folder,
            os.getenv("HOSTNAME", None),
            f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
            object_name,
        )

        os.makedirs(os.path.dirname(filename), exist_ok=True)
...
...
@@ -1383,6 +1399,24 @@ class TokenizerManager:
            f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
        )

        def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
            from google.cloud import storage

            client = storage.Client()
            bucket = client.bucket(bucket_name)
            blob = bucket.blob(object_name)
            blob.upload_from_filename(source_file_path, if_generation_match=0)
            logger.error(
                f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
            )

        if not use_nfs_dir:
            _upload_file_to_gcs(
                "sglang_crash_dump",
                filename,
                os.getenv("HOSTNAME", None) + "/" + object_name,
            )

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(5)
...
...
@@ -1426,7 +1460,7 @@ class TokenizerManager:
        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(recv_obj)
            self.last_receive_tstamp = time.perf_counter()
            self.last_receive_tstamp = time.time()

    def _handle_batch_output(
        self,
...
...
@@ -1697,24 +1731,13 @@ class TokenizerManager:
            self.dump_requests_folder,
            datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
        )
        logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
        to_dump = self.dump_request_list
        self._dump_data_to_file(
            data_list=self.dump_request_list,
            filename=filename,
            log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
        )
        self.dump_request_list = []

        to_dump_with_server_args = {
            "server_args": self.server_args,
            "requests": to_dump,
        }

        def background_task():
            os.makedirs(self.dump_requests_folder, exist_ok=True)
            with open(filename, "wb") as f:
                pickle.dump(to_dump_with_server_args, f)

        # Schedule the task to run in the background without awaiting it
        asyncio.create_task(asyncio.to_thread(background_task))

    def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
        current_time = time.time()
        self.crash_dump_request_list.append(
...
...
@@ -1727,6 +1750,22 @@ class TokenizerManager:
        ):
            self.crash_dump_request_list.popleft()

    def _dump_data_to_file(
        self, data_list: List[Tuple], filename: str, log_message: str
    ):
        logger.info(log_message)
        to_dump_with_server_args = {
            "server_args": self.server_args,
            "requests": data_list.copy(),
        }

        def background_task():
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "wb") as f:
                pickle.dump(to_dump_with_server_args, f)

        asyncio.create_task(asyncio.to_thread(background_task))
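_dump_data_to_file keeps the event loop responsive by pushing the blocking pickle write into a worker thread via asyncio.to_thread, without awaiting it on the request path. A small standalone version of the same pattern (the file path and payload are placeholders):

import asyncio
import os
import pickle
import tempfile


def write_snapshot(filename, payload):
    # Blocking I/O, safe to run inside a worker thread.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "wb") as f:
        pickle.dump(payload, f)


async def main():
    filename = os.path.join(tempfile.gettempdir(), "dump_demo", "requests.pkl")
    # Fire-and-forget: the write runs in a thread while the loop keeps serving.
    task = asyncio.create_task(asyncio.to_thread(write_snapshot, filename, {"requests": []}))
    await task  # awaited here only so the demo exits cleanly


asyncio.run(main())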
    def _handle_abort_req(self, recv_obj):
        state = self.rid_to_state[recv_obj.rid]
        state.finished = True
...
...
@@ -1862,6 +1901,16 @@ class TokenizerManager:
        return scores


def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
    is_cross_node = server_args.dist_init_addr

    if is_cross_node:
        # Fallback to default CPU transport for multi-node
        return "default"
    else:
        return "cuda_ipc"


async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print exception.
...
...
python/sglang/srt/server_args.py
View file @
a4c3b121
...
...
@@ -2071,6 +2071,9 @@ class PortArgs:
        dist_init_host, dist_init_port = dist_init_addr
        port_base = int(dist_init_port) + 1
        detokenizer_port = port_base + 1
        rpc_port = port_base + 2
        metrics_ipc_name = port_base + 3
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 4
...
...
@@ -2080,10 +2083,10 @@ class PortArgs:
        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{detokenizer_port}",
            nccl_port=nccl_port,
            rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
            metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
            rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
            metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
        )
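With a multi-node dist_init_addr, the hunk above derives every IPC endpoint from port_base = int(dist_init_port) + 1 and named offsets instead of inline arithmetic. A quick sketch of that derivation for the dp_rank-is-None case only (host and port values are examples, not defaults):

def derive_ports(dist_init_addr="10.0.0.1:5000"):
    # Mirrors the offsets shown in the diff above.
    host, dist_init_port = dist_init_addr.split(":")
    port_base = int(dist_init_port) + 1
    return {
        "tokenizer": f"tcp://{host}:{port_base}",
        "detokenizer": f"tcp://{host}:{port_base + 1}",
        "rpc": f"tcp://{host}:{port_base + 2}",
        "metrics": f"tcp://{host}:{port_base + 3}",
        "scheduler_input": f"tcp://{host}:{port_base + 4}",
    }


print(derive_ports())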
...
...
python/sglang/utils.py
View file @
a4c3b121
...
...
@@ -291,17 +291,6 @@ def find_printable_text(text: str):
    return text[: text.rfind(" ") + 1]


def graceful_registry(sub_module_name: str):
    def graceful_shutdown(signum, frame):
        logger.info(
            f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
        )
        if signum == signal.SIGTERM:
            logger.info(f"{sub_module_name} receive sigterm")

    signal.signal(signal.SIGTERM, graceful_shutdown)


class LazyImport:
    """Lazy import to make `import sglang` run faster."""
...
...