Unverified Commit 32852fe9 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Move memory runtime checker to mixin class (#12014)

parent 53c2934d
......@@ -137,6 +137,9 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
from sglang.srt.managers.scheduler_runtime_checker_mixin import (
SchedulerRuntimeCheckerMixin,
)
from sglang.srt.managers.scheduler_update_weights_mixin import (
SchedulerUpdateWeightsMixin,
)
......@@ -207,6 +210,7 @@ class Scheduler(
SchedulerMetricsMixin,
SchedulerDisaggregationDecodeMixin,
SchedulerDisaggregationPrefillMixin,
SchedulerRuntimeCheckerMixin,
SchedulerPPMixin,
):
"""A scheduler that manages a tensor parallel GPU worker."""
......@@ -1506,141 +1510,6 @@ class Scheduler(
for tokenized_req in recv_req:
self.handle_embedding_request(tokenized_req)
def self_check_during_idle(self):
    """Run periodic sanity checks while the scheduler has no queued work."""
    self.check_memory()
    self.check_tree_cache()
    # Restore the token-ratio heuristic to its configured initial value.
    self.new_token_ratio = self.init_new_token_ratio
    self.maybe_sleep_on_idle()
def check_memory(self):
    """Verify that all KV-cache tokens and request slots have been returned.

    Intended to run while the scheduler is idle.  Raises ValueError when a
    token-pool or request-pool leak is detected; also publishes metrics
    every 30 seconds when metrics collection is enabled.
    """
    if self.is_hybrid:
        # SWA hybrid models: when idle, both the full-attention and the
        # sliding-window pools must report zero tokens in use.
        (
            full_num_used,
            swa_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        ) = self._get_swa_token_info()
        memory_leak = full_num_used != 0 or swa_num_used != 0
        token_msg = (
            f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
        )
    elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
        # Hybrid GDN (Mamba) models: used counts may legitimately equal the
        # sizes the radix cache is protecting; any other value is a leak.
        (
            full_num_used,
            mamba_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            mamba_available_size,
            mamba_evictable_size,
        ) = self._get_mamba_token_info()
        memory_leak = (
            full_num_used != self.tree_cache.full_protected_size()
            or mamba_num_used != self.tree_cache.mamba_protected_size()
        )
        token_msg = (
            f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
        )
    else:
        # Standard radix cache: free plus evictable tokens must add back up
        # to the total pool size minus what the tree cache has protected.
        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()
        memory_leak = (available_size + evictable_size) != (
            # self.max_total_num_tokens
            # if not self.enable_hierarchical_cache
            # else self.max_total_num_tokens - protected_size
            self.max_total_num_tokens
            - protected_size
        )
        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

    if memory_leak:
        msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
        raise ValueError(msg)

    if self.disaggregation_mode == DisaggregationMode.DECODE:
        # Decode workers additionally hold pre-allocated slots (pre_alloc_size),
        # so those count toward the expected free total.
        req_total_size = (
            self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
        )
    else:
        req_total_size = self.req_to_token_pool.size

    if len(self.req_to_token_pool.free_slots) != req_total_size:
        msg = (
            "req_to_token_pool memory leak detected!"
            f"available_size={len(self.req_to_token_pool.free_slots)}, "
            f"total_size={self.req_to_token_pool.size}\n"
        )
        raise ValueError(msg)

    if (
        self.enable_metrics
        and self.current_scheduler_metrics_enabled()
        and time.perf_counter() > self.metrics_collector.last_log_time + 30
    ):
        # During idle time, also collect metrics every 30 seconds.
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
        elif self.is_hybrid_gdn:
            (
                num_used,
                _,
                token_usage,
                _,
                _,
                _,
                _,
                _,
            ) = self._get_mamba_token_info()
        else:
            num_used, token_usage, _, _ = self._get_token_info()
        num_running_reqs = len(self.running_batch.reqs)
        self.stats.num_running_reqs = num_running_reqs
        self.stats.num_used_tokens = num_used
        self.stats.token_usage = round(token_usage, 2)
        # Nothing is generated while idle, so throughput is reported as zero.
        self.stats.gen_throughput = 0
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.stats.num_prefill_prealloc_queue_reqs = len(
                self.disagg_prefill_bootstrap_queue.queue
            )
            self.stats.num_prefill_inflight_queue_reqs = len(
                self.disagg_prefill_inflight_queue
            )
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            self.stats.num_decode_prealloc_queue_reqs = len(
                self.disagg_decode_prealloc_queue.queue
            )
            self.stats.num_decode_transfer_queue_reqs = len(
                self.disagg_decode_transfer_queue.queue
            )
        self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
def check_tree_cache(self):
    """Run the radix-tree consistency check for cache types that support it."""
    is_swa_cache = self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)
    is_mamba_cache = self.is_hybrid_gdn and isinstance(
        self.tree_cache, MambaRadixCache
    )
    if is_swa_cache or is_mamba_cache:
        self.tree_cache.sanity_check()
def _get_token_info(self):
available_size = self.token_to_kv_pool_allocator.available_size()
evictable_size = self.tree_cache.evictable_size()
......
from __future__ import annotations
import time
from typing import TYPE_CHECKING
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import Scheduler
class SchedulerRuntimeCheckerMixin:
    """Runtime sanity checks for the scheduler's memory pools and radix caches.

    Mixed into ``Scheduler``; every attribute accessed through ``self`` is
    defined by that class.  The checks are intended to run while the scheduler
    is idle, when all tokens and request slots should have been released.
    """

    def _check_hybrid_memory(self: Scheduler):
        """Leak check for hybrid (SWA) models.

        When idle, both the full-attention and the sliding-window pools must
        report zero tokens in use.

        Returns:
            tuple[bool, str]: (leak detected, diagnostic message).
        """
        (
            full_num_used,
            swa_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        ) = self._get_swa_token_info()
        memory_leak = full_num_used != 0 or swa_num_used != 0
        token_msg = (
            f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def _check_mamba_memory(self: Scheduler):
        """Leak check for hybrid-GDN (Mamba) models.

        Used counts may legitimately equal the sizes the radix cache is
        protecting; any other value indicates a leak.

        Returns:
            tuple[bool, str]: (leak detected, diagnostic message).
        """
        (
            full_num_used,
            mamba_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            mamba_available_size,
            mamba_evictable_size,
        ) = self._get_mamba_token_info()
        memory_leak = (
            full_num_used != self.tree_cache.full_protected_size()
            or mamba_num_used != self.tree_cache.mamba_protected_size()
        )
        token_msg = (
            f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def _check_radix_cache_memory(self: Scheduler):
        """Leak check for the standard radix cache.

        Free plus evictable tokens must add back up to the total pool size
        minus whatever the tree cache currently protects.

        Returns:
            tuple[bool, str]: (leak detected, diagnostic message).
        """
        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()
        memory_leak = (available_size + evictable_size) != (
            self.max_total_num_tokens - protected_size
        )
        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
        return memory_leak, token_msg

    def _check_req_pool(self: Scheduler):
        """Verify every request slot has been returned to ``req_to_token_pool``.

        Raises:
            ValueError: if the free-slot count does not match the expected total.
        """
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # Decode workers additionally hold pre-allocated slots
            # (pre_alloc_size), so those count toward the expected total.
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        if len(self.req_to_token_pool.free_slots) != req_total_size:
            # NOTE: a separator space was added after "detected!" so the
            # message no longer runs straight into "available_size=".
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

    def _log_idle_metrics(self: Scheduler):
        """Collect and publish scheduler stats (helper for ``check_memory``)."""
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
        elif self.is_hybrid_gdn:
            (
                num_used,
                _,
                token_usage,
                _,
                _,
                _,
                _,
                _,
            ) = self._get_mamba_token_info()
        else:
            num_used, token_usage, _, _ = self._get_token_info()

        self.stats.num_running_reqs = len(self.running_batch.reqs)
        self.stats.num_used_tokens = num_used
        self.stats.token_usage = round(token_usage, 2)
        # Nothing is generated while idle, so throughput is reported as zero.
        self.stats.gen_throughput = 0
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.stats.num_prefill_prealloc_queue_reqs = len(
                self.disagg_prefill_bootstrap_queue.queue
            )
            self.stats.num_prefill_inflight_queue_reqs = len(
                self.disagg_prefill_inflight_queue
            )
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            self.stats.num_decode_prealloc_queue_reqs = len(
                self.disagg_decode_prealloc_queue.queue
            )
            self.stats.num_decode_transfer_queue_reqs = len(
                self.disagg_decode_transfer_queue.queue
            )
        self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()

    def check_memory(self: Scheduler):
        """Verify that all KV-cache tokens and request slots were released.

        Raises:
            ValueError: if a token-pool or request-pool leak is detected.
        """
        if self.is_hybrid:
            memory_leak, token_msg = self._check_hybrid_memory()
        elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
            memory_leak, token_msg = self._check_mamba_memory()
        else:
            memory_leak, token_msg = self._check_radix_cache_memory()

        if memory_leak:
            raise ValueError(
                f"token_to_kv_pool_allocator memory leak detected! {token_msg}"
            )

        self._check_req_pool()

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            self._log_idle_metrics()

    def check_tree_cache(self: Scheduler):
        """Run the radix-tree consistency check for cache types that support it."""
        if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
            self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
        ):
            self.tree_cache.sanity_check()

    def self_check_during_idle(self: Scheduler):
        """Entry point: run all idle-time sanity checks and reset idle state."""
        self.check_memory()
        self.check_tree_cache()
        # Restore the token-ratio heuristic to its configured initial value.
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment