Unverified commit a4c3b121 authored by Lianmin Zheng, committed by GitHub

Split the scheduler into multiple mixin classes to reduce the file size (#8483)

parent 5973675b
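The change below splits groups of related `Scheduler` methods into standalone mixin classes (metrics, profiler, weight updates, output processing, disaggregation) and has `Scheduler` inherit from all of them, so existing `self.method()` call sites keep working unchanged. A minimal sketch of that composition pattern, with illustrative class and method names rather than the real ones:

```python
class MetricsMixin:
    """One group of related methods; shared state lives on the host class."""

    def init_metrics(self) -> None:
        self.num_requests = 0

    def log_stats(self) -> str:
        return f"requests={self.num_requests}"


class ProfilerMixin:
    def init_profiler(self) -> None:
        self.profile_in_progress = False


class Scheduler(MetricsMixin, ProfilerMixin):
    """The composed class: methods from every mixin resolve through normal MRO."""

    def __init__(self) -> None:
        self.init_metrics()
        self.init_profiler()


if __name__ == "__main__":
    print(Scheduler().log_stats())  # -> requests=0
```

Keeping each concern in its own file without changing the public surface of `Scheduler` is what lets the diff below replace inline method bodies with imports of the new mixin modules.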
@@ -694,10 +694,7 @@ class SchedulerDisaggregationDecodeMixin:
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
@@ -771,10 +768,7 @@ class SchedulerDisaggregationDecodeMixin:
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
self.last_batch_in_queue = last_batch_in_queue
@@ -287,9 +287,7 @@ class SchedulerDisaggregationPrefillMixin:
self.process_disagg_prefill_inflight_queue()
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -337,9 +335,7 @@ class SchedulerDisaggregationPrefillMixin:
self.process_disagg_prefill_inflight_queue()
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -652,25 +652,19 @@ def _set_envs_and_config(server_args: ServerArgs):
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)
def sigchld_handler(signum, frame):
pid, exitcode = os.waitpid(0, os.WNOHANG)
if exitcode != 0:
logger.warning(
f"Child process unexpectedly failed with {exitcode=}. {pid=}"
if True: # Keep this check for internal code compatibility
# Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens
# This process then cleans up the whole process tree
# Note: This sigquit handler is used in the launch phase, and may be replaced by
# the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched.
def launch_phase_sigquit_handler(signum, frame):
logger.error(
"Received sigquit from a child process. It usually means the child failed."
)
kill_process_tree(os.getpid())
signal.signal(signal.SIGCHLD, sigchld_handler)
# Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens
# This process then cleans up the whole process tree
def sigquit_handler(signum, frame):
logger.error(
"Received sigquit from a child process. It usually means the child failed."
)
kill_process_tree(os.getpid())
signal.signal(signal.SIGQUIT, sigquit_handler)
signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)
# Set mp start method
mp.set_start_method("spawn", force=True)
@@ -238,6 +238,9 @@ async def health() -> Response:
@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token."""
if _global_state.tokenizer_manager.gracefully_exit:
logger.info("Health check request received during shutdown. Returning 503.")
return Response(status_code=503)
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
rid = f"HEALTH_CHECK_{time.time()}"
@@ -260,9 +263,14 @@ async def health_generate(request: Request) -> Response:
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
break
tic = time.perf_counter()
# This request is a special request.
# If the server already has something running, this request will be ignored, so it creates zero overhead.
# If the server is not running, this request will be run, so we know whether the server is healthy.
task = asyncio.create_task(gen())
while time.perf_counter() < tic + HEALTH_CHECK_TIMEOUT:
# As long as we receive any response from the detokenizer/scheduler, we consider the server healthy.
tic = time.time()
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
await asyncio.sleep(1)
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
task.cancel()
@@ -152,8 +152,6 @@ class GenerateReqInput:
else:
self._normalize_batch_inputs()
self._validate_session_params()
def _validate_inputs(self):
"""Validate that the input configuration is valid."""
if (
@@ -13,7 +13,6 @@
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""
import datetime
import faulthandler
import logging
import os
@@ -21,11 +20,10 @@ import signal
import sys
import threading
import time
from collections import defaultdict, deque
from collections import deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
@@ -37,7 +35,6 @@ from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
create_grammar_backend,
@@ -47,7 +44,6 @@ from sglang.srt.disaggregation.decode import (
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin,
@@ -78,21 +74,15 @@ from sglang.srt.managers.io_struct import (
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
RpcReqInput,
RpcReqOutput,
SetInternalStateReq,
@@ -104,11 +94,8 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.schedule_batch import (
@@ -124,9 +111,17 @@ from sglang.srt.managers.schedule_policy import (
SchedulePolicy,
)
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
from sglang.srt.managers.scheduler_metrics_mixin import (
RECORD_STEP_TIME,
SchedulerMetricsMixin,
)
from sglang.srt.managers.scheduler_output_processor_mixin import (
SchedulerOutputProcessorMixin,
)
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
from sglang.srt.managers.scheduler_update_weights_mixin import (
SchedulerUpdateWeightsMixin,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -135,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -168,7 +162,6 @@ logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
_is_cpu = is_cpu()
@@ -191,41 +184,11 @@ class EmbeddingBatchResult:
bid: int
class KvMetrics:
def __init__(self):
self.request_active_slots = None
self.request_total_slots = None
self.kv_active_blocks = None
self.kv_total_blocks = None
self.num_requests_waiting = None
self.gpu_cache_usage_perc = None
self.gpu_prefix_cache_hit_rate = None
self.data_parallel_rank = None
class IdleSleeper:
"""
In setups which have long inactivity periods it is desirable to reduce
system power consumption when sglang does nothing. This would lead not only
to power savings, but also to more CPU thermal headroom when a request
eventually comes. This is important in cases when multiple GPUs are connected
as each GPU would otherwise pin one thread at 100% CPU usage.
The simplest solution is to use zmq.Poller on all sockets that may receive
data that needs handling immediately.
"""
def __init__(self, sockets):
self.poller = zmq.Poller()
for s in sockets:
self.poller.register(s, zmq.POLLIN)
def maybe_sleep(self):
self.poller.poll(1000)
class Scheduler(
SchedulerOutputProcessorMixin,
SchedulerUpdateWeightsMixin,
SchedulerProfilerMixin,
SchedulerMetricsMixin,
SchedulerDisaggregationDecodeMixin,
SchedulerDisaggregationPrefillMixin,
):
@@ -266,7 +229,7 @@ class Scheduler(
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.enable_hicache_storage = server_args.hicache_storage_backend is not None
self.page_size = server_args.page_size
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
compute_dp_attention_world_info(
server_args.enable_dp_attention,
@@ -284,10 +247,13 @@ class Scheduler(
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
self.send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
)
if server_args.skip_tokenizer_init:
# Directly send to the TokenizerManager
self.send_to_detokenizer = get_zmq_socket(
@@ -299,9 +265,6 @@ class Scheduler(
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
if self.server_args.sleep_on_idle:
self.idle_sleeper = IdleSleeper(
[
@@ -398,7 +361,7 @@ class Scheduler(
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
# Hybrid
# Hybrid memory pool
self.is_hybrid = self.tp_worker.is_hybrid
if self.is_hybrid:
self.sliding_window_size = self.tp_worker.sliding_window_size
@@ -515,6 +478,15 @@ class Scheduler(
self.init_metrics(tp_rank, pp_rank, dp_rank)
self.init_kv_events(server_args.kv_events_config)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.init_disaggregation()
if get_bool_env_var("SGLANG_GC_LOG"):
configure_gc_logger()
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
@@ -545,22 +517,6 @@ class Scheduler(
]
)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.init_disaggregation()
if get_bool_env_var("SGLANG_GC_LOG"):
configure_gc_logger()
def current_scheduler_metrics_enabled(self):
return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
def maybe_sleep_on_idle(self):
if self.idle_sleeper is not None:
self.idle_sleeper.maybe_sleep()
def init_tokenizer(self):
server_args = self.server_args
@@ -668,50 +624,6 @@ class Scheduler(
embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
init_embedding_cache(embedding_cache_size * 1024 * 1024)
def init_profier(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_start_forward_ct: Optional[int] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.total_retracted_reqs = 0
self.stats = SchedulerStats()
if self.enable_metrics:
engine_type = "unified"
labels = {
"model_name": self.server_args.served_model_name,
"engine_type": engine_type,
"tp_rank": tp_rank,
"pp_rank": pp_rank,
}
if dp_rank is not None:
labels["dp_rank"] = dp_rank
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
def init_kv_events(self, kv_events_config: Optional[str]):
if self.enable_kv_cache_events:
self.kv_event_publisher = EventPublisherFactory.create(
kv_events_config, self.attn_dp_rank
)
def init_disaggregation(self):
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
@@ -820,10 +732,7 @@ class Scheduler(
self.process_batch_result(batch, result)
else:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
@@ -866,10 +775,7 @@ class Scheduler(
)
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
self.self_check_during_idle()
self.last_batch = batch
@@ -1003,10 +909,8 @@ class Scheduler(
# When the server is idle, self-check and re-init some states
if server_is_idle:
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
# When the server is idle, do self-check and re-init some states
self.self_check_during_idle()
def recv_requests(self) -> List[Req]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1355,170 +1259,11 @@ class Scheduler(
req.logprob_start_len = len(req.origin_input_ids) - 1
self._add_request_to_queue(req)
def _emit_kv_metrics(self):
kv_metrics = KvMetrics()
kv_metrics.request_active_slots = self.stats.num_running_reqs
kv_metrics.request_total_slots = self.max_running_requests
kv_metrics.kv_active_blocks = int(
self.stats.token_usage * self.max_total_num_tokens
)
kv_metrics.kv_total_blocks = self.max_total_num_tokens
kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
if not self.send_metrics_from_scheduler.closed:
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
def log_prefill_stats(
self,
adder: PrefillAdder,
can_run_list: List[Req],
running_bs: int,
):
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.perf_counter()
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"token usage: {token_usage:.2f}, "
num_new_seq = len(can_run_list)
f = (
f"Prefill batch. "
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"{token_msg}"
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
else:
f += f"#running-req: {running_bs}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
logger.info(f)
if self.enable_metrics:
total_tokens = adder.log_input_tokens + adder.log_hit_tokens
cache_hit_rate = (
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
)
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.cache_hit_rate = cache_hit_rate
total_queue_latency = 0
for req in can_run_list:
total_queue_latency += req.queue_time_end - req.queue_time_start
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def log_decode_stats(
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
):
batch = running_batch or self.running_batch
gap_latency = time.perf_counter() - self.last_decode_stats_tic
self.last_decode_stats_tic = time.perf_counter()
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
if RECORD_STEP_TIME:
self.step_time_dict[num_running_reqs].append(
gap_latency / self.server_args.decode_log_interval
)
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
else:
spec_accept_length = (
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
)
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
self.cum_spec_accept_count += self.spec_num_total_forward_ct
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg += f"accept len: {spec_accept_length:.2f}, "
if self.disaggregation_mode == DisaggregationMode.DECODE:
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
msg += (
f"cuda graph: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
logger.info(msg)
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.cache_hit_rate = 0.0
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.spec_accept_length = spec_accept_length
self.stats.total_retracted_reqs = self.total_retracted_reqs
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def self_check_during_idle(self):
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
def check_memory(self):
if self.is_hybrid:
@@ -2422,22 +2167,6 @@ class Scheduler(
barrier()
return RpcReqOutput(success, "" if not exec else str(exec))
def save_remote_model(self, params):
url = params["url"]
worker = self.tp_worker.worker
worker.model_runner.save_remote_model(url)
def save_sharded_model(self, params):
worker = self.tp_worker.worker
worker.model_runner.save_sharded_model(
path=params["path"],
pattern=params["pattern"],
max_size=params["max_size"],
)
def abort_request(self, recv_req: AbortReq):
# Delete requests in the waiting queue
to_del = []
@@ -2515,16 +2244,6 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)
def load_lora_adapter(
self, recv_req: LoadLoRAAdapterReqInput
) -> LoadLoRAAdapterReqOutput:
@@ -2541,81 +2260,6 @@ class Scheduler(
result = self.tp_worker.unload_lora_adapter(recv_req)
return result
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
return InitWeightsUpdateGroupReqOutput(success, message)
def update_weights_from_distributed(
self,
recv_req: UpdateWeightsFromDistributedReqInput,
) -> Tuple[bool, str]:
"""Update the online model parameter."""
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightsFromDistributedReqOutput(success, message)
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
"""Update the online model parameter from tensors."""
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
barrier(group=self.tp_cpu_group)
return UpdateWeightsFromTensorReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
self.flush_cache()
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model
)
torch.distributed.barrier(self.tp_cpu_group)
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
torch.distributed.barrier(self.tp_cpu_group)
_import_static_state(
self.tp_worker.worker.model_runner.model,
self.stashed_model_static_state,
)
del self.stashed_model_static_state
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
return ResumeMemoryOccupationReqOutput()
def slow_down(self, recv_req: SlowDownReqInput):
t = recv_req.forward_sleep_time
if t is not None and t <= 0:
@@ -2623,254 +2267,6 @@ class Scheduler(
self.forward_sleep_time = t
return SlowDownReqOutput()
def profile(self, recv_req: ProfileReq):
if recv_req.type == ProfileReqType.START_PROFILE:
if recv_req.profile_by_stage or recv_req.start_step:
return self.init_profile(
recv_req.output_dir,
recv_req.start_step,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
)
else:
self.init_profile(
recv_req.output_dir,
recv_req.start_step,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
)
return self.start_profile(True)
else:
return self.stop_profile()
def init_profile(
self,
output_dir: Optional[str],
start_step: Optional[int],
num_steps: Optional[int],
activities: Optional[List[str]],
with_stack: Optional[bool],
record_shapes: Optional[bool],
profile_by_stage: bool,
profile_id: str,
) -> ProfileReqOutput:
if self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is already in progress. Call /stop_profile first.",
)
self.profile_by_stage = profile_by_stage
if output_dir is None:
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
if activities is None:
activities = ["CPU", "GPU"]
self.torch_profiler_output_dir = output_dir
self.torch_profiler_with_stack = with_stack
self.torch_profiler_record_shapes = record_shapes
self.profiler_activities = activities
self.profile_id = profile_id
if start_step:
self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
if num_steps:
self.profile_steps = num_steps
if self.profile_by_stage:
self.profiler_target_prefill_ct = num_steps
self.profiler_target_decode_ct = num_steps
self.profiler_prefill_ct = 0
self.profiler_decode_ct = 0
elif start_step:
self.profiler_target_forward_ct = (
self.profiler_start_forward_ct + num_steps
)
else:
self.profiler_target_forward_ct = self.forward_ct + num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else:
self.profiler_target_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded")
def start_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
stage_str = f" for {stage.__str__()}" if stage else ""
logger.info(
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
)
activities = self.profiler_activities
with_stack = self.torch_profiler_with_stack
record_shapes = self.torch_profiler_record_shapes
activity_map = {
"CPU": torch.profiler.ProfilerActivity.CPU,
"GPU": torch.profiler.ProfilerActivity.CUDA,
}
torchprof_activities = [
activity_map[a] for a in activities if a in activity_map
]
if "RPD" in activities:
from rpdTracerControl import rpdTracerControl
rpdTracerControl.skipCreate()
self.rpd_profile_path = os.path.join(
self.torch_profiler_output_dir,
"rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
)
if self.tp_rank == 0:
import sqlite3
from rocpd.schema import RocpdSchema
if os.path.exists("trace.rpd"):
os.unlink("trace.rpd")
schema = RocpdSchema()
connection = sqlite3.connect("trace.rpd")
schema.writeSchema(connection)
connection.commit()
del connection
torch.distributed.barrier(self.tp_cpu_group)
self.rpd_profiler = rpdTracerControl()
self.rpd_profiler.setPythonTrace(True)
self.rpd_profiler.start()
self.rpd_profiler.rangePush("", "rpd profile range", "")
self.profile_in_progress = True
elif torchprof_activities:
self.torch_profiler = torch.profiler.profile(
activities=torchprof_activities,
with_stack=with_stack if with_stack is not None else True,
record_shapes=record_shapes if record_shapes is not None else False,
)
self.torch_profiler.start()
self.profile_in_progress = True
if "MEM" in activities:
torch.cuda.memory._record_memory_history(max_entries=100000)
self.profile_in_progress = True
if "CUDA_PROFILER" in activities:
torch.cuda.cudart().cudaProfilerStart()
self.profile_in_progress = True
return ProfileReqOutput(success=True, message="Succeeded")
def stop_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
if not self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is not in progress. Call /start_profile first.",
)
if not Path(self.torch_profiler_output_dir).exists():
Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
stage_suffix = f"-{stage.__str__()}" if stage else ""
logger.info("Stop profiling" + stage_suffix + "...")
if self.torch_profiler is not None:
self.torch_profiler.stop()
self.torch_profiler.export_chrome_trace(
os.path.join(
self.torch_profiler_output_dir,
self.profile_id
+ f"-TP-{self.tp_rank}"
+ stage_suffix
+ ".trace.json.gz",
)
)
torch.distributed.barrier(self.tp_cpu_group)
if self.rpd_profiler is not None:
self.rpd_profiler.rangePop()
self.rpd_profiler.stop()
self.rpd_profiler.flush()
torch.distributed.barrier(self.tp_cpu_group)
if self.tp_rank == 0:
from sglang.srt.utils import rpd_to_chrome_trace
rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
self.rpd_profiler = None
self.rpd_profiler_path = None
if self.profiler_activities is not None and "MEM" in self.profiler_activities:
memory_profile_path = os.path.join(
self.torch_profiler_output_dir,
str(time.time())
+ f"-TP-{self.tp_rank}-memory"
+ stage_suffix
+ ".pickle",
)
torch.cuda.memory._dump_snapshot(memory_profile_path)
torch.cuda.memory._record_memory_history(enabled=None)
if "CUDA_PROFILER" in self.profiler_activities:
torch.cuda.cudart().cudaProfilerStop()
logger.info(
"Profiling done. Traces are saved to: %s",
self.torch_profiler_output_dir,
)
self.torch_profiler = None
self.profile_in_progress = False
self.profiler_start_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded.")
def _profile_batch_predicate(self, batch):
if self.profile_by_stage:
if batch.forward_mode.is_prefill():
if self.profiler_prefill_ct == 0:
self.start_profile(batch.forward_mode)
self.profiler_prefill_ct += 1
if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.EXTEND)
elif batch.forward_mode.is_decode():
if self.profiler_decode_ct == 0:
if self.profile_in_progress:
# force trace flush
self.stop_profile(ForwardMode.EXTEND)
self.start_profile(batch.forward_mode)
self.profiler_decode_ct += 1
if self.profiler_decode_ct > self.profiler_target_decode_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.DECODE)
elif batch.forward_mode.is_idle():
pass
else:
raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
else:
# Check profiler
if (
self.profiler_target_forward_ct
and self.profiler_target_forward_ct <= self.forward_ct
):
self.stop_profile()
if (
self.profiler_start_forward_ct
and self.profiler_start_forward_ct == self.forward_ct
):
self.start_profile()
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD:
get_global_expert_distribution_recorder().start_record()
@@ -2879,7 +2275,7 @@ class Scheduler(
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
get_global_expert_distribution_recorder().dump_record()
else:
raise ValueError("Unrecognized ExpertDistributionReq value")
raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
return ExpertDistributionReqOutput()
def open_session(self, recv_req: OpenSessionReqInput):
@@ -2915,34 +2311,41 @@ class Scheduler(
prefix += f" PP{self.pp_rank}"
return prefix
def _publish_kv_events(self):
if self.enable_kv_cache_events:
events = self.tree_cache.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
def current_scheduler_metrics_enabled(self):
return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
def maybe_sleep_on_idle(self):
if self.idle_sleeper is not None:
self.idle_sleeper.maybe_sleep()
def is_health_check_generate_req(recv_req):
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
class IdleSleeper:
"""
In setups which have long inactivity periods it is desirable to reduce
system power consumption when sglang does nothing. This would lead not only
to power savings, but also to more CPU thermal headroom when a request
eventually comes. This is important in cases when multiple GPUs are connected
as each GPU would otherwise pin one thread at 100% CPU usage.
def is_work_request(recv_req):
return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
The simplest solution is to use zmq.Poller on all sockets that may receive
data that needs handling immediately.
"""
def __init__(self, sockets):
self.poller = zmq.Poller()
for s in sockets:
self.poller.register(s, zmq.POLLIN)
def maybe_sleep(self):
self.poller.poll(1000)
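The `IdleSleeper` docstring above describes the intent; a minimal standalone sketch of the same poller-based sleep (the endpoint and variable names here are illustrative, not the scheduler's real IPC sockets):

```python
import zmq

# Build a PULL socket on a hypothetical endpoint and register it with a poller.
ctx = zmq.Context()
sock = ctx.socket(zmq.PULL)
sock.bind("ipc:///tmp/example_idle_sleep")  # illustrative endpoint only

poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)


def idle_step() -> None:
    # Sleep for up to 1000 ms, but wake immediately if any registered socket
    # has data, instead of spinning on recv in a tight loop.
    events = dict(poller.poll(1000))
    if sock in events:
        print(sock.recv())
```

The 1000 ms cap in `poll` bounds how long an idle scheduler sleeps before its idle self-checks run again, while an incoming message still wakes it immediately.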
def _export_static_state(model):
return dict(
buffers=[
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
]
)
def is_health_check_generate_req(recv_req):
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
def _import_static_state(model, static_params):
self_named_buffers = dict(model.named_buffers())
for name, tensor in static_params["buffers"]:
self_named_buffers[name][...] = tensor
def is_work_request(recv_req):
return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
def run_scheduler_process(
(rest of the scheduler.py diff omitted)

New file: scheduler_metrics_mixin.py (module sglang.srt.managers.scheduler_metrics_mixin):
import logging
import time
from collections import defaultdict
from typing import List, Optional
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.managers.schedule_policy import PrefillAdder
from sglang.srt.managers.scheduler import Req, ScheduleBatch
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__)
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
class KvMetrics:
def __init__(self):
self.request_active_slots = None
self.request_total_slots = None
self.kv_active_blocks = None
self.kv_total_blocks = None
self.num_requests_waiting = None
self.gpu_cache_usage_perc = None
self.gpu_prefix_cache_hit_rate = None
self.data_parallel_rank = None
class SchedulerMetricsMixin:
def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.total_retracted_reqs = 0
self.stats = SchedulerStats()
if self.enable_metrics:
engine_type = "unified"
labels = {
"model_name": self.server_args.served_model_name,
"engine_type": engine_type,
"tp_rank": tp_rank,
"pp_rank": pp_rank,
}
if dp_rank is not None:
labels["dp_rank"] = dp_rank
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
def init_kv_events(self, kv_events_config: Optional[str]):
if self.enable_kv_cache_events:
self.kv_event_publisher = EventPublisherFactory.create(
kv_events_config, self.attn_dp_rank
)
def log_prefill_stats(
self,
adder: PrefillAdder,
can_run_list: List[Req],
running_bs: int,
):
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.perf_counter()
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"token usage: {token_usage:.2f}, "
num_new_seq = len(can_run_list)
f = (
f"Prefill batch. "
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"{token_msg}"
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
else:
f += f"#running-req: {running_bs}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
logger.info(f)
if self.enable_metrics:
total_tokens = adder.log_input_tokens + adder.log_hit_tokens
cache_hit_rate = (
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
)
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.cache_hit_rate = cache_hit_rate
total_queue_latency = 0
for req in can_run_list:
total_queue_latency += req.queue_time_end - req.queue_time_start
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def log_decode_stats(
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
):
batch = running_batch or self.running_batch
gap_latency = time.perf_counter() - self.last_decode_stats_tic
self.last_decode_stats_tic = time.perf_counter()
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
if RECORD_STEP_TIME:
self.step_time_dict[num_running_reqs].append(
gap_latency / self.server_args.decode_log_interval
)
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
else:
spec_accept_length = (
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
)
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
self.cum_spec_accept_count += self.spec_num_total_forward_ct
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg += f"accept len: {spec_accept_length:.2f}, "
if self.disaggregation_mode == DisaggregationMode.DECODE:
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
msg += (
f"cuda graph: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
logger.info(msg)
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.cache_hit_rate = 0.0
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.spec_accept_length = spec_accept_length
self.stats.total_retracted_reqs = self.total_retracted_reqs
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def _emit_kv_metrics(self):
kv_metrics = KvMetrics()
kv_metrics.request_active_slots = self.stats.num_running_reqs
kv_metrics.request_total_slots = self.max_running_requests
kv_metrics.kv_active_blocks = int(
self.stats.token_usage * self.max_total_num_tokens
)
kv_metrics.kv_total_blocks = self.max_total_num_tokens
kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
if not self.send_metrics_from_scheduler.closed:
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
def _publish_kv_events(self):
if self.enable_kv_cache_events:
events = self.tree_cache.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
New file: scheduler_profiler_mixin.py (module sglang.srt.managers.scheduler_profiler_mixin):

import logging
import os
import time
from pathlib import Path
from typing import List, Optional
import torch
from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
from sglang.srt.model_executor.forward_batch_info import ForwardMode
logger = logging.getLogger(__name__)
class SchedulerProfilerMixin:
def init_profier(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_start_forward_ct: Optional[int] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
def init_profile(
self,
output_dir: Optional[str],
start_step: Optional[int],
num_steps: Optional[int],
activities: Optional[List[str]],
with_stack: Optional[bool],
record_shapes: Optional[bool],
profile_by_stage: bool,
profile_id: str,
) -> ProfileReqOutput:
if self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is already in progress. Call /stop_profile first.",
)
self.profile_by_stage = profile_by_stage
if output_dir is None:
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
if activities is None:
activities = ["CPU", "GPU"]
self.torch_profiler_output_dir = output_dir
self.torch_profiler_with_stack = with_stack
self.torch_profiler_record_shapes = record_shapes
self.profiler_activities = activities
self.profile_id = profile_id
if start_step:
self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
if num_steps:
self.profile_steps = num_steps
if self.profile_by_stage:
self.profiler_target_prefill_ct = num_steps
self.profiler_target_decode_ct = num_steps
self.profiler_prefill_ct = 0
self.profiler_decode_ct = 0
elif start_step:
self.profiler_target_forward_ct = (
self.profiler_start_forward_ct + num_steps
)
else:
self.profiler_target_forward_ct = self.forward_ct + num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else:
self.profiler_target_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded")
def start_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
stage_str = f" for {stage.__str__()}" if stage else ""
logger.info(
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
)
activities = self.profiler_activities
with_stack = self.torch_profiler_with_stack
record_shapes = self.torch_profiler_record_shapes
activity_map = {
"CPU": torch.profiler.ProfilerActivity.CPU,
"GPU": torch.profiler.ProfilerActivity.CUDA,
}
torchprof_activities = [
activity_map[a] for a in activities if a in activity_map
]
if "RPD" in activities:
from rpdTracerControl import rpdTracerControl
rpdTracerControl.skipCreate()
self.rpd_profile_path = os.path.join(
self.torch_profiler_output_dir,
"rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
)
if self.tp_rank == 0:
import sqlite3
from rocpd.schema import RocpdSchema
if os.path.exists("trace.rpd"):
os.unlink("trace.rpd")
schema = RocpdSchema()
connection = sqlite3.connect("trace.rpd")
schema.writeSchema(connection)
connection.commit()
del connection
torch.distributed.barrier(self.tp_cpu_group)
self.rpd_profiler = rpdTracerControl()
self.rpd_profiler.setPythonTrace(True)
self.rpd_profiler.start()
self.rpd_profiler.rangePush("", "rpd profile range", "")
self.profile_in_progress = True
elif torchprof_activities:
self.torch_profiler = torch.profiler.profile(
activities=torchprof_activities,
with_stack=with_stack if with_stack is not None else True,
record_shapes=record_shapes if record_shapes is not None else False,
)
self.torch_profiler.start()
self.profile_in_progress = True
if "MEM" in activities:
torch.cuda.memory._record_memory_history(max_entries=100000)
self.profile_in_progress = True
if "CUDA_PROFILER" in activities:
torch.cuda.cudart().cudaProfilerStart()
self.profile_in_progress = True
return ProfileReqOutput(success=True, message="Succeeded")
def stop_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
if not self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is not in progress. Call /start_profile first.",
)
if not Path(self.torch_profiler_output_dir).exists():
Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
stage_suffix = f"-{stage.__str__()}" if stage else ""
logger.info("Stop profiling" + stage_suffix + "...")
if self.torch_profiler is not None:
self.torch_profiler.stop()
self.torch_profiler.export_chrome_trace(
os.path.join(
self.torch_profiler_output_dir,
self.profile_id
+ f"-TP-{self.tp_rank}"
+ stage_suffix
+ ".trace.json.gz",
)
)
torch.distributed.barrier(self.tp_cpu_group)
if self.rpd_profiler is not None:
self.rpd_profiler.rangePop()
self.rpd_profiler.stop()
self.rpd_profiler.flush()
torch.distributed.barrier(self.tp_cpu_group)
if self.tp_rank == 0:
from sglang.srt.utils import rpd_to_chrome_trace
rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
self.rpd_profiler = None
self.rpd_profiler_path = None
if self.profiler_activities is not None and "MEM" in self.profiler_activities:
memory_profile_path = os.path.join(
self.torch_profiler_output_dir,
str(time.time())
+ f"-TP-{self.tp_rank}-memory"
+ stage_suffix
+ ".pickle",
)
torch.cuda.memory._dump_snapshot(memory_profile_path)
torch.cuda.memory._record_memory_history(enabled=None)
if "CUDA_PROFILER" in self.profiler_activities:
torch.cuda.cudart().cudaProfilerStop()
logger.info(
"Profiling done. Traces are saved to: %s",
self.torch_profiler_output_dir,
)
self.torch_profiler = None
self.profile_in_progress = False
self.profiler_start_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded.")
def _profile_batch_predicate(self, batch):
if self.profile_by_stage:
if batch.forward_mode.is_prefill():
if self.profiler_prefill_ct == 0:
self.start_profile(batch.forward_mode)
self.profiler_prefill_ct += 1
if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.EXTEND)
elif batch.forward_mode.is_decode():
if self.profiler_decode_ct == 0:
if self.profile_in_progress:
# force trace flush
self.stop_profile(ForwardMode.EXTEND)
self.start_profile(batch.forward_mode)
self.profiler_decode_ct += 1
if self.profiler_decode_ct > self.profiler_target_decode_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.DECODE)
elif batch.forward_mode.is_idle():
pass
else:
raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
else:
# Check profiler
if (
self.profiler_target_forward_ct
and self.profiler_target_forward_ct <= self.forward_ct
):
self.stop_profile()
if (
self.profiler_start_forward_ct
and self.profiler_start_forward_ct == self.forward_ct
):
self.start_profile()
def profile(self, recv_req: ProfileReq):
if recv_req.type == ProfileReqType.START_PROFILE:
if recv_req.profile_by_stage or recv_req.start_step:
return self.init_profile(
recv_req.output_dir,
recv_req.start_step,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
)
else:
self.init_profile(
recv_req.output_dir,
recv_req.start_step,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
)
return self.start_profile(True)
else:
return self.stop_profile()
New file: scheduler_update_weights_mixin.py (module sglang.srt.managers.scheduler_update_weights_mixin):

import logging
from typing import Tuple
import torch
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
logger = logging.getLogger(__name__)
class SchedulerUpdateWeightsMixin:
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
return InitWeightsUpdateGroupReqOutput(success, message)
def update_weights_from_distributed(
self,
recv_req: UpdateWeightsFromDistributedReqInput,
) -> Tuple[bool, str]:
"""Update the online model parameter."""
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightsFromDistributedReqOutput(success, message)
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
"""Update the online model parameter from tensors."""
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
torch.distributed.barrier(group=self.tp_cpu_group)
return UpdateWeightsFromTensorReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
self.flush_cache()
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model
)
torch.distributed.barrier(self.tp_cpu_group)
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
torch.distributed.barrier(self.tp_cpu_group)
_import_static_state(
self.tp_worker.worker.model_runner.model,
self.stashed_model_static_state,
)
del self.stashed_model_static_state
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
return ResumeMemoryOccupationReqOutput()
def save_remote_model(self, params):
url = params["url"]
worker = self.tp_worker.worker
worker.model_runner.save_remote_model(url)
def save_sharded_model(self, params):
worker = self.tp_worker.worker
worker.model_runner.save_sharded_model(
path=params["path"],
pattern=params["pattern"],
max_size=params["max_size"],
)
def _export_static_state(model):
return dict(
buffers=[
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
]
)
def _import_static_state(model, static_params):
self_named_buffers = dict(model.named_buffers())
for name, tensor in static_params["buffers"]:
self_named_buffers[name][...] = tensor
@@ -170,16 +170,6 @@ class ReqState:
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
is_cross_node = server_args.dist_init_addr
if is_cross_node:
# Fallback to default CPU transport for multi-node
return "default"
else:
return "cuda_ipc"
class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text."""
@@ -199,16 +189,6 @@ class TokenizerManager:
else None
)
self.crash_dump_folder = server_args.crash_dump_folder
self.crash_dump_performed = False # Flag to ensure dump is only called once
# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
# Read model args
self.model_path = server_args.model_path
@@ -218,8 +198,7 @@ class TokenizerManager:
self.is_image_gen = self.model_config.is_image_gen
self.context_len = self.model_config.context_len
self.image_token_id = self.model_config.image_token_id
self._updating = False
self._cond = asyncio.Condition()
self.max_req_input_len = None # Will be set later in engine.py
if self.model_config.is_multimodal:
import_processors()
@@ -258,39 +237,57 @@ class TokenizerManager:
revision=server_args.revision,
)
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
# serves as the source of truth for available adapters and maps user-friendly LoRA names
# to internally used unique LoRA IDs.
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
# Store states
# Request states
self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {}
self.asyncio_tasks = set()
# Health check
self.health_check_failed = False
self.gracefully_exit = False
self.last_receive_tstamp = 0
# Dumping
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
self.crash_dump_request_list: deque[Tuple] = deque()
self.log_request_metadata = self.get_log_request_metadata()
self.crash_dump_request_list: deque[Tuple] = deque()
self.crash_dump_performed = False # Flag to ensure dump is only called once
# Session
self.session_futures = {} # session_id -> asyncio event
self.max_req_input_len = None
self.asyncio_tasks = set()
# Weight updates
# The event used to notify that the weight sync is finished.
self.model_update_lock = RWLock()
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
None
)
self._is_updating = False
self._is_updating_cond = asyncio.Condition()
# LoRA
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
# serves as the source of truth for available adapters and maps user-friendly LoRA names
# to internally used unique LoRA IDs.
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
# Lock to serialize LoRA update operations.
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
# LoRA updates and inference to overlap.
self.lora_update_lock = asyncio.Lock()
# For pd disaggregtion
# For PD disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
......@@ -458,17 +455,11 @@ class TokenizerManager:
request: Optional[fastapi.Request] = None,
):
created_time = time.time()
async with self._cond:
await self._cond.wait_for(lambda: not self._updating)
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
async with self._is_updating_cond:
await self._is_updating_cond.wait_for(lambda: not self._is_updating)
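# Descriptive note: the condition wait above holds new requests while generation is
# paused (e.g. for a weight update); pause_generation() sets _is_updating and
# continue_generation() clears it and notifies waiters (see further below).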
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
......@@ -567,6 +558,12 @@ class TokenizerManager:
f"model's context length ({self.context_len} tokens)."
)
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
# Check total tokens (input + max_new_tokens)
max_new_tokens = obj.sampling_params.get("max_new_tokens")
if (
......@@ -959,14 +956,14 @@ class TokenizerManager:
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
async def pause_generation(self):
async with self._cond:
self._updating = True
async with self._is_updating_cond:
self._is_updating = True
self.abort_request(abort_all=True)
async def continue_generation(self):
async with self._cond:
self._updating = False
self._cond.notify_all()
async with self._is_updating_cond:
self._is_updating = False
self._is_updating_cond.notify_all()
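# Illustrative usage sketch (not part of this change): a caller that needs to swap
# weights safely would typically fence generation with the two methods above:
#
#     await tokenizer_manager.pause_generation()     # sets _is_updating, aborts in-flight requests
#     ...load or swap weights...
#     await tokenizer_manager.continue_generation()  # clears the flag and wakes waiting requests
#
# New requests block on _is_updating_cond (see the wait_for earlier in this diff)
# until the flag is cleared again.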
async def update_weights_from_disk(
self,
......@@ -1208,14 +1205,6 @@ class TokenizerManager:
# Many DP ranks
return [res.internal_state for res in responses]
async def get_load(self) -> dict:
# TODO(lsyin): fake load report server
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput:
......@@ -1224,6 +1213,14 @@ class TokenizerManager:
)
return [res.internal_state for res in responses]
async def get_load(self) -> dict:
# TODO(lsyin): fake load report server
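# Refresh the cached load only when no other coroutine is already refreshing it;
# concurrent callers skip the refresh and return the last cached value below.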
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
def get_log_request_metadata(self):
max_length = None
skip_names = None
......@@ -1343,11 +1340,24 @@ class TokenizerManager:
"SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
)
return
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
self.crash_dump_performed = True
if not self.crash_dump_folder:
return
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
self.crash_dump_performed = True
# Check if NFS directory is available
# expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
# use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
# expected_nfs_dir, os.W_OK
# )
use_nfs_dir = False
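# The NFS probe above is disabled for now, so use_nfs_dir is always False: the dump
# is written to local disk and then uploaded to GCS at the end of this method.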
if not use_nfs_dir:
logger.error(
f"Expected NFS directory is not available or writable. Uploading to GCS."
)
data_to_dump = []
if self.crash_dump_request_list:
data_to_dump.extend(self.crash_dump_request_list)
......@@ -1357,7 +1367,12 @@ class TokenizerManager:
for rid, state in self.rid_to_state.items():
if not state.finished:
unfinished_requests.append(
(state.obj, {}, state.created_time, time.time())
(
state.obj,
state.out_list[-1] if state.out_list else {},
state.created_time,
time.time(),
)
)
if unfinished_requests:
data_to_dump.extend(unfinished_requests)
......@@ -1365,10 +1380,11 @@ class TokenizerManager:
if not data_to_dump:
return
object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
filename = os.path.join(
self.crash_dump_folder,
os.getenv("HOSTNAME", None),
f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
object_name,
)
os.makedirs(os.path.dirname(filename), exist_ok=True)
......@@ -1383,6 +1399,24 @@ class TokenizerManager:
f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
)
def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
from google.cloud import storage
client = storage.Client()
bucket = client.bucket(bucket_name)
blob = bucket.blob(object_name)
blob.upload_from_filename(source_file_path, if_generation_match=0)
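# if_generation_match=0 makes the upload conditional on the object not existing yet,
# so an already-uploaded crash dump is never silently overwritten.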
logger.error(
f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
)
if not use_nfs_dir:
_upload_file_to_gcs(
"sglang_crash_dump",
filename,
os.getenv("HOSTNAME", None) + "/" + object_name,
)
async def sigterm_watchdog(self):
while not self.gracefully_exit:
await asyncio.sleep(5)
......@@ -1426,7 +1460,7 @@ class TokenizerManager:
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
self._result_dispatcher(recv_obj)
self.last_receive_tstamp = time.perf_counter()
self.last_receive_tstamp = time.time()
def _handle_batch_output(
self,
......@@ -1697,24 +1731,13 @@ class TokenizerManager:
self.dump_requests_folder,
datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
)
logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
to_dump = self.dump_request_list
self._dump_data_to_file(
data_list=self.dump_request_list,
filename=filename,
log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
)
self.dump_request_list = []
to_dump_with_server_args = {
"server_args": self.server_args,
"requests": to_dump,
}
def background_task():
os.makedirs(self.dump_requests_folder, exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump_with_server_args, f)
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
current_time = time.time()
self.crash_dump_request_list.append(
......@@ -1727,6 +1750,22 @@ class TokenizerManager:
):
self.crash_dump_request_list.popleft()
def _dump_data_to_file(
self, data_list: List[Tuple], filename: str, log_message: str
):
logger.info(log_message)
to_dump_with_server_args = {
"server_args": self.server_args,
"requests": data_list.copy(),
}
def background_task():
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump_with_server_args, f)
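# Schedule the blocking directory creation and pickle write on a worker thread
# without awaiting it, so the caller is not delayed.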
asyncio.create_task(asyncio.to_thread(background_task))
def _handle_abort_req(self, recv_obj):
state = self.rid_to_state[recv_obj.rid]
state.finished = True
......@@ -1862,6 +1901,16 @@ class TokenizerManager:
return scores
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
is_cross_node = server_args.dist_init_addr
if is_cross_node:
# Fallback to default CPU transport for multi-node
return "default"
else:
return "cuda_ipc"
async def print_exception_wrapper(func):
"""
Sometimes an asyncio function does not print exception.
......
......@@ -2071,6 +2071,9 @@ class PortArgs:
dist_init_host, dist_init_port = dist_init_addr
port_base = int(dist_init_port) + 1
detokenizer_port = port_base + 1
rpc_port = port_base + 2
metrics_ipc_name = port_base + 3
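# Descriptive note on the offsets from port_base: +0 tokenizer, +1 detokenizer,
# +2 rpc, +3 metrics, +4 scheduler input (when dp_rank is None).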
if dp_rank is None:
# TokenizerManager to DataParallelController
scheduler_input_port = port_base + 4
......@@ -2080,10 +2083,10 @@ class PortArgs:
return PortArgs(
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
detokenizer_ipc_name=f"tcp://{dist_init_host}:{detokenizer_port}",
nccl_port=nccl_port,
rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
)
......
......@@ -291,17 +291,6 @@ def find_printable_text(text: str):
return text[: text.rfind(" ") + 1]
def graceful_registry(sub_module_name: str):
def graceful_shutdown(signum, frame):
logger.info(
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
)
if signum == signal.SIGTERM:
logger.info(f"{sub_module_name} receive sigterm")
signal.signal(signal.SIGTERM, graceful_shutdown)
class LazyImport:
"""Lazy import to make `import sglang` run faster."""
......