Unverified Commit a06bf664 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Auto Sync] Update collector.py, startup_func_log_and_timer... (20250910) (#10242)


Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: cctry <shiyang@x.ai>
parent bf72b801
...@@ -50,6 +50,9 @@ class TimeStats: ...@@ -50,6 +50,9 @@ class TimeStats:
DECODE = "decode" DECODE = "decode"
INVALID = "invalid" INVALID = "invalid"
def get_queueing_time(self) -> float:
    """Return how long the request waited in the queue before the forward pass began."""
    queueing_delay = self.forward_entry_time - self.wait_queue_entry_time
    return queueing_delay
def __str__(self) -> str: def __str__(self) -> str:
# if unified # if unified
_type = self.get_type() _type = self.get_type()
...@@ -134,27 +137,48 @@ class TimeStats: ...@@ -134,27 +137,48 @@ class TimeStats:
@dataclass @dataclass
class SchedulerStats: class SchedulerStats:
# Basics
num_running_reqs: int = 0 num_running_reqs: int = 0
num_used_tokens: int = 0 num_used_tokens: int = 0
token_usage: float = 0.0 token_usage: float = 0.0
swa_token_usage: float = 0.0
gen_throughput: float = 0.0 gen_throughput: float = 0.0
num_queue_reqs: int = 0 num_queue_reqs: int = 0
cache_hit_rate: float = 0.0
num_grammar_queue_reqs: int = 0 num_grammar_queue_reqs: int = 0
spec_accept_length: float = 0.0 num_running_reqs_offline_batch: int = 0
avg_request_queue_latency: float = 0.0 avg_request_queue_latency: float = 0.0
cache_hit_rate: float = 0.0
# Speculative decoding
spec_accept_length: float = 0.0
# PD disaggregation
num_prefill_prealloc_queue_reqs: int = 0 num_prefill_prealloc_queue_reqs: int = 0
num_prefill_inflight_queue_reqs: int = 0 num_prefill_inflight_queue_reqs: int = 0
num_decode_prealloc_queue_reqs: int = 0 num_decode_prealloc_queue_reqs: int = 0
num_decode_transfer_queue_reqs: int = 0 num_decode_transfer_queue_reqs: int = 0
kv_transfer_speed_gb_s: float = 0.0
kv_transfer_latency_ms: float = 0.0
# Retract
total_retracted_reqs: int = 0 total_retracted_reqs: int = 0
num_retracted_reqs: int = 0
num_paused_reqs: int = 0
# Utilization
utilization: float = 0.0
max_running_requests_under_SLO: Optional[int] = None
# Engine startup
engine_startup_time: float = 0.0
engine_load_weights_time: float = 0.0
class SchedulerMetricsCollector: class SchedulerMetricsCollector:
def __init__(self, labels: Dict[str, str]) -> None: def __init__(self, labels: Dict[str, str]) -> None:
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from prometheus_client import Counter, Gauge from prometheus_client import Counter, Gauge, Histogram
self.labels = labels self.labels = labels
self.last_log_time = time.perf_counter() self.last_log_time = time.perf_counter()
...@@ -165,115 +189,338 @@ class SchedulerMetricsCollector: ...@@ -165,115 +189,338 @@ class SchedulerMetricsCollector:
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.num_used_tokens = Gauge( self.num_used_tokens = Gauge(
name="sglang:num_used_tokens", name="sglang:num_used_tokens",
documentation="The number of used tokens.", documentation="The number of used tokens.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.token_usage = Gauge( self.token_usage = Gauge(
name="sglang:token_usage", name="sglang:token_usage",
documentation="The token usage.", documentation="The token usage.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.swa_token_usage = Gauge(
name="sglang:swa_token_usage",
documentation="The token usage for SWA layers.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.gen_throughput = Gauge( self.gen_throughput = Gauge(
name="sglang:gen_throughput", name="sglang:gen_throughput",
documentation="The generation throughput (token/s).", documentation="The generation throughput (token/s).",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.num_queue_reqs = Gauge( self.num_queue_reqs = Gauge(
name="sglang:num_queue_reqs", name="sglang:num_queue_reqs",
documentation="The number of requests in the waiting queue.", documentation="The number of requests in the waiting queue.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.num_grammar_queue_reqs = Gauge( self.num_grammar_queue_reqs = Gauge(
name="sglang:num_grammar_queue_reqs", name="sglang:num_grammar_queue_reqs",
documentation="The number of requests in the grammar waiting queue.", documentation="The number of requests in the grammar waiting queue.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.num_running_reqs_offline_batch = Gauge(
self.cache_hit_rate = Gauge( name="sglang:num_running_reqs_offline_batch",
name="sglang:cache_hit_rate", documentation="The number of running low-priority offline batch requests(label is 'batch').",
documentation="The prefix cache hit rate.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.spec_accept_length = Gauge(
name="sglang:spec_accept_length",
documentation="The average acceptance length of speculative decoding.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.avg_request_queue_latency = Gauge( self.avg_request_queue_latency = Gauge(
name="sglang:avg_request_queue_latency", name="sglang:avg_request_queue_latency",
documentation="The average request queue latency for the last batch of requests in seconds.", documentation="The average request queue latency for the last batch of requests in seconds.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.cache_hit_rate = Gauge(
name="sglang:cache_hit_rate",
documentation="The prefix cache hit rate.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.total_retracted_reqs = Gauge( # Speculative decoding
name="sglang:total_retracted_reqs", self.spec_accept_length = Gauge(
documentation="The total number of retracted requests due to kvcache full.", name="sglang:spec_accept_length",
documentation="The average acceptance length of speculative decoding.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
# Disaggregation queue metrics # PD disaggregation
self.num_prefill_prealloc_queue_reqs = Gauge( self.num_prefill_prealloc_queue_reqs = Gauge(
name="sglang:num_prefill_prealloc_queue_reqs", name="sglang:num_prefill_prealloc_queue_reqs",
documentation="The number of requests in the prefill prealloc queue.", documentation="The number of requests in the prefill prealloc queue.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.num_prefill_inflight_queue_reqs = Gauge( self.num_prefill_inflight_queue_reqs = Gauge(
name="sglang:num_prefill_inflight_queue_reqs", name="sglang:num_prefill_inflight_queue_reqs",
documentation="The number of requests in the prefill inflight queue.", documentation="The number of requests in the prefill inflight queue.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.num_decode_prealloc_queue_reqs = Gauge( self.num_decode_prealloc_queue_reqs = Gauge(
name="sglang:num_decode_prealloc_queue_reqs", name="sglang:num_decode_prealloc_queue_reqs",
documentation="The number of requests in the decode prealloc queue.", documentation="The number of requests in the decode prealloc queue.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.num_decode_transfer_queue_reqs = Gauge( self.num_decode_transfer_queue_reqs = Gauge(
name="sglang:num_decode_transfer_queue_reqs", name="sglang:num_decode_transfer_queue_reqs",
documentation="The number of requests in the decode transfer queue.", documentation="The number of requests in the decode transfer queue.",
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.num_bootstrap_failed_reqs = Counter( self.num_bootstrap_failed_reqs = Counter(
name="sglang:num_bootstrap_failed_reqs", name="sglang:num_bootstrap_failed_reqs_total",
documentation="The number of bootstrap failed requests.", documentation="The number of bootstrap failed requests.",
labelnames=labels.keys(), labelnames=labels.keys(),
) )
self.num_transfer_failed_reqs = Counter( self.num_transfer_failed_reqs = Counter(
name="sglang:num_transfer_failed_reqs", name="sglang:num_transfer_failed_reqs_total",
documentation="The number of transfer failed requests.", documentation="The number of transfer failed requests.",
labelnames=labels.keys(), labelnames=labels.keys(),
) )
self.kv_transfer_speed_gb_s = Gauge(
name="sglang:kv_transfer_speed_gb_s",
documentation="The transfer speed of the KV cache in GB/s.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.kv_transfer_latency_ms = Gauge(
name="sglang:kv_transfer_latency_ms",
documentation="The transfer latency of the KV cache in ms.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
# Retract
self.total_retracted_reqs = Gauge(
name="sglang:total_retracted_reqs",
documentation="The total number of retracted requests due to kvcache full.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.num_retracted_reqs = Gauge(
name="sglang:num_retracted_reqs",
documentation="The number of retracted requests.",
labelnames=labels.keys(),
)
self.num_paused_reqs = Gauge(
name="sglang:num_paused_reqs",
documentation="The number of paused requests by async weight sync.",
labelnames=labels.keys(),
)
# Utilization
self.utilization = Gauge(
name="sglang:utilization",
documentation="The utilization.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.max_running_requests_under_SLO = Gauge(
name="sglang:max_running_requests_under_SLO",
documentation="The maximum number of running requests under SLO.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
# Engine startup
self.engine_startup_time = Gauge(
name="sglang:engine_startup_time",
documentation="The time taken for the engine to start up.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.engine_load_weights_time = Gauge(
name="sglang:engine_load_weights_time",
documentation="The time taken for the engine to load weights.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
# Additional queueing time histogram
self.queue_time = Histogram(
name="sglang:queue_time_s",
documentation="Histogram of queueing time in seconds.",
labelnames=labels.keys(),
buckets=[
0.0,
0.1,
0.2,
0.5,
1,
2,
3,
4,
5,
10,
15,
20,
30,
40,
50,
60,
70,
80,
90,
100,
200,
300,
400,
500,
600,
700,
800,
900,
1000,
1200,
1400,
1600,
1800,
2000,
2500,
3000,
],
)
# Grammar metrics
self.grammar_compilation_time = Histogram(
name="sglang:grammar_compilation_time_seconds",
documentation="Histogram of grammar compilation time in seconds.",
labelnames=labels.keys(),
buckets=[
0.0,
0.01,
0.02,
0.05,
0.1,
0.2,
0.5,
1,
2,
5,
10,
20,
30,
60,
90,
120,
240,
],
)
self.num_grammar_cache_hit = Counter(
name="sglang:num_grammar_cache_hit_total",
documentation="Number of grammar cache hits.",
labelnames=labels.keys(),
)
self.num_grammar_aborted = Counter(
name="sglang:num_grammar_aborted_total",
documentation="Number of grammar aborted requests.",
labelnames=labels.keys(),
)
self.num_grammar_total = Counter(
name="sglang:num_grammar_total",
documentation="Number of the total grammar requests.",
labelnames=labels.keys(),
)
self.grammar_schema_count = Histogram(
name="sglang:grammar_schema_count",
documentation="Histogram of grammar schema count.",
labelnames=labels.keys(),
buckets=[
0,
1,
2,
5,
10,
20,
30,
40,
60,
80,
100,
120,
140,
160,
180,
200,
300,
400,
500,
700,
1000,
],
)
self.grammar_ebnf_size = Histogram(
name="sglang:grammar_ebnf_size",
documentation="Histogram of grammar EBNF size.",
labelnames=labels.keys(),
buckets=[
0,
50,
100,
200,
300,
500,
1000,
2000,
3000,
5000,
10000,
20000,
30000,
50000,
100000,
],
)
tree_traversal_time_buckets = [
0.0,
0.01,
0.02,
0.05,
0.1,
0.2,
0.5,
1,
2,
5,
10,
15,
30,
60,
90,
120,
240,
]
self.grammar_tree_traversal_time_avg = Histogram(
name="sglang:grammar_tree_traversal_time_avg",
documentation="Histogram of average grammar tree traversal time in seconds.",
labelnames=labels.keys(),
buckets=tree_traversal_time_buckets,
)
self.grammar_tree_traversal_time_max = Histogram(
name="sglang:grammar_tree_traversal_time_max",
documentation="Histogram of max grammar tree traversal time in seconds.",
labelnames=labels.keys(),
buckets=tree_traversal_time_buckets,
)
def _log_gauge(self, gauge, data: Union[int, float]) -> None: def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge. # Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data) gauge.labels(**self.labels).set(data)
def log_histogram(self, histogram, data: Union[int, float]) -> None:
    """Observe *data* on *histogram*, tagged with this collector's label set."""
    labeled_histogram = histogram.labels(**self.labels)
    labeled_histogram.observe(data)
def increment_bootstrap_failed_reqs(self) -> None: def increment_bootstrap_failed_reqs(self) -> None:
self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1) self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1)
...@@ -284,14 +531,19 @@ class SchedulerMetricsCollector: ...@@ -284,14 +531,19 @@ class SchedulerMetricsCollector:
self._log_gauge(self.num_running_reqs, stats.num_running_reqs) self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
self._log_gauge(self.num_used_tokens, stats.num_used_tokens) self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
self._log_gauge(self.token_usage, stats.token_usage) self._log_gauge(self.token_usage, stats.token_usage)
self._log_gauge(self.swa_token_usage, stats.swa_token_usage)
self._log_gauge(self.gen_throughput, stats.gen_throughput) self._log_gauge(self.gen_throughput, stats.gen_throughput)
self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs) self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs) self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
self._log_gauge(
self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch
)
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
# Speculative decoding
self._log_gauge(self.spec_accept_length, stats.spec_accept_length) self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs)
# Disaggregation metrics # PD disaggregation
self._log_gauge( self._log_gauge(
self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs
) )
...@@ -304,15 +556,59 @@ class SchedulerMetricsCollector: ...@@ -304,15 +556,59 @@ class SchedulerMetricsCollector:
self._log_gauge( self._log_gauge(
self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs
) )
self._log_gauge(self.kv_transfer_speed_gb_s, stats.kv_transfer_speed_gb_s)
self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms)
# Retract
self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs)
self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs)
self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs)
# Utilization
self._log_gauge(self.utilization, stats.utilization)
if stats.max_running_requests_under_SLO is not None:
self._log_gauge(
self.max_running_requests_under_SLO,
stats.max_running_requests_under_SLO,
)
# Engine startup time
self._log_gauge(self.engine_startup_time, stats.engine_startup_time)
if stats.engine_load_weights_time is not None:
self._log_gauge(
self.engine_load_weights_time, stats.engine_load_weights_time
)
self.last_log_time = time.perf_counter() self.last_log_time = time.perf_counter()
def log_grammar_stats(self, grammar_stats) -> None:
    """Record grammar metrics from a duck-typed stats object.

    *grammar_stats* is read via getattr to avoid a cross-package dependency
    on the GrammarStats type; absent fields are simply skipped. Every call
    increments the total-requests counter.
    """
    compilation_time = getattr(grammar_stats, "compilation_time", None)
    if compilation_time is not None:
        self.log_histogram(self.grammar_compilation_time, compilation_time)
    schema_count = getattr(grammar_stats, "schema_count", None)
    if schema_count is not None:
        self.log_histogram(self.grammar_schema_count, schema_count)
    ebnf_size = getattr(grammar_stats, "ebnf_size", None)
    if ebnf_size is not None:
        self.log_histogram(self.grammar_ebnf_size, ebnf_size)
    traversal_times = getattr(grammar_stats, "tree_traversal_time", None)
    if traversal_times:
        # Observe the max first, then the mean (same order as before).
        self.log_histogram(self.grammar_tree_traversal_time_max, max(traversal_times))
        mean_time = sum(traversal_times) / len(traversal_times)
        self.log_histogram(self.grammar_tree_traversal_time_avg, mean_time)
    if getattr(grammar_stats, "is_cache_hit", False):
        self.num_grammar_cache_hit.labels(**self.labels).inc(1)
    if getattr(grammar_stats, "is_grammar_aborted", False):
        self.num_grammar_aborted.labels(**self.labels).inc(1)
    self.num_grammar_total.labels(**self.labels).inc(1)
class TokenizerMetricsCollector: class TokenizerMetricsCollector:
def __init__( def __init__(
self, self,
server_args: ServerArgs, server_args: Optional[ServerArgs] = None,
labels: Dict[str, str], labels: Dict[str, str] = None,
bucket_time_to_first_token: Optional[List[float]] = None, bucket_time_to_first_token: Optional[List[float]] = None,
bucket_inter_token_latency: Optional[List[float]] = None, bucket_inter_token_latency: Optional[List[float]] = None,
bucket_e2e_request_latency: Optional[List[float]] = None, bucket_e2e_request_latency: Optional[List[float]] = None,
...@@ -321,7 +617,7 @@ class TokenizerMetricsCollector: ...@@ -321,7 +617,7 @@ class TokenizerMetricsCollector:
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from prometheus_client import Counter, Histogram from prometheus_client import Counter, Histogram
self.labels = labels self.labels = labels or {}
self.collect_tokens_histogram = collect_tokens_histogram self.collect_tokens_histogram = collect_tokens_histogram
self.prompt_tokens_total = Counter( self.prompt_tokens_total = Counter(
...@@ -361,6 +657,13 @@ class TokenizerMetricsCollector: ...@@ -361,6 +657,13 @@ class TokenizerMetricsCollector:
30000, 30000,
35000, 35000,
40000, 40000,
66000,
99000,
132000,
300000,
600000,
900000,
1100000,
] ]
self.prompt_tokens_histogram = Histogram( self.prompt_tokens_histogram = Histogram(
name="sglang:prompt_tokens_histogram", name="sglang:prompt_tokens_histogram",
...@@ -370,34 +673,13 @@ class TokenizerMetricsCollector: ...@@ -370,34 +673,13 @@ class TokenizerMetricsCollector:
server_args.prompt_tokens_buckets, default_bucket_prompt_tokens server_args.prompt_tokens_buckets, default_bucket_prompt_tokens
), ),
) )
default_bucket_generation_tokens = [
100,
300,
500,
1000,
1200,
1500,
1700,
2000,
2500,
3000,
3500,
4000,
4500,
5000,
6000,
7000,
8000,
9000,
10000,
]
self.generation_tokens_histogram = Histogram( self.generation_tokens_histogram = Histogram(
name="sglang:generation_tokens_histogram", name="sglang:generation_tokens_histogram",
documentation="Histogram of generation token length.", documentation="Histogram of generation token length.",
labelnames=labels.keys(), labelnames=labels.keys(),
buckets=generate_buckets( buckets=generate_buckets(
server_args.generation_tokens_buckets, server_args.generation_tokens_buckets,
default_bucket_generation_tokens, default_bucket_prompt_tokens,
), ),
) )
...@@ -467,7 +749,10 @@ class TokenizerMetricsCollector: ...@@ -467,7 +749,10 @@ class TokenizerMetricsCollector:
100, 100,
200, 200,
400, 400,
800, 600,
1200,
1800,
2400,
] ]
if bucket_inter_token_latency is None: if bucket_inter_token_latency is None:
...@@ -518,6 +803,14 @@ class TokenizerMetricsCollector: ...@@ -518,6 +803,14 @@ class TokenizerMetricsCollector:
buckets=bucket_e2e_request_latency, buckets=bucket_e2e_request_latency,
) )
# Offline batch specific TTFB histogram
self.histogram_time_to_first_token_offline_batch = Histogram(
name="sglang:time_to_first_token_seconds_offline_batch",
documentation="Histogram of time to first token in seconds for offline batch requests.",
labelnames=labels.keys(),
buckets=bucket_time_to_first_token,
)
def _log_histogram(self, histogram, data: Union[int, float]) -> None: def _log_histogram(self, histogram, data: Union[int, float]) -> None:
histogram.labels(**self.labels).observe(data) histogram.labels(**self.labels).observe(data)
...@@ -541,8 +834,26 @@ class TokenizerMetricsCollector: ...@@ -541,8 +834,26 @@ class TokenizerMetricsCollector:
self._log_histogram(self.prompt_tokens_histogram, prompt_tokens) self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
self._log_histogram(self.generation_tokens_histogram, generation_tokens) self._log_histogram(self.generation_tokens_histogram, generation_tokens)
def observe_time_to_first_token(self, value: float): def observe_time_to_first_token(self, value: float, label: str = ""):
self.histogram_time_to_first_token.labels(**self.labels).observe(value) if label == "batch":
self.histogram_time_to_first_token_offline_batch.labels(
**self.labels
).observe(value)
else:
self.histogram_time_to_first_token.labels(**self.labels).observe(value)
def check_time_to_first_token_straggler(self, value: float) -> bool:
    """Return True if *value* (TTFT in seconds) falls at or beyond the p99 bucket.

    Requires at least 100 recorded observations before flagging anything.

    NOTE(review): reaches into prometheus_client private internals
    (`_buckets`, `_value`, `_upper_bounds`) — verify on client upgrades.
    """
    labeled = self.histogram_time_to_first_token.labels(**self.labels)
    counts = [bucket._value for bucket in labeled._buckets]
    total = sum(counts)
    # Too few samples to estimate a meaningful p99.
    if total < 100:
        return False
    threshold = total * 0.99
    cumulative = 0
    for count, upper_bound in zip(counts, labeled._upper_bounds):
        cumulative += count
        if cumulative > threshold:
            # value is a straggler if it reaches the p99 bucket's boundary.
            return value >= upper_bound
    return False
def observe_inter_token_latency(self, internval: float, num_new_tokens: int): def observe_inter_token_latency(self, internval: float, num_new_tokens: int):
adjusted_interval = internval / num_new_tokens adjusted_interval = internval / num_new_tokens
......
"""
Records startup latency breakdown by context using gauge metrics in seconds
"""
import logging
import time
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, Dict, Generator, Optional
logger = logging.getLogger(__name__)

# Global switch flipped by enable_startup_timer(); the gauge is created lazily
# so prometheus_client is only imported after PROMETHEUS_MULTIPROC_DIR is set.
enable_startup_metrics = False
STARTUP_LATENCY_SECONDS = None

# Maximum observed duration (seconds) per context name.
_max_durations: Dict[str, float] = {}


def enable_startup_timer():
    """Initialize startup latency metrics when metrics are enabled."""
    # We need to import prometheus_client after setting the env variable
    # `PROMETHEUS_MULTIPROC_DIR`.
    from prometheus_client import Gauge

    global enable_startup_metrics, STARTUP_LATENCY_SECONDS
    enable_startup_metrics = True
    STARTUP_LATENCY_SECONDS = Gauge(
        "sglang:startup_latency_breakdown_seconds_max",
        "Startup latency breakdown in seconds by context, only records the maximum duration if the context is called multiple times.",
        labelnames=["context"],
        multiprocess_mode="mostrecent",
    )


def set_startup_metric(context: str, value: float, should_log: bool = True):
    """Set the startup metric for a given context.

    Only the maximum *value* per *context* is kept. The maximum is tracked in
    `_max_durations` even when Prometheus metrics are disabled so that
    get_max_duration() stays consistent with startup_timer(); the Prometheus
    gauge itself is only updated when metrics are enabled.
    """
    if should_log:
        logger.info(f"Setting startup metric: {context} took {value:.3f}s")
    current_max = _max_durations.get(context, 0.0)
    if value <= current_max:
        return
    _max_durations[context] = value
    if enable_startup_metrics:
        STARTUP_LATENCY_SECONDS.labels(context=context).set(value)


def reset_startup_timers():
    """Reset all recorded maximum durations. Useful for testing or reinitialization."""
    # clear() mutates in place, so no `global` declaration is needed.
    _max_durations.clear()


def get_max_duration(context: str) -> Optional[float]:
    """Get the maximum recorded duration for a context name, or None if unseen."""
    return _max_durations.get(context)


@contextmanager
def startup_timer(name: str, log_only: bool = False) -> Generator[None, None, None]:
    """
    Context manager to measure startup latency for arbitrary code blocks.

    Only the maximum duration per *name* is recorded if the context runs
    multiple times. The Prometheus gauge is updated only when metrics are
    enabled, *log_only* is False, and the duration is a new maximum; the
    duration is always logged. The duration is recorded even if the body
    raises (the exception still propagates).

    Usage:
        with startup_timer("model_loading"):
            model = load_model()

        with startup_timer("memory_allocation"):
            allocate_memory()
    """
    start_time = time.monotonic()
    try:
        yield
    finally:
        duration_seconds = time.monotonic() - start_time
        current_max = _max_durations.get(name, 0.0)
        if duration_seconds > current_max:
            _max_durations[name] = duration_seconds
            if enable_startup_metrics and not log_only:
                STARTUP_LATENCY_SECONDS.labels(context=name).set(duration_seconds)
        logger.info(f"Startup timing: {name} took {duration_seconds:.3f}s")
def time_startup_latency(
    func: Callable = None, name: Optional[str] = None, log_only: bool = False
) -> Callable[..., Any]:
    """
    A decorator to measure startup context latency and record it in seconds.

    Only the maximum duration is recorded if the context runs multiple times.
    The heavy lifting (max tracking, Prometheus gauge, logging) is delegated
    to startup_timer() so the logic lives in exactly one place.

    Usage:
        @time_startup_latency
        def load_model():
            # model loading code

        @time_startup_latency(name="custom_init")
        def initialize_something():
            # initialization code

        @time_startup_latency(name="debug_only", log_only=True)
        def debug_function():
            # This will only log, not record to Prometheus
    """

    def measure(wrapped: Callable[..., Any]) -> Callable[..., Any]:
        # Resolve the label per wrapped function instead of rebinding the
        # enclosing `name` via `nonlocal`: the old rebind froze the first
        # function's name if `measure` decorated more than one function.
        label = name or wrapped.__name__

        @wraps(wrapped)
        def wrapper(*args, **kwargs):
            # startup_timer records the duration (and logs) even when
            # the wrapped call raises; the exception still propagates.
            with startup_timer(label, log_only=log_only):
                return wrapped(*args, **kwargs)

        return wrapper

    # Bare `@time_startup_latency` usage passes the function directly.
    if func:
        return measure(func)
    else:
        return measure
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment