Unverified Commit b1721edb authored by Yingchun Lai's avatar Yingchun Lai Committed by GitHub
Browse files

[PD metrics] Add latency Histogram metrics of each stage for generate requests (#8710)

parent 57234d0c
......@@ -45,7 +45,7 @@ from sglang.srt.disaggregation.utils import (
prepare_abort,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
......@@ -253,6 +253,7 @@ class DecodePreallocQueue:
prefill_dp_rank=req.data_parallel_rank,
)
req.add_latency(RequestStage.DECODE_PREPARE)
self.queue.append(
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
)
......@@ -421,6 +422,7 @@ class DecodePreallocQueue:
kv_indices, self.token_to_kv_pool_allocator.page_size
)
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
preallocated_reqs.append(decode_req)
indices_to_remove.add(i)
......@@ -662,6 +664,7 @@ class DecodeTransferQueue:
for i in indices_to_remove:
idx = self.queue[i].metadata_buffer_index
assert idx != -1
self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
self.req_to_metadata_buffer_idx_allocator.free(idx)
self.queue = [
......@@ -853,6 +856,7 @@ class SchedulerDisaggregationDecodeMixin:
# we can only add at least `num_not_used_batch` new batch to the running queue
if i < num_not_used_batch:
can_run_list.append(req)
req.add_latency(RequestStage.DECODE_WAITING)
req.init_next_round_input(self.tree_cache)
else:
waiting_queue.append(req)
......
......@@ -42,7 +42,12 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
from sglang.srt.managers.schedule_batch import (
FINISH_LENGTH,
Req,
RequestStage,
ScheduleBatch,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import (
DynamicGradMode,
......@@ -170,6 +175,7 @@ class PrefillBootstrapQueue:
pp_rank=self.pp_rank,
)
self._process_req(req)
req.add_latency(RequestStage.PREFILL_PREPARE)
self.queue.append(req)
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
......@@ -256,6 +262,8 @@ class PrefillBootstrapQueue:
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
bootstrapped_reqs.append(req)
indices_to_remove.add(i)
......@@ -404,6 +412,7 @@ class SchedulerDisaggregationPrefillMixin:
# There is no output_ids for prefill
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
req.add_latency(RequestStage.PREFILL_FORWARD)
self.disagg_prefill_inflight_queue.append(req)
if (
logits_output is not None
......@@ -539,6 +548,7 @@ class SchedulerDisaggregationPrefillMixin:
)
for req in done_reqs:
req: Req
req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
req.metadata_buffer_index = -1
......
from __future__ import annotations
import enum
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -35,6 +37,7 @@ import copy
import dataclasses
import logging
import threading
import time
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
......@@ -61,7 +64,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
......@@ -407,6 +410,23 @@ class MultimodalInputs:
# other args would be kept intact
class RequestStage(str, enum.Enum):
    """Per-request pipeline stages whose wall-clock latency is recorded
    via ``Req.add_latency``; each member's ``value`` becomes the
    ``stage`` label on the scheduler's request-latency Histogram."""

    # prefill (common path): time spent in the scheduler waiting queue
    PREFILL_WAITING = "prefill_waiting"

    # disaggregation prefill
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # disaggregation decode
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"
class Req:
"""The input and output status of a request."""
......@@ -433,6 +453,7 @@ class Req:
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
):
# Input and output info
self.rid = rid
......@@ -590,10 +611,12 @@ class Req:
self.spec_verify_ct = 0
# For metrics
self.metrics_collector = metrics_collector
self.time_stats: TimeStats = TimeStats()
self.has_log_time_stats: bool = False
self.queue_time_start = None
self.queue_time_end = None
self.last_tic = time.monotonic()
# For disaggregation
self.bootstrap_host: str = bootstrap_host
......@@ -626,6 +649,16 @@ class Req:
"""Check if this request is prefill-only (no token generation needed)."""
return self.sampling_params.max_new_tokens == 0
def add_latency(self, stage: "RequestStage") -> None:
    """Record the latency of the just-finished *stage* for this request.

    Observes the time elapsed since the previous ``add_latency`` call
    (or since ``self.last_tic`` was initialized) into the scheduler's
    per-stage latency Histogram, then resets the reference timestamp.
    No-op when metrics collection is disabled (``metrics_collector`` is
    None), so callers can invoke it unconditionally.
    """
    if self.metrics_collector is None:
        return
    # NOTE: the previous `stage.name in RequestStage.__members__` check
    # was a tautology for any RequestStage member; validate the type
    # of the argument instead.
    assert isinstance(stage, RequestStage), f"{stage=} is invalid"
    now = time.monotonic()
    self.metrics_collector.observe_request_latency_seconds(
        stage.value, now - self.last_tic
    )
    self.last_tic = now
def extend_image_inputs(self, image_inputs):
if self.multimodal_inputs is None:
self.multimodal_inputs = image_inputs
......
......@@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
MultimodalInputs,
Req,
RequestStage,
ScheduleBatch,
global_server_args_dict,
)
......@@ -1232,6 +1233,9 @@ class Scheduler(
bootstrap_room=recv_req.bootstrap_room,
data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size,
metrics_collector=(
self.metrics_collector if self.enable_metrics else None
),
)
req.tokenizer = self.tokenizer
......@@ -1768,6 +1772,7 @@ class Scheduler(
# only record queue time when enable_metrics is True to avoid overhead
for req in can_run_list:
req.queue_time_end = time.perf_counter()
req.add_latency(RequestStage.PREFILL_WAITING)
self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list)
......
......@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union
from sglang.srt.metrics.utils import generate_buckets
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var
......@@ -513,6 +513,14 @@ class SchedulerMetricsCollector:
buckets=tree_traversal_time_buckets,
)
self.request_latency_seconds = Histogram(
name="sglang:request_latency_seconds",
documentation="The latency of each stage of requests.",
# captures latency in range [1ms - ~1191s]
buckets=exponential_buckets(start=0.001, width=1.62, length=30),
labelnames=list(labels.keys()) + ["stage"],
)
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
......@@ -526,6 +534,10 @@ class SchedulerMetricsCollector:
def increment_transfer_failed_reqs(self) -> None:
    """Bump the transfer-failed-requests counter by one."""
    counter = self.num_transfer_failed_reqs.labels(**self.labels)
    counter.inc(1)
def observe_request_latency_seconds(self, stage: str, latency: float) -> None:
    """Observe *latency* (seconds) for *stage* into the per-stage
    request-latency Histogram, tagged with this collector's labels."""
    labels = dict(self.labels)
    labels["stage"] = stage
    self.request_latency_seconds.labels(**labels).observe(latency)
def log_stats(self, stats: SchedulerStats) -> None:
self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
......
......@@ -20,6 +20,8 @@ import time
from functools import wraps
from typing import Any, Callable, List, Optional
from sglang.srt.metrics.utils import exponential_buckets
enable_metrics = False
......@@ -42,13 +44,6 @@ def enable_func_timer():
FUNC_LATENCY = None
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    """Return ``length`` geometrically spaced bucket bounds:
    ``start, start*width, start*width**2, ...``."""
    return [start * (width**i) for i in range(length)]
def time_func_latency(
func: Callable = None, name: Optional[str] = None
) -> Callable[..., Any]:
......
......@@ -46,3 +46,10 @@ def generate_buckets(
return sorted(set(default_buckets))
assert rule == "customer"
return sorted(set([float(x) for x in buckets_rule[1:]]))
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    """Build ``length`` Histogram bucket boundaries forming a geometric
    series with first term ``start`` and ratio ``width``."""
    return [start * pow(width, exponent) for exponent in range(length)]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment