Unverified commit 2d62af6b, authored by Lianmin Zheng and committed by GitHub

Fix metrics and request tracing (TimeStats) (#11123)

parent a28b394f
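For orientation before the diff: the commit replaces the ad-hoc `queue_time_start`/`queue_time_end` fields on `Req` with per-stage timestamps recorded on `req.time_stats` via `time.perf_counter()`, and derives queueing latency from those timestamps. The snippet below is a minimal, hedged sketch of that pattern only; the field names are taken from the diff, but this simplified `TimeStats` and its `get_queueing_time()` semantics are illustrative assumptions, not the actual class in `sglang.srt.metrics.collector`.

```python
import time
from dataclasses import dataclass


@dataclass
class TimeStats:
    # Simplified stand-in for the TimeStats dataclass touched in this commit.
    # The real class carries more per-stage fields plus a disagg_mode attribute.
    wait_queue_entry_time: float = 0.0
    forward_entry_time: float = 0.0
    completion_time: float = 0.0

    def get_queueing_time(self) -> float:
        # Assumed semantics: time spent between entering the wait queue
        # and entering the forward pass.
        return self.forward_entry_time - self.wait_queue_entry_time


stats = TimeStats()
stats.wait_queue_entry_time = time.perf_counter()  # request enters the wait queue
stats.forward_entry_time = time.perf_counter()     # request is picked into a batch
stats.completion_time = time.perf_counter()        # request finishes
print(f"queueing time: {stats.get_queueing_time():.6f} s")
```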
@@ -14,18 +14,17 @@ classifiers = [
"License :: OSI Approved :: Apache Software License", "License :: OSI Approved :: Apache Software License",
] ]
dependencies = [ dependencies = [
"aiohttp",
"requests",
"tqdm",
"numpy",
"IPython", "IPython",
"setproctitle", "aiohttp",
"anthropic>=0.20.0",
"blobfile==3.0.0", "blobfile==3.0.0",
"build", "build",
"compressed-tensors", "compressed-tensors",
"cuda-python",
"datasets", "datasets",
"einops", "einops",
"fastapi", "fastapi",
"flashinfer_python==0.4.0rc3",
"hf_transfer", "hf_transfer",
"huggingface_hub", "huggingface_hub",
"interegular", "interegular",
@@ -33,8 +32,10 @@ dependencies = [
"modelscope", "modelscope",
"msgspec", "msgspec",
"ninja", "ninja",
"openai==1.99.1", "numpy",
"nvidia-cutlass-dsl==4.2.1",
"openai-harmony==0.0.4", "openai-harmony==0.0.4",
"openai==1.99.1",
"orjson", "orjson",
"outlines==0.1.11", "outlines==0.1.11",
"packaging", "packaging",
@@ -42,32 +43,30 @@ dependencies = [
"pillow", "pillow",
"prometheus-client>=0.20.0", "prometheus-client>=0.20.0",
"psutil", "psutil",
"py-spy",
"pybase64", "pybase64",
"pydantic", "pydantic",
"pynvml", "pynvml",
"python-multipart", "python-multipart",
"pyzmq>=25.1.2", "pyzmq>=25.1.2",
"requests",
"scipy", "scipy",
"sentencepiece", "sentencepiece",
"setproctitle",
"sgl-kernel==0.3.13",
"soundfile==0.13.1", "soundfile==0.13.1",
"timm==1.0.16",
"tiktoken", "tiktoken",
"timm==1.0.16",
"torch==2.8.0",
"torch_memory_saver==0.0.8",
"torchao==0.9.0", "torchao==0.9.0",
"torchaudio==2.8.0",
"torchvision",
"tqdm",
"transformers==4.56.1", "transformers==4.56.1",
"uvicorn", "uvicorn",
"uvloop", "uvloop",
"xgrammar==0.1.24", "xgrammar==0.1.24"
"sgl-kernel==0.3.13",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"flashinfer_python==0.4.0rc3",
"openai==1.99.1",
"tiktoken",
"anthropic>=0.20.0",
"torch_memory_saver==0.0.8",
"nvidia-cutlass-dsl==4.2.1",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -79,15 +78,15 @@ test = [
"matplotlib", "matplotlib",
"pandas", "pandas",
"peft", "peft",
"sentence_transformers",
"pytest", "pytest",
"sentence_transformers",
"tabulate", "tabulate",
] ]
tracing = [ tracing = [
"opentelemetry-sdk",
"opentelemetry-api", "opentelemetry-api",
"opentelemetry-exporter-otlp", "opentelemetry-exporter-otlp",
"opentelemetry-exporter-otlp-proto-grpc", "opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
] ]
all = ["sglang[test]", "sglang[decord]"] all = ["sglang[test]", "sglang[decord]"]
blackwell = ["sglang[test]", "sglang[decord]"] blackwell = ["sglang[test]", "sglang[decord]"]
......
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
from __future__ import annotations from __future__ import annotations
import logging import logging
import time
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
@@ -422,9 +423,13 @@ class DecodePreallocQueue:
kv_indices, self.token_to_kv_pool_allocator.page_size kv_indices, self.token_to_kv_pool_allocator.page_size
) )
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index) decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
preallocated_reqs.append(decode_req) preallocated_reqs.append(decode_req)
indices_to_remove.add(i) indices_to_remove.add(i)
decode_req.req.time_stats.decode_transfer_queue_entry_time = (
time.perf_counter()
)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
self.queue = [ self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -625,6 +630,7 @@ class DecodeTransferQueue:
decode_req.req.output_topk_p = output_topk_p decode_req.req.output_topk_p = output_topk_p
decode_req.req.output_topk_index = output_topk_index decode_req.req.output_topk_index = output_topk_index
decode_req.req.hidden_states_tensor = output_hidden_states decode_req.req.hidden_states_tensor = output_hidden_states
if decode_req.req.return_logprob: if decode_req.req.return_logprob:
decode_req.req.output_token_logprobs_val.append( decode_req.req.output_token_logprobs_val.append(
output_token_logprobs_val[0].item() output_token_logprobs_val[0].item()
@@ -645,10 +651,17 @@ class DecodeTransferQueue:
if hasattr(decode_req.kv_receiver, "clear"): if hasattr(decode_req.kv_receiver, "clear"):
decode_req.kv_receiver.clear() decode_req.kv_receiver.clear()
decode_req.kv_receiver = None
indices_to_remove.add(i)
decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
# special handling for sampling_params.max_new_tokens == 1 # special handling for sampling_params.max_new_tokens == 1
if decode_req.req.sampling_params.max_new_tokens == 1: if decode_req.req.sampling_params.max_new_tokens == 1:
# finish immediately # finish immediately
decode_req.req.time_stats.forward_entry_time = (
decode_req.req.time_stats.completion_time
) = time.perf_counter()
decode_req.req.check_finished() decode_req.req.check_finished()
self.scheduler.stream_output( self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob [decode_req.req], decode_req.req.return_logprob
@@ -656,8 +669,6 @@ class DecodeTransferQueue:
self.tree_cache.cache_finished_req(decode_req.req) self.tree_cache.cache_finished_req(decode_req.req)
else: else:
transferred_reqs.append(decode_req.req) transferred_reqs.append(decode_req.req)
indices_to_remove.add(i)
elif poll in [ elif poll in [
KVPoll.Bootstrapping, KVPoll.Bootstrapping,
KVPoll.WaitingForInput, KVPoll.WaitingForInput,
@@ -877,6 +888,9 @@ class SchedulerDisaggregationDecodeMixin:
if len(can_run_list) == 0: if len(can_run_list) == 0:
return None return None
for req in can_run_list:
req.time_stats.forward_entry_time = time.perf_counter()
# construct a schedule batch with those requests and mark as decode # construct a schedule batch with those requests and mark as decode
new_batch = ScheduleBatch.init_new( new_batch = ScheduleBatch.init_new(
can_run_list, can_run_list,
......
@@ -21,6 +21,7 @@ from __future__ import annotations
import logging import logging
import threading import threading
import time
from collections import deque from collections import deque
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Type from typing import TYPE_CHECKING, List, Optional, Type
@@ -263,9 +264,10 @@ class PrefillBootstrapQueue:
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size) num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index) req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
bootstrapped_reqs.append(req) bootstrapped_reqs.append(req)
indices_to_remove.add(i) indices_to_remove.add(i)
req.time_stats.wait_queue_entry_time = time.perf_counter()
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
self.queue = [ self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -407,7 +409,6 @@ class SchedulerDisaggregationPrefillMixin:
for i, (req, next_token_id) in enumerate( for i, (req, next_token_id) in enumerate(
zip(batch.reqs, next_token_ids, strict=True) zip(batch.reqs, next_token_ids, strict=True)
): ):
req: Req
if req.is_chunked <= 0: if req.is_chunked <= 0:
# There is no output_ids for prefill # There is no output_ids for prefill
req.output_ids.append(next_token_id) req.output_ids.append(next_token_id)
@@ -450,6 +451,7 @@ class SchedulerDisaggregationPrefillMixin:
) )
logprob_pt += num_input_logprobs logprob_pt += num_input_logprobs
self.send_kv_chunk(req, last_chunk=True) self.send_kv_chunk(req, last_chunk=True)
req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
if req.grammar is not None: if req.grammar is not None:
# FIXME: this try-except block is for handling unexpected xgrammar issue. # FIXME: this try-except block is for handling unexpected xgrammar issue.
@@ -547,6 +549,9 @@ class SchedulerDisaggregationPrefillMixin:
else: else:
assert False, f"Unexpected polling state {poll=}" assert False, f"Unexpected polling state {poll=}"
for req in done_reqs:
req.time_stats.completion_time = time.perf_counter()
# Stream requests which have finished transfer # Stream requests which have finished transfer
self.stream_output( self.stream_output(
done_reqs, done_reqs,
......
@@ -5,7 +5,7 @@ import random
from collections import deque from collections import deque
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Type, Union from typing import TYPE_CHECKING, Optional, Type
import numpy as np import numpy as np
import torch import torch
......
@@ -41,7 +41,7 @@ import time
from enum import Enum, auto from enum import Enum, auto
from http import HTTPStatus from http import HTTPStatus
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -54,6 +54,7 @@ from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin, ScheduleBatchDisaggregationDecodeMixin,
) )
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.allocator import ( from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator, BaseTokenToKVPoolAllocator,
@@ -452,6 +453,7 @@ class Req:
bootstrap_host: Optional[str] = None, bootstrap_host: Optional[str] = None,
bootstrap_port: Optional[int] = None, bootstrap_port: Optional[int] = None,
bootstrap_room: Optional[int] = None, bootstrap_room: Optional[int] = None,
disagg_mode: Optional[DisaggregationMode] = None,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None, vocab_size: Optional[int] = None,
priority: Optional[int] = None, priority: Optional[int] = None,
@@ -628,10 +630,8 @@ class Req:
# For metrics # For metrics
self.metrics_collector = metrics_collector self.metrics_collector = metrics_collector
self.time_stats: TimeStats = TimeStats() self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
self.has_log_time_stats: bool = False self.has_log_time_stats: bool = False
self.queue_time_start = None
self.queue_time_end = None
self.last_tic = time.monotonic() self.last_tic = time.monotonic()
# For disaggregation # For disaggregation
@@ -668,9 +668,9 @@ class Req:
def add_latency(self, stage: RequestStage): def add_latency(self, stage: RequestStage):
if self.metrics_collector is None: if self.metrics_collector is None:
return return
assert stage.name in RequestStage.__members__, f"{stage=} is invalid"
now = time.monotonic() now = time.monotonic()
self.metrics_collector.observe_request_latency_seconds( self.metrics_collector.observe_per_stage_req_latency(
stage.value, now - self.last_tic stage.value, now - self.last_tic
) )
self.last_tic = now self.last_tic = now
@@ -834,10 +834,10 @@ class Req:
return return
if self.bootstrap_room is not None: if self.bootstrap_room is not None:
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})" prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
else: else:
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})" prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
logger.info(f"{prefix}: {self.time_stats}") logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
self.has_log_time_stats = True self.has_log_time_stats = True
def set_finish_with_abort(self, error_msg: str): def set_finish_with_abort(self, error_msg: str):
@@ -1544,7 +1544,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
) / total_max_new_tokens ) / total_max_new_tokens
new_estimate_ratio = min(1.0, new_estimate_ratio) new_estimate_ratio = min(1.0, new_estimate_ratio)
return retracted_reqs, new_estimate_ratio return retracted_reqs, new_estimate_ratio, []
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs): def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
req = self.reqs[idx] req = self.reqs[idx]
......
@@ -276,9 +276,13 @@ class SchedulePolicy:
) -> None: ) -> None:
"""Sorts the waiting queue based on the request priority then received titmestamp.""" """Sorts the waiting queue based on the request priority then received titmestamp."""
if schedule_low_priority_values_first: if schedule_low_priority_values_first:
waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start)) waiting_queue.sort(
key=lambda x: (x.priority, x.time_stats.wait_queue_entry_time)
)
else: else:
waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start)) waiting_queue.sort(
key=lambda x: (-x.priority, x.time_stats.wait_queue_entry_time)
)
@staticmethod @staticmethod
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None: def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
@@ -642,12 +646,12 @@ class PrefillAdder:
if server_args.schedule_low_priority_values_first: if server_args.schedule_low_priority_values_first:
sorted_running_reqs = sorted( sorted_running_reqs = sorted(
self.running_batch.reqs, self.running_batch.reqs,
key=lambda x: (-x.priority, -x.queue_time_start), key=lambda x: (-x.priority, -x.time_stats.wait_queue_entry_time),
) )
else: else:
sorted_running_reqs = sorted( sorted_running_reqs = sorted(
self.running_batch.reqs, self.running_batch.reqs,
key=lambda x: (x.priority, -x.queue_time_start), key=lambda x: (x.priority, -x.time_stats.wait_queue_entry_time),
) )
preemptible_reqs = [] preemptible_reqs = []
min_tokens_to_remove = ( min_tokens_to_remove = (
......
@@ -157,10 +157,9 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.tracing.trace import ( from sglang.srt.tracing.trace import (
process_tracing_init, process_tracing_init,
trace_event,
trace_set_proc_propagate_context, trace_set_proc_propagate_context,
trace_set_thread_info, trace_set_thread_info,
trace_slice, trace_slice_batch,
trace_slice_end, trace_slice_end,
trace_slice_start, trace_slice_start,
) )
@@ -263,6 +262,7 @@ class Scheduler(
server_args.enable_metrics_for_all_schedulers server_args.enable_metrics_for_all_schedulers
) )
self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0 self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
self.enable_trace = server_args.enable_trace
self.stream_interval = server_args.stream_interval self.stream_interval = server_args.stream_interval
self.spec_algorithm = SpeculativeAlgorithm.from_string( self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm server_args.speculative_algorithm
@@ -899,10 +899,6 @@ class Scheduler(
batch = self.get_next_batch_to_run() batch = self.get_next_batch_to_run()
self.cur_batch = batch self.cur_batch = batch
if batch:
for req in batch.reqs:
trace_event("schedule", req.rid)
if batch: if batch:
result = self.run_batch(batch) result = self.run_batch(batch)
self.process_batch_result(batch, result) self.process_batch_result(batch, result)
@@ -924,10 +920,6 @@ class Scheduler(
batch = self.get_next_batch_to_run() batch = self.get_next_batch_to_run()
self.cur_batch = batch self.cur_batch = batch
if batch:
for req in batch.reqs:
trace_event("schedule", req.rid)
if batch: if batch:
batch.launch_done = threading.Event() batch.launch_done = threading.Event()
result = self.run_batch(batch) result = self.run_batch(batch)
@@ -1192,10 +1184,13 @@ class Scheduler(
src=self.tp_group.ranks[0], src=self.tp_group.ranks[0],
) )
for req in recv_reqs: if self.enable_trace:
if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)): for req in recv_reqs:
trace_set_proc_propagate_context(req.rid, req.trace_context) if isinstance(
trace_slice_start("", req.rid, anonymous=True) req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
trace_set_proc_propagate_context(req.rid, req.trace_context)
trace_slice_start("", req.rid, anonymous=True)
return recv_reqs return recv_reqs
@@ -1277,6 +1272,7 @@ class Scheduler(
bootstrap_host=recv_req.bootstrap_host, bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port, bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room, bootstrap_room=recv_req.bootstrap_room,
disagg_mode=self.disaggregation_mode,
data_parallel_rank=recv_req.data_parallel_rank, data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size, vocab_size=self.model_config.vocab_size,
priority=recv_req.priority, priority=recv_req.priority,
@@ -1403,7 +1399,6 @@ class Scheduler(
req.set_finish_with_abort(error_msg) req.set_finish_with_abort(error_msg)
if add_to_grammar_queue: if add_to_grammar_queue:
req.queue_time_start = time.perf_counter()
self.grammar_queue.append(req) self.grammar_queue.append(req)
else: else:
self._add_request_to_queue(req) self._add_request_to_queue(req)
@@ -1419,23 +1414,6 @@ class Scheduler(
for tokenized_req in recv_req: for tokenized_req in recv_req:
self.handle_generate_request(tokenized_req) self.handle_generate_request(tokenized_req)
def _add_request_to_queue(self, req: Req):
req.queue_time_start = time.perf_counter()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self._prefetch_kvcache(req)
self.disagg_prefill_bootstrap_queue.add(
req, self.model_config.num_key_value_heads
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
trace_slice_end("process req", req.rid, auto_next_anon=True)
def _prefetch_kvcache(self, req: Req): def _prefetch_kvcache(self, req: Req):
if self.enable_hicache_storage: if self.enable_hicache_storage:
req.init_next_round_input(self.tree_cache) req.init_next_round_input(self.tree_cache)
@@ -1449,19 +1427,27 @@ class Scheduler(
req.rid, req.last_host_node, new_input_tokens, last_hash req.rid, req.last_host_node, new_input_tokens, last_hash
) )
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False): def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.NULL:
self.disagg_prefill_bootstrap_queue.extend( self._set_or_validate_priority(req)
reqs, self.model_config.num_key_value_heads if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
req.time_stats.wait_queue_entry_time = time.perf_counter()
trace_slice_end("process req", req.rid, auto_next_anon=True)
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
self._prefetch_kvcache(req)
self.disagg_prefill_bootstrap_queue.add(
req, self.model_config.num_key_value_heads
) )
req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
# If this is a decode server, we put the request to the decode pending prealloc queue self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted) if not is_retracted:
req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
else: else:
for req in reqs: raise ValueError(f"Invalid {self.disaggregation_mode=}")
self._set_or_validate_priority(req)
if not self._abort_on_queued_limit(req):
self.waiting_queue.append(req)
def _set_or_validate_priority(self, req: Req): def _set_or_validate_priority(self, req: Req):
"""Set the default priority value, or abort the request based on the priority scheduling mode.""" """Set the default priority value, or abort the request based on the priority scheduling mode."""
@@ -1500,7 +1486,7 @@ class Scheduler(
direction = 1 if self.schedule_low_priority_values_first else -1 direction = 1 if self.schedule_low_priority_values_first else -1
key_fn = lambda item: ( key_fn = lambda item: (
direction * item[1].priority, direction * item[1].priority,
item[1].queue_time_start, item[1].time_stats.wait_queue_entry_time,
) )
idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn) idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
abort_existing_req = ( abort_existing_req = (
@@ -1902,14 +1888,14 @@ class Scheduler(
if self.enable_metrics: if self.enable_metrics:
# only record queue time when enable_metrics is True to avoid overhead # only record queue time when enable_metrics is True to avoid overhead
for req in can_run_list: for req in can_run_list:
req.queue_time_end = time.perf_counter()
req.add_latency(RequestStage.PREFILL_WAITING) req.add_latency(RequestStage.PREFILL_WAITING)
self.waiting_queue = [ self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list) x for x in self.waiting_queue if x not in set(can_run_list)
] ]
if adder.preempt_list: if adder.preempt_list:
self._extend_requests_to_queue(adder.preempt_list) for req in adder.preempt_list:
self._add_request_to_queue(req)
if adder.new_chunked_req is not None: if adder.new_chunked_req is not None:
assert self.chunked_req is None assert self.chunked_req is None
@@ -1920,7 +1906,16 @@ class Scheduler(
# Print stats # Print stats
if self.current_scheduler_metrics_enabled(): if self.current_scheduler_metrics_enabled():
self.log_prefill_stats(adder, can_run_list, running_bs) self.log_prefill_stats(adder, can_run_list, running_bs, 0)
for req in can_run_list:
if req.time_stats.forward_entry_time == 0:
# Avoid update chunked request many times
req.time_stats.forward_entry_time = time.perf_counter()
if self.enable_metrics:
self.metrics_collector.observe_queue_time(
req.time_stats.get_queueing_time(),
)
# Create a new batch # Create a new batch
new_batch = ScheduleBatch.init_new( new_batch = ScheduleBatch.init_new(
@@ -1975,19 +1970,25 @@ class Scheduler(
TEST_RETRACT and batch.batch_size() > 10 TEST_RETRACT and batch.batch_size() > 10
): ):
old_ratio = self.new_token_ratio old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args) self.server_args
num_retracted_reqs = len(retracted_reqs) )
self.num_retracted_reqs = len(retracted_reqs)
self.new_token_ratio = new_token_ratio self.new_token_ratio = new_token_ratio
for req in reqs_to_abort:
self.send_to_tokenizer.send_pyobj(
AbortReq(req.rid, abort_reason=req.to_abort_message)
)
logger.info( logger.info(
"KV cache pool is full. Retract requests. " "KV cache pool is full. Retract requests. "
f"#retracted_reqs: {num_retracted_reqs}, " f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
) )
self._extend_requests_to_queue(retracted_reqs, is_retracted=True) for req in retracted_reqs:
self.total_retracted_reqs += num_retracted_reqs self._add_request_to_queue(req, is_retracted=True)
else: else:
self.new_token_ratio = max( self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay, self.new_token_ratio - self.new_token_ratio_decay,
@@ -2086,23 +2087,14 @@ class Scheduler(
): ):
if batch.forward_mode.is_decode(): if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result, launch_done) self.process_batch_result_decode(batch, result, launch_done)
for req in batch.reqs: if self.enable_trace:
trace_slice( trace_slice_batch("decode loop", batch.reqs)
"decode loop",
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
elif batch.forward_mode.is_extend(): elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result, launch_done) self.process_batch_result_prefill(batch, result, launch_done)
for req in batch.reqs: if self.enable_trace:
trace_slice( trace_slice_batch("prefill", batch.reqs)
"prefill",
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
elif batch.forward_mode.is_idle(): elif batch.forward_mode.is_idle():
if self.enable_overlap: if self.enable_overlap:
self.tp_worker.resolve_last_batch_result(launch_done) self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2261,12 +2253,13 @@ class Scheduler(
if req.finished(): # It is aborted by AbortReq if req.finished(): # It is aborted by AbortReq
num_ready_reqs += 1 num_ready_reqs += 1
continue continue
req.grammar = req.grammar.result(timeout=0.03) req.grammar = req.grammar.result(timeout=0.03)
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy()) self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
if req.grammar is INVALID_GRAMMAR_OBJ: if req.grammar is INVALID_GRAMMAR_OBJ:
req.set_finish_with_abort( error_msg = f"Invalid grammar request: {req.grammar_key=}"
f"Invalid grammar request: {req.grammar_key=}" req.set_finish_with_abort(error_msg)
)
num_ready_reqs += 1 num_ready_reqs += 1
except futures._base.TimeoutError: except futures._base.TimeoutError:
req.grammar_wait_ct += 1 req.grammar_wait_ct += 1
@@ -2298,9 +2291,8 @@ class Scheduler(
req.grammar = req.grammar.result() req.grammar = req.grammar.result()
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy()) self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
if req.grammar is INVALID_GRAMMAR_OBJ: if req.grammar is INVALID_GRAMMAR_OBJ:
req.set_finish_with_abort( error_msg = f"Invalid grammar request: {req.grammar_key=}"
f"Invalid grammar request: {req.grammar_key=}" req.set_finish_with_abort(error_msg)
)
else: else:
num_ready_reqs_max = num_ready_reqs num_ready_reqs_max = num_ready_reqs
num_timeout_reqs_max = num_timeout_reqs num_timeout_reqs_max = num_timeout_reqs
@@ -2308,12 +2300,14 @@ class Scheduler(
for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max): for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
req = self.grammar_queue[i] req = self.grammar_queue[i]
req.grammar.cancel() req.grammar.cancel()
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}" error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
req.set_finish_with_abort(error_msg) req.set_finish_with_abort(error_msg)
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs]) for req in self.grammar_queue[:num_ready_reqs]:
self._add_request_to_queue(req)
self.grammar_queue = self.grammar_queue[num_ready_reqs:] self.grammar_queue = self.grammar_queue[num_ready_reqs:]
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch): def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
@@ -2795,17 +2789,11 @@ def run_scheduler_process(
pipe_writer, pipe_writer,
balance_meta: Optional[DPBalanceMeta] = None, balance_meta: Optional[DPBalanceMeta] = None,
): ):
if server_args.enable_trace: # Generate the logger prefix
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)
if (numa_node := server_args.numa_node) is not None:
numa_bind_to_node(numa_node[gpu_id])
# Generate the prefix
prefix = "" prefix = ""
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
dp_rank = int(os.environ["SGLANG_DP_RANK"])
if dp_rank is not None: if dp_rank is not None:
prefix += f" DP{dp_rank}" prefix += f" DP{dp_rank}"
if server_args.tp_size > 1: if server_args.tp_size > 1:
@@ -2821,10 +2809,6 @@ def run_scheduler_process(
kill_itself_when_parent_died() kill_itself_when_parent_died()
parent_process = psutil.Process().parent() parent_process = psutil.Process().parent()
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
dp_rank = int(os.environ["SGLANG_DP_RANK"])
# Configure the logger # Configure the logger
configure_logger(server_args, prefix=prefix) configure_logger(server_args, prefix=prefix)
suppress_other_loggers() suppress_other_loggers()
@@ -2832,6 +2816,15 @@ def run_scheduler_process(
# Set cpu affinity to this gpu process # Set cpu affinity to this gpu process
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
if (numa_node := server_args.numa_node) is not None:
numa_bind_to_node(numa_node[gpu_id])
# Set up tracing
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)
# Create a scheduler and run the event loop # Create a scheduler and run the event loop
try: try:
......
@@ -47,8 +47,11 @@ class SchedulerMetricsMixin:
self.spec_num_total_forward_ct = 0 self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0 self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0 self.cum_spec_accept_count = 0
self.total_retracted_reqs = 0 self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0
self.stats = SchedulerStats() self.stats = SchedulerStats()
if self.enable_metrics: if self.enable_metrics:
engine_type = "unified" engine_type = "unified"
labels = { labels = {
@@ -82,12 +85,14 @@ class SchedulerMetricsMixin:
adder: PrefillAdder, adder: PrefillAdder,
can_run_list: List[Req], can_run_list: List[Req],
running_bs: int, running_bs: int,
running_bs_offline_batch: int,
): ):
gap_latency = time.perf_counter() - self.last_prefill_stats_tic gap_latency = time.perf_counter() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.perf_counter() self.last_prefill_stats_tic = time.perf_counter()
self.last_input_throughput = self.last_prefill_tokens / gap_latency self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens self.last_prefill_tokens = adder.log_input_tokens
# TODO: generalize this for various memory pools
if self.is_hybrid: if self.is_hybrid:
( (
full_num_used, full_num_used,
@@ -101,51 +106,53 @@ class SchedulerMetricsMixin:
) = self._get_swa_token_info() ) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used) num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage) token_usage = max(full_token_usage, swa_token_usage)
token_msg = ( token_usage_msg = (
f"full token usage: {full_token_usage:.2f}, " f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, " f"swa token usage: {swa_token_usage:.2f}, "
) )
else: else:
num_used, token_usage, _, _ = self._get_token_info() num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"token usage: {token_usage:.2f}, " token_usage_msg = f"token usage: {token_usage:.2f}, "
num_new_seq = len(can_run_list)
f = ( f = (
f"Prefill batch. " f"Prefill batch. "
f"#new-seq: {num_new_seq}, " f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, " f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, " f"#cached-token: {adder.log_hit_tokens}, "
f"{token_msg}" f"{token_usage_msg}"
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue)}, "
) )
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, " f += f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#queue-req: {len(self.waiting_queue)}, " f += f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, "
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
else:
f += f"#running-req: {running_bs}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
logger.info(f) logger.info(f)
if self.enable_metrics: if self.enable_metrics:
# Basics
total_tokens = adder.log_input_tokens + adder.log_hit_tokens total_tokens = adder.log_input_tokens + adder.log_hit_tokens
cache_hit_rate = ( cache_hit_rate = (
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0 adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
) )
self.stats.num_running_reqs = running_bs self.stats.num_running_reqs = running_bs
self.stats.num_running_reqs_offline_batch = running_bs_offline_batch
self.stats.num_used_tokens = num_used self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2) self.stats.token_usage = token_usage
if self.is_hybrid:
self.stats.swa_token_usage = swa_token_usage
self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.cache_hit_rate = cache_hit_rate self.stats.cache_hit_rate = cache_hit_rate
total_queue_latency = 0 # Retract
for req in can_run_list: self.stats.num_retracted_reqs = self.num_retracted_reqs
total_queue_latency += req.queue_time_end - req.queue_time_start self.stats.num_paused_reqs = self.num_paused_reqs
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq self.num_retracted_reqs = self.num_paused_reqs = 0
# PD disaggregation
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.num_prefill_prealloc_queue_reqs = len( self.stats.num_prefill_prealloc_queue_reqs = len(
self.disagg_prefill_bootstrap_queue.queue self.disagg_prefill_bootstrap_queue.queue
@@ -153,7 +160,18 @@ class SchedulerMetricsMixin:
self.stats.num_prefill_inflight_queue_reqs = len( self.stats.num_prefill_inflight_queue_reqs = len(
self.disagg_prefill_inflight_queue self.disagg_prefill_inflight_queue
) )
self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.stats.num_decode_prealloc_queue_reqs = len(
self.disagg_decode_prealloc_queue.queue
)
self.stats.num_decode_transfer_queue_reqs = len(
self.disagg_decode_transfer_queue.queue
)
# Others
self.calculate_utilization()
self.metrics_collector.log_stats(self.stats) self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics() self._emit_kv_metrics()
self._publish_kv_events() self._publish_kv_events()
@@ -166,8 +184,12 @@ class SchedulerMetricsMixin:
gap_latency = time.perf_counter() - self.last_decode_stats_tic gap_latency = time.perf_counter() - self.last_decode_stats_tic
self.last_decode_stats_tic = time.perf_counter() self.last_decode_stats_tic = time.perf_counter()
self.last_gen_throughput = self.num_generated_tokens / gap_latency self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0 self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs) num_running_reqs = len(batch.reqs)
num_running_reqs_offline_batch = 0
# TODO: generalize this for various memory pools
if self.is_hybrid: if self.is_hybrid:
( (
full_num_used, full_num_used,
@@ -181,7 +203,7 @@ class SchedulerMetricsMixin:
) = self._get_swa_token_info() ) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used) num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage) token_usage = max(full_token_usage, swa_token_usage)
token_msg = ( token_usage_msg = (
f"#full token: {full_num_used}, " f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, " f"full token usage: {full_token_usage:.2f}, "
f"#swa token: {swa_num_used}, " f"#swa token: {swa_num_used}, "
@@ -189,14 +211,14 @@ class SchedulerMetricsMixin:
) )
else: else:
num_used, token_usage, _, _ = self._get_token_info() num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, " token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
if RECORD_STEP_TIME: if RECORD_STEP_TIME:
self.step_time_dict[num_running_reqs].append( self.step_time_dict[num_running_reqs].append(
gap_latency / self.server_args.decode_log_interval gap_latency / self.server_args.decode_log_interval
) )
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}" msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"
if self.spec_algorithm.is_none(): if self.spec_algorithm.is_none():
spec_accept_length = 0 spec_accept_length = 0
@@ -208,41 +230,66 @@ class SchedulerMetricsMixin:
self.cum_spec_accept_count += self.spec_num_total_forward_ct self.cum_spec_accept_count += self.spec_num_total_forward_ct
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg += f"accept len: {spec_accept_length:.2f}, " msg += f"accept len: {spec_accept_length:.2f}, "
cache_hit_rate = 0.0
if self.disaggregation_mode == DisaggregationMode.DECODE: if self.disaggregation_mode == DisaggregationMode.DECODE:
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, " msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, "
msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, "
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, " msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
msg += ( msg += (
f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, " f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, " f"#queue-req: {len(self.waiting_queue)}, "
) )
logger.info(msg) logger.info(msg)
if self.enable_metrics: if self.enable_metrics:
# Basics
self.stats.num_running_reqs = num_running_reqs self.stats.num_running_reqs = num_running_reqs
self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
self.stats.num_used_tokens = num_used self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2) self.stats.token_usage = token_usage
self.stats.cache_hit_rate = 0.0 if self.is_hybrid:
self.stats.swa_token_usage = swa_token_usage
self.stats.gen_throughput = self.last_gen_throughput self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.cache_hit_rate = cache_hit_rate
self.stats.spec_accept_length = spec_accept_length self.stats.spec_accept_length = spec_accept_length
self.stats.total_retracted_reqs = self.total_retracted_reqs
self.stats.avg_request_queue_latency = 0.0 # Retract
if self.disaggregation_mode == DisaggregationMode.DECODE: self.stats.num_retracted_reqs = self.num_retracted_reqs
self.stats.num_paused_reqs = self.num_paused_reqs
self.num_retracted_reqs = self.num_paused_reqs = 0
# PD disaggregation
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.num_prefill_prealloc_queue_reqs = len(
self.disagg_prefill_bootstrap_queue.queue
)
self.stats.num_prefill_inflight_queue_reqs = len(
self.disagg_prefill_inflight_queue
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.stats.num_decode_prealloc_queue_reqs = len( self.stats.num_decode_prealloc_queue_reqs = len(
self.disagg_decode_prealloc_queue.queue self.disagg_decode_prealloc_queue.queue
) )
self.stats.num_decode_transfer_queue_reqs = len( self.stats.num_decode_transfer_queue_reqs = len(
self.disagg_decode_transfer_queue.queue self.disagg_decode_transfer_queue.queue
) )
# Others
self.calculate_utilization()
self.metrics_collector.log_stats(self.stats) self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics() self._emit_kv_metrics()
self._publish_kv_events() self._publish_kv_events()
def _emit_kv_metrics(self: Scheduler): def _emit_kv_metrics(self: Scheduler):
if not self.enable_kv_cache_events:
return
kv_metrics = KvMetrics() kv_metrics = KvMetrics()
kv_metrics.request_active_slots = self.stats.num_running_reqs kv_metrics.request_active_slots = self.stats.num_running_reqs
kv_metrics.request_total_slots = self.max_running_requests kv_metrics.request_total_slots = self.max_running_requests
@@ -259,11 +306,13 @@ class SchedulerMetricsMixin:
self.send_metrics_from_scheduler.send_pyobj(kv_metrics) self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
def _publish_kv_events(self: Scheduler): def _publish_kv_events(self: Scheduler):
if self.enable_kv_cache_events: if not self.enable_kv_cache_events:
events = self.tree_cache.take_events() return
if events:
batch = KVEventBatch(ts=time.time(), events=events) events = self.tree_cache.take_events()
self.kv_event_publisher.publish(batch) if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
def maybe_update_dp_balance_data( def maybe_update_dp_balance_data(
self: Scheduler, recv_req: TokenizedGenerateReqInput self: Scheduler, recv_req: TokenizedGenerateReqInput
@@ -349,3 +398,17 @@ class SchedulerMetricsMixin:
# 2. Atomically write local_tokens and onfly into shm under the mutex # 2. Atomically write local_tokens and onfly into shm under the mutex
meta.set_shared_onfly_info(onfly_list) meta.set_shared_onfly_info(onfly_list)
meta.set_shared_local_tokens(local_tokens) meta.set_shared_local_tokens(local_tokens)
def calculate_utilization(self):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.utilization = -1
else:
if (
self.stats.max_running_requests_under_SLO is not None
and self.stats.max_running_requests_under_SLO > 0
):
self.stats.utilization = max(
self.stats.num_running_reqs
/ self.stats.max_running_requests_under_SLO,
self.stats.token_usage / 0.9,
)
@@ -91,7 +91,7 @@ class SchedulerOutputProcessorMixin:
if req.finished(): if req.finished():
self.tree_cache.cache_finished_req(req) self.tree_cache.cache_finished_req(req)
req.time_stats.completion_time = time.time() req.time_stats.completion_time = time.perf_counter()
elif not batch.decoding_reqs or req not in batch.decoding_reqs: elif not batch.decoding_reqs or req not in batch.decoding_reqs:
# This updates radix so others can match # This updates radix so others can match
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
@@ -257,7 +257,7 @@ class SchedulerOutputProcessorMixin:
else: else:
self.tree_cache.cache_finished_req(req) self.tree_cache.cache_finished_req(req)
req.time_stats.completion_time = time.time() req.time_stats.completion_time = time.perf_counter()
if req.return_logprob and batch.spec_algorithm.is_none(): if req.return_logprob and batch.spec_algorithm.is_none():
# speculative worker handles logprob in speculative decoding # speculative worker handles logprob in speculative decoding
@@ -707,6 +707,7 @@ class SchedulerOutputProcessorMixin:
and self.tp_rank == 0 and self.tp_rank == 0
and self.server_args.enable_request_time_stats_logging and self.server_args.enable_request_time_stats_logging
): ):
print(f"{req.finished_reason=}")
req.log_time_stats() req.log_time_stats()
# Send to detokenizer # Send to detokenizer
......
@@ -5,6 +5,7 @@ import copy
import logging import logging
import os import os
import time import time
import uuid
from collections import deque from collections import deque
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@@ -24,6 +25,7 @@ import zmq
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
ClearHiCacheReqInput, ClearHiCacheReqInput,
ClearHiCacheReqOutput, ClearHiCacheReqOutput,
CloseSessionReqInput,
DestroyWeightsUpdateGroupReqInput, DestroyWeightsUpdateGroupReqInput,
DestroyWeightsUpdateGroupReqOutput, DestroyWeightsUpdateGroupReqOutput,
ExpertDistributionReq, ExpertDistributionReq,
@@ -44,6 +46,7 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqOutput, LoadLoRAAdapterReqOutput,
LoRAUpdateResult, LoRAUpdateResult,
MultiTokenizerWrapper, MultiTokenizerWrapper,
OpenSessionReqInput,
ProfileReq, ProfileReq,
ProfileReqOutput, ProfileReqOutput,
ProfileReqType, ProfileReqType,
@@ -588,3 +591,81 @@ class TokenizerCommunicatorMixin:
async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]: async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
req = GetLoadReqInput() req = GetLoadReqInput()
return await self.get_load_communicator(req) return await self.get_load_communicator(req)
async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
self.auto_create_handle_loop()
if obj.session_id is None:
obj.session_id = uuid.uuid4().hex
elif obj.session_id in self.session_futures:
return None
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWrapper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.session_futures[obj.session_id] = asyncio.Future()
session_id = await self.session_futures[obj.session_id]
del self.session_futures[obj.session_id]
return session_id
async def close_session(
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
await self.send_to_scheduler.send_pyobj(obj)
def get_log_request_metadata(self):
max_length = None
skip_names = None
out_skip_names = None
if self.log_requests:
if self.log_requests_level == 0:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
"sampling_params",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 1:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 2:
max_length = 2048
elif self.log_requests_level == 3:
max_length = 1 << 30
else:
raise ValueError(
f"Invalid --log-requests-level: {self.log_requests_level=}"
)
return max_length, skip_names, out_skip_names
@@ -164,6 +164,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
else None else None
) )
self.crash_dump_folder = server_args.crash_dump_folder self.crash_dump_folder = server_args.crash_dump_folder
self.enable_trace = server_args.enable_trace
# Read model args # Read model args
self.model_path = server_args.model_path self.model_path = server_args.model_path
@@ -381,23 +382,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
# If it's a single value, add worker_id prefix # If it's a single value, add worker_id prefix
obj.rid = f"{self.worker_id}_{obj.rid}" obj.rid = f"{self.worker_id}_{obj.rid}"
if obj.is_single: if self.enable_trace:
bootstrap_room = ( self._trace_request_start(obj, created_time)
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
)
trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
else:
for i in range(len(obj.rid)):
bootstrap_room = (
obj.bootstrap_room[i]
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
else None
)
trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start(
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
)
if self.log_requests: if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata max_length, skip_names, _ = self.log_request_metadata
@@ -1055,7 +1041,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
req = AbortReq(rid, abort_all) req = AbortReq(rid, abort_all)
self.send_to_scheduler.send_pyobj(req) self.send_to_scheduler.send_pyobj(req)
if self.enable_metrics: if self.enable_metrics:
self.metrics_collector.observe_one_aborted_request() # TODO: also use custom_labels from the request
self.metrics_collector.observe_one_aborted_request(
self.metrics_collector.labels
)
async def pause_generation(self): async def pause_generation(self):
async with self.is_pause_cond: async with self.is_pause_cond:
@@ -1117,84 +1106,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
all_paused_requests = [r.num_paused_requests for r in result] all_paused_requests = [r.num_paused_requests for r in result]
return all_success, all_message, all_paused_requests return all_success, all_message, all_paused_requests
async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
self.auto_create_handle_loop()
if obj.session_id is None:
obj.session_id = uuid.uuid4().hex
elif obj.session_id in self.session_futures:
return None
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWrapper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.session_futures[obj.session_id] = asyncio.Future()
session_id = await self.session_futures[obj.session_id]
del self.session_futures[obj.session_id]
return session_id
async def close_session(
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
await self.send_to_scheduler.send_pyobj(obj)
def get_log_request_metadata(self):
max_length = None
skip_names = None
out_skip_names = None
if self.log_requests:
if self.log_requests_level == 0:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
"sampling_params",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 1:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 2:
max_length = 2048
elif self.log_requests_level == 3:
max_length = 1 << 30
else:
raise ValueError(
f"Invalid --log-requests-level: {self.log_requests_level=}"
)
return max_length, skip_names, out_skip_names
def configure_logging(self, obj: ConfigureLoggingReq): def configure_logging(self, obj: ConfigureLoggingReq):
if obj.log_requests is not None: if obj.log_requests is not None:
self.log_requests = obj.log_requests self.log_requests = obj.log_requests
@@ -1353,12 +1264,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
# Drain requests # Drain requests
while True: while True:
remain_num_req = len(self.rid_to_state) remain_num_req = len(self.rid_to_state)
remaining_rids = list(self.rid_to_state.keys())
if self.server_status == ServerStatus.UnHealthy: if self.server_status == ServerStatus.UnHealthy:
# if health check failed, we should exit immediately # if health check failed, we should exit immediately
logger.error( logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d", "Signal SIGTERM received while health check failed. Force exiting."
remain_num_req,
) )
self.dump_requests_before_crash() self.dump_requests_before_crash()
break break
@@ -1366,13 +1277,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"): elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
# if force shutdown flag set, exit immediately # if force shutdown flag set, exit immediately
logger.error( logger.error(
"Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d", "Signal SIGTERM received while force shutdown flag set. Force exiting."
remain_num_req,
) )
break break
logger.info( logger.info(
f"Gracefully exiting... remaining number of requests {remain_num_req}" f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
) )
if remain_num_req > 0: if remain_num_req > 0:
await asyncio.sleep(5) await asyncio.sleep(5)
...@@ -1888,6 +1798,29 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -1888,6 +1798,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
 load_update_req = WatchLoadUpdateReq(loads=loads) load_update_req = WatchLoadUpdateReq(loads=loads)
 self.send_to_scheduler.send_pyobj(load_update_req) self.send_to_scheduler.send_pyobj(load_update_req)
def _trace_request_start(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
created_time: Optional[float] = None,
):
if obj.is_single:
bootstrap_room = (
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
)
trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
else:
for i in range(len(obj.rid)):
bootstrap_room = (
obj.bootstrap_room[i]
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
else None
)
trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start(
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
)
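A brief runnable sketch of the timestamp handling assumed by _trace_request_start: created_time is a plain time.time() float in seconds, while trace_req_start and trace_slice_start take integer nanoseconds, hence the int(created_time * 1e9) conversions above.

import time

created_time = time.time()       # seconds as a float, recorded when the request arrived
ts_ns = int(created_time * 1e9)  # integer nanoseconds, the unit the tracing helpers expect

# For a single request the manager then issues (names taken from the code above):
#   trace_req_start(obj.rid, bootstrap_room, ts=ts_ns)
#   trace_slice_start("", obj.rid, ts=ts_ns, anonymous=True)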
class ServerStatus(Enum): class ServerStatus(Enum):
Up = "Up" Up = "Up"
...@@ -1933,7 +1866,7 @@ class SignalHandler: ...@@ -1933,7 +1866,7 @@ class SignalHandler:
def running_phase_sigquit_handler(self, signum=None, frame=None): def running_phase_sigquit_handler(self, signum=None, frame=None):
logger.error( logger.error(
"Received sigquit from a child process. It usually means the child failed." f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
) )
self.tokenizer_manager.dump_requests_before_crash() self.tokenizer_manager.dump_requests_before_crash()
kill_process_tree(os.getpid()) kill_process_tree(os.getpid())
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
"""Utilities for Prometheus Metrics Collection.""" """Utilities for Prometheus Metrics Collection."""
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var from sglang.srt.utils import get_bool_env_var
...@@ -34,6 +34,7 @@ class TimeStats: ...@@ -34,6 +34,7 @@ class TimeStats:
Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
""" """
disagg_mode: DisaggregationMode = DisaggregationMode.NULL
lb_entry_time: float = 0.0 lb_entry_time: float = 0.0
wait_queue_entry_time: float = 0.0 wait_queue_entry_time: float = 0.0
forward_entry_time: float = 0.0 forward_entry_time: float = 0.0
...@@ -43,20 +44,11 @@ class TimeStats: ...@@ -43,20 +44,11 @@ class TimeStats:
decode_prealloc_queue_entry_time: float = 0.0 decode_prealloc_queue_entry_time: float = 0.0
decode_transfer_queue_entry_time: float = 0.0 decode_transfer_queue_entry_time: float = 0.0
class RequestType(Enum):
UNIFIED = "unified"
PREFILL = "prefill"
DECODE = "decode"
INVALID = "invalid"
def get_queueing_time(self) -> float: def get_queueing_time(self) -> float:
return self.forward_entry_time - self.wait_queue_entry_time return self.forward_entry_time - self.wait_queue_entry_time
def __str__(self) -> str: def convert_to_duration(self) -> str:
# if unified if self.disagg_mode == DisaggregationMode.NULL:
_type = self.get_type()
if _type == self.RequestType.UNIFIED:
queue_duration = self.forward_entry_time - self.wait_queue_entry_time queue_duration = self.forward_entry_time - self.wait_queue_entry_time
forward_duration = self.completion_time - self.forward_entry_time forward_duration = self.completion_time - self.forward_entry_time
...@@ -65,30 +57,28 @@ class TimeStats: ...@@ -65,30 +57,28 @@ class TimeStats:
queue_duration >= 0 and forward_duration >= 0 queue_duration >= 0 and forward_duration >= 0
), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}" return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time:.3f}"
elif _type == self.RequestType.PREFILL: elif self.disagg_mode == DisaggregationMode.PREFILL:
bootstrap_duration = ( bootstrap_duration = (
self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
) )
queue_duration = self.forward_entry_time - self.wait_queue_entry_time queue_duration = self.forward_entry_time - self.wait_queue_entry_time
forward_duration = self.completion_time - self.forward_entry_time forward_duration = self.completion_time - self.forward_entry_time
if SGLANG_TEST_REQUEST_TIME_STATS: if SGLANG_TEST_REQUEST_TIME_STATS:
assert ( if self.wait_queue_entry_time > 0:
bootstrap_duration >= 0 assert (
and queue_duration >= 0 bootstrap_duration >= 0
and forward_duration >= 0 and queue_duration >= 0
), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" and forward_duration >= 0
return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}" ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
# if decode
elif _type == self.RequestType.DECODE: return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time:.3f}"
elif self.disagg_mode == DisaggregationMode.DECODE:
prealloc_duration = ( prealloc_duration = (
self.decode_transfer_queue_entry_time self.decode_transfer_queue_entry_time
- self.decode_prealloc_queue_entry_time - self.decode_prealloc_queue_entry_time
) )
transfer_duration = ( transfer_duration = (
self.wait_queue_entry_time - self.decode_transfer_queue_entry_time self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
) )
...@@ -96,42 +86,30 @@ class TimeStats: ...@@ -96,42 +86,30 @@ class TimeStats:
forward_duration = self.completion_time - self.forward_entry_time forward_duration = self.completion_time - self.forward_entry_time
if SGLANG_TEST_REQUEST_TIME_STATS: if SGLANG_TEST_REQUEST_TIME_STATS:
assert ( if self.wait_queue_entry_time > 0:
prealloc_duration >= 0 assert (
and transfer_duration >= 0 prealloc_duration >= 0
and queue_duration >= 0 and transfer_duration >= 0
and forward_duration >= 0 and queue_duration >= 0
), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" and forward_duration >= 0
), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0. {self=}"
return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time:.3f}"
else: else:
return "Invalid Time Stats" return "Unknown Time Stats"
def format_duration(self, duration: float) -> str: def format_duration(self, duration: float) -> str:
return f"{duration * 1e3:.2f}ms" return f"{duration * 1e3:.2f}ms"
def get_type(self) -> RequestType: def disagg_mode_str(self) -> str:
"""Determine the type of request based on timestamp values.""" if self.disagg_mode == DisaggregationMode.NULL:
if ( return "unified"
self.prefill_bootstrap_queue_entry_time == 0.0 elif self.disagg_mode == DisaggregationMode.DECODE:
and self.prefill_transfer_queue_entry_time == 0.0 return "decode"
and self.decode_prealloc_queue_entry_time == 0.0 elif self.disagg_mode == DisaggregationMode.PREFILL:
and self.decode_transfer_queue_entry_time == 0.0 return "prefill"
):
return self.RequestType.UNIFIED
elif (
self.prefill_bootstrap_queue_entry_time > 0.0
and self.prefill_transfer_queue_entry_time > 0.0
):
return self.RequestType.PREFILL
elif (
self.decode_prealloc_queue_entry_time > 0.0
and self.decode_transfer_queue_entry_time > 0.0
and self.wait_queue_entry_time > 0.0
):
return self.RequestType.DECODE
else: else:
return self.RequestType.INVALID return "unknown"
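To make the reworked TimeStats easier to follow, a hedged example of constructing and reading one (the import path for TimeStats is assumed; field values are made up):

from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.metrics.collector import TimeStats  # module path assumed

stats = TimeStats(disagg_mode=DisaggregationMode.NULL)
stats.wait_queue_entry_time = 100.0  # entered the wait queue
stats.forward_entry_time = 100.5     # scheduled for forward
stats.completion_time = 102.0        # finished

print(stats.disagg_mode_str())       # "unified"
print(stats.get_queueing_time())     # 0.5 (seconds spent queueing)
print(stats.convert_to_duration())   # "queue_duration=500.00ms, forward_duration=1500.00ms, start_time=100.000"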
@dataclass @dataclass
...@@ -145,12 +123,15 @@ class SchedulerStats: ...@@ -145,12 +123,15 @@ class SchedulerStats:
num_queue_reqs: int = 0 num_queue_reqs: int = 0
num_grammar_queue_reqs: int = 0 num_grammar_queue_reqs: int = 0
num_running_reqs_offline_batch: int = 0 num_running_reqs_offline_batch: int = 0
avg_request_queue_latency: float = 0.0
cache_hit_rate: float = 0.0 cache_hit_rate: float = 0.0
# Speculative decoding # Speculative decoding
spec_accept_length: float = 0.0 spec_accept_length: float = 0.0
# Retract
num_retracted_reqs: int = 0
num_paused_reqs: int = 0
# PD disaggregation # PD disaggregation
num_prefill_prealloc_queue_reqs: int = 0 num_prefill_prealloc_queue_reqs: int = 0
num_prefill_inflight_queue_reqs: int = 0 num_prefill_inflight_queue_reqs: int = 0
...@@ -159,11 +140,6 @@ class SchedulerStats: ...@@ -159,11 +140,6 @@ class SchedulerStats:
kv_transfer_speed_gb_s: float = 0.0 kv_transfer_speed_gb_s: float = 0.0
kv_transfer_latency_ms: float = 0.0 kv_transfer_latency_ms: float = 0.0
# Retract
total_retracted_reqs: int = 0
num_retracted_reqs: int = 0
num_paused_reqs: int = 0
# Utilization # Utilization
utilization: float = 0.0 utilization: float = 0.0
max_running_requests_under_SLO: Optional[int] = None max_running_requests_under_SLO: Optional[int] = None
...@@ -230,12 +206,6 @@ class SchedulerMetricsCollector: ...@@ -230,12 +206,6 @@ class SchedulerMetricsCollector:
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.avg_request_queue_latency = Gauge(
name="sglang:avg_request_queue_latency",
documentation="The average request queue latency for the last batch of requests in seconds.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.cache_hit_rate = Gauge( self.cache_hit_rate = Gauge(
name="sglang:cache_hit_rate", name="sglang:cache_hit_rate",
documentation="The prefix cache hit rate.", documentation="The prefix cache hit rate.",
...@@ -251,6 +221,18 @@ class SchedulerMetricsCollector: ...@@ -251,6 +221,18 @@ class SchedulerMetricsCollector:
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
# Retract
self.num_retracted_reqs = Gauge(
name="sglang:num_retracted_reqs",
documentation="The number of retracted requests.",
labelnames=labels.keys(),
)
self.num_paused_reqs = Gauge(
name="sglang:num_paused_reqs",
documentation="The number of paused requests by async weight sync.",
labelnames=labels.keys(),
)
# PD disaggregation # PD disaggregation
self.num_prefill_prealloc_queue_reqs = Gauge( self.num_prefill_prealloc_queue_reqs = Gauge(
name="sglang:num_prefill_prealloc_queue_reqs", name="sglang:num_prefill_prealloc_queue_reqs",
...@@ -299,24 +281,6 @@ class SchedulerMetricsCollector: ...@@ -299,24 +281,6 @@ class SchedulerMetricsCollector:
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
# Retract
self.total_retracted_reqs = Gauge(
name="sglang:total_retracted_reqs",
documentation="The total number of retracted requests due to kvcache full.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.num_retracted_reqs = Gauge(
name="sglang:num_retracted_reqs",
documentation="The number of retracted requests.",
labelnames=labels.keys(),
)
self.num_paused_reqs = Gauge(
name="sglang:num_paused_reqs",
documentation="The number of paused requests by async weight sync.",
labelnames=labels.keys(),
)
# Utilization # Utilization
self.utilization = Gauge( self.utilization = Gauge(
name="sglang:utilization", name="sglang:utilization",
...@@ -347,7 +311,7 @@ class SchedulerMetricsCollector: ...@@ -347,7 +311,7 @@ class SchedulerMetricsCollector:
# Additional queueing time histogram # Additional queueing time histogram
self.queue_time = Histogram( self.queue_time = Histogram(
name="sglang:queue_time_s", name="sglang:queue_time_seconds",
documentation="Histogram of queueing time in seconds.", documentation="Histogram of queueing time in seconds.",
labelnames=labels.keys(), labelnames=labels.keys(),
buckets=[ buckets=[
...@@ -513,8 +477,8 @@ class SchedulerMetricsCollector: ...@@ -513,8 +477,8 @@ class SchedulerMetricsCollector:
buckets=tree_traversal_time_buckets, buckets=tree_traversal_time_buckets,
) )
self.request_latency_seconds = Histogram( self.per_stage_req_latency_seconds = Histogram(
name="sglang:request_latency_seconds", name="sglang:per_stage_req_latency_seconds",
documentation="The latency of each stage of requests.", documentation="The latency of each stage of requests.",
# captures latency in range [1ms - ~1191s] # captures latency in range [1ms - ~1191s]
buckets=exponential_buckets(start=0.001, width=1.62, length=30), buckets=exponential_buckets(start=0.001, width=1.62, length=30),
...@@ -525,7 +489,7 @@ class SchedulerMetricsCollector: ...@@ -525,7 +489,7 @@ class SchedulerMetricsCollector:
# Convenience function for logging to gauge. # Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data) gauge.labels(**self.labels).set(data)
def log_histogram(self, histogram, data: Union[int, float]) -> None: def _log_histogram(self, histogram, data: Union[int, float]) -> None:
histogram.labels(**self.labels).observe(data) histogram.labels(**self.labels).observe(data)
def increment_bootstrap_failed_reqs(self) -> None: def increment_bootstrap_failed_reqs(self) -> None:
...@@ -534,9 +498,12 @@ class SchedulerMetricsCollector: ...@@ -534,9 +498,12 @@ class SchedulerMetricsCollector:
def increment_transfer_failed_reqs(self) -> None: def increment_transfer_failed_reqs(self) -> None:
self.num_transfer_failed_reqs.labels(**self.labels).inc(1) self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
def observe_request_latency_seconds(self, stage: str, latency: float) -> None: def observe_per_stage_req_latency(self, stage: str, latency: float) -> None:
labels_with_stage = {**self.labels, "stage": stage} labels_with_stage = {**self.labels, "stage": stage}
self.request_latency_seconds.labels(**labels_with_stage).observe(latency) self.per_stage_req_latency_seconds.labels(**labels_with_stage).observe(latency)
def observe_queue_time(self, latency: float) -> None:
self._log_histogram(self.queue_time, latency)
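A short sketch of the intended call pattern for the two helpers above, with the collector passed in explicitly (stage name and timing values are illustrative):

import time

def record_request_metrics(collector, queue_entry_ts: float, stage: str, stage_latency: float) -> None:
    # collector is assumed to be a SchedulerMetricsCollector instance.
    queue_latency = time.monotonic() - queue_entry_ts
    collector.observe_queue_time(queue_latency)                    # feeds sglang:queue_time_seconds
    collector.observe_per_stage_req_latency(stage, stage_latency)  # feeds sglang:per_stage_req_latency_seconds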
def log_stats(self, stats: SchedulerStats) -> None: def log_stats(self, stats: SchedulerStats) -> None:
self._log_gauge(self.num_running_reqs, stats.num_running_reqs) self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
...@@ -550,7 +517,6 @@ class SchedulerMetricsCollector: ...@@ -550,7 +517,6 @@ class SchedulerMetricsCollector:
self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch
) )
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
# Speculative decoding # Speculative decoding
self._log_gauge(self.spec_accept_length, stats.spec_accept_length) self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
...@@ -572,7 +538,6 @@ class SchedulerMetricsCollector: ...@@ -572,7 +538,6 @@ class SchedulerMetricsCollector:
self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms) self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms)
# Retract # Retract
self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs)
self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs) self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs)
self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs) self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs)
...@@ -596,19 +561,19 @@ class SchedulerMetricsCollector: ...@@ -596,19 +561,19 @@ class SchedulerMetricsCollector:
def log_grammar_stats(self, grammar_stats) -> None: def log_grammar_stats(self, grammar_stats) -> None:
# Duck-typed GrammarStats to avoid cross-package dependency # Duck-typed GrammarStats to avoid cross-package dependency
if getattr(grammar_stats, "compilation_time", None) is not None: if getattr(grammar_stats, "compilation_time", None) is not None:
self.log_histogram( self._log_histogram(
self.grammar_compilation_time, grammar_stats.compilation_time self.grammar_compilation_time, grammar_stats.compilation_time
) )
if getattr(grammar_stats, "schema_count", None) is not None: if getattr(grammar_stats, "schema_count", None) is not None:
self.log_histogram(self.grammar_schema_count, grammar_stats.schema_count) self._log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
if getattr(grammar_stats, "ebnf_size", None) is not None: if getattr(grammar_stats, "ebnf_size", None) is not None:
self.log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size) self._log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
tree_times = getattr(grammar_stats, "tree_traversal_time", None) tree_times = getattr(grammar_stats, "tree_traversal_time", None)
if tree_times: if tree_times:
max_time = max(tree_times) max_time = max(tree_times)
avg_time = sum(tree_times) / len(tree_times) avg_time = sum(tree_times) / len(tree_times)
self.log_histogram(self.grammar_tree_traversal_time_max, max_time) self._log_histogram(self.grammar_tree_traversal_time_max, max_time)
self.log_histogram(self.grammar_tree_traversal_time_avg, avg_time) self._log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
if getattr(grammar_stats, "is_cache_hit", False): if getattr(grammar_stats, "is_cache_hit", False):
self.num_grammar_cache_hit.labels(**self.labels).inc(1) self.num_grammar_cache_hit.labels(**self.labels).inc(1)
if getattr(grammar_stats, "is_grammar_aborted", False): if getattr(grammar_stats, "is_grammar_aborted", False):
...@@ -714,7 +679,7 @@ class TokenizerMetricsCollector: ...@@ -714,7 +679,7 @@ class TokenizerMetricsCollector:
) )
self.num_aborted_requests_total = Counter( self.num_aborted_requests_total = Counter(
name="sglang:num_aborted_requests", name="sglang:num_aborted_requests_total",
documentation="Number of requests aborted.", documentation="Number of requests aborted.",
labelnames=labels.keys(), labelnames=labels.keys(),
) )
...@@ -801,7 +766,7 @@ class TokenizerMetricsCollector: ...@@ -801,7 +766,7 @@ class TokenizerMetricsCollector:
buckets=bucket_time_to_first_token, buckets=bucket_time_to_first_token,
) )
self.histogram_inter_token_latency_seconds = Histogram( self.histogram_inter_token_latency = Histogram(
name="sglang:inter_token_latency_seconds", name="sglang:inter_token_latency_seconds",
documentation="Histogram of inter-token latency in seconds.", documentation="Histogram of inter-token latency in seconds.",
labelnames=labels.keys(), labelnames=labels.keys(),
...@@ -815,14 +780,6 @@ class TokenizerMetricsCollector: ...@@ -815,14 +780,6 @@ class TokenizerMetricsCollector:
buckets=bucket_e2e_request_latency, buckets=bucket_e2e_request_latency,
) )
# Offline batch specific TTFB histogram
self.histogram_time_to_first_token_offline_batch = Histogram(
name="sglang:time_to_first_token_seconds_offline_batch",
documentation="Histogram of time to first token in seconds for offline batch requests.",
labelnames=labels.keys(),
buckets=bucket_time_to_first_token,
)
def observe_one_finished_request( def observe_one_finished_request(
self, self,
labels: Dict[str, str], labels: Dict[str, str],
...@@ -846,15 +803,8 @@ class TokenizerMetricsCollector: ...@@ -846,15 +803,8 @@ class TokenizerMetricsCollector:
float(generation_tokens) float(generation_tokens)
) )
def observe_time_to_first_token( def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
self, labels: Dict[str, str], value: float, type: str = "" self.histogram_time_to_first_token.labels(**labels).observe(value)
):
if type == "batch":
self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(
value
)
else:
self.histogram_time_to_first_token.labels(**labels).observe(value)
def check_time_to_first_token_straggler(self, value: float) -> bool: def check_time_to_first_token_straggler(self, value: float) -> bool:
his = self.histogram_time_to_first_token.labels(**self.labels) his = self.histogram_time_to_first_token.labels(**self.labels)
...@@ -876,7 +826,7 @@ class TokenizerMetricsCollector: ...@@ -876,7 +826,7 @@ class TokenizerMetricsCollector:
# A faster version of the Histogram::observe which observes multiple values at the same time. # A faster version of the Histogram::observe which observes multiple values at the same time.
# reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639 # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
his = self.histogram_inter_token_latency_seconds.labels(**labels) his = self.histogram_inter_token_latency.labels(**labels)
his._sum.inc(internval) his._sum.inc(internval)
for i, bound in enumerate(his._upper_bounds): for i, bound in enumerate(his._upper_bounds):
...@@ -884,8 +834,8 @@ class TokenizerMetricsCollector: ...@@ -884,8 +834,8 @@ class TokenizerMetricsCollector:
his._buckets[i].inc(num_new_tokens) his._buckets[i].inc(num_new_tokens)
break break
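For reference, a hedged standalone sketch of the batched-observe trick used above; it touches prometheus_client's private _sum/_buckets/_upper_bounds fields, so it is tied to the client version referenced in the comment:

from prometheus_client import Histogram

# Illustrative metric; name and buckets are made up for the sketch.
demo_itl = Histogram("demo_inter_token_latency_seconds", "demo", buckets=[0.01, 0.1, 1.0])

def observe_batch(total_latency: float, num_new_tokens: int) -> None:
    # Equivalent in aggregate to num_new_tokens observe() calls of the per-token latency,
    # but with a single sum increment and a single bucket update.
    per_token = total_latency / num_new_tokens
    demo_itl._sum.inc(total_latency)
    for i, bound in enumerate(demo_itl._upper_bounds):
        if per_token <= bound:
            demo_itl._buckets[i].inc(num_new_tokens)
            break

observe_batch(0.5, num_new_tokens=10)  # ten tokens at 50 ms each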
def observe_one_aborted_request(self): def observe_one_aborted_request(self, labels: Dict[str, str]):
self.num_aborted_requests_total.labels(**self.labels).inc(1) self.num_aborted_requests_total.labels(**labels).inc(1)
@dataclass @dataclass
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from __future__ import annotations from __future__ import annotations
import ctypes
import logging import logging
import os import os
import random import random
...@@ -23,7 +22,10 @@ import threading ...@@ -23,7 +22,10 @@ import threading
import time import time
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import Req
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
opentelemetry_imported = False opentelemetry_imported = False
...@@ -407,9 +409,11 @@ def trace_slice_start( ...@@ -407,9 +409,11 @@ def trace_slice_start(
ts: Optional[int] = None, ts: Optional[int] = None,
anonymous: bool = False, anonymous: bool = False,
): ):
if not tracing_enabled:
return
rid = str(rid) rid = str(rid)
if not tracing_enabled or rid not in reqs_context: if rid not in reqs_context:
return return
pid = threading.get_native_id() pid = threading.get_native_id()
...@@ -458,8 +462,11 @@ def trace_slice_end( ...@@ -458,8 +462,11 @@ def trace_slice_end(
auto_next_anon: bool = False, auto_next_anon: bool = False,
thread_finish_flag: bool = False, thread_finish_flag: bool = False,
): ):
if not tracing_enabled:
return
rid = str(rid) rid = str(rid)
if not tracing_enabled or rid not in reqs_context: if rid not in reqs_context:
return return
pid = threading.get_native_id() pid = threading.get_native_id()
...@@ -512,10 +519,13 @@ trace_slice = trace_slice_end ...@@ -512,10 +519,13 @@ trace_slice = trace_slice_end
# Add event to the current slice on the same thread with the same rid. # Add event to the current slice on the same thread with the same rid.
def trace_event(name: str, rid: str, ts: Optional[int] = None): def trace_event(name: str, rid: str, ts: Optional[int] = None):
if not tracing_enabled or rid not in reqs_context: if not tracing_enabled:
return return
rid = str(rid) rid = str(rid)
if rid not in reqs_context:
return
pid = threading.get_native_id() pid = threading.get_native_id()
if pid not in reqs_context[rid].threads_context: if pid not in reqs_context[rid].threads_context:
return return
...@@ -534,10 +544,13 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None): ...@@ -534,10 +544,13 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None):
# Add attrs to the current slice on the same thread with the same rid. # Add attrs to the current slice on the same thread with the same rid.
def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]): def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
if not tracing_enabled or rid not in reqs_context: if not tracing_enabled:
return return
rid = str(rid) rid = str(rid)
if rid not in reqs_context:
return
pid = threading.get_native_id() pid = threading.get_native_id()
if pid not in reqs_context[rid].threads_context: if pid not in reqs_context[rid].threads_context:
return return
...@@ -550,3 +563,16 @@ def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]): ...@@ -550,3 +563,16 @@ def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
slice_info = thread_context.cur_slice_stack[-1] slice_info = thread_context.cur_slice_stack[-1]
slice_info.span.set_attributes(attrs) slice_info.span.set_attributes(attrs)
def trace_slice_batch(
name: str,
reqs: List[Req],
):
for req in reqs:
trace_slice(
name,
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
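A usage sketch for the new trace_slice_batch helper, assuming a scheduler-side list of Req objects exposing rid and finished() (the import path and slice name are assumptions):

from sglang.srt.tracing.trace import trace_slice_batch  # module path assumed

class _DemoReq:
    def __init__(self, rid: str, done: bool):
        self.rid = rid
        self._done = done

    def finished(self) -> bool:
        return self._done

# Finished requests close out their trace thread; unfinished ones roll into the
# next anonymous slice, matching the auto_next_anon/thread_finish_flag logic above.
trace_slice_batch("decode-step", [_DemoReq("r1", False), _DemoReq("r2", True)])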