Unverified commit a4c3b121, authored by Lianmin Zheng, committed by GitHub

Split the scheduler into multiple mixin classes to reduce the file size (#8483)

parent 5973675b
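
The refactor relies on ordinary Python multiple inheritance: each mixin contributes one group of methods, and the Scheduler class inherits from all of them, so runtime behavior is unchanged while scheduler.py shrinks. A minimal sketch of that composition pattern (illustrative names only, not the actual sglang classes):

# Minimal sketch of the mixin-composition pattern; the class and method names
# below are illustrative and do not exist in sglang.
class MetricsMixin:
    def init_metrics(self):
        self.num_steps = 0

    def record_step(self):
        self.num_steps += 1


class ProfilerMixin:
    def init_profiler(self):
        self.profile_in_progress = False


class MiniScheduler(MetricsMixin, ProfilerMixin):
    """Methods from every mixin are available on the combined class."""

    def __init__(self):
        self.init_metrics()
        self.init_profiler()


s = MiniScheduler()
s.record_step()
print(s.num_steps, s.profile_in_progress)  # -> 1 False
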
...@@ -694,10 +694,7 @@ class SchedulerDisaggregationDecodeMixin:
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
...@@ -771,10 +768,7 @@ class SchedulerDisaggregationDecodeMixin:
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
self.last_batch_in_queue = last_batch_in_queue
...
...@@ -287,9 +287,7 @@ class SchedulerDisaggregationPrefillMixin:
self.process_disagg_prefill_inflight_queue()
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
...@@ -337,9 +335,7 @@ class SchedulerDisaggregationPrefillMixin:
self.process_disagg_prefill_inflight_queue()
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
...
...@@ -652,25 +652,19 @@ def _set_envs_and_config(server_args: ServerArgs):
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)
def sigchld_handler(signum, frame):
pid, exitcode = os.waitpid(0, os.WNOHANG)
if exitcode != 0:
logger.warning(
f"Child process unexpectedly failed with {exitcode=}. {pid=}"
)
signal.signal(signal.SIGCHLD, sigchld_handler)
if True: # Keep this check for internal code compatibility
# Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens
# This process then clean up the whole process tree
def sigquit_handler(signum, frame):
# Note: This sigquit handler is used in the launch phase, and may be replaced by
# the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched.
def launch_phase_sigquit_handler(signum, frame):
logger.error(
"Received sigquit from a child process. It usually means the child failed."
)
kill_process_tree(os.getpid())
signal.signal(signal.SIGQUIT, sigquit_handler)
signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)
# Set mp start method
mp.set_start_method("spawn", force=True)
...
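
The comments in the hunk above describe the launch-phase contract: a child process that hits a fatal error sends SIGQUIT to the launcher, and the launcher then tears down the whole process tree. A hedged sketch of that pattern using only the standard library (POSIX only; terminating mp.active_children() is a simplified stand-in for sglang's kill_process_tree helper, not the same implementation):

# Simplified stand-in for the SIGQUIT launch-phase cleanup described above.
import multiprocessing as mp
import os
import signal
import sys


def launch_phase_sigquit_handler(signum, frame):
    # A child reported a fatal error: stop all known children, then exit.
    for child in mp.active_children():
        child.terminate()
    sys.exit(1)


def child_main(parent_pid: int) -> None:
    # Simulate a failed startup and notify the parent.
    os.kill(parent_pid, signal.SIGQUIT)


if __name__ == "__main__":
    signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)
    mp.set_start_method("spawn", force=True)
    mp.Process(target=child_main, args=(os.getpid(),)).start()
    signal.pause()  # wait for the child's signal
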
...@@ -238,6 +238,9 @@ async def health() -> Response:
@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token."""
if _global_state.tokenizer_manager.gracefully_exit:
logger.info("Health check request received during shutdown. Returning 503.")
return Response(status_code=503)
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
rid = f"HEALTH_CHECK_{time.time()}"
...@@ -260,9 +263,14 @@ async def health_generate(request: Request) -> Response:
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
break
tic = time.perf_counter()
# This request is a special request.
# If the server already has something running, this request will be ignored, so it creates zero overhead.
# If the server is not running, this request will be run, so we know whether the server is healthy.
task = asyncio.create_task(gen())
while time.perf_counter() < tic + HEALTH_CHECK_TIMEOUT:
# As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
tic = time.time()
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
await asyncio.sleep(1)
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
task.cancel()
...
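
The comments added in this hunk explain the health-check strategy: fire one tiny generation request, then poll the tokenizer manager's last_receive_tstamp for up to HEALTH_CHECK_TIMEOUT seconds; any response observed after polling started means the server is alive. A self-contained asyncio sketch of that loop (FakeManager and the timeout value are stand-ins, not sglang's objects):

# Minimal sketch of the timestamp-polling health check described above.
import asyncio
import time

HEALTH_CHECK_TIMEOUT = 5  # seconds, illustrative value


class FakeManager:
    def __init__(self):
        self.last_receive_tstamp = 0.0

    async def generate_one_token(self):
        await asyncio.sleep(0.5)  # pretend a scheduler produced a token
        self.last_receive_tstamp = time.time()


async def health_generate(manager: FakeManager) -> int:
    task = asyncio.create_task(manager.generate_one_token())
    tic = time.time()
    while time.time() < tic + HEALTH_CHECK_TIMEOUT:
        await asyncio.sleep(1)
        # Any response received after polling started means the server is alive.
        if manager.last_receive_tstamp > tic:
            task.cancel()
            return 200
    task.cancel()
    return 503


if __name__ == "__main__":
    print(asyncio.run(health_generate(FakeManager())))  # -> 200
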
...@@ -152,8 +152,6 @@ class GenerateReqInput:
else:
self._normalize_batch_inputs()
self._validate_session_params()
def _validate_inputs(self):
"""Validate that the input configuration is valid."""
if (
...
...@@ -13,7 +13,6 @@
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""
import datetime
import faulthandler
import logging
import os
...@@ -21,11 +20,10 @@ import signal
import sys
import threading
import time
from collections import defaultdict, deque
from collections import deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
...@@ -37,7 +35,6 @@ from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
create_grammar_backend,
...@@ -47,7 +44,6 @@ from sglang.srt.disaggregation.decode import (
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin,
...@@ -78,21 +74,15 @@ from sglang.srt.managers.io_struct import (
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
RpcReqInput,
RpcReqOutput,
SetInternalStateReq,
...@@ -104,11 +94,8 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.schedule_batch import (
...@@ -124,9 +111,17 @@ from sglang.srt.managers.schedule_policy import (
SchedulePolicy,
)
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
from sglang.srt.managers.scheduler_metrics_mixin import (
RECORD_STEP_TIME,
SchedulerMetricsMixin,
)
from sglang.srt.managers.scheduler_output_processor_mixin import (
SchedulerOutputProcessorMixin,
)
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
from sglang.srt.managers.scheduler_update_weights_mixin import (
SchedulerUpdateWeightsMixin,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
...@@ -135,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
...@@ -168,7 +162,6 @@ logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
_is_cpu = is_cpu()
...@@ -191,41 +184,11 @@ class EmbeddingBatchResult:
bid: int
class KvMetrics:
def __init__(self):
self.request_active_slots = None
self.request_total_slots = None
self.kv_active_blocks = None
self.kv_total_blocks = None
self.num_requests_waiting = None
self.gpu_cache_usage_perc = None
self.gpu_prefix_cache_hit_rate = None
self.data_parallel_rank = None
class IdleSleeper:
"""
In setups which have long inactivity periods it is desirable to reduce
system power consumption when sglang does nothing. This would lead not only
to power savings, but also to more CPU thermal headroom when a request
eventually comes. This is important in cases when multiple GPUs are connected
as each GPU would otherwise pin one thread at 100% CPU usage.
The simplest solution is to use zmq.Poller on all sockets that may receive
data that needs handling immediately.
"""
def __init__(self, sockets):
self.poller = zmq.Poller()
for s in sockets:
self.poller.register(s, zmq.POLLIN)
def maybe_sleep(self):
self.poller.poll(1000)
class Scheduler(
SchedulerOutputProcessorMixin,
SchedulerUpdateWeightsMixin,
SchedulerProfilerMixin,
SchedulerMetricsMixin,
SchedulerDisaggregationDecodeMixin,
SchedulerDisaggregationPrefillMixin,
):
...@@ -266,7 +229,7 @@ class Scheduler(
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.enable_hicache_storage = server_args.hicache_storage_backend is not None
self.page_size = server_args.page_size
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
compute_dp_attention_world_info(
server_args.enable_dp_attention,
...@@ -284,10 +247,13 @@
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
self.send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
if server_args.skip_tokenizer_init:
# Directly send to the TokenizerManager
self.send_to_detokenizer = get_zmq_socket(
...@@ -299,9 +265,6 @@
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
if self.server_args.sleep_on_idle:
self.idle_sleeper = IdleSleeper(
[
...@@ -398,7 +361,7 @@
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
# Hybrid
# Hybrid memory pool
self.is_hybrid = self.tp_worker.is_hybrid
if self.is_hybrid:
self.sliding_window_size = self.tp_worker.sliding_window_size
...@@ -515,6 +478,15 @@
self.init_metrics(tp_rank, pp_rank, dp_rank)
self.init_kv_events(server_args.kv_events_config)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.init_disaggregation()
if get_bool_env_var("SGLANG_GC_LOG"):
configure_gc_logger()
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
...@@ -545,22 +517,6 @@
]
)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.init_disaggregation()
if get_bool_env_var("SGLANG_GC_LOG"):
configure_gc_logger()
def current_scheduler_metrics_enabled(self):
return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
def maybe_sleep_on_idle(self):
if self.idle_sleeper is not None:
self.idle_sleeper.maybe_sleep()
def init_tokenizer(self):
server_args = self.server_args
...@@ -668,50 +624,6 @@
embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
init_embedding_cache(embedding_cache_size * 1024 * 1024)
def init_profier(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_start_forward_ct: Optional[int] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.total_retracted_reqs = 0
self.stats = SchedulerStats()
if self.enable_metrics:
engine_type = "unified"
labels = {
"model_name": self.server_args.served_model_name,
"engine_type": engine_type,
"tp_rank": tp_rank,
"pp_rank": pp_rank,
}
if dp_rank is not None:
labels["dp_rank"] = dp_rank
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
def init_kv_events(self, kv_events_config: Optional[str]):
if self.enable_kv_cache_events:
self.kv_event_publisher = EventPublisherFactory.create(
kv_events_config, self.attn_dp_rank
)
def init_disaggregation(self):
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
...@@ -820,10 +732,7 @@
self.process_batch_result(batch, result)
else:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
...@@ -866,10 +775,7 @@
)
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
...@@ -1003,10 +909,8 @@ class Scheduler(
# When the server is idle, self-check and re-init some states
if server_is_idle:
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
# When the server is idle, do self-check and re-init some states
self.self_check_during_idle()
def recv_requests(self) -> List[Req]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
...@@ -1355,170 +1259,11 @@
req.logprob_start_len = len(req.origin_input_ids) - 1
self._add_request_to_queue(req)
def self_check_during_idle(self):
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
def _emit_kv_metrics(self):
kv_metrics = KvMetrics()
kv_metrics.request_active_slots = self.stats.num_running_reqs
kv_metrics.request_total_slots = self.max_running_requests
kv_metrics.kv_active_blocks = int(
self.stats.token_usage * self.max_total_num_tokens
)
kv_metrics.kv_total_blocks = self.max_total_num_tokens
kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
if not self.send_metrics_from_scheduler.closed:
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
def log_prefill_stats(
self,
adder: PrefillAdder,
can_run_list: List[Req],
running_bs: int,
):
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.perf_counter()
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"token usage: {token_usage:.2f}, "
num_new_seq = len(can_run_list)
f = (
f"Prefill batch. "
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"{token_msg}"
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
else:
f += f"#running-req: {running_bs}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
logger.info(f)
if self.enable_metrics:
total_tokens = adder.log_input_tokens + adder.log_hit_tokens
cache_hit_rate = (
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
)
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.cache_hit_rate = cache_hit_rate
total_queue_latency = 0
for req in can_run_list:
total_queue_latency += req.queue_time_end - req.queue_time_start
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def log_decode_stats(
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
):
batch = running_batch or self.running_batch
gap_latency = time.perf_counter() - self.last_decode_stats_tic
self.last_decode_stats_tic = time.perf_counter()
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
if RECORD_STEP_TIME:
self.step_time_dict[num_running_reqs].append(
gap_latency / self.server_args.decode_log_interval
)
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
else:
spec_accept_length = (
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
)
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
self.cum_spec_accept_count += self.spec_num_total_forward_ct
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg += f"accept len: {spec_accept_length:.2f}, "
if self.disaggregation_mode == DisaggregationMode.DECODE:
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
msg += (
f"cuda graph: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
logger.info(msg)
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.cache_hit_rate = 0.0
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.spec_accept_length = spec_accept_length
self.stats.total_retracted_reqs = self.total_retracted_reqs
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def check_memory(self):
if self.is_hybrid:
...@@ -2422,22 +2167,6 @@
barrier()
return RpcReqOutput(success, "" if not exec else str(exec))
def save_remote_model(self, params):
url = params["url"]
worker = self.tp_worker.worker
worker.model_runner.save_remote_model(url)
def save_sharded_model(self, params):
worker = self.tp_worker.worker
worker.model_runner.save_sharded_model(
path=params["path"],
pattern=params["pattern"],
max_size=params["max_size"],
)
def abort_request(self, recv_req: AbortReq):
# Delete requests in the waiting queue
to_del = []
...@@ -2515,16 +2244,6 @@
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)
def load_lora_adapter(
self, recv_req: LoadLoRAAdapterReqInput
) -> LoadLoRAAdapterReqOutput:
...@@ -2541,81 +2260,6 @@
result = self.tp_worker.unload_lora_adapter(recv_req)
return result
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
return InitWeightsUpdateGroupReqOutput(success, message)
def update_weights_from_distributed(
self,
recv_req: UpdateWeightsFromDistributedReqInput,
) -> Tuple[bool, str]:
"""Update the online model parameter."""
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightsFromDistributedReqOutput(success, message)
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
"""Update the online model parameter from tensors."""
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
barrier(group=self.tp_cpu_group)
return UpdateWeightsFromTensorReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
self.flush_cache()
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model
)
torch.distributed.barrier(self.tp_cpu_group)
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
torch.distributed.barrier(self.tp_cpu_group)
_import_static_state(
self.tp_worker.worker.model_runner.model,
self.stashed_model_static_state,
)
del self.stashed_model_static_state
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
return ResumeMemoryOccupationReqOutput()
def slow_down(self, recv_req: SlowDownReqInput):
t = recv_req.forward_sleep_time
if t is not None and t <= 0:
...@@ -2623,254 +2267,6 @@
self.forward_sleep_time = t
return SlowDownReqOutput()
def profile(self, recv_req: ProfileReq):
if recv_req.type == ProfileReqType.START_PROFILE:
if recv_req.profile_by_stage or recv_req.start_step:
return self.init_profile(
recv_req.output_dir,
recv_req.start_step,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
)
else:
self.init_profile(
recv_req.output_dir,
recv_req.start_step,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
)
return self.start_profile(True)
else:
return self.stop_profile()
def init_profile(
self,
output_dir: Optional[str],
start_step: Optional[int],
num_steps: Optional[int],
activities: Optional[List[str]],
with_stack: Optional[bool],
record_shapes: Optional[bool],
profile_by_stage: bool,
profile_id: str,
) -> ProfileReqOutput:
if self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is already in progress. Call /stop_profile first.",
)
self.profile_by_stage = profile_by_stage
if output_dir is None:
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
if activities is None:
activities = ["CPU", "GPU"]
self.torch_profiler_output_dir = output_dir
self.torch_profiler_with_stack = with_stack
self.torch_profiler_record_shapes = record_shapes
self.profiler_activities = activities
self.profile_id = profile_id
if start_step:
self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
if num_steps:
self.profile_steps = num_steps
if self.profile_by_stage:
self.profiler_target_prefill_ct = num_steps
self.profiler_target_decode_ct = num_steps
self.profiler_prefill_ct = 0
self.profiler_decode_ct = 0
elif start_step:
self.profiler_target_forward_ct = (
self.profiler_start_forward_ct + num_steps
)
else:
self.profiler_target_forward_ct = self.forward_ct + num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else:
self.profiler_target_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded")
def start_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
stage_str = f" for {stage.__str__()}" if stage else ""
logger.info(
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
)
activities = self.profiler_activities
with_stack = self.torch_profiler_with_stack
record_shapes = self.torch_profiler_record_shapes
activity_map = {
"CPU": torch.profiler.ProfilerActivity.CPU,
"GPU": torch.profiler.ProfilerActivity.CUDA,
}
torchprof_activities = [
activity_map[a] for a in activities if a in activity_map
]
if "RPD" in activities:
from rpdTracerControl import rpdTracerControl
rpdTracerControl.skipCreate()
self.rpd_profile_path = os.path.join(
self.torch_profiler_output_dir,
"rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
)
if self.tp_rank == 0:
import sqlite3
from rocpd.schema import RocpdSchema
if os.path.exists("trace.rpd"):
os.unlink("trace.rpd")
schema = RocpdSchema()
connection = sqlite3.connect("trace.rpd")
schema.writeSchema(connection)
connection.commit()
del connection
torch.distributed.barrier(self.tp_cpu_group)
self.rpd_profiler = rpdTracerControl()
self.rpd_profiler.setPythonTrace(True)
self.rpd_profiler.start()
self.rpd_profiler.rangePush("", "rpd profile range", "")
self.profile_in_progress = True
elif torchprof_activities:
self.torch_profiler = torch.profiler.profile(
activities=torchprof_activities,
with_stack=with_stack if with_stack is not None else True,
record_shapes=record_shapes if record_shapes is not None else False,
)
self.torch_profiler.start()
self.profile_in_progress = True
if "MEM" in activities:
torch.cuda.memory._record_memory_history(max_entries=100000)
self.profile_in_progress = True
if "CUDA_PROFILER" in activities:
torch.cuda.cudart().cudaProfilerStart()
self.profile_in_progress = True
return ProfileReqOutput(success=True, message="Succeeded")
def stop_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
if not self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is not in progress. Call /start_profile first.",
)
if not Path(self.torch_profiler_output_dir).exists():
Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
stage_suffix = f"-{stage.__str__()}" if stage else ""
logger.info("Stop profiling" + stage_suffix + "...")
if self.torch_profiler is not None:
self.torch_profiler.stop()
self.torch_profiler.export_chrome_trace(
os.path.join(
self.torch_profiler_output_dir,
self.profile_id
+ f"-TP-{self.tp_rank}"
+ stage_suffix
+ ".trace.json.gz",
)
)
torch.distributed.barrier(self.tp_cpu_group)
if self.rpd_profiler is not None:
self.rpd_profiler.rangePop()
self.rpd_profiler.stop()
self.rpd_profiler.flush()
torch.distributed.barrier(self.tp_cpu_group)
if self.tp_rank == 0:
from sglang.srt.utils import rpd_to_chrome_trace
rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
self.rpd_profiler = None
self.rpd_profiler_path = None
if self.profiler_activities is not None and "MEM" in self.profiler_activities:
memory_profile_path = os.path.join(
self.torch_profiler_output_dir,
str(time.time())
+ f"-TP-{self.tp_rank}-memory"
+ stage_suffix
+ ".pickle",
)
torch.cuda.memory._dump_snapshot(memory_profile_path)
torch.cuda.memory._record_memory_history(enabled=None)
if "CUDA_PROFILER" in self.profiler_activities:
torch.cuda.cudart().cudaProfilerStop()
logger.info(
"Profiling done. Traces are saved to: %s",
self.torch_profiler_output_dir,
)
self.torch_profiler = None
self.profile_in_progress = False
self.profiler_start_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded.")
def _profile_batch_predicate(self, batch):
if self.profile_by_stage:
if batch.forward_mode.is_prefill():
if self.profiler_prefill_ct == 0:
self.start_profile(batch.forward_mode)
self.profiler_prefill_ct += 1
if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.EXTEND)
elif batch.forward_mode.is_decode():
if self.profiler_decode_ct == 0:
if self.profile_in_progress:
# force trace flush
self.stop_profile(ForwardMode.EXTEND)
self.start_profile(batch.forward_mode)
self.profiler_decode_ct += 1
if self.profiler_decode_ct > self.profiler_target_decode_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.DECODE)
elif batch.forward_mode.is_idle():
pass
else:
raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
else:
# Check profiler
if (
self.profiler_target_forward_ct
and self.profiler_target_forward_ct <= self.forward_ct
):
self.stop_profile()
if (
self.profiler_start_forward_ct
and self.profiler_start_forward_ct == self.forward_ct
):
self.start_profile()
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD:
get_global_expert_distribution_recorder().start_record()
...@@ -2879,7 +2275,7 @@
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
get_global_expert_distribution_recorder().dump_record()
else:
raise ValueError("Unrecognized ExpertDistributionReq value")
raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
return ExpertDistributionReqOutput()
def open_session(self, recv_req: OpenSessionReqInput):
...@@ -2915,34 +2311,41 @@
prefix += f" PP{self.pp_rank}"
return prefix
def _publish_kv_events(self):
if self.enable_kv_cache_events:
events = self.tree_cache.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
def current_scheduler_metrics_enabled(self):
return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
def maybe_sleep_on_idle(self):
if self.idle_sleeper is not None:
self.idle_sleeper.maybe_sleep()
class IdleSleeper:
"""
In setups which have long inactivity periods it is desirable to reduce
system power consumption when sglang does nothing. This would lead not only
to power savings, but also to more CPU thermal headroom when a request
eventually comes. This is important in cases when multiple GPUs are connected
as each GPU would otherwise pin one thread at 100% CPU usage.
The simplest solution is to use zmq.Poller on all sockets that may receive
data that needs handling immediately.
"""
def __init__(self, sockets):
self.poller = zmq.Poller()
for s in sockets:
self.poller.register(s, zmq.POLLIN)
def maybe_sleep(self):
self.poller.poll(1000)
def is_health_check_generate_req(recv_req):
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
def is_work_request(recv_req):
return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
def _export_static_state(model):
return dict(
buffers=[
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
]
)
def _import_static_state(model, static_params):
self_named_buffers = dict(model.named_buffers())
for name, tensor in static_params["buffers"]:
self_named_buffers[name][...] = tensor
def run_scheduler_process(
...
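
The IdleSleeper class that the diff moves to module level (docstring above) replaces busy-polling with a blocking zmq.Poller.poll(1000). A small usage sketch under assumptions: the IPC endpoint below is made up, and in sglang the registered sockets are the scheduler's real input sockets:

# Usage sketch of the zmq.Poller-based idle sleep; endpoint is illustrative.
import zmq


class IdleSleeper:
    def __init__(self, sockets):
        self.poller = zmq.Poller()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        # Block for up to 1 second instead of spinning; wakes early when any
        # registered socket becomes readable.
        self.poller.poll(1000)


if __name__ == "__main__":
    ctx = zmq.Context.instance()
    pull = ctx.socket(zmq.PULL)
    pull.bind("ipc:///tmp/idle_sleeper_demo")
    sleeper = IdleSleeper([pull])
    for _ in range(3):
        sleeper.maybe_sleep()  # each call blocks ~1s while nothing arrives
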
import logging
import time
from collections import defaultdict
from typing import List, Optional
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.managers.schedule_policy import PrefillAdder
from sglang.srt.managers.scheduler import Req, ScheduleBatch
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__)
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
class KvMetrics:
def __init__(self):
self.request_active_slots = None
self.request_total_slots = None
self.kv_active_blocks = None
self.kv_total_blocks = None
self.num_requests_waiting = None
self.gpu_cache_usage_perc = None
self.gpu_prefix_cache_hit_rate = None
self.data_parallel_rank = None
class SchedulerMetricsMixin:
def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.total_retracted_reqs = 0
self.stats = SchedulerStats()
if self.enable_metrics:
engine_type = "unified"
labels = {
"model_name": self.server_args.served_model_name,
"engine_type": engine_type,
"tp_rank": tp_rank,
"pp_rank": pp_rank,
}
if dp_rank is not None:
labels["dp_rank"] = dp_rank
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
def init_kv_events(self, kv_events_config: Optional[str]):
if self.enable_kv_cache_events:
self.kv_event_publisher = EventPublisherFactory.create(
kv_events_config, self.attn_dp_rank
)
def log_prefill_stats(
self,
adder: PrefillAdder,
can_run_list: List[Req],
running_bs: int,
):
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.perf_counter()
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"token usage: {token_usage:.2f}, "
num_new_seq = len(can_run_list)
f = (
f"Prefill batch. "
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"{token_msg}"
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
else:
f += f"#running-req: {running_bs}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
logger.info(f)
if self.enable_metrics:
total_tokens = adder.log_input_tokens + adder.log_hit_tokens
cache_hit_rate = (
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
)
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.cache_hit_rate = cache_hit_rate
total_queue_latency = 0
for req in can_run_list:
total_queue_latency += req.queue_time_end - req.queue_time_start
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def log_decode_stats(
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
):
batch = running_batch or self.running_batch
gap_latency = time.perf_counter() - self.last_decode_stats_tic
self.last_decode_stats_tic = time.perf_counter()
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
if RECORD_STEP_TIME:
self.step_time_dict[num_running_reqs].append(
gap_latency / self.server_args.decode_log_interval
)
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
else:
spec_accept_length = (
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
)
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
self.cum_spec_accept_count += self.spec_num_total_forward_ct
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg += f"accept len: {spec_accept_length:.2f}, "
if self.disaggregation_mode == DisaggregationMode.DECODE:
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
msg += (
f"cuda graph: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
logger.info(msg)
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.cache_hit_rate = 0.0
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.spec_accept_length = spec_accept_length
self.stats.total_retracted_reqs = self.total_retracted_reqs
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def _emit_kv_metrics(self):
kv_metrics = KvMetrics()
kv_metrics.request_active_slots = self.stats.num_running_reqs
kv_metrics.request_total_slots = self.max_running_requests
kv_metrics.kv_active_blocks = int(
self.stats.token_usage * self.max_total_num_tokens
)
kv_metrics.kv_total_blocks = self.max_total_num_tokens
kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
if not self.send_metrics_from_scheduler.closed:
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
def _publish_kv_events(self):
if self.enable_kv_cache_events:
events = self.tree_cache.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
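
One detail of the metrics mixin above worth noting is _emit_kv_metrics: the KvMetrics object is shipped to another process as a pickled Python object via send_pyobj on a zmq socket. A hedged, self-contained sketch of that transport (the endpoint and two-socket wiring are made up for the example; sglang uses its own send_metrics_from_scheduler socket):

# Illustrative sketch of sending a metrics object over zmq with send_pyobj /
# recv_pyobj; the ipc endpoint is arbitrary and not sglang's real wiring.
import zmq


class KvMetrics:
    def __init__(self, request_active_slots=0, num_requests_waiting=0):
        self.request_active_slots = request_active_slots
        self.num_requests_waiting = num_requests_waiting


ctx = zmq.Context.instance()
pull = ctx.socket(zmq.PULL)
pull.bind("ipc:///tmp/kv_metrics_demo")
push = ctx.socket(zmq.PUSH)
push.connect("ipc:///tmp/kv_metrics_demo")

push.send_pyobj(KvMetrics(request_active_slots=3, num_requests_waiting=1))
metrics = pull.recv_pyobj()
print(metrics.request_active_slots, metrics.num_requests_waiting)  # -> 3 1
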
import logging
import os
import time
from pathlib import Path
from typing import List, Optional
import torch
from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
from sglang.srt.model_executor.forward_batch_info import ForwardMode
logger = logging.getLogger(__name__)
class SchedulerProfilerMixin:
def init_profier(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_start_forward_ct: Optional[int] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
def init_profile(
self,
output_dir: Optional[str],
start_step: Optional[int],
num_steps: Optional[int],
activities: Optional[List[str]],
with_stack: Optional[bool],
record_shapes: Optional[bool],
profile_by_stage: bool,
profile_id: str,
) -> ProfileReqOutput:
if self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is already in progress. Call /stop_profile first.",
)
self.profile_by_stage = profile_by_stage
if output_dir is None:
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
if activities is None:
activities = ["CPU", "GPU"]
self.torch_profiler_output_dir = output_dir
self.torch_profiler_with_stack = with_stack
self.torch_profiler_record_shapes = record_shapes
self.profiler_activities = activities
self.profile_id = profile_id
if start_step:
self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
if num_steps:
self.profile_steps = num_steps
if self.profile_by_stage:
self.profiler_target_prefill_ct = num_steps
self.profiler_target_decode_ct = num_steps
self.profiler_prefill_ct = 0
self.profiler_decode_ct = 0
elif start_step:
self.profiler_target_forward_ct = (
self.profiler_start_forward_ct + num_steps
)
else:
self.profiler_target_forward_ct = self.forward_ct + num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else:
self.profiler_target_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded")
def start_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
stage_str = f" for {stage.__str__()}" if stage else ""
logger.info(
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
)
activities = self.profiler_activities
with_stack = self.torch_profiler_with_stack
record_shapes = self.torch_profiler_record_shapes
activity_map = {
"CPU": torch.profiler.ProfilerActivity.CPU,
"GPU": torch.profiler.ProfilerActivity.CUDA,
}
torchprof_activities = [
activity_map[a] for a in activities if a in activity_map
]
if "RPD" in activities:
from rpdTracerControl import rpdTracerControl
rpdTracerControl.skipCreate()
self.rpd_profile_path = os.path.join(
self.torch_profiler_output_dir,
"rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
)
if self.tp_rank == 0:
import sqlite3
from rocpd.schema import RocpdSchema
if os.path.exists("trace.rpd"):
os.unlink("trace.rpd")
schema = RocpdSchema()
connection = sqlite3.connect("trace.rpd")
schema.writeSchema(connection)
connection.commit()
del connection
torch.distributed.barrier(self.tp_cpu_group)
self.rpd_profiler = rpdTracerControl()
self.rpd_profiler.setPythonTrace(True)
self.rpd_profiler.start()
self.rpd_profiler.rangePush("", "rpd profile range", "")
self.profile_in_progress = True
elif torchprof_activities:
self.torch_profiler = torch.profiler.profile(
activities=torchprof_activities,
with_stack=with_stack if with_stack is not None else True,
record_shapes=record_shapes if record_shapes is not None else False,
)
self.torch_profiler.start()
self.profile_in_progress = True
if "MEM" in activities:
torch.cuda.memory._record_memory_history(max_entries=100000)
self.profile_in_progress = True
if "CUDA_PROFILER" in activities:
torch.cuda.cudart().cudaProfilerStart()
self.profile_in_progress = True
return ProfileReqOutput(success=True, message="Succeeded")
def stop_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
if not self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is not in progress. Call /start_profile first.",
)
if not Path(self.torch_profiler_output_dir).exists():
Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
stage_suffix = f"-{stage.__str__()}" if stage else ""
logger.info("Stop profiling" + stage_suffix + "...")
if self.torch_profiler is not None:
self.torch_profiler.stop()
self.torch_profiler.export_chrome_trace(
os.path.join(
self.torch_profiler_output_dir,
self.profile_id
+ f"-TP-{self.tp_rank}"
+ stage_suffix
+ ".trace.json.gz",
)
)
torch.distributed.barrier(self.tp_cpu_group)
if self.rpd_profiler is not None:
self.rpd_profiler.rangePop()
self.rpd_profiler.stop()
self.rpd_profiler.flush()
torch.distributed.barrier(self.tp_cpu_group)
if self.tp_rank == 0:
from sglang.srt.utils import rpd_to_chrome_trace
rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
self.rpd_profiler = None
self.rpd_profiler_path = None
if self.profiler_activities is not None and "MEM" in self.profiler_activities:
memory_profile_path = os.path.join(
self.torch_profiler_output_dir,
str(time.time())
+ f"-TP-{self.tp_rank}-memory"
+ stage_suffix
+ ".pickle",
)
torch.cuda.memory._dump_snapshot(memory_profile_path)
torch.cuda.memory._record_memory_history(enabled=None)
if "CUDA_PROFILER" in self.profiler_activities:
torch.cuda.cudart().cudaProfilerStop()
logger.info(
"Profiling done. Traces are saved to: %s",
self.torch_profiler_output_dir,
)
self.torch_profiler = None
self.profile_in_progress = False
self.profiler_start_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded.")
def _profile_batch_predicate(self, batch):
if self.profile_by_stage:
if batch.forward_mode.is_prefill():
if self.profiler_prefill_ct == 0:
self.start_profile(batch.forward_mode)
self.profiler_prefill_ct += 1
if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.EXTEND)
elif batch.forward_mode.is_decode():
if self.profiler_decode_ct == 0:
if self.profile_in_progress:
# force trace flush
self.stop_profile(ForwardMode.EXTEND)
self.start_profile(batch.forward_mode)
self.profiler_decode_ct += 1
if self.profiler_decode_ct > self.profiler_target_decode_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.DECODE)
elif batch.forward_mode.is_idle():
pass
else:
raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
else:
# Step-based profiling: stop once forward_ct reaches the target stop step,
# and start when it reaches the configured start step.
if (
self.profiler_target_forward_ct
and self.profiler_target_forward_ct <= self.forward_ct
):
self.stop_profile()
if (
self.profiler_start_forward_ct
and self.profiler_start_forward_ct == self.forward_ct
):
self.start_profile()
def profile(self, recv_req: ProfileReq):
if recv_req.type == ProfileReqType.START_PROFILE:
if recv_req.profile_by_stage or recv_req.start_step:
return self.init_profile(
recv_req.output_dir,
recv_req.start_step,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
)
else:
self.init_profile(
recv_req.output_dir,
recv_req.start_step,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
)
return self.start_profile(True)
else:
return self.stop_profile()
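The profile() dispatcher above ultimately drives torch.profiler (plus the optional memory-history and CUDA-profiler hooks). The following standalone sketch shows that same torch.profiler flow in isolation; the output directory and the toy matmul workload are illustrative and are not part of the mixin.

# Standalone sketch (not part of the mixin): the torch.profiler flow wrapped by
# start_profile()/stop_profile() above. Paths and workload are illustrative only.
import os
import torch

output_dir = "/tmp/sglang_profile_example"  # illustrative path
os.makedirs(output_dir, exist_ok=True)

activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(torch.profiler.ProfilerActivity.CUDA)

prof = torch.profiler.profile(
    activities=activities,
    with_stack=True,      # mirrors the with_stack default above
    record_shapes=False,  # mirrors the record_shapes default above
)
prof.start()

x = torch.randn(1024, 1024)
for _ in range(5):
    y = x @ x  # toy workload standing in for model forward passes

prof.stop()
prof.export_chrome_trace(os.path.join(output_dir, "example.trace.json.gz"))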
import logging
from typing import Tuple
import torch
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
logger = logging.getLogger(__name__)
class SchedulerUpdateWeightsMixin:
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
return InitWeightsUpdateGroupReqOutput(success, message)
def update_weights_from_distributed(
self,
recv_req: UpdateWeightsFromDistributedReqInput,
) -> Tuple[bool, str]:
"""Update the online model parameter."""
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightsFromDistributedReqOutput(success, message)
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
"""Update the online model parameter from tensors."""
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
# TODO: extract the common code between update_weights_from_distributed and update_weights_from_tensor later
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
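# Wait until every TP rank has finished applying the tensor update before replying.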
torch.distributed.barrier(group=self.tp_cpu_group)
return UpdateWeightsFromTensorReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
self.flush_cache()
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model
)
torch.distributed.barrier(self.tp_cpu_group)
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
torch.distributed.barrier(self.tp_cpu_group)
_import_static_state(
self.tp_worker.worker.model_runner.model,
self.stashed_model_static_state,
)
del self.stashed_model_static_state
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
return ResumeMemoryOccupationReqOutput()
def save_remote_model(self, params):
url = params["url"]
worker = self.tp_worker.worker
worker.model_runner.save_remote_model(url)
def save_sharded_model(self, params):
worker = self.tp_worker.worker
worker.model_runner.save_sharded_model(
path=params["path"],
pattern=params["pattern"],
max_size=params["max_size"],
)
def _export_static_state(model):
return dict(
buffers=[
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
]
)
def _import_static_state(model, static_params):
self_named_buffers = dict(model.named_buffers())
for name, tensor in static_params["buffers"]:
self_named_buffers[name][...] = tensor
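The two module-level helpers above snapshot and restore a model's named buffers around the release/resume memory-occupation calls. Below is a minimal round-trip sketch, assuming only torch is installed and that the two helpers are in scope; the toy module is illustrative.

# Minimal round-trip sketch for the buffer snapshot helpers above.
# The toy module is illustrative; any nn.Module with registered buffers works.
import torch
from torch import nn

class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(4))

model = ToyModule()
model.running_mean += 1.0

snapshot = _export_static_state(model)  # clones every named buffer
model.running_mean.zero_()              # simulate the buffer being clobbered
_import_static_state(model, snapshot)   # copy the cloned values back in place

assert torch.equal(model.running_mean, torch.ones(4))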
...@@ -170,16 +170,6 @@ class ReqState: ...@@ -170,16 +170,6 @@ class ReqState:
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
is_cross_node = server_args.dist_init_addr
if is_cross_node:
# Fallback to default CPU transport for multi-node
return "default"
else:
return "cuda_ipc"
class TokenizerManager: class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text.""" """TokenizerManager is a process that tokenizes the text."""
...@@ -199,16 +189,6 @@ class TokenizerManager: ...@@ -199,16 +189,6 @@ class TokenizerManager:
else None else None
) )
self.crash_dump_folder = server_args.crash_dump_folder self.crash_dump_folder = server_args.crash_dump_folder
self.crash_dump_performed = False # Flag to ensure dump is only called once
# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
# Read model args # Read model args
self.model_path = server_args.model_path self.model_path = server_args.model_path
...@@ -218,8 +198,7 @@ class TokenizerManager: ...@@ -218,8 +198,7 @@ class TokenizerManager:
self.is_image_gen = self.model_config.is_image_gen self.is_image_gen = self.model_config.is_image_gen
self.context_len = self.model_config.context_len self.context_len = self.model_config.context_len
self.image_token_id = self.model_config.image_token_id self.image_token_id = self.model_config.image_token_id
self._updating = False self.max_req_input_len = None # Will be set later in engine.py
self._cond = asyncio.Condition()
if self.model_config.is_multimodal: if self.model_config.is_multimodal:
import_processors() import_processors()
...@@ -258,39 +237,57 @@ class TokenizerManager: ...@@ -258,39 +237,57 @@ class TokenizerManager:
revision=server_args.revision, revision=server_args.revision,
) )
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`. # Init inter-process communication
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It context = zmq.asyncio.Context(2)
# serves as the source of truth for available adapters and maps user-friendly LoRA names self.recv_from_detokenizer = get_zmq_socket(
# to internally used unique LoRA IDs. context, zmq.PULL, port_args.tokenizer_ipc_name, True
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {}) )
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
# Store states # Request states
self.no_create_loop = False self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {} self.rid_to_state: Dict[str, ReqState] = {}
self.asyncio_tasks = set()
# Health check
self.health_check_failed = False self.health_check_failed = False
self.gracefully_exit = False self.gracefully_exit = False
self.last_receive_tstamp = 0 self.last_receive_tstamp = 0
# Dumping
self.dump_requests_folder = "" # By default do not dump self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000 self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = [] self.dump_request_list: List[Tuple] = []
self.crash_dump_request_list: deque[Tuple] = deque()
self.log_request_metadata = self.get_log_request_metadata() self.log_request_metadata = self.get_log_request_metadata()
self.crash_dump_request_list: deque[Tuple] = deque()
self.crash_dump_performed = False # Flag to ensure dump is only called once
# Session
self.session_futures = {} # session_id -> asyncio event self.session_futures = {} # session_id -> asyncio event
self.max_req_input_len = None
self.asyncio_tasks = set()
# Weight updates
# The event to notify the weight sync is finished. # The event to notify the weight sync is finished.
self.model_update_lock = RWLock() self.model_update_lock = RWLock()
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = ( self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
None None
) )
self._is_updating = False
self._is_updating_cond = asyncio.Condition()
# LoRA
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
# serves as the source of truth for available adapters and maps user-friendly LoRA names
# to internally used unique LoRA IDs.
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
# Lock to serialize LoRA update operations. # Lock to serialize LoRA update operations.
# Please note that, unlike `model_update_lock`, this does not block inference, allowing # Please note that, unlike `model_update_lock`, this does not block inference, allowing
# LoRA updates and inference to overlap. # LoRA updates and inference to overlap.
self.lora_update_lock = asyncio.Lock() self.lora_update_lock = asyncio.Lock()
# For pd disaggregtion # For PD disaggregation
self.disaggregation_mode = DisaggregationMode( self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode self.server_args.disaggregation_mode
) )
...@@ -458,17 +455,11 @@ class TokenizerManager: ...@@ -458,17 +455,11 @@ class TokenizerManager:
request: Optional[fastapi.Request] = None, request: Optional[fastapi.Request] = None,
): ):
created_time = time.time() created_time = time.time()
async with self._cond:
await self._cond.wait_for(lambda: not self._updating)
self.auto_create_handle_loop() self.auto_create_handle_loop()
obj.normalize_batch_and_arguments() obj.normalize_batch_and_arguments()
if isinstance(obj, EmbeddingReqInput) and self.is_generation: async with self._is_updating_cond:
raise ValueError( await self._is_updating_cond.wait_for(lambda: not self._is_updating)
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
if self.log_requests: if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata max_length, skip_names, _ = self.log_request_metadata
...@@ -567,6 +558,12 @@ class TokenizerManager: ...@@ -567,6 +558,12 @@ class TokenizerManager:
f"model's context length ({self.context_len} tokens)." f"model's context length ({self.context_len} tokens)."
) )
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
# Check total tokens (input + max_new_tokens) # Check total tokens (input + max_new_tokens)
max_new_tokens = obj.sampling_params.get("max_new_tokens") max_new_tokens = obj.sampling_params.get("max_new_tokens")
if ( if (
...@@ -959,14 +956,14 @@ class TokenizerManager: ...@@ -959,14 +956,14 @@ class TokenizerManager:
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD) await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
async def pause_generation(self): async def pause_generation(self):
async with self._cond: async with self._is_updating_cond:
self._updating = True self._is_updating = True
self.abort_request(abort_all=True) self.abort_request(abort_all=True)
async def continue_generation(self): async def continue_generation(self):
async with self._cond: async with self._is_updating_cond:
self._updating = False self._is_updating = False
self._cond.notify_all() self._is_updating_cond.notify_all()
async def update_weights_from_disk( async def update_weights_from_disk(
self, self,
...@@ -1208,14 +1205,6 @@ class TokenizerManager: ...@@ -1208,14 +1205,6 @@ class TokenizerManager:
# Many DP ranks # Many DP ranks
return [res.internal_state for res in responses] return [res.internal_state for res in responses]
async def get_load(self) -> dict:
# TODO(lsyin): fake load report server
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
async def set_internal_state( async def set_internal_state(
self, obj: SetInternalStateReq self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput: ) -> SetInternalStateReqOutput:
...@@ -1224,6 +1213,14 @@ class TokenizerManager: ...@@ -1224,6 +1213,14 @@ class TokenizerManager:
) )
return [res.internal_state for res in responses] return [res.internal_state for res in responses]
async def get_load(self) -> dict:
# TODO(lsyin): fake load report server
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
def get_log_request_metadata(self): def get_log_request_metadata(self):
max_length = None max_length = None
skip_names = None skip_names = None
...@@ -1343,11 +1340,24 @@ class TokenizerManager: ...@@ -1343,11 +1340,24 @@ class TokenizerManager:
"SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping." "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
) )
return return
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
self.crash_dump_performed = True
if not self.crash_dump_folder: if not self.crash_dump_folder:
return return
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
self.crash_dump_performed = True
# Check if NFS directory is available
# expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
# use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
# expected_nfs_dir, os.W_OK
# )
use_nfs_dir = False
if not use_nfs_dir:
logger.error(
f"Expected NFS directory is not available or writable. Uploading to GCS."
)
data_to_dump = [] data_to_dump = []
if self.crash_dump_request_list: if self.crash_dump_request_list:
data_to_dump.extend(self.crash_dump_request_list) data_to_dump.extend(self.crash_dump_request_list)
...@@ -1357,7 +1367,12 @@ class TokenizerManager: ...@@ -1357,7 +1367,12 @@ class TokenizerManager:
for rid, state in self.rid_to_state.items(): for rid, state in self.rid_to_state.items():
if not state.finished: if not state.finished:
unfinished_requests.append( unfinished_requests.append(
(state.obj, {}, state.created_time, time.time()) (
state.obj,
state.out_list[-1] if state.out_list else {},
state.created_time,
time.time(),
)
) )
if unfinished_requests: if unfinished_requests:
data_to_dump.extend(unfinished_requests) data_to_dump.extend(unfinished_requests)
...@@ -1365,10 +1380,11 @@ class TokenizerManager: ...@@ -1365,10 +1380,11 @@ class TokenizerManager:
if not data_to_dump: if not data_to_dump:
return return
object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
filename = os.path.join( filename = os.path.join(
self.crash_dump_folder, self.crash_dump_folder,
os.getenv("HOSTNAME", None), os.getenv("HOSTNAME", None),
f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl", object_name,
) )
os.makedirs(os.path.dirname(filename), exist_ok=True) os.makedirs(os.path.dirname(filename), exist_ok=True)
...@@ -1383,6 +1399,24 @@ class TokenizerManager: ...@@ -1383,6 +1399,24 @@ class TokenizerManager:
f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}" f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
) )
def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
from google.cloud import storage
client = storage.Client()
bucket = client.bucket(bucket_name)
blob = bucket.blob(object_name)
blob.upload_from_filename(source_file_path, if_generation_match=0)
logger.error(
f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
)
if not use_nfs_dir:
_upload_file_to_gcs(
"sglang_crash_dump",
filename,
os.getenv("HOSTNAME", None) + "/" + object_name,
)
async def sigterm_watchdog(self): async def sigterm_watchdog(self):
while not self.gracefully_exit: while not self.gracefully_exit:
await asyncio.sleep(5) await asyncio.sleep(5)
...@@ -1426,7 +1460,7 @@ class TokenizerManager: ...@@ -1426,7 +1460,7 @@ class TokenizerManager:
while True: while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj() recv_obj = await self.recv_from_detokenizer.recv_pyobj()
self._result_dispatcher(recv_obj) self._result_dispatcher(recv_obj)
self.last_receive_tstamp = time.perf_counter() self.last_receive_tstamp = time.time()
def _handle_batch_output( def _handle_batch_output(
self, self,
...@@ -1697,24 +1731,13 @@ class TokenizerManager: ...@@ -1697,24 +1731,13 @@ class TokenizerManager:
self.dump_requests_folder, self.dump_requests_folder,
datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl", datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
) )
logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}") self._dump_data_to_file(
data_list=self.dump_request_list,
to_dump = self.dump_request_list filename=filename,
log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
)
self.dump_request_list = [] self.dump_request_list = []
to_dump_with_server_args = {
"server_args": self.server_args,
"requests": to_dump,
}
def background_task():
os.makedirs(self.dump_requests_folder, exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump_with_server_args, f)
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
def record_request_for_crash_dump(self, state: ReqState, out_dict: dict): def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
current_time = time.time() current_time = time.time()
self.crash_dump_request_list.append( self.crash_dump_request_list.append(
...@@ -1727,6 +1750,22 @@ class TokenizerManager: ...@@ -1727,6 +1750,22 @@ class TokenizerManager:
): ):
self.crash_dump_request_list.popleft() self.crash_dump_request_list.popleft()
def _dump_data_to_file(
self, data_list: List[Tuple], filename: str, log_message: str
):
logger.info(log_message)
to_dump_with_server_args = {
"server_args": self.server_args,
"requests": data_list.copy(),
}
def background_task():
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump_with_server_args, f)
asyncio.create_task(asyncio.to_thread(background_task))
def _handle_abort_req(self, recv_obj): def _handle_abort_req(self, recv_obj):
state = self.rid_to_state[recv_obj.rid] state = self.rid_to_state[recv_obj.rid]
state.finished = True state.finished = True
...@@ -1862,6 +1901,16 @@ class TokenizerManager: ...@@ -1862,6 +1901,16 @@ class TokenizerManager:
return scores return scores
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
is_cross_node = server_args.dist_init_addr
if is_cross_node:
# Fallback to default CPU transport for multi-node
return "default"
else:
return "cuda_ipc"
async def print_exception_wrapper(func): async def print_exception_wrapper(func):
""" """
Sometimes an asyncio function does not print exception. Sometimes an asyncio function does not print exception.
......
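The TokenizerManager hunks above rename the update gate to _is_updating / _is_updating_cond: an asyncio.Condition that parks incoming generation requests while pause_generation() is in effect and wakes them in continue_generation(). Below is a minimal self-contained sketch of that gate pattern; the Gate class and its method names are hypothetical, and only the wait_for / notify_all usage mirrors the diff.

# Hypothetical, self-contained sketch of the asyncio.Condition gate pattern above.
import asyncio

class Gate:
    def __init__(self):
        self._is_updating = False
        self._is_updating_cond = asyncio.Condition()

    async def pause(self):
        async with self._is_updating_cond:
            self._is_updating = True

    async def resume(self):
        async with self._is_updating_cond:
            self._is_updating = False
            self._is_updating_cond.notify_all()

    async def handle_request(self, i):
        # New requests block here until the update finishes.
        async with self._is_updating_cond:
            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
        print(f"request {i} admitted")

async def main():
    gate = Gate()
    await gate.pause()
    tasks = [asyncio.create_task(gate.handle_request(i)) for i in range(3)]
    await asyncio.sleep(0.1)  # requests stay parked while "updating"
    await gate.resume()       # wake them all up
    await asyncio.gather(*tasks)

asyncio.run(main())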
...@@ -2071,6 +2071,9 @@ class PortArgs: ...@@ -2071,6 +2071,9 @@ class PortArgs:
dist_init_host, dist_init_port = dist_init_addr dist_init_host, dist_init_port = dist_init_addr
port_base = int(dist_init_port) + 1 port_base = int(dist_init_port) + 1
detokenizer_port = port_base + 1
rpc_port = port_base + 2
metrics_ipc_name = port_base + 3
if dp_rank is None: if dp_rank is None:
# TokenizerManager to DataParallelController # TokenizerManager to DataParallelController
scheduler_input_port = port_base + 4 scheduler_input_port = port_base + 4
...@@ -2080,10 +2083,10 @@ class PortArgs: ...@@ -2080,10 +2083,10 @@ class PortArgs:
return PortArgs( return PortArgs(
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}", tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}", scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}", detokenizer_ipc_name=f"tcp://{dist_init_host}:{detokenizer_port}",
nccl_port=nccl_port, nccl_port=nccl_port,
rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}", rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}", metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
) )
......
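The PortArgs hunk above names the ports derived from the distributed init address instead of repeating port_base + N inline (note the metrics port is stored in a variable called metrics_ipc_name even though it holds an integer). The sketch below recomputes the same endpoint layout with made-up host and port values.

# Illustrative recomputation of the endpoint layout in the hunk above.
# The dist_init_addr value is made up; the offsets mirror the diff.
dist_init_host, dist_init_port = "10.0.0.1", 25000  # e.g. from --dist-init-addr

port_base = int(dist_init_port) + 1
detokenizer_port = port_base + 1
rpc_port = port_base + 2
metrics_port = port_base + 3          # named `metrics_ipc_name` in the diff
scheduler_input_port = port_base + 4  # when dp_rank is None

endpoints = {
    "tokenizer_ipc_name": f"tcp://{dist_init_host}:{port_base}",
    "scheduler_input_ipc_name": f"tcp://{dist_init_host}:{scheduler_input_port}",
    "detokenizer_ipc_name": f"tcp://{dist_init_host}:{detokenizer_port}",
    "rpc_ipc_name": f"tcp://{dist_init_host}:{rpc_port}",
    "metrics_ipc_name": f"tcp://{dist_init_host}:{metrics_port}",
}
print(endpoints)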
...@@ -291,17 +291,6 @@ def find_printable_text(text: str): ...@@ -291,17 +291,6 @@ def find_printable_text(text: str):
return text[: text.rfind(" ") + 1] return text[: text.rfind(" ") + 1]
def graceful_registry(sub_module_name: str):
def graceful_shutdown(signum, frame):
logger.info(
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
)
if signum == signal.SIGTERM:
logger.info(f"{sub_module_name} receive sigterm")
signal.signal(signal.SIGTERM, graceful_shutdown)
class LazyImport: class LazyImport:
"""Lazy import to make `import sglang` run faster.""" """Lazy import to make `import sglang` run faster."""
......