Unverified Commit 2d62af6b authored by Lianmin Zheng, committed by GitHub

Fix metrics and request tracing (TimeStats) (#11123)

parent a28b394f
......@@ -14,18 +14,17 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"aiohttp",
"requests",
"tqdm",
"numpy",
"IPython",
"setproctitle",
"aiohttp",
"anthropic>=0.20.0",
"blobfile==3.0.0",
"build",
"compressed-tensors",
"cuda-python",
"datasets",
"einops",
"fastapi",
"flashinfer_python==0.4.0rc3",
"hf_transfer",
"huggingface_hub",
"interegular",
......@@ -33,8 +32,10 @@ dependencies = [
"modelscope",
"msgspec",
"ninja",
"openai==1.99.1",
"numpy",
"nvidia-cutlass-dsl==4.2.1",
"openai-harmony==0.0.4",
"openai==1.99.1",
"orjson",
"outlines==0.1.11",
"packaging",
......@@ -42,32 +43,30 @@ dependencies = [
"pillow",
"prometheus-client>=0.20.0",
"psutil",
"py-spy",
"pybase64",
"pydantic",
"pynvml",
"python-multipart",
"pyzmq>=25.1.2",
"requests",
"scipy",
"sentencepiece",
"setproctitle",
"sgl-kernel==0.3.13",
"soundfile==0.13.1",
"timm==1.0.16",
"tiktoken",
"timm==1.0.16",
"torch==2.8.0",
"torch_memory_saver==0.0.8",
"torchao==0.9.0",
"torchaudio==2.8.0",
"torchvision",
"tqdm",
"transformers==4.56.1",
"uvicorn",
"uvloop",
"xgrammar==0.1.24",
"sgl-kernel==0.3.13",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"flashinfer_python==0.4.0rc3",
"openai==1.99.1",
"tiktoken",
"anthropic>=0.20.0",
"torch_memory_saver==0.0.8",
"nvidia-cutlass-dsl==4.2.1",
"xgrammar==0.1.24"
]
[project.optional-dependencies]
......@@ -79,15 +78,15 @@ test = [
"matplotlib",
"pandas",
"peft",
"sentence_transformers",
"pytest",
"sentence_transformers",
"tabulate",
]
tracing = [
"opentelemetry-sdk",
"opentelemetry-api",
"opentelemetry-exporter-otlp",
"opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
]
all = ["sglang[test]", "sglang[decord]"]
blackwell = ["sglang[test]", "sglang[decord]"]
......
......@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
from __future__ import annotations
import logging
import time
from collections import deque
from dataclasses import dataclass
from http import HTTPStatus
......@@ -422,9 +423,13 @@ class DecodePreallocQueue:
kv_indices, self.token_to_kv_pool_allocator.page_size
)
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
preallocated_reqs.append(decode_req)
indices_to_remove.add(i)
decode_req.req.time_stats.decode_transfer_queue_entry_time = (
time.perf_counter()
)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
......@@ -625,6 +630,7 @@ class DecodeTransferQueue:
decode_req.req.output_topk_p = output_topk_p
decode_req.req.output_topk_index = output_topk_index
decode_req.req.hidden_states_tensor = output_hidden_states
if decode_req.req.return_logprob:
decode_req.req.output_token_logprobs_val.append(
output_token_logprobs_val[0].item()
......@@ -645,10 +651,17 @@ class DecodeTransferQueue:
if hasattr(decode_req.kv_receiver, "clear"):
decode_req.kv_receiver.clear()
decode_req.kv_receiver = None
indices_to_remove.add(i)
decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
# special handling for sampling_params.max_new_tokens == 1
if decode_req.req.sampling_params.max_new_tokens == 1:
# finish immediately
decode_req.req.time_stats.forward_entry_time = (
decode_req.req.time_stats.completion_time
) = time.perf_counter()
decode_req.req.check_finished()
self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob
......@@ -656,8 +669,6 @@ class DecodeTransferQueue:
self.tree_cache.cache_finished_req(decode_req.req)
else:
transferred_reqs.append(decode_req.req)
indices_to_remove.add(i)
elif poll in [
KVPoll.Bootstrapping,
KVPoll.WaitingForInput,
......@@ -877,6 +888,9 @@ class SchedulerDisaggregationDecodeMixin:
if len(can_run_list) == 0:
return None
for req in can_run_list:
req.time_stats.forward_entry_time = time.perf_counter()
# construct a schedule batch with those requests and mark as decode
new_batch = ScheduleBatch.init_new(
can_run_list,
......
......@@ -21,6 +21,7 @@ from __future__ import annotations
import logging
import threading
import time
from collections import deque
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Type
......@@ -263,9 +264,10 @@ class PrefillBootstrapQueue:
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
bootstrapped_reqs.append(req)
indices_to_remove.add(i)
req.time_stats.wait_queue_entry_time = time.perf_counter()
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
......@@ -407,7 +409,6 @@ class SchedulerDisaggregationPrefillMixin:
for i, (req, next_token_id) in enumerate(
zip(batch.reqs, next_token_ids, strict=True)
):
req: Req
if req.is_chunked <= 0:
# There is no output_ids for prefill
req.output_ids.append(next_token_id)
......@@ -450,6 +451,7 @@ class SchedulerDisaggregationPrefillMixin:
)
logprob_pt += num_input_logprobs
self.send_kv_chunk(req, last_chunk=True)
req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
if req.grammar is not None:
# FIXME: this try-except block is for handling unexpected xgrammar issue.
......@@ -547,6 +549,9 @@ class SchedulerDisaggregationPrefillMixin:
else:
assert False, f"Unexpected polling state {poll=}"
for req in done_reqs:
req.time_stats.completion_time = time.perf_counter()
# Stream requests which have finished transfer
self.stream_output(
done_reqs,
......
......@@ -5,7 +5,7 @@ import random
from collections import deque
from contextlib import nullcontext
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Type, Union
from typing import TYPE_CHECKING, Optional, Type
import numpy as np
import torch
......
......@@ -41,7 +41,7 @@ import time
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
import numpy as np
import torch
......@@ -54,6 +54,7 @@ from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
......@@ -452,6 +453,7 @@ class Req:
bootstrap_host: Optional[str] = None,
bootstrap_port: Optional[int] = None,
bootstrap_room: Optional[int] = None,
disagg_mode: Optional[DisaggregationMode] = None,
data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None,
priority: Optional[int] = None,
......@@ -628,10 +630,8 @@ class Req:
# For metrics
self.metrics_collector = metrics_collector
self.time_stats: TimeStats = TimeStats()
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
self.has_log_time_stats: bool = False
self.queue_time_start = None
self.queue_time_end = None
self.last_tic = time.monotonic()
# For disaggregation
......@@ -668,9 +668,9 @@ class Req:
def add_latency(self, stage: RequestStage):
if self.metrics_collector is None:
return
assert stage.name in RequestStage.__members__, f"{stage=} is invalid"
now = time.monotonic()
self.metrics_collector.observe_request_latency_seconds(
self.metrics_collector.observe_per_stage_req_latency(
stage.value, now - self.last_tic
)
self.last_tic = now
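The method above keeps a single last_tic cursor and reports only the time elapsed since the previous stage transition. A minimal, self-contained sketch of the same pattern, assuming a hypothetical StubCollector in place of the real SchedulerMetricsCollector and an abbreviated stage enum:
import time
from enum import Enum
class RequestStage(Enum):
    # Abbreviated for illustration; the real enum lives in sglang.
    PREFILL_WAITING = "prefill_waiting"
    DECODE_BOOTSTRAP = "decode_bootstrap"
class StubCollector:
    """Hypothetical stand-in for SchedulerMetricsCollector."""
    def observe_per_stage_req_latency(self, stage: str, latency: float) -> None:
        print(f"stage={stage} latency={latency * 1e3:.2f}ms")
collector = StubCollector()
last_tic = time.monotonic()
time.sleep(0.01)  # the request sits in some stage for a while
now = time.monotonic()
collector.observe_per_stage_req_latency(RequestStage.PREFILL_WAITING.value, now - last_tic)
last_tic = now  # advance the cursor so the next stage measures only its own share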
......@@ -834,10 +834,10 @@ class Req:
return
if self.bootstrap_room is not None:
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
else:
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
logger.info(f"{prefix}: {self.time_stats}")
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
self.has_log_time_stats = True
def set_finish_with_abort(self, error_msg: str):
......@@ -1544,7 +1544,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
) / total_max_new_tokens
new_estimate_ratio = min(1.0, new_estimate_ratio)
return retracted_reqs, new_estimate_ratio
return retracted_reqs, new_estimate_ratio, []
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
req = self.reqs[idx]
......
......@@ -276,9 +276,13 @@ class SchedulePolicy:
) -> None:
"""Sorts the waiting queue based on the request priority then received titmestamp."""
if schedule_low_priority_values_first:
waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start))
waiting_queue.sort(
key=lambda x: (x.priority, x.time_stats.wait_queue_entry_time)
)
else:
waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start))
waiting_queue.sort(
key=lambda x: (-x.priority, x.time_stats.wait_queue_entry_time)
)
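As an illustration of the sort keys above, a small standalone example; the FakeReq objects are made up and only the key function mirrors the code:
from dataclasses import dataclass
@dataclass
class FakeReq:
    priority: int
    wait_queue_entry_time: float
queue = [FakeReq(1, 10.0), FakeReq(0, 12.0), FakeReq(0, 11.0)]
# schedule_low_priority_values_first=True: lower priority values run first,
# ties broken by the earlier wait-queue entry time.
queue.sort(key=lambda x: (x.priority, x.wait_queue_entry_time))
assert [(r.priority, r.wait_queue_entry_time) for r in queue] == [
    (0, 11.0),
    (0, 12.0),
    (1, 10.0),
]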
@staticmethod
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
......@@ -642,12 +646,12 @@ class PrefillAdder:
if server_args.schedule_low_priority_values_first:
sorted_running_reqs = sorted(
self.running_batch.reqs,
key=lambda x: (-x.priority, -x.queue_time_start),
key=lambda x: (-x.priority, -x.time_stats.wait_queue_entry_time),
)
else:
sorted_running_reqs = sorted(
self.running_batch.reqs,
key=lambda x: (x.priority, -x.queue_time_start),
key=lambda x: (x.priority, -x.time_stats.wait_queue_entry_time),
)
preemptible_reqs = []
min_tokens_to_remove = (
......
......@@ -157,10 +157,9 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.tracing.trace import (
process_tracing_init,
trace_event,
trace_set_proc_propagate_context,
trace_set_thread_info,
trace_slice,
trace_slice_batch,
trace_slice_end,
trace_slice_start,
)
......@@ -263,6 +262,7 @@ class Scheduler(
server_args.enable_metrics_for_all_schedulers
)
self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
self.enable_trace = server_args.enable_trace
self.stream_interval = server_args.stream_interval
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
......@@ -899,10 +899,6 @@ class Scheduler(
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
for req in batch.reqs:
trace_event("schedule", req.rid)
if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
......@@ -924,10 +920,6 @@ class Scheduler(
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
for req in batch.reqs:
trace_event("schedule", req.rid)
if batch:
batch.launch_done = threading.Event()
result = self.run_batch(batch)
......@@ -1192,10 +1184,13 @@ class Scheduler(
src=self.tp_group.ranks[0],
)
for req in recv_reqs:
if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
trace_set_proc_propagate_context(req.rid, req.trace_context)
trace_slice_start("", req.rid, anonymous=True)
if self.enable_trace:
for req in recv_reqs:
if isinstance(
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
trace_set_proc_propagate_context(req.rid, req.trace_context)
trace_slice_start("", req.rid, anonymous=True)
return recv_reqs
......@@ -1277,6 +1272,7 @@ class Scheduler(
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room,
disagg_mode=self.disaggregation_mode,
data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size,
priority=recv_req.priority,
......@@ -1403,7 +1399,6 @@ class Scheduler(
req.set_finish_with_abort(error_msg)
if add_to_grammar_queue:
req.queue_time_start = time.perf_counter()
self.grammar_queue.append(req)
else:
self._add_request_to_queue(req)
......@@ -1419,23 +1414,6 @@ class Scheduler(
for tokenized_req in recv_req:
self.handle_generate_request(tokenized_req)
def _add_request_to_queue(self, req: Req):
req.queue_time_start = time.perf_counter()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self._prefetch_kvcache(req)
self.disagg_prefill_bootstrap_queue.add(
req, self.model_config.num_key_value_heads
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
trace_slice_end("process req", req.rid, auto_next_anon=True)
def _prefetch_kvcache(self, req: Req):
if self.enable_hicache_storage:
req.init_next_round_input(self.tree_cache)
......@@ -1449,19 +1427,27 @@ class Scheduler(
req.rid, req.last_host_node, new_input_tokens, last_hash
)
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_bootstrap_queue.extend(
reqs, self.model_config.num_key_value_heads
def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.NULL:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
req.time_stats.wait_queue_entry_time = time.perf_counter()
trace_slice_end("process req", req.rid, auto_next_anon=True)
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
self._prefetch_kvcache(req)
self.disagg_prefill_bootstrap_queue.add(
req, self.model_config.num_key_value_heads
)
req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
# If this is a decode server, put the request into the decode prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
if not is_retracted:
req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
else:
for req in reqs:
self._set_or_validate_priority(req)
if not self._abort_on_queued_limit(req):
self.waiting_queue.append(req)
raise ValueError(f"Invalid {self.disaggregation_mode=}")
def _set_or_validate_priority(self, req: Req):
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
......@@ -1500,7 +1486,7 @@ class Scheduler(
direction = 1 if self.schedule_low_priority_values_first else -1
key_fn = lambda item: (
direction * item[1].priority,
item[1].queue_time_start,
item[1].time_stats.wait_queue_entry_time,
)
idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
abort_existing_req = (
......@@ -1902,14 +1888,14 @@ class Scheduler(
if self.enable_metrics:
# only record queue time when enable_metrics is True to avoid overhead
for req in can_run_list:
req.queue_time_end = time.perf_counter()
req.add_latency(RequestStage.PREFILL_WAITING)
self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list)
]
if adder.preempt_list:
self._extend_requests_to_queue(adder.preempt_list)
for req in adder.preempt_list:
self._add_request_to_queue(req)
if adder.new_chunked_req is not None:
assert self.chunked_req is None
......@@ -1920,7 +1906,16 @@ class Scheduler(
# Print stats
if self.current_scheduler_metrics_enabled():
self.log_prefill_stats(adder, can_run_list, running_bs)
self.log_prefill_stats(adder, can_run_list, running_bs, 0)
for req in can_run_list:
if req.time_stats.forward_entry_time == 0:
# Avoid updating a chunked request multiple times
req.time_stats.forward_entry_time = time.perf_counter()
if self.enable_metrics:
self.metrics_collector.observe_queue_time(
req.time_stats.get_queueing_time(),
)
# Create a new batch
new_batch = ScheduleBatch.init_new(
......@@ -1975,19 +1970,25 @@ class Scheduler(
TEST_RETRACT and batch.batch_size() > 10
):
old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
num_retracted_reqs = len(retracted_reqs)
retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
self.server_args
)
self.num_retracted_reqs = len(retracted_reqs)
self.new_token_ratio = new_token_ratio
for req in reqs_to_abort:
self.send_to_tokenizer.send_pyobj(
AbortReq(req.rid, abort_reason=req.to_abort_message)
)
logger.info(
"KV cache pool is full. Retract requests. "
f"#retracted_reqs: {num_retracted_reqs}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
)
self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
self.total_retracted_reqs += num_retracted_reqs
for req in retracted_reqs:
self._add_request_to_queue(req, is_retracted=True)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
......@@ -2086,23 +2087,14 @@ class Scheduler(
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result, launch_done)
for req in batch.reqs:
trace_slice(
"decode loop",
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
if self.enable_trace:
trace_slice_batch("decode loop", batch.reqs)
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result, launch_done)
for req in batch.reqs:
trace_slice(
"prefill",
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
if self.enable_trace:
trace_slice_batch("prefill", batch.reqs)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_last_batch_result(launch_done)
......@@ -2261,12 +2253,13 @@ class Scheduler(
if req.finished(): # It is aborted by AbortReq
num_ready_reqs += 1
continue
req.grammar = req.grammar.result(timeout=0.03)
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
if req.grammar is INVALID_GRAMMAR_OBJ:
req.set_finish_with_abort(
f"Invalid grammar request: {req.grammar_key=}"
)
error_msg = f"Invalid grammar request: {req.grammar_key=}"
req.set_finish_with_abort(error_msg)
num_ready_reqs += 1
except futures._base.TimeoutError:
req.grammar_wait_ct += 1
......@@ -2298,9 +2291,8 @@ class Scheduler(
req.grammar = req.grammar.result()
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
if req.grammar is INVALID_GRAMMAR_OBJ:
req.set_finish_with_abort(
f"Invalid grammar request: {req.grammar_key=}"
)
error_msg = f"Invalid grammar request: {req.grammar_key=}"
req.set_finish_with_abort(error_msg)
else:
num_ready_reqs_max = num_ready_reqs
num_timeout_reqs_max = num_timeout_reqs
......@@ -2308,12 +2300,14 @@ class Scheduler(
for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
req = self.grammar_queue[i]
req.grammar.cancel()
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
req.set_finish_with_abort(error_msg)
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
for req in self.grammar_queue[:num_ready_reqs]:
self._add_request_to_queue(req)
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
......@@ -2795,17 +2789,11 @@ def run_scheduler_process(
pipe_writer,
balance_meta: Optional[DPBalanceMeta] = None,
):
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)
if (numa_node := server_args.numa_node) is not None:
numa_bind_to_node(numa_node[gpu_id])
# Generate the prefix
# Generate the logger prefix
prefix = ""
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
# [For Router] if env var "SGLANG_DP_RANK" exists, set dp_rank to the value of the env var
dp_rank = int(os.environ["SGLANG_DP_RANK"])
if dp_rank is not None:
prefix += f" DP{dp_rank}"
if server_args.tp_size > 1:
......@@ -2821,10 +2809,6 @@ def run_scheduler_process(
kill_itself_when_parent_died()
parent_process = psutil.Process().parent()
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
dp_rank = int(os.environ["SGLANG_DP_RANK"])
# Configure the logger
configure_logger(server_args, prefix=prefix)
suppress_other_loggers()
......@@ -2832,6 +2816,15 @@ def run_scheduler_process(
# Set cpu affinity to this gpu process
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
if (numa_node := server_args.numa_node) is not None:
numa_bind_to_node(numa_node[gpu_id])
# Set up tracing
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)
# Create a scheduler and run the event loop
try:
......
......@@ -47,8 +47,11 @@ class SchedulerMetricsMixin:
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.total_retracted_reqs = 0
self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0
self.stats = SchedulerStats()
if self.enable_metrics:
engine_type = "unified"
labels = {
......@@ -82,12 +85,14 @@ class SchedulerMetricsMixin:
adder: PrefillAdder,
can_run_list: List[Req],
running_bs: int,
running_bs_offline_batch: int,
):
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.perf_counter()
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
# TODO: generalize this for various memory pools
if self.is_hybrid:
(
full_num_used,
......@@ -101,51 +106,53 @@ class SchedulerMetricsMixin:
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
token_usage_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"token usage: {token_usage:.2f}, "
token_usage_msg = f"token usage: {token_usage:.2f}, "
num_new_seq = len(can_run_list)
f = (
f"Prefill batch. "
f"#new-seq: {num_new_seq}, "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"{token_msg}"
f"{token_usage_msg}"
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
else:
f += f"#running-req: {running_bs}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, "
logger.info(f)
if self.enable_metrics:
# Basics
total_tokens = adder.log_input_tokens + adder.log_hit_tokens
cache_hit_rate = (
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
)
self.stats.num_running_reqs = running_bs
self.stats.num_running_reqs_offline_batch = running_bs_offline_batch
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.token_usage = token_usage
if self.is_hybrid:
self.stats.swa_token_usage = swa_token_usage
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.cache_hit_rate = cache_hit_rate
total_queue_latency = 0
for req in can_run_list:
total_queue_latency += req.queue_time_end - req.queue_time_start
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
# Retract
self.stats.num_retracted_reqs = self.num_retracted_reqs
self.stats.num_paused_reqs = self.num_paused_reqs
self.num_retracted_reqs = self.num_paused_reqs = 0
# PD disaggregation
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.num_prefill_prealloc_queue_reqs = len(
self.disagg_prefill_bootstrap_queue.queue
......@@ -153,7 +160,18 @@ class SchedulerMetricsMixin:
self.stats.num_prefill_inflight_queue_reqs = len(
self.disagg_prefill_inflight_queue
)
self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.stats.num_decode_prealloc_queue_reqs = len(
self.disagg_decode_prealloc_queue.queue
)
self.stats.num_decode_transfer_queue_reqs = len(
self.disagg_decode_transfer_queue.queue
)
# Others
self.calculate_utilization()
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
......@@ -166,8 +184,12 @@ class SchedulerMetricsMixin:
gap_latency = time.perf_counter() - self.last_decode_stats_tic
self.last_decode_stats_tic = time.perf_counter()
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
num_running_reqs_offline_batch = 0
# TODO: generalize this for various memory pools
if self.is_hybrid:
(
full_num_used,
......@@ -181,7 +203,7 @@ class SchedulerMetricsMixin:
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
token_usage_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"#swa token: {swa_num_used}, "
......@@ -189,14 +211,14 @@ class SchedulerMetricsMixin:
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
if RECORD_STEP_TIME:
self.step_time_dict[num_running_reqs].append(
gap_latency / self.server_args.decode_log_interval
)
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
......@@ -208,41 +230,66 @@ class SchedulerMetricsMixin:
self.cum_spec_accept_count += self.spec_num_total_forward_ct
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
msg += f"accept len: {spec_accept_length:.2f}, "
cache_hit_rate = 0.0
if self.disaggregation_mode == DisaggregationMode.DECODE:
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, "
msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, "
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
msg += (
f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
logger.info(msg)
if self.enable_metrics:
# Basics
self.stats.num_running_reqs = num_running_reqs
self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.cache_hit_rate = 0.0
self.stats.token_usage = token_usage
if self.is_hybrid:
self.stats.swa_token_usage = swa_token_usage
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.cache_hit_rate = cache_hit_rate
self.stats.spec_accept_length = spec_accept_length
self.stats.total_retracted_reqs = self.total_retracted_reqs
self.stats.avg_request_queue_latency = 0.0
if self.disaggregation_mode == DisaggregationMode.DECODE:
# Retract
self.stats.num_retracted_reqs = self.num_retracted_reqs
self.stats.num_paused_reqs = self.num_paused_reqs
self.num_retracted_reqs = self.num_paused_reqs = 0
# PD disaggregation
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.num_prefill_prealloc_queue_reqs = len(
self.disagg_prefill_bootstrap_queue.queue
)
self.stats.num_prefill_inflight_queue_reqs = len(
self.disagg_prefill_inflight_queue
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.stats.num_decode_prealloc_queue_reqs = len(
self.disagg_decode_prealloc_queue.queue
)
self.stats.num_decode_transfer_queue_reqs = len(
self.disagg_decode_transfer_queue.queue
)
# Others
self.calculate_utilization()
self.metrics_collector.log_stats(self.stats)
self._emit_kv_metrics()
self._publish_kv_events()
def _emit_kv_metrics(self: Scheduler):
if not self.enable_kv_cache_events:
return
kv_metrics = KvMetrics()
kv_metrics.request_active_slots = self.stats.num_running_reqs
kv_metrics.request_total_slots = self.max_running_requests
......@@ -259,11 +306,13 @@ class SchedulerMetricsMixin:
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
def _publish_kv_events(self: Scheduler):
if self.enable_kv_cache_events:
events = self.tree_cache.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
if not self.enable_kv_cache_events:
return
events = self.tree_cache.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
def maybe_update_dp_balance_data(
self: Scheduler, recv_req: TokenizedGenerateReqInput
......@@ -349,3 +398,17 @@ class SchedulerMetricsMixin:
# 2. Atomically write local_tokens and onfly into shm under the mutex
meta.set_shared_onfly_info(onfly_list)
meta.set_shared_local_tokens(local_tokens)
def calculate_utilization(self):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.utilization = -1
else:
if (
self.stats.max_running_requests_under_SLO is not None
and self.stats.max_running_requests_under_SLO > 0
):
self.stats.utilization = max(
self.stats.num_running_reqs
/ self.stats.max_running_requests_under_SLO,
self.stats.token_usage / 0.9,
)
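As a worked example of the formula above (numbers are made up): with num_running_reqs = 32, max_running_requests_under_SLO = 64, and token_usage = 0.72, utilization is max(32 / 64, 0.72 / 0.9) = max(0.5, 0.8) = 0.8; in prefill mode the field is simply pinned to -1.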
......@@ -91,7 +91,7 @@ class SchedulerOutputProcessorMixin:
if req.finished():
self.tree_cache.cache_finished_req(req)
req.time_stats.completion_time = time.time()
req.time_stats.completion_time = time.perf_counter()
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
# This updates radix so others can match
self.tree_cache.cache_unfinished_req(req)
......@@ -257,7 +257,7 @@ class SchedulerOutputProcessorMixin:
else:
self.tree_cache.cache_finished_req(req)
req.time_stats.completion_time = time.time()
req.time_stats.completion_time = time.perf_counter()
if req.return_logprob and batch.spec_algorithm.is_none():
# speculative worker handles logprob in speculative decoding
......@@ -707,6 +707,7 @@ class SchedulerOutputProcessorMixin:
and self.tp_rank == 0
and self.server_args.enable_request_time_stats_logging
):
print(f"{req.finished_reason=}")
req.log_time_stats()
# Send to detokenizer
......
......@@ -5,6 +5,7 @@ import copy
import logging
import os
import time
import uuid
from collections import deque
from typing import (
TYPE_CHECKING,
......@@ -24,6 +25,7 @@ import zmq
from sglang.srt.managers.io_struct import (
ClearHiCacheReqInput,
ClearHiCacheReqOutput,
CloseSessionReqInput,
DestroyWeightsUpdateGroupReqInput,
DestroyWeightsUpdateGroupReqOutput,
ExpertDistributionReq,
......@@ -44,6 +46,7 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
MultiTokenizerWrapper,
OpenSessionReqInput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
......@@ -588,3 +591,81 @@ class TokenizerCommunicatorMixin:
async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
req = GetLoadReqInput()
return await self.get_load_communicator(req)
async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
self.auto_create_handle_loop()
if obj.session_id is None:
obj.session_id = uuid.uuid4().hex
elif obj.session_id in self.session_futures:
return None
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWrapper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.session_futures[obj.session_id] = asyncio.Future()
session_id = await self.session_futures[obj.session_id]
del self.session_futures[obj.session_id]
return session_id
async def close_session(
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
await self.send_to_scheduler.send_pyobj(obj)
def get_log_request_metadata(self):
max_length = None
skip_names = None
out_skip_names = None
if self.log_requests:
if self.log_requests_level == 0:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
"sampling_params",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 1:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 2:
max_length = 2048
elif self.log_requests_level == 3:
max_length = 1 << 30
else:
raise ValueError(
f"Invalid --log-requests-level: {self.log_requests_level=}"
)
return max_length, skip_names, out_skip_names
......@@ -164,6 +164,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
else None
)
self.crash_dump_folder = server_args.crash_dump_folder
self.enable_trace = server_args.enable_trace
# Read model args
self.model_path = server_args.model_path
......@@ -381,23 +382,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
# If it's a single value, add worker_id prefix
obj.rid = f"{self.worker_id}_{obj.rid}"
if obj.is_single:
bootstrap_room = (
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
)
trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
else:
for i in range(len(obj.rid)):
bootstrap_room = (
obj.bootstrap_room[i]
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
else None
)
trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start(
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
)
if self.enable_trace:
self._trace_request_start(obj, created_time)
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
......@@ -1055,7 +1041,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
req = AbortReq(rid, abort_all)
self.send_to_scheduler.send_pyobj(req)
if self.enable_metrics:
self.metrics_collector.observe_one_aborted_request()
# TODO: also use custom_labels from the request
self.metrics_collector.observe_one_aborted_request(
self.metrics_collector.labels
)
async def pause_generation(self):
async with self.is_pause_cond:
......@@ -1117,84 +1106,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
all_paused_requests = [r.num_paused_requests for r in result]
return all_success, all_message, all_paused_requests
async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
self.auto_create_handle_loop()
if obj.session_id is None:
obj.session_id = uuid.uuid4().hex
elif obj.session_id in self.session_futures:
return None
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWrapper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.session_futures[obj.session_id] = asyncio.Future()
session_id = await self.session_futures[obj.session_id]
del self.session_futures[obj.session_id]
return session_id
async def close_session(
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
await self.send_to_scheduler.send_pyobj(obj)
def get_log_request_metadata(self):
max_length = None
skip_names = None
out_skip_names = None
if self.log_requests:
if self.log_requests_level == 0:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
"sampling_params",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 1:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 2:
max_length = 2048
elif self.log_requests_level == 3:
max_length = 1 << 30
else:
raise ValueError(
f"Invalid --log-requests-level: {self.log_requests_level=}"
)
return max_length, skip_names, out_skip_names
def configure_logging(self, obj: ConfigureLoggingReq):
if obj.log_requests is not None:
self.log_requests = obj.log_requests
......@@ -1353,12 +1264,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
# Drain requests
while True:
remain_num_req = len(self.rid_to_state)
remaining_rids = list(self.rid_to_state.keys())
if self.server_status == ServerStatus.UnHealthy:
# if health check failed, we should exit immediately
logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
remain_num_req,
"Signal SIGTERM received while health check failed. Force exiting."
)
self.dump_requests_before_crash()
break
......@@ -1366,13 +1277,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
# if force shutdown flag set, exit immediately
logger.error(
"Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
remain_num_req,
"Signal SIGTERM received while force shutdown flag set. Force exiting."
)
break
logger.info(
f"Gracefully exiting... remaining number of requests {remain_num_req}"
f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
)
if remain_num_req > 0:
await asyncio.sleep(5)
......@@ -1888,6 +1798,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
load_udpate_req = WatchLoadUpdateReq(loads=loads)
self.send_to_scheduler.send_pyobj(load_udpate_req)
def _trace_request_start(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
created_time: Optional[float] = None,
):
if obj.is_single:
bootstrap_room = (
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
)
trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
else:
for i in range(len(obj.rid)):
bootstrap_room = (
obj.bootstrap_room[i]
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
else None
)
trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
trace_slice_start(
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
)
class ServerStatus(Enum):
Up = "Up"
......@@ -1933,7 +1866,7 @@ class SignalHandler:
def running_phase_sigquit_handler(self, signum=None, frame=None):
logger.error(
"Received sigquit from a child process. It usually means the child failed."
f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
)
self.tokenizer_manager.dump_requests_before_crash()
kill_process_tree(os.getpid())
......
......@@ -14,9 +14,9 @@
"""Utilities for Prometheus Metrics Collection."""
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var
......@@ -34,6 +34,7 @@ class TimeStats:
Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
"""
disagg_mode: DisaggregationMode = DisaggregationMode.NULL
lb_entry_time: float = 0.0
wait_queue_entry_time: float = 0.0
forward_entry_time: float = 0.0
......@@ -43,20 +44,11 @@ class TimeStats:
decode_prealloc_queue_entry_time: float = 0.0
decode_transfer_queue_entry_time: float = 0.0
class RequestType(Enum):
UNIFIED = "unified"
PREFILL = "prefill"
DECODE = "decode"
INVALID = "invalid"
def get_queueing_time(self) -> float:
return self.forward_entry_time - self.wait_queue_entry_time
def __str__(self) -> str:
# if unified
_type = self.get_type()
if _type == self.RequestType.UNIFIED:
def convert_to_duration(self) -> str:
if self.disagg_mode == DisaggregationMode.NULL:
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
forward_duration = self.completion_time - self.forward_entry_time
......@@ -65,30 +57,28 @@ class TimeStats:
queue_duration >= 0 and forward_duration >= 0
), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
elif _type == self.RequestType.PREFILL:
return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time:.3f}"
elif self.disagg_mode == DisaggregationMode.PREFILL:
bootstrap_duration = (
self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
)
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
forward_duration = self.completion_time - self.forward_entry_time
if SGLANG_TEST_REQUEST_TIME_STATS:
assert (
bootstrap_duration >= 0
and queue_duration >= 0
and forward_duration >= 0
), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
# if decode
elif _type == self.RequestType.DECODE:
if self.wait_queue_entry_time > 0:
assert (
bootstrap_duration >= 0
and queue_duration >= 0
and forward_duration >= 0
), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time:.3f}"
elif self.disagg_mode == DisaggregationMode.DECODE:
prealloc_duration = (
self.decode_transfer_queue_entry_time
- self.decode_prealloc_queue_entry_time
)
transfer_duration = (
self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
)
......@@ -96,42 +86,30 @@ class TimeStats:
forward_duration = self.completion_time - self.forward_entry_time
if SGLANG_TEST_REQUEST_TIME_STATS:
assert (
prealloc_duration >= 0
and transfer_duration >= 0
and queue_duration >= 0
and forward_duration >= 0
), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
if self.wait_queue_entry_time > 0:
assert (
prealloc_duration >= 0
and transfer_duration >= 0
and queue_duration >= 0
and forward_duration >= 0
), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0. {self=}"
return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time:.3f}"
else:
return "Invalid Time Stats"
return "Unknown Time Stats"
def format_duration(self, duration: float) -> str:
return f"{duration * 1e3:.2f}ms"
def get_type(self) -> RequestType:
"""Determine the type of request based on timestamp values."""
if (
self.prefill_bootstrap_queue_entry_time == 0.0
and self.prefill_transfer_queue_entry_time == 0.0
and self.decode_prealloc_queue_entry_time == 0.0
and self.decode_transfer_queue_entry_time == 0.0
):
return self.RequestType.UNIFIED
elif (
self.prefill_bootstrap_queue_entry_time > 0.0
and self.prefill_transfer_queue_entry_time > 0.0
):
return self.RequestType.PREFILL
elif (
self.decode_prealloc_queue_entry_time > 0.0
and self.decode_transfer_queue_entry_time > 0.0
and self.wait_queue_entry_time > 0.0
):
return self.RequestType.DECODE
def disagg_mode_str(self) -> str:
if self.disagg_mode == DisaggregationMode.NULL:
return "unified"
elif self.disagg_mode == DisaggregationMode.DECODE:
return "decode"
elif self.disagg_mode == DisaggregationMode.PREFILL:
return "prefill"
else:
return self.RequestType.INVALID
return "unknown"
@dataclass
......@@ -145,12 +123,15 @@ class SchedulerStats:
num_queue_reqs: int = 0
num_grammar_queue_reqs: int = 0
num_running_reqs_offline_batch: int = 0
avg_request_queue_latency: float = 0.0
cache_hit_rate: float = 0.0
# Speculative decoding
spec_accept_length: float = 0.0
# Retract
num_retracted_reqs: int = 0
num_paused_reqs: int = 0
# PD disaggregation
num_prefill_prealloc_queue_reqs: int = 0
num_prefill_inflight_queue_reqs: int = 0
......@@ -159,11 +140,6 @@ class SchedulerStats:
kv_transfer_speed_gb_s: float = 0.0
kv_transfer_latency_ms: float = 0.0
# Retract
total_retracted_reqs: int = 0
num_retracted_reqs: int = 0
num_paused_reqs: int = 0
# Utilization
utilization: float = 0.0
max_running_requests_under_SLO: Optional[int] = None
......@@ -230,12 +206,6 @@ class SchedulerMetricsCollector:
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.avg_request_queue_latency = Gauge(
name="sglang:avg_request_queue_latency",
documentation="The average request queue latency for the last batch of requests in seconds.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.cache_hit_rate = Gauge(
name="sglang:cache_hit_rate",
documentation="The prefix cache hit rate.",
......@@ -251,6 +221,18 @@ class SchedulerMetricsCollector:
multiprocess_mode="mostrecent",
)
# Retract
self.num_retracted_reqs = Gauge(
name="sglang:num_retracted_reqs",
documentation="The number of retracted requests.",
labelnames=labels.keys(),
)
self.num_paused_reqs = Gauge(
name="sglang:num_paused_reqs",
documentation="The number of paused requests by async weight sync.",
labelnames=labels.keys(),
)
# PD disaggregation
self.num_prefill_prealloc_queue_reqs = Gauge(
name="sglang:num_prefill_prealloc_queue_reqs",
......@@ -299,24 +281,6 @@ class SchedulerMetricsCollector:
multiprocess_mode="mostrecent",
)
# Retract
self.total_retracted_reqs = Gauge(
name="sglang:total_retracted_reqs",
documentation="The total number of retracted requests due to kvcache full.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.num_retracted_reqs = Gauge(
name="sglang:num_retracted_reqs",
documentation="The number of retracted requests.",
labelnames=labels.keys(),
)
self.num_paused_reqs = Gauge(
name="sglang:num_paused_reqs",
documentation="The number of paused requests by async weight sync.",
labelnames=labels.keys(),
)
# Utilization
self.utilization = Gauge(
name="sglang:utilization",
......@@ -347,7 +311,7 @@ class SchedulerMetricsCollector:
# Additional queueing time histogram
self.queue_time = Histogram(
name="sglang:queue_time_s",
name="sglang:queue_time_seconds",
documentation="Histogram of queueing time in seconds.",
labelnames=labels.keys(),
buckets=[
......@@ -513,8 +477,8 @@ class SchedulerMetricsCollector:
buckets=tree_traversal_time_buckets,
)
self.request_latency_seconds = Histogram(
name="sglang:request_latency_seconds",
self.per_stage_req_latency_seconds = Histogram(
name="sglang:per_stage_req_latency_seconds",
documentation="The latency of each stage of requests.",
# captures latency in range [1ms - ~1191s]
buckets=exponential_buckets(start=0.001, width=1.62, length=30),
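A quick sanity check of the comment above, assuming exponential_buckets(start, width, length) yields start * width**i for i in range(length) (an assumption about its implementation, not verified here):
buckets = [0.001 * 1.62**i for i in range(30)]
print(f"{buckets[0]:.3f}s .. {buckets[-1]:.0f}s")  # roughly 0.001s .. 1191s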
......@@ -525,7 +489,7 @@ class SchedulerMetricsCollector:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def log_histogram(self, histogram, data: Union[int, float]) -> None:
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
histogram.labels(**self.labels).observe(data)
def increment_bootstrap_failed_reqs(self) -> None:
......@@ -534,9 +498,12 @@ class SchedulerMetricsCollector:
def increment_transfer_failed_reqs(self) -> None:
self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
def observe_request_latency_seconds(self, stage: str, latency: float) -> None:
def observe_per_stage_req_latency(self, stage: str, latency: float) -> None:
labels_with_stage = {**self.labels, "stage": stage}
self.request_latency_seconds.labels(**labels_with_stage).observe(latency)
self.per_stage_req_latency_seconds.labels(**labels_with_stage).observe(latency)
def observe_queue_time(self, latency: float) -> None:
self._log_histogram(self.queue_time, latency)
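For reference, the per-stage histogram simply widens the collector's label set with an extra stage label; a tiny sketch with made-up label values:
labels = {"model_name": "llama-example", "engine_type": "unified"}  # hypothetical collector labels
stage = "prefill_waiting"  # e.g. RequestStage.PREFILL_WAITING.value
labels_with_stage = {**labels, "stage": stage}
# -> {'model_name': 'llama-example', 'engine_type': 'unified', 'stage': 'prefill_waiting'}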
def log_stats(self, stats: SchedulerStats) -> None:
self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
......@@ -550,7 +517,6 @@ class SchedulerMetricsCollector:
self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch
)
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
# Speculative decoding
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
......@@ -572,7 +538,6 @@ class SchedulerMetricsCollector:
self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms)
# Retract
self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs)
self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs)
self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs)
......@@ -596,19 +561,19 @@ class SchedulerMetricsCollector:
def log_grammar_stats(self, grammar_stats) -> None:
# Duck-typed GrammarStats to avoid cross-package dependency
if getattr(grammar_stats, "compilation_time", None) is not None:
self.log_histogram(
self._log_histogram(
self.grammar_compilation_time, grammar_stats.compilation_time
)
if getattr(grammar_stats, "schema_count", None) is not None:
self.log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
self._log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
if getattr(grammar_stats, "ebnf_size", None) is not None:
self.log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
self._log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
tree_times = getattr(grammar_stats, "tree_traversal_time", None)
if tree_times:
max_time = max(tree_times)
avg_time = sum(tree_times) / len(tree_times)
self.log_histogram(self.grammar_tree_traversal_time_max, max_time)
self.log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
self._log_histogram(self.grammar_tree_traversal_time_max, max_time)
self._log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
if getattr(grammar_stats, "is_cache_hit", False):
self.num_grammar_cache_hit.labels(**self.labels).inc(1)
if getattr(grammar_stats, "is_grammar_aborted", False):
......@@ -714,7 +679,7 @@ class TokenizerMetricsCollector:
)
self.num_aborted_requests_total = Counter(
name="sglang:num_aborted_requests",
name="sglang:num_aborted_requests_total",
documentation="Number of requests aborted.",
labelnames=labels.keys(),
)
......@@ -801,7 +766,7 @@ class TokenizerMetricsCollector:
buckets=bucket_time_to_first_token,
)
self.histogram_inter_token_latency_seconds = Histogram(
self.histogram_inter_token_latency = Histogram(
name="sglang:inter_token_latency_seconds",
documentation="Histogram of inter-token latency in seconds.",
labelnames=labels.keys(),
......@@ -815,14 +780,6 @@ class TokenizerMetricsCollector:
buckets=bucket_e2e_request_latency,
)
# Offline batch specific TTFB histogram
self.histogram_time_to_first_token_offline_batch = Histogram(
name="sglang:time_to_first_token_seconds_offline_batch",
documentation="Histogram of time to first token in seconds for offline batch requests.",
labelnames=labels.keys(),
buckets=bucket_time_to_first_token,
)
def observe_one_finished_request(
self,
labels: Dict[str, str],
......@@ -846,15 +803,8 @@ class TokenizerMetricsCollector:
float(generation_tokens)
)
def observe_time_to_first_token(
self, labels: Dict[str, str], value: float, type: str = ""
):
if type == "batch":
self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(
value
)
else:
self.histogram_time_to_first_token.labels(**labels).observe(value)
def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
self.histogram_time_to_first_token.labels(**labels).observe(value)
def check_time_to_first_token_straggler(self, value: float) -> bool:
his = self.histogram_time_to_first_token.labels(**self.labels)
......@@ -876,7 +826,7 @@ class TokenizerMetricsCollector:
# A faster version of the Histogram::observe which observes multiple values at the same time.
# reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
his = self.histogram_inter_token_latency_seconds.labels(**labels)
his = self.histogram_inter_token_latency.labels(**labels)
his._sum.inc(internval)
for i, bound in enumerate(his._upper_bounds):
......@@ -884,8 +834,8 @@ class TokenizerMetricsCollector:
his._buckets[i].inc(num_new_tokens)
break
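A self-contained illustration of the batched update referenced in the comment above, using prometheus_client directly; the bucket bounds and input values are arbitrary, and it assumes the comparison elided by the hunk is against interval / num_new_tokens. One _sum increment plus one bucket increment stands in for num_new_tokens separate observe() calls:
from prometheus_client import CollectorRegistry, Histogram
registry = CollectorRegistry()
his = Histogram(
    "demo_inter_token_latency_seconds",
    "Demo inter-token latency histogram",
    buckets=[0.005, 0.01, 0.05, 0.1],
    registry=registry,
)
interval = 0.12      # total latency accumulated since the last report
num_new_tokens = 6   # tokens generated in that window, i.e. 0.02 s per token
his._sum.inc(interval)  # same total as six observe(0.02) calls
per_token = interval / num_new_tokens
for i, bound in enumerate(his._upper_bounds):  # bounds end with +Inf, so the loop always breaks
    if per_token <= bound:
        his._buckets[i].inc(num_new_tokens)  # credit all six samples to one bucket
        break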
def observe_one_aborted_request(self):
self.num_aborted_requests_total.labels(**self.labels).inc(1)
def observe_one_aborted_request(self, labels: Dict[str, str]):
self.num_aborted_requests_total.labels(**labels).inc(1)
@dataclass
......
......@@ -15,7 +15,6 @@
from __future__ import annotations
import ctypes
import logging
import os
import random
......@@ -23,7 +22,10 @@ import threading
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import Req
logger = logging.getLogger(__name__)
opentelemetry_imported = False
......@@ -407,9 +409,11 @@ def trace_slice_start(
ts: Optional[int] = None,
anonymous: bool = False,
):
if not tracing_enabled:
return
rid = str(rid)
if not tracing_enabled or rid not in reqs_context:
if rid not in reqs_context:
return
pid = threading.get_native_id()
......@@ -458,8 +462,11 @@ def trace_slice_end(
auto_next_anon: bool = False,
thread_finish_flag: bool = False,
):
if not tracing_enabled:
return
rid = str(rid)
if not tracing_enabled or rid not in reqs_context:
if rid not in reqs_context:
return
pid = threading.get_native_id()
......@@ -512,10 +519,13 @@ trace_slice = trace_slice_end
# Add event to the current slice on the same thread with the same rid.
def trace_event(name: str, rid: str, ts: Optional[int] = None):
if not tracing_enabled or rid not in reqs_context:
if not tracing_enabled:
return
rid = str(rid)
if rid not in reqs_context:
return
pid = threading.get_native_id()
if pid not in reqs_context[rid].threads_context:
return
......@@ -534,10 +544,13 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None):
# Add attrs to the current slice on the same thread with the same rid.
def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
if not tracing_enabled or rid not in reqs_context:
if not tracing_enabled:
return
rid = str(rid)
if rid not in reqs_context:
return
pid = threading.get_native_id()
if pid not in reqs_context[rid].threads_context:
return
......@@ -550,3 +563,16 @@ def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
slice_info = thread_context.cur_slice_stack[-1]
slice_info.span.set_attributes(attrs)
def trace_slice_batch(
name: str,
reqs: List[Req],
):
for req in reqs:
trace_slice(
name,
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)