Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
32852fe9
"projects/CLIP/tests/test_multi_head_attn.py" did not exist on "478602ba59c0bfe7ab9a094b9f1b7b33cfeecba4"
Unverified
Commit
32852fe9
authored
Oct 23, 2025
by
Liangsheng Yin
Committed by
GitHub
Oct 23, 2025
Browse files
Move memory runtime checker to mixin class (#12014)
parent
53c2934d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
168 additions
and
135 deletions
+168
-135
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+4
-135
python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
...on/sglang/srt/managers/scheduler_runtime_checker_mixin.py
+164
-0
No files found.
python/sglang/srt/managers/scheduler.py
View file @
32852fe9
...
...
@@ -137,6 +137,9 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
from
sglang.srt.managers.scheduler_pp_mixin
import
SchedulerPPMixin
from
sglang.srt.managers.scheduler_profiler_mixin
import
SchedulerProfilerMixin
from
sglang.srt.managers.scheduler_recv_skipper
import
SchedulerRecvSkipper
from
sglang.srt.managers.scheduler_runtime_checker_mixin
import
(
SchedulerRuntimeCheckerMixin
,
)
from
sglang.srt.managers.scheduler_update_weights_mixin
import
(
SchedulerUpdateWeightsMixin
,
)
...
...
@@ -207,6 +210,7 @@ class Scheduler(
SchedulerMetricsMixin
,
SchedulerDisaggregationDecodeMixin
,
SchedulerDisaggregationPrefillMixin
,
SchedulerRuntimeCheckerMixin
,
SchedulerPPMixin
,
):
"""A scheduler that manages a tensor parallel GPU worker."""
...
...
@@ -1506,141 +1510,6 @@ class Scheduler(
for
tokenized_req
in
recv_req
:
self
.
handle_embedding_request
(
tokenized_req
)
def self_check_during_idle(self):
    """Run the idle-time checks, reset the new-token ratio, then maybe sleep.

    The steps run strictly in this order: ``check_memory`` (which raises
    ``ValueError`` when it detects a leak), ``check_tree_cache``, the reset of
    ``new_token_ratio`` to ``init_new_token_ratio``, and finally
    ``maybe_sleep_on_idle``.
    """
    # Consistency checks come first so that a failing check aborts before any
    # scheduler state is touched.
    self.check_memory()
    self.check_tree_cache()

    # Restore the ratio to its initial value.
    initial_ratio = self.init_new_token_ratio
    self.new_token_ratio = initial_ratio

    self.maybe_sleep_on_idle()
def check_memory(self):
    """Check both memory pools for leaks, then refresh idle-time metrics.

    Steps, in order:

    1. Pick the token accounting that matches the cache layout
       (``is_hybrid``; ``is_hybrid_gdn`` with a ``MambaRadixCache``; or the
       default) and raise if it reports a leak.
    2. Compare the number of free request slots with the expected pool size
       (``size`` plus ``pre_alloc_size`` in DECODE disaggregation mode) and
       raise on a mismatch.
    3. When metrics are enabled and more than 30 seconds have passed since
       ``metrics_collector.last_log_time``, populate ``self.stats`` and log it.
    4. Call ``self._publish_kv_events()``.

    Raises:
        ValueError: if either the token pool or the request pool fails its
            check.
    """
    if self.is_hybrid:
        (
            full_num_used,
            swa_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        ) = self._get_swa_token_info()
        # Any token still counted as used is treated as a leak.
        memory_leak = full_num_used != 0 or swa_num_used != 0
        token_msg = (
            f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
        )
    elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
        (
            full_num_used,
            mamba_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            mamba_available_size,
            mamba_evictable_size,
        ) = self._get_mamba_token_info()
        # Usage must match exactly what the tree cache reports as protected.
        memory_leak = (
            full_num_used != self.tree_cache.full_protected_size()
            or mamba_num_used != self.tree_cache.mamba_protected_size()
        )
        token_msg = (
            f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
        )
    else:
        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()
        memory_leak = (available_size + evictable_size) != (
            # self.max_total_num_tokens
            # if not self.enable_hierarchical_cache
            # else self.max_total_num_tokens - protected_size
            self.max_total_num_tokens
            - protected_size
        )
        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

    if memory_leak:
        msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
        raise ValueError(msg)

    if self.disaggregation_mode == DisaggregationMode.DECODE:
        req_total_size = (
            self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
        )
    else:
        req_total_size = self.req_to_token_pool.size

    num_free_slots = len(self.req_to_token_pool.free_slots)
    if num_free_slots != req_total_size:
        # Report the total that was actually compared (it includes
        # pre_alloc_size in DECODE mode) so the message cannot show two equal
        # numbers while claiming a leak.
        msg = (
            "req_to_token_pool memory leak detected! "
            f"available_size={num_free_slots}, "
            f"total_size={req_total_size}\n"
        )
        raise ValueError(msg)

    if (
        self.enable_metrics
        and self.current_scheduler_metrics_enabled()
        and time.perf_counter() > self.metrics_collector.last_log_time + 30
    ):
        # During idle time, also collect metrics every 30 seconds.
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            # Report the larger of the two pools.
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
        elif self.is_hybrid_gdn:
            (
                num_used,
                _,
                token_usage,
                _,
                _,
                _,
                _,
                _,
            ) = self._get_mamba_token_info()
        else:
            num_used, token_usage, _, _ = self._get_token_info()

        num_running_reqs = len(self.running_batch.reqs)
        self.stats.num_running_reqs = num_running_reqs
        self.stats.num_used_tokens = num_used
        self.stats.token_usage = round(token_usage, 2)
        self.stats.gen_throughput = 0
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.stats.num_prefill_prealloc_queue_reqs = len(
                self.disagg_prefill_bootstrap_queue.queue
            )
            self.stats.num_prefill_inflight_queue_reqs = len(
                self.disagg_prefill_inflight_queue
            )
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            self.stats.num_decode_prealloc_queue_reqs = len(
                self.disagg_decode_prealloc_queue.queue
            )
            self.stats.num_decode_transfer_queue_reqs = len(
                self.disagg_decode_transfer_queue.queue
            )
        self.metrics_collector.log_stats(self.stats)
    self._publish_kv_events()
def check_tree_cache(self):
    """Call ``tree_cache.sanity_check()`` when the cache type matches the hybrid mode.

    The check runs only if ``is_hybrid`` is set and the tree cache is a
    ``SWARadixCache``, or ``is_hybrid_gdn`` is set and the tree cache is a
    ``MambaRadixCache``. Otherwise this does nothing.
    """
    is_swa_tree = self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)
    if not is_swa_tree:
        is_mamba_tree = self.is_hybrid_gdn and isinstance(
            self.tree_cache, MambaRadixCache
        )
        if not is_mamba_tree:
            return
    self.tree_cache.sanity_check()
def
_get_token_info
(
self
):
available_size
=
self
.
token_to_kv_pool_allocator
.
available_size
()
evictable_size
=
self
.
tree_cache
.
evictable_size
()
...
...
python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
0 → 100644
View file @
32852fe9
from
__future__
import
annotations
import
time
from
typing
import
TYPE_CHECKING
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.mem_cache.mamba_radix_cache
import
MambaRadixCache
from
sglang.srt.mem_cache.swa_radix_cache
import
SWARadixCache
if
TYPE_CHECKING
:
from
sglang.srt.managers.scheduler
import
Scheduler
class SchedulerRuntimeCheckerMixin:
    """Idle-time consistency checks for the scheduler's memory pools and tree cache.

    Every method reads scheduler state through ``self`` (annotated as
    ``Scheduler``); the class itself defines no attributes.
    """

    def _check_hybrid_memory(self: Scheduler) -> tuple[bool, str]:
        """Token-pool check for the ``is_hybrid`` layout.

        Returns:
            ``(memory_leak, token_msg)``. ``memory_leak`` is true when either
            the full or the SWA used-token count from
            ``_get_swa_token_info()`` is non-zero. ``token_msg`` is a
            two-line summary of the sizes involved.
        """
        (
            full_num_used,
            swa_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        ) = self._get_swa_token_info()
        memory_leak = full_num_used != 0 or swa_num_used != 0
        token_msg = (
            f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def _check_mamba_memory(self: Scheduler) -> tuple[bool, str]:
        """Token-pool check for the ``is_hybrid_gdn`` layout with a ``MambaRadixCache``.

        Returns:
            ``(memory_leak, token_msg)``. ``memory_leak`` is true when the
            full or mamba used count from ``_get_mamba_token_info()`` differs
            from the matching ``*_protected_size()`` of the tree cache.
            ``token_msg`` is a two-line summary of the sizes involved.
        """
        (
            full_num_used,
            mamba_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            mamba_available_size,
            mamba_evictable_size,
        ) = self._get_mamba_token_info()
        memory_leak = (
            full_num_used != self.tree_cache.full_protected_size()
            or mamba_num_used != self.tree_cache.mamba_protected_size()
        )
        token_msg = (
            f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def _check_radix_cache_memory(self: Scheduler) -> tuple[bool, str]:
        """Token-pool check for the default (non-hybrid) layout.

        Returns:
            ``(memory_leak, token_msg)``. ``memory_leak`` is true when
            ``available_size + evictable_size`` differs from
            ``max_total_num_tokens - protected_size``. ``token_msg`` is a
            one-line summary of those four values.
        """
        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()
        memory_leak = (available_size + evictable_size) != (
            # self.max_total_num_tokens
            # if not self.enable_hierarchical_cache
            # else self.max_total_num_tokens - protected_size
            self.max_total_num_tokens
            - protected_size
        )
        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
        return memory_leak, token_msg

    def _check_req_pool(self: Scheduler) -> None:
        """Verify that every request slot in ``req_to_token_pool`` is free.

        The expected total is ``req_to_token_pool.size``, plus
        ``pre_alloc_size`` when the scheduler runs in DECODE disaggregation
        mode.

        Raises:
            ValueError: if the number of free slots differs from the
                expected total.
        """
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        num_free_slots = len(self.req_to_token_pool.free_slots)
        if num_free_slots != req_total_size:
            # Report the total that was actually compared (it includes
            # pre_alloc_size in DECODE mode) so the message cannot show two
            # equal numbers while claiming a leak.
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={num_free_slots}, "
                f"total_size={req_total_size}\n"
            )
            raise ValueError(msg)

    def check_memory(self: Scheduler) -> None:
        """Check both memory pools for leaks, then refresh idle-time metrics.

        Steps, in order:

        1. Run the token-pool check that matches the cache layout and raise
           if it reports a leak.
        2. Run ``_check_req_pool()``.
        3. When metrics are enabled and more than 30 seconds have passed
           since ``metrics_collector.last_log_time``, populate ``self.stats``
           and log it.
        4. Call ``self._publish_kv_events()``.

        Raises:
            ValueError: if the token pool or the request pool fails its check.
        """
        if self.is_hybrid:
            memory_leak, token_msg = self._check_hybrid_memory()
        elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
            memory_leak, token_msg = self._check_mamba_memory()
        else:
            memory_leak, token_msg = self._check_radix_cache_memory()

        if memory_leak:
            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
            raise ValueError(msg)

        self._check_req_pool()

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                # Report the larger of the two pools.
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            elif self.is_hybrid_gdn:
                (
                    num_used,
                    _,
                    token_usage,
                    _,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_mamba_token_info()
            else:
                num_used, token_usage, _, _ = self._get_token_info()

            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                self.stats.num_prefill_prealloc_queue_reqs = len(
                    self.disagg_prefill_bootstrap_queue.queue
                )
                self.stats.num_prefill_inflight_queue_reqs = len(
                    self.disagg_prefill_inflight_queue
                )
            if self.disaggregation_mode == DisaggregationMode.DECODE:
                self.stats.num_decode_prealloc_queue_reqs = len(
                    self.disagg_decode_prealloc_queue.queue
                )
                self.stats.num_decode_transfer_queue_reqs = len(
                    self.disagg_decode_transfer_queue.queue
                )
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()

    def check_tree_cache(self: Scheduler) -> None:
        """Call ``tree_cache.sanity_check()`` when the cache type matches the hybrid mode.

        The check runs only if ``is_hybrid`` is set and the tree cache is a
        ``SWARadixCache``, or ``is_hybrid_gdn`` is set and the tree cache is
        a ``MambaRadixCache``. Otherwise this does nothing.
        """
        if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
            self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
        ):
            self.tree_cache.sanity_check()

    def self_check_during_idle(self: Scheduler) -> None:
        """Run the idle-time checks, reset the new-token ratio, then maybe sleep.

        The steps run in this order: ``check_memory`` (raises ``ValueError``
        on a detected leak), ``check_tree_cache``, the reset of
        ``new_token_ratio`` to ``init_new_token_ratio``, and finally
        ``maybe_sleep_on_idle``.
        """
        self.check_memory()
        self.check_tree_cache()
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment