change / sglang · Commits · a4c3b121

Commit a4c3b121 (unverified) — authored Jul 29, 2025 by Lianmin Zheng, committed by GitHub on Jul 29, 2025

    Split the scheduler into multiple mixin classes to reduce the file size (#8483)

Parent: 5973675b

Showing 12 changed files with 868 additions and 784 deletions (+868 −784)
  python/sglang/srt/disaggregation/decode.py                     +2    −8
  python/sglang/srt/disaggregation/prefill.py                    +2    −6
  python/sglang/srt/entrypoints/engine.py                        +11   −17
  python/sglang/srt/entrypoints/http_server.py                   +10   −2
  python/sglang/srt/managers/io_struct.py                        +0    −2
  python/sglang/srt/managers/scheduler.py                        +64   −661
  python/sglang/srt/managers/scheduler_metrics_mixin.py          +229  −0
  python/sglang/srt/managers/scheduler_profiler_mixin.py         +279  −0
  python/sglang/srt/managers/scheduler_update_weights_mixin.py   +142  −0
  python/sglang/srt/managers/tokenizer_manager.py                +123  −74
  python/sglang/srt/server_args.py                               +6    −3
  python/sglang/utils.py                                         +0    −11
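The commit message describes the goal: scheduler.py shrinks by roughly 660 lines, and the metrics, profiler, and weight-update logic move into dedicated mixin classes that the scheduler inherits from. A minimal, self-contained sketch of this composition pattern (hypothetical class and method names, not the literal code from scheduler.py, whose diff is collapsed below):

```python
# Hypothetical illustration of the mixin split; the real mixins are the new
# scheduler_*_mixin.py files added in this commit.
class MetricsMixin:
    def init_metrics(self):
        self.num_logged = 0  # state lives on the composed instance

    def log_stats(self, msg: str):
        self.num_logged += 1
        print(f"[{self.name}] {msg}")


class ProfilerMixin:
    def init_profiler(self):
        self.profile_in_progress = False


class Scheduler(MetricsMixin, ProfilerMixin):
    def __init__(self, name: str):
        self.name = name
        self.init_metrics()
        self.init_profiler()


if __name__ == "__main__":
    s = Scheduler("scheduler-0")
    s.log_stats("Prefill batch. #new-seq: 3")
```

Each mixin only attaches attributes and methods to the shared instance, so the public `Scheduler` type and its call sites stay unchanged while the file is split.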
python/sglang/srt/disaggregation/decode.py  (+2 −8)

```diff
@@ -694,10 +694,7 @@ class SchedulerDisaggregationDecodeMixin:
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
                 # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

             self.last_batch = batch
@@ -771,10 +768,7 @@ class SchedulerDisaggregationDecodeMixin:
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
                 # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

             self.last_batch = batch
             self.last_batch_in_queue = last_batch_in_queue
```
python/sglang/srt/disaggregation/prefill.py  (+2 −6)

```diff
@@ -287,9 +287,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_disagg_prefill_inflight_queue()

             if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

             self.last_batch = batch
             # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -337,9 +335,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_disagg_prefill_inflight_queue()

             if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-                self.maybe_sleep_on_idle()
+                self.self_check_during_idle()

             self.last_batch = batch
             # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
```
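Both disaggregation event loops replace three idle-path calls with a single `self.self_check_during_idle()`. Based only on the lines removed here, the consolidated helper (which lives in scheduler.py, collapsed below) plausibly has the following shape; the stub methods are hypothetical and just make the sketch runnable:

```python
# Plausible shape of the consolidated idle-path helper, inferred from the
# removed lines above; the real definition in scheduler.py may include more checks.
class IdleCheckSketch:
    init_new_token_ratio = 0.7

    def check_memory(self):
        print("check_memory: verify KV-cache accounting while idle")

    def maybe_sleep_on_idle(self):
        print("maybe_sleep_on_idle: optionally release resources while idle")

    def self_check_during_idle(self):
        self.check_memory()
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()


IdleCheckSketch().self_check_during_idle()
```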
python/sglang/srt/entrypoints/engine.py  (+11 −17)

```diff
@@ -652,25 +652,19 @@ def _set_envs_and_config(server_args: ServerArgs):
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

-    def sigchld_handler(signum, frame):
-        if True:  # Keep this check for internal code compatibility
-            pid, exitcode = os.waitpid(0, os.WNOHANG)
-            if exitcode != 0:
-                logger.warning(f"Child process unexpectedly failed with {exitcode=}. {pid=}")
-
-    signal.signal(signal.SIGCHLD, sigchld_handler)
-
     # Register the signal handler.
     # The child processes will send SIGQUIT to this process when any error happens
     # This process then clean up the whole process tree
-    def sigquit_handler(signum, frame):
+    # Note: This sigquit handler is used in the launch phase, and may be replaced by
+    # the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched.
+    def launch_phase_sigquit_handler(signum, frame):
         logger.error(
             "Received sigquit from a child process. It usually means the child failed."
         )
         kill_process_tree(os.getpid())

-    signal.signal(signal.SIGQUIT, sigquit_handler)
+    signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)

     # Set mp start method
     mp.set_start_method("spawn", force=True)
```
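The launch-phase handler simply logs and tears down the process tree when any child sends SIGQUIT. A generic, Unix-only sketch of that wiring (`kill_process_tree` is sglang's helper; this standalone version just logs and exits):

```python
# Standalone sketch of launch-phase SIGQUIT handling; not sglang's exact code.
import logging
import os
import signal
import sys

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def launch_phase_sigquit_handler(signum, frame):
    logger.error("Received sigquit from a child process. It usually means the child failed.")
    sys.exit(1)


signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)
print(f"PID {os.getpid()} is waiting; send SIGQUIT to trigger the handler.")
signal.pause()
```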
python/sglang/srt/entrypoints/http_server.py  (+10 −2)

```diff
@@ -238,6 +238,9 @@ async def health() -> Response:
 @app.get("/health_generate")
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
+    if _global_state.tokenizer_manager.gracefully_exit:
+        logger.info("Health check request received during shutdown. Returning 503.")
+        return Response(status_code=503)

     sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
     rid = f"HEALTH_CHECK_{time.time()}"
@@ -260,9 +263,14 @@ async def health_generate(request: Request) -> Response:
         async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
             break

+    tic = time.perf_counter()
+    # This request is a special request.
+    # If the server already has something running, this request will be ignored, so it creates zero overhead.
+    # If the server is not running, this request will be run, so we know whether the server is healthy.
     task = asyncio.create_task(gen())
-    tic = time.time()
-    while time.time() < tic + HEALTH_CHECK_TIMEOUT:
+
+    # As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
+    while time.perf_counter() < tic + HEALTH_CHECK_TIMEOUT:
         await asyncio.sleep(1)
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()
```
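For reference, a small client-side sketch of how the `/health_generate` route is typically polled (the base URL and port are assumptions; 503 is the shutdown response added above):

```python
# Hypothetical health poller; adjust base_url to your deployment.
import urllib.error
import urllib.request


def is_server_healthy(base_url: str = "http://127.0.0.1:30000", timeout_s: float = 30.0) -> bool:
    try:
        with urllib.request.urlopen(f"{base_url}/health_generate", timeout=timeout_s) as resp:
            # 200 -> healthy; 503 is returned while the server is shutting down.
            return resp.status == 200
    except urllib.error.HTTPError:
        return False
    except (urllib.error.URLError, TimeoutError):
        return False


if __name__ == "__main__":
    print("healthy" if is_server_healthy() else "unhealthy")
```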
python/sglang/srt/managers/io_struct.py  (+0 −2)

```diff
@@ -152,8 +152,6 @@ class GenerateReqInput:
         else:
             self._normalize_batch_inputs()

-        self._validate_session_params()
-
     def _validate_inputs(self):
         """Validate that the input configuration is valid."""
         if (
```
python/sglang/srt/managers/scheduler.py  (+64 −661)

This diff is collapsed in the original page view.
python/sglang/srt/managers/scheduler_metrics_mixin.py  (new file, +229)

```python
import logging
import time
from collections import defaultdict
from typing import List, Optional

from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.managers.schedule_policy import PrefillAdder
from sglang.srt.managers.scheduler import Req, ScheduleBatch
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.utils import get_bool_env_var

logger = logging.getLogger(__name__)

RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")


class KvMetrics:
    def __init__(self):
        self.request_active_slots = None
        self.request_total_slots = None
        self.kv_active_blocks = None
        self.kv_total_blocks = None
        self.num_requests_waiting = None
        self.gpu_cache_usage_perc = None
        self.gpu_prefix_cache_hit_rate = None
        self.data_parallel_rank = None


class SchedulerMetricsMixin:
    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.total_retracted_reqs = 0
        self.stats = SchedulerStats()
        if self.enable_metrics:
            engine_type = "unified"
            labels = {
                "model_name": self.server_args.served_model_name,
                "engine_type": engine_type,
                "tp_rank": tp_rank,
                "pp_rank": pp_rank,
            }
            if dp_rank is not None:
                labels["dp_rank"] = dp_rank
            self.metrics_collector = SchedulerMetricsCollector(labels=labels)

    def init_kv_events(self, kv_events_config: Optional[str]):
        if self.enable_kv_cache_events:
            self.kv_event_publisher = EventPublisherFactory.create(
                kv_events_config, self.attn_dp_rank
            )

    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.perf_counter()
        self.last_input_throughput = self.last_prefill_tokens / gap_latency
        self.last_prefill_tokens = adder.log_input_tokens

        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
            token_msg = (
                f"full token usage: {full_token_usage:.2f}, "
                f"swa token usage: {swa_token_usage:.2f}, "
            )
        else:
            num_used, token_usage, _, _ = self._get_token_info()
            token_msg = f"token usage: {token_usage:.2f}, "

        num_new_seq = len(can_run_list)
        f = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"{token_msg}"
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "
            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
        else:
            f += f"#running-req: {running_bs}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "

        logger.info(f)

        if self.enable_metrics:
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens

            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            total_queue_latency = 0
            for req in can_run_list:
                total_queue_latency += req.queue_time_end - req.queue_time_start
            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq

            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()

    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
    ):
        batch = running_batch or self.running_batch

        gap_latency = time.perf_counter() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.perf_counter()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
            token_msg = (
                f"#full token: {full_num_used}, "
                f"full token usage: {full_token_usage:.2f}, "
                f"#swa token: {swa_num_used}, "
                f"swa token usage: {swa_token_usage:.2f}, "
            )
        else:
            num_used, token_usage, _, _ = self._get_token_info()
            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.stats.total_retracted_reqs = self.total_retracted_reqs
            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()

    def _emit_kv_metrics(self):
        kv_metrics = KvMetrics()
        kv_metrics.request_active_slots = self.stats.num_running_reqs
        kv_metrics.request_total_slots = self.max_running_requests
        kv_metrics.kv_active_blocks = int(
            self.stats.token_usage * self.max_total_num_tokens
        )
        kv_metrics.kv_total_blocks = self.max_total_num_tokens
        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0

        if not self.send_metrics_from_scheduler.closed:
            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

    def _publish_kv_events(self):
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)
```
0 → 100644
View file @
a4c3b121
import
logging
import
os
import
time
from
pathlib
import
Path
from
typing
import
List
,
Optional
import
torch
from
sglang.srt.managers.io_struct
import
ProfileReq
,
ProfileReqOutput
,
ProfileReqType
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
logger
=
logging
.
getLogger
(
__name__
)
class
SchedulerProfilerMixin
:
def
init_profier
(
self
):
self
.
torch_profiler
=
None
self
.
torch_profiler_output_dir
:
Optional
[
str
]
=
None
self
.
profiler_activities
:
Optional
[
List
[
str
]]
=
None
self
.
profile_id
:
Optional
[
str
]
=
None
self
.
profiler_start_forward_ct
:
Optional
[
int
]
=
None
self
.
profiler_target_forward_ct
:
Optional
[
int
]
=
None
self
.
profiler_target_prefill_ct
:
Optional
[
int
]
=
None
self
.
profiler_target_decode_ct
:
Optional
[
int
]
=
None
self
.
profiler_prefill_ct
:
Optional
[
int
]
=
None
self
.
profiler_decode_ct
:
Optional
[
int
]
=
None
self
.
profile_by_stage
:
bool
=
False
self
.
profile_steps
:
Optional
[
int
]
=
None
self
.
profile_in_progress
:
bool
=
False
self
.
rpd_profiler
=
None
def
init_profile
(
self
,
output_dir
:
Optional
[
str
],
start_step
:
Optional
[
int
],
num_steps
:
Optional
[
int
],
activities
:
Optional
[
List
[
str
]],
with_stack
:
Optional
[
bool
],
record_shapes
:
Optional
[
bool
],
profile_by_stage
:
bool
,
profile_id
:
str
,
)
->
ProfileReqOutput
:
if
self
.
profile_in_progress
:
return
ProfileReqOutput
(
success
=
False
,
message
=
"Profiling is already in progress. Call /stop_profile first."
,
)
self
.
profile_by_stage
=
profile_by_stage
if
output_dir
is
None
:
output_dir
=
os
.
getenv
(
"SGLANG_TORCH_PROFILER_DIR"
,
"/tmp"
)
if
activities
is
None
:
activities
=
[
"CPU"
,
"GPU"
]
self
.
torch_profiler_output_dir
=
output_dir
self
.
torch_profiler_with_stack
=
with_stack
self
.
torch_profiler_record_shapes
=
record_shapes
self
.
profiler_activities
=
activities
self
.
profile_id
=
profile_id
if
start_step
:
self
.
profiler_start_forward_ct
=
max
(
start_step
,
self
.
forward_ct
+
1
)
if
num_steps
:
self
.
profile_steps
=
num_steps
if
self
.
profile_by_stage
:
self
.
profiler_target_prefill_ct
=
num_steps
self
.
profiler_target_decode_ct
=
num_steps
self
.
profiler_prefill_ct
=
0
self
.
profiler_decode_ct
=
0
elif
start_step
:
self
.
profiler_target_forward_ct
=
(
self
.
profiler_start_forward_ct
+
num_steps
)
else
:
self
.
profiler_target_forward_ct
=
self
.
forward_ct
+
num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else
:
self
.
profiler_target_forward_ct
=
None
return
ProfileReqOutput
(
success
=
True
,
message
=
"Succeeded"
)
def
start_profile
(
self
,
stage
:
Optional
[
ForwardMode
]
=
None
)
->
ProfileReqOutput
|
None
:
stage_str
=
f
" for
{
stage
.
__str__
()
}
"
if
stage
else
""
logger
.
info
(
f
"Profiling starts
{
stage_str
}
. Traces will be saved to:
{
self
.
torch_profiler_output_dir
}
(with profile id:
{
self
.
profile_id
}
)"
,
)
activities
=
self
.
profiler_activities
with_stack
=
self
.
torch_profiler_with_stack
record_shapes
=
self
.
torch_profiler_record_shapes
activity_map
=
{
"CPU"
:
torch
.
profiler
.
ProfilerActivity
.
CPU
,
"GPU"
:
torch
.
profiler
.
ProfilerActivity
.
CUDA
,
}
torchprof_activities
=
[
activity_map
[
a
]
for
a
in
activities
if
a
in
activity_map
]
if
"RPD"
in
activities
:
from
rpdTracerControl
import
rpdTracerControl
rpdTracerControl
.
skipCreate
()
self
.
rpd_profile_path
=
os
.
path
.
join
(
self
.
torch_profiler_output_dir
,
"rpd-"
+
str
(
time
.
time
())
+
f
"-TP-
{
self
.
tp_rank
}
"
+
".trace.json.gz"
,
)
if
self
.
tp_rank
==
0
:
import
sqlite3
from
rocpd.schema
import
RocpdSchema
if
os
.
path
.
exists
(
"trace.rpd"
):
os
.
unlink
(
"trace.rpd"
)
schema
=
RocpdSchema
()
connection
=
sqlite3
.
connect
(
"trace.rpd"
)
schema
.
writeSchema
(
connection
)
connection
.
commit
()
del
connection
torch
.
distributed
.
barrier
(
self
.
tp_cpu_group
)
self
.
rpd_profiler
=
rpdTracerControl
()
self
.
rpd_profiler
.
setPythonTrace
(
True
)
self
.
rpd_profiler
.
start
()
self
.
rpd_profiler
.
rangePush
(
""
,
"rpd profile range"
,
""
)
self
.
profile_in_progress
=
True
elif
torchprof_activities
:
self
.
torch_profiler
=
torch
.
profiler
.
profile
(
activities
=
torchprof_activities
,
with_stack
=
with_stack
if
with_stack
is
not
None
else
True
,
record_shapes
=
record_shapes
if
record_shapes
is
not
None
else
False
,
)
self
.
torch_profiler
.
start
()
self
.
profile_in_progress
=
True
if
"MEM"
in
activities
:
torch
.
cuda
.
memory
.
_record_memory_history
(
max_entries
=
100000
)
self
.
profile_in_progress
=
True
if
"CUDA_PROFILER"
in
activities
:
torch
.
cuda
.
cudart
().
cudaProfilerStart
()
self
.
profile_in_progress
=
True
return
ProfileReqOutput
(
success
=
True
,
message
=
"Succeeded"
)
def
stop_profile
(
self
,
stage
:
Optional
[
ForwardMode
]
=
None
)
->
ProfileReqOutput
|
None
:
if
not
self
.
profile_in_progress
:
return
ProfileReqOutput
(
success
=
False
,
message
=
"Profiling is not in progress. Call /start_profile first."
,
)
if
not
Path
(
self
.
torch_profiler_output_dir
).
exists
():
Path
(
self
.
torch_profiler_output_dir
).
mkdir
(
parents
=
True
,
exist_ok
=
True
)
stage_suffix
=
f
"-
{
stage
.
__str__
()
}
"
if
stage
else
""
logger
.
info
(
"Stop profiling"
+
stage_suffix
+
"..."
)
if
self
.
torch_profiler
is
not
None
:
self
.
torch_profiler
.
stop
()
self
.
torch_profiler
.
export_chrome_trace
(
os
.
path
.
join
(
self
.
torch_profiler_output_dir
,
self
.
profile_id
+
f
"-TP-
{
self
.
tp_rank
}
"
+
stage_suffix
+
".trace.json.gz"
,
)
)
torch
.
distributed
.
barrier
(
self
.
tp_cpu_group
)
if
self
.
rpd_profiler
is
not
None
:
self
.
rpd_profiler
.
rangePop
()
self
.
rpd_profiler
.
stop
()
self
.
rpd_profiler
.
flush
()
torch
.
distributed
.
barrier
(
self
.
tp_cpu_group
)
if
self
.
tp_rank
==
0
:
from
sglang.srt.utils
import
rpd_to_chrome_trace
rpd_to_chrome_trace
(
"trace.rpd"
,
self
.
rpd_profile_path
)
self
.
rpd_profiler
=
None
self
.
rpd_profiler_path
=
None
if
self
.
profiler_activities
is
not
None
and
"MEM"
in
self
.
profiler_activities
:
memory_profile_path
=
os
.
path
.
join
(
self
.
torch_profiler_output_dir
,
str
(
time
.
time
())
+
f
"-TP-
{
self
.
tp_rank
}
-memory"
+
stage_suffix
+
".pickle"
,
)
torch
.
cuda
.
memory
.
_dump_snapshot
(
memory_profile_path
)
torch
.
cuda
.
memory
.
_record_memory_history
(
enabled
=
None
)
if
"CUDA_PROFILER"
in
self
.
profiler_activities
:
torch
.
cuda
.
cudart
().
cudaProfilerStop
()
logger
.
info
(
"Profiling done. Traces are saved to: %s"
,
self
.
torch_profiler_output_dir
,
)
self
.
torch_profiler
=
None
self
.
profile_in_progress
=
False
self
.
profiler_start_forward_ct
=
None
return
ProfileReqOutput
(
success
=
True
,
message
=
"Succeeded."
)
def
_profile_batch_predicate
(
self
,
batch
):
if
self
.
profile_by_stage
:
if
batch
.
forward_mode
.
is_prefill
():
if
self
.
profiler_prefill_ct
==
0
:
self
.
start_profile
(
batch
.
forward_mode
)
self
.
profiler_prefill_ct
+=
1
if
self
.
profiler_prefill_ct
>
self
.
profiler_target_prefill_ct
:
if
self
.
profile_in_progress
:
self
.
stop_profile
(
stage
=
ForwardMode
.
EXTEND
)
elif
batch
.
forward_mode
.
is_decode
():
if
self
.
profiler_decode_ct
==
0
:
if
self
.
profile_in_progress
:
# force trace flush
self
.
stop_profile
(
ForwardMode
.
EXTEND
)
self
.
start_profile
(
batch
.
forward_mode
)
self
.
profiler_decode_ct
+=
1
if
self
.
profiler_decode_ct
>
self
.
profiler_target_decode_ct
:
if
self
.
profile_in_progress
:
self
.
stop_profile
(
stage
=
ForwardMode
.
DECODE
)
elif
batch
.
forward_mode
.
is_idle
():
pass
else
:
raise
RuntimeError
(
f
"unsupported profile stage:
{
batch
.
forward_mode
}
"
)
else
:
# Check profiler
if
(
self
.
profiler_target_forward_ct
and
self
.
profiler_target_forward_ct
<=
self
.
forward_ct
):
self
.
stop_profile
()
if
(
self
.
profiler_start_forward_ct
and
self
.
profiler_start_forward_ct
==
self
.
forward_ct
):
self
.
start_profile
()
def
profile
(
self
,
recv_req
:
ProfileReq
):
if
recv_req
.
type
==
ProfileReqType
.
START_PROFILE
:
if
recv_req
.
profile_by_stage
or
recv_req
.
start_step
:
return
self
.
init_profile
(
recv_req
.
output_dir
,
recv_req
.
start_step
,
recv_req
.
num_steps
,
recv_req
.
activities
,
recv_req
.
with_stack
,
recv_req
.
record_shapes
,
recv_req
.
profile_by_stage
,
recv_req
.
profile_id
,
)
else
:
self
.
init_profile
(
recv_req
.
output_dir
,
recv_req
.
start_step
,
recv_req
.
num_steps
,
recv_req
.
activities
,
recv_req
.
with_stack
,
recv_req
.
record_shapes
,
recv_req
.
profile_by_stage
,
recv_req
.
profile_id
,
)
return
self
.
start_profile
(
True
)
else
:
return
self
.
stop_profile
()
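Underneath the request plumbing, `start_profile`/`stop_profile` drive a standard `torch.profiler` session. A minimal standalone sketch of that flow (the output path, step count, and workload are made up for illustration):

```python
# Bare torch.profiler start/stop flow that the mixin wraps with request handling.
import torch


def profile_steps(step_fn, num_steps: int, out_path: str = "/tmp/example.trace.json.gz"):
    prof = torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        with_stack=True,
        record_shapes=False,
    )
    prof.start()
    for _ in range(num_steps):
        step_fn()          # one scheduler forward step in the real server
    prof.stop()
    prof.export_chrome_trace(out_path)  # same Chrome-trace format the mixin emits


if __name__ == "__main__":
    profile_steps(lambda: torch.randn(256, 256) @ torch.randn(256, 256), num_steps=4)
```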
python/sglang/srt/managers/scheduler_update_weights_mixin.py  (new file, +142)

```python
import logging
from typing import Tuple

import torch

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)

logger = logging.getLogger(__name__)


class SchedulerUpdateWeightsMixin:

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk."""
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message, 0)

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group."""
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> Tuple[bool, str]:
        """Update the online model parameter."""
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors."""
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        torch.distributed.barrier(group=self.tp_cpu_group)
        return UpdateWeightsFromTensorReqOutput(success, message)

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(parameter)

    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        tags = recv_req.tags

        if tags is None or len(tags) == 0:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
            self.flush_cache()

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            self.stashed_model_static_state = _export_static_state(
                self.tp_worker.worker.model_runner.model
            )
            torch.distributed.barrier(self.tp_cpu_group)
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        tags = recv_req.tags

        if tags is None or len(tags) == 0:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
            torch.distributed.barrier(self.tp_cpu_group)
            _import_static_state(
                self.tp_worker.worker.model_runner.model,
                self.stashed_model_static_state,
            )
            del self.stashed_model_static_state

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)

        return ResumeMemoryOccupationReqOutput()

    def save_remote_model(self, params):
        url = params["url"]
        worker = self.tp_worker.worker
        worker.model_runner.save_remote_model(url)

    def save_sharded_model(self, params):
        worker = self.tp_worker.worker
        worker.model_runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )


def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor
```
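The module-level helpers stash and restore a model's non-parameter buffers so the weights' memory can be released and later resumed. A self-contained round trip with a toy module (the example module and values are illustrative only):

```python
# Stash/restore of named buffers, mirroring _export_static_state/_import_static_state.
import torch
import torch.nn as nn


def export_static_state(model: nn.Module) -> dict:
    return dict(
        buffers=[(name, buf.detach().clone()) for name, buf in model.named_buffers()]
    )


def import_static_state(model: nn.Module, static_params: dict) -> None:
    named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        named_buffers[name][...] = tensor  # in-place restore


bn = nn.BatchNorm1d(4)
stash = export_static_state(bn)   # snapshot running_mean / running_var, etc.
bn.running_mean.fill_(123.0)      # pretend the buffers were clobbered
import_static_state(bn, stash)    # restore them in place
print(bn.running_mean)            # back to zeros
```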
python/sglang/srt/managers/tokenizer_manager.py  (+123 −74)

```diff
@@ -170,16 +170,6 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


-def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
-    is_cross_node = server_args.dist_init_addr
-
-    if is_cross_node:
-        # Fallback to default CPU transport for multi-node
-        return "default"
-    else:
-        return "cuda_ipc"
-
-
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
@@ -199,16 +189,6 @@ class TokenizerManager:
             else None
         )
         self.crash_dump_folder = server_args.crash_dump_folder
-        self.crash_dump_performed = False  # Flag to ensure dump is only called once
-
-        # Init inter-process communication
-        context = zmq.asyncio.Context(2)
-        self.recv_from_detokenizer = get_zmq_socket(
-            context, zmq.PULL, port_args.tokenizer_ipc_name, True
-        )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )

         # Read model args
         self.model_path = server_args.model_path
@@ -218,8 +198,7 @@ class TokenizerManager:
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
-        self._updating = False
-        self._cond = asyncio.Condition()
+        self.max_req_input_len = None  # Will be set later in engine.py

         if self.model_config.is_multimodal:
             import_processors()
@@ -258,39 +237,57 @@ class TokenizerManager:
             revision=server_args.revision,
         )

-        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
-        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
-        # serves as the source of truth for available adapters and maps user-friendly LoRA names
-        # to internally used unique LoRA IDs.
-        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
+        # Init inter-process communication
+        context = zmq.asyncio.Context(2)
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_ipc_name, True
+        )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+        )

-        # Store states
+        # Request states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.asyncio_tasks = set()
+
+        # Health check
         self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
+
+        # Dumping
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
-        self.crash_dump_request_list: deque[Tuple] = deque()
         self.log_request_metadata = self.get_log_request_metadata()
+        self.crash_dump_request_list: deque[Tuple] = deque()
+        self.crash_dump_performed = False  # Flag to ensure dump is only called once
+
+        # Session
         self.session_futures = {}  # session_id -> asyncio event
-        self.max_req_input_len = None
-        self.asyncio_tasks = set()

+        # Weight updates
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
+        self._is_updating = False
+        self._is_updating_cond = asyncio.Condition()

+        # LoRA
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
         # Lock to serialize LoRA update operations.
         # Please note that, unlike `model_update_lock`, this does not block inference, allowing
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

-        # For pd disaggregtion
+        # For PD disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -458,17 +455,11 @@ class TokenizerManager:
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
-        async with self._cond:
-            await self._cond.wait_for(lambda: not self._updating)

         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

-        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
-            raise ValueError(
-                "This model does not appear to be an embedding model by default. "
-                "Please add `--is-embedding` when launching the server or try another model."
-            )
+        async with self._is_updating_cond:
+            await self._is_updating_cond.wait_for(lambda: not self._is_updating)

         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
@@ -567,6 +558,12 @@ class TokenizerManager:
                 f"model's context length ({self.context_len} tokens)."
             )

+        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
+            raise ValueError(
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
+            )
+
         # Check total tokens (input + max_new_tokens)
         max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
@@ -959,14 +956,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

     async def pause_generation(self):
-        async with self._cond:
-            self._updating = True
+        async with self._is_updating_cond:
+            self._is_updating = True
             self.abort_request(abort_all=True)

     async def continue_generation(self):
-        async with self._cond:
-            self._updating = False
-            self._cond.notify_all()
+        async with self._is_updating_cond:
+            self._is_updating = False
+            self._is_updating_cond.notify_all()

     async def update_weights_from_disk(
         self,
@@ -1208,14 +1205,6 @@ class TokenizerManager:
             # Many DP ranks
             return [res.internal_state for res in responses]

-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     async def set_internal_state(
         self, obj: SetInternalStateReq
     ) -> SetInternalStateReqOutput:
@@ -1224,6 +1213,14 @@ class TokenizerManager:
         )
         return [res.internal_state for res in responses]

+    async def get_load(self) -> dict:
+        # TODO(lsyin): fake load report server
+        if not self.current_load_lock.locked():
+            async with self.current_load_lock:
+                internal_state = await self.get_internal_state()
+                self.current_load = internal_state[0]["load"]
+        return {"load": self.current_load}
+
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -1343,11 +1340,24 @@ class TokenizerManager:
                 "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
             )
             return
-        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
-        self.crash_dump_performed = True
+
         if not self.crash_dump_folder:
             return
+
+        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
+        self.crash_dump_performed = True
+
+        # Check if NFS directory is available
+        # expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
+        # use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
+        #     expected_nfs_dir, os.W_OK
+        # )
+        use_nfs_dir = False
+        if not use_nfs_dir:
+            logger.error(f"Expected NFS directory is not available or writable. Uploading to GCS.")
+
         data_to_dump = []
         if self.crash_dump_request_list:
             data_to_dump.extend(self.crash_dump_request_list)
@@ -1357,7 +1367,12 @@ class TokenizerManager:
         for rid, state in self.rid_to_state.items():
             if not state.finished:
                 unfinished_requests.append(
-                    (state.obj, {}, state.created_time, time.time())
+                    (
+                        state.obj,
+                        state.out_list[-1] if state.out_list else {},
+                        state.created_time,
+                        time.time(),
+                    )
                 )
         if unfinished_requests:
             data_to_dump.extend(unfinished_requests)
@@ -1365,10 +1380,11 @@ class TokenizerManager:
         if not data_to_dump:
             return

+        object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
         filename = os.path.join(
             self.crash_dump_folder,
             os.getenv("HOSTNAME", None),
-            f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
+            object_name,
         )

         os.makedirs(os.path.dirname(filename), exist_ok=True)
@@ -1383,6 +1399,24 @@ class TokenizerManager:
             f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
         )

+        def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
+            from google.cloud import storage
+
+            client = storage.Client()
+            bucket = client.bucket(bucket_name)
+            blob = bucket.blob(object_name)
+            blob.upload_from_filename(source_file_path, if_generation_match=0)
+            logger.error(
+                f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
+            )
+
+        if not use_nfs_dir:
+            _upload_file_to_gcs(
+                "sglang_crash_dump",
+                filename,
+                os.getenv("HOSTNAME", None) + "/" + object_name,
+            )
+
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
             await asyncio.sleep(5)
@@ -1426,7 +1460,7 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
-            self.last_receive_tstamp = time.time()
+            self.last_receive_tstamp = time.perf_counter()

     def _handle_batch_output(
         self,
@@ -1697,24 +1731,13 @@ class TokenizerManager:
                 self.dump_requests_folder,
                 datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
             )
-            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
-
-            to_dump = self.dump_request_list
+            self._dump_data_to_file(
+                data_list=self.dump_request_list,
+                filename=filename,
+                log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
+            )
             self.dump_request_list = []

-            to_dump_with_server_args = {
-                "server_args": self.server_args,
-                "requests": to_dump,
-            }
-
-            def background_task():
-                os.makedirs(self.dump_requests_folder, exist_ok=True)
-                with open(filename, "wb") as f:
-                    pickle.dump(to_dump_with_server_args, f)
-
-            # Schedule the task to run in the background without awaiting it
-            asyncio.create_task(asyncio.to_thread(background_task))
-
     def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
         current_time = time.time()
         self.crash_dump_request_list.append(
@@ -1727,6 +1750,22 @@ class TokenizerManager:
         ):
             self.crash_dump_request_list.popleft()

+    def _dump_data_to_file(
+        self, data_list: List[Tuple], filename: str, log_message: str
+    ):
+        logger.info(log_message)
+        to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": data_list.copy(),
+        }
+
+        def background_task():
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            with open(filename, "wb") as f:
+                pickle.dump(to_dump_with_server_args, f)
+
+        asyncio.create_task(asyncio.to_thread(background_task))
+
     def _handle_abort_req(self, recv_obj):
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
@@ -1862,6 +1901,16 @@ class TokenizerManager:
         return scores


+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 async def print_exception_wrapper(func):
     """
     Sometimes an asyncio function does not print exception.
```
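The new `_dump_data_to_file` helper funnels both periodic request dumps and crash dumps through the same fire-and-forget thread write. A standalone sketch of that pattern (hypothetical payload and path; unlike the server, the toy script awaits the task so it exits cleanly):

```python
# Fire-a-background-thread-write pattern, simplified from _dump_data_to_file.
import asyncio
import os
import pickle


def background_task(payload: dict, filename: str) -> None:
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "wb") as f:
        pickle.dump(payload, f)


async def main() -> None:
    # In the server this task is created and not awaited, so request handling
    # never blocks on disk I/O; here we await it so the example finishes cleanly.
    task = asyncio.create_task(
        asyncio.to_thread(background_task, {"requests": []}, "/tmp/sglang_dump/example.pkl")
    )
    await task


asyncio.run(main())
```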
python/sglang/srt/server_args.py  (+6 −3)

```diff
@@ -2071,6 +2071,9 @@ class PortArgs:
             dist_init_host, dist_init_port = dist_init_addr
             port_base = int(dist_init_port) + 1
+            detokenizer_port = port_base + 1
+            rpc_port = port_base + 2
+            metrics_ipc_name = port_base + 3
             if dp_rank is None:
                 # TokenizerManager to DataParallelController
                 scheduler_input_port = port_base + 4
@@ -2080,10 +2083,10 @@ class PortArgs:
             return PortArgs(
                 tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
                 scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
-                detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
+                detokenizer_ipc_name=f"tcp://{dist_init_host}:{detokenizer_port}",
                 nccl_port=nccl_port,
-                rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
-                metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
+                rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
+                metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
             )
```
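The refactor names the derived ports once and reuses them when building the IPC addresses. A quick illustration of the resulting port layout (hypothetical address; the offsets mirror the diff above):

```python
# Port layout derived from a single dist_init_addr, as in the PortArgs diff.
dist_init_addr = "10.0.0.1:5000"
dist_init_host, dist_init_port = dist_init_addr.split(":")
port_base = int(dist_init_port) + 1

detokenizer_port = port_base + 1
rpc_port = port_base + 2
metrics_port = port_base + 3          # called metrics_ipc_name in the diff
scheduler_input_port = port_base + 4  # when dp_rank is None

print(f"tokenizer:       tcp://{dist_init_host}:{port_base}")
print(f"detokenizer:     tcp://{dist_init_host}:{detokenizer_port}")
print(f"rpc:             tcp://{dist_init_host}:{rpc_port}")
print(f"metrics:         tcp://{dist_init_host}:{metrics_port}")
print(f"scheduler input: tcp://{dist_init_host}:{scheduler_input_port}")
```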
python/sglang/utils.py  (+0 −11)

```diff
@@ -291,17 +291,6 @@ def find_printable_text(text: str):
         return text[: text.rfind(" ") + 1]


-def graceful_registry(sub_module_name: str):
-    def graceful_shutdown(signum, frame):
-        logger.info(
-            f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
-        )
-        if signum == signal.SIGTERM:
-            logger.info(f"{sub_module_name} receive sigterm")
-
-    signal.signal(signal.SIGTERM, graceful_shutdown)
-
-
 class LazyImport:
     """Lazy import to make `import sglang` run faster."""
```