Unverified Commit b1721edb authored by Yingchun Lai's avatar Yingchun Lai Committed by GitHub
Browse files

[PD metrics] Add latency Histogram metrics of each stage for generate requests (#8710)

parent 57234d0c
...@@ -45,7 +45,7 @@ from sglang.srt.disaggregation.utils import ( ...@@ -45,7 +45,7 @@ from sglang.srt.disaggregation.utils import (
prepare_abort, prepare_abort,
) )
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
...@@ -253,6 +253,7 @@ class DecodePreallocQueue: ...@@ -253,6 +253,7 @@ class DecodePreallocQueue:
prefill_dp_rank=req.data_parallel_rank, prefill_dp_rank=req.data_parallel_rank,
) )
req.add_latency(RequestStage.DECODE_PREPARE)
self.queue.append( self.queue.append(
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False) DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
) )
...@@ -421,6 +422,7 @@ class DecodePreallocQueue: ...@@ -421,6 +422,7 @@ class DecodePreallocQueue:
kv_indices, self.token_to_kv_pool_allocator.page_size kv_indices, self.token_to_kv_pool_allocator.page_size
) )
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index) decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
preallocated_reqs.append(decode_req) preallocated_reqs.append(decode_req)
indices_to_remove.add(i) indices_to_remove.add(i)
...@@ -662,6 +664,7 @@ class DecodeTransferQueue: ...@@ -662,6 +664,7 @@ class DecodeTransferQueue:
for i in indices_to_remove: for i in indices_to_remove:
idx = self.queue[i].metadata_buffer_index idx = self.queue[i].metadata_buffer_index
assert idx != -1 assert idx != -1
self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
self.req_to_metadata_buffer_idx_allocator.free(idx) self.req_to_metadata_buffer_idx_allocator.free(idx)
self.queue = [ self.queue = [
...@@ -853,6 +856,7 @@ class SchedulerDisaggregationDecodeMixin: ...@@ -853,6 +856,7 @@ class SchedulerDisaggregationDecodeMixin:
# we can only add at least `num_not_used_batch` new batch to the running queue # we can only add at least `num_not_used_batch` new batch to the running queue
if i < num_not_used_batch: if i < num_not_used_batch:
can_run_list.append(req) can_run_list.append(req)
req.add_latency(RequestStage.DECODE_WAITING)
req.init_next_round_input(self.tree_cache) req.init_next_round_input(self.tree_cache)
else: else:
waiting_queue.append(req) waiting_queue.append(req)
......
...@@ -42,7 +42,12 @@ from sglang.srt.disaggregation.utils import ( ...@@ -42,7 +42,12 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce, poll_and_all_reduce,
prepare_abort, prepare_abort,
) )
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch from sglang.srt.managers.schedule_batch import (
FINISH_LENGTH,
Req,
RequestStage,
ScheduleBatch,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import ( from sglang.srt.utils import (
DynamicGradMode, DynamicGradMode,
...@@ -170,6 +175,7 @@ class PrefillBootstrapQueue: ...@@ -170,6 +175,7 @@ class PrefillBootstrapQueue:
pp_rank=self.pp_rank, pp_rank=self.pp_rank,
) )
self._process_req(req) self._process_req(req)
req.add_latency(RequestStage.PREFILL_PREPARE)
self.queue.append(req) self.queue.append(req)
def extend(self, reqs: List[Req], num_kv_heads: int) -> None: def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
...@@ -256,6 +262,8 @@ class PrefillBootstrapQueue: ...@@ -256,6 +262,8 @@ class PrefillBootstrapQueue:
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size) num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index) req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
bootstrapped_reqs.append(req) bootstrapped_reqs.append(req)
indices_to_remove.add(i) indices_to_remove.add(i)
...@@ -404,6 +412,7 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -404,6 +412,7 @@ class SchedulerDisaggregationPrefillMixin:
# There is no output_ids for prefill # There is no output_ids for prefill
req.output_ids.append(next_token_id) req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock self.tree_cache.cache_unfinished_req(req) # update the tree and lock
req.add_latency(RequestStage.PREFILL_FORWARD)
self.disagg_prefill_inflight_queue.append(req) self.disagg_prefill_inflight_queue.append(req)
if ( if (
logits_output is not None logits_output is not None
...@@ -539,6 +548,7 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -539,6 +548,7 @@ class SchedulerDisaggregationPrefillMixin:
) )
for req in done_reqs: for req in done_reqs:
req: Req req: Req
req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index) self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
req.metadata_buffer_index = -1 req.metadata_buffer_index = -1
......
from __future__ import annotations from __future__ import annotations
import enum
# Copyright 2023-2024 SGLang Team # Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -35,6 +37,7 @@ import copy ...@@ -35,6 +37,7 @@ import copy
import dataclasses import dataclasses
import logging import logging
import threading import threading
import time
from enum import Enum, auto from enum import Enum, auto
from http import HTTPStatus from http import HTTPStatus
from itertools import chain from itertools import chain
...@@ -61,7 +64,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache ...@@ -61,7 +64,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import TimeStats from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
...@@ -407,6 +410,23 @@ class MultimodalInputs: ...@@ -407,6 +410,23 @@ class MultimodalInputs:
# other args would be kept intact # other args would be kept intact
class RequestStage(str, enum.Enum):
    """Named lifecycle stages of a request, used as the ``stage`` label
    on the per-stage latency histogram. Inherits ``str`` so members can
    be passed directly wherever a plain label string is expected."""

    # Regular (non-disaggregated) prefill
    PREFILL_WAITING = "prefill_waiting"

    # Disaggregated serving: prefill-side stages
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # Disaggregated serving: decode-side stages
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"
class Req: class Req:
"""The input and output status of a request.""" """The input and output status of a request."""
...@@ -433,6 +453,7 @@ class Req: ...@@ -433,6 +453,7 @@ class Req:
bootstrap_room: Optional[int] = None, bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None, vocab_size: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
): ):
# Input and output info # Input and output info
self.rid = rid self.rid = rid
...@@ -590,10 +611,12 @@ class Req: ...@@ -590,10 +611,12 @@ class Req:
self.spec_verify_ct = 0 self.spec_verify_ct = 0
# For metrics # For metrics
self.metrics_collector = metrics_collector
self.time_stats: TimeStats = TimeStats() self.time_stats: TimeStats = TimeStats()
self.has_log_time_stats: bool = False self.has_log_time_stats: bool = False
self.queue_time_start = None self.queue_time_start = None
self.queue_time_end = None self.queue_time_end = None
self.last_tic = time.monotonic()
# For disaggregation # For disaggregation
self.bootstrap_host: str = bootstrap_host self.bootstrap_host: str = bootstrap_host
...@@ -626,6 +649,16 @@ class Req: ...@@ -626,6 +649,16 @@ class Req:
"""Check if this request is prefill-only (no token generation needed).""" """Check if this request is prefill-only (no token generation needed)."""
return self.sampling_params.max_new_tokens == 0 return self.sampling_params.max_new_tokens == 0
def add_latency(self, stage: RequestStage):
    """Observe the latency of the just-finished request *stage*.

    Records ``now - self.last_tic`` on the scheduler metrics collector
    under the stage's label and advances ``last_tic``, so consecutive
    calls measure back-to-back stage durations.

    No-op when metrics collection is disabled (``metrics_collector`` is
    None); in that case the timer is deliberately not advanced, since no
    observation will ever consume it.
    """
    if self.metrics_collector is None:
        return
    # monotonic() is immune to wall-clock adjustments, matching last_tic's
    # initialization in __init__.
    now = time.monotonic()
    self.metrics_collector.observe_request_latency_seconds(
        stage.value, now - self.last_tic
    )
    self.last_tic = now
def extend_image_inputs(self, image_inputs): def extend_image_inputs(self, image_inputs):
if self.multimodal_inputs is None: if self.multimodal_inputs is None:
self.multimodal_inputs = image_inputs self.multimodal_inputs = image_inputs
......
...@@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import ( ...@@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
FINISH_ABORT, FINISH_ABORT,
MultimodalInputs, MultimodalInputs,
Req, Req,
RequestStage,
ScheduleBatch, ScheduleBatch,
global_server_args_dict, global_server_args_dict,
) )
...@@ -1232,6 +1233,9 @@ class Scheduler( ...@@ -1232,6 +1233,9 @@ class Scheduler(
bootstrap_room=recv_req.bootstrap_room, bootstrap_room=recv_req.bootstrap_room,
data_parallel_rank=recv_req.data_parallel_rank, data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size, vocab_size=self.model_config.vocab_size,
metrics_collector=(
self.metrics_collector if self.enable_metrics else None
),
) )
req.tokenizer = self.tokenizer req.tokenizer = self.tokenizer
...@@ -1768,6 +1772,7 @@ class Scheduler( ...@@ -1768,6 +1772,7 @@ class Scheduler(
# only record queue time when enable_metrics is True to avoid overhead # only record queue time when enable_metrics is True to avoid overhead
for req in can_run_list: for req in can_run_list:
req.queue_time_end = time.perf_counter() req.queue_time_end = time.perf_counter()
req.add_latency(RequestStage.PREFILL_WAITING)
self.waiting_queue = [ self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list) x for x in self.waiting_queue if x not in set(can_run_list)
......
...@@ -17,7 +17,7 @@ from dataclasses import dataclass, field ...@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from sglang.srt.metrics.utils import generate_buckets from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var from sglang.srt.utils import get_bool_env_var
...@@ -513,6 +513,14 @@ class SchedulerMetricsCollector: ...@@ -513,6 +513,14 @@ class SchedulerMetricsCollector:
buckets=tree_traversal_time_buckets, buckets=tree_traversal_time_buckets,
) )
# Per-stage latency histogram; "stage" label values come from
# RequestStage (prefill/decode phases), observed via
# observe_request_latency_seconds().
self.request_latency_seconds = Histogram(
    name="sglang:request_latency_seconds",
    documentation="The latency of each stage of requests.",
    # captures latency in range [1ms - ~1191s]
    # (0.001 * 1.62**29 ~= 1191 s for the last bucket bound)
    buckets=exponential_buckets(start=0.001, width=1.62, length=30),
    labelnames=list(labels.keys()) + ["stage"],
)
def _log_gauge(self, gauge, data: Union[int, float]) -> None: def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge. # Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data) gauge.labels(**self.labels).set(data)
...@@ -526,6 +534,10 @@ class SchedulerMetricsCollector: ...@@ -526,6 +534,10 @@ class SchedulerMetricsCollector:
def increment_transfer_failed_reqs(self) -> None: def increment_transfer_failed_reqs(self) -> None:
self.num_transfer_failed_reqs.labels(**self.labels).inc(1) self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
def observe_request_latency_seconds(self, stage: str, latency: float) -> None:
    """Record *latency* (seconds) for *stage* on the per-stage histogram."""
    # Extend the collector's base labels with the stage name before observing.
    per_stage_labels = {**self.labels, "stage": stage}
    self.request_latency_seconds.labels(**per_stage_labels).observe(latency)
def log_stats(self, stats: SchedulerStats) -> None: def log_stats(self, stats: SchedulerStats) -> None:
self._log_gauge(self.num_running_reqs, stats.num_running_reqs) self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
self._log_gauge(self.num_used_tokens, stats.num_used_tokens) self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
......
...@@ -20,6 +20,8 @@ import time ...@@ -20,6 +20,8 @@ import time
from functools import wraps from functools import wraps
from typing import Any, Callable, List, Optional from typing import Any, Callable, List, Optional
from sglang.srt.metrics.utils import exponential_buckets
enable_metrics = False enable_metrics = False
...@@ -42,13 +44,6 @@ def enable_func_timer(): ...@@ -42,13 +44,6 @@ def enable_func_timer():
FUNC_LATENCY = None FUNC_LATENCY = None
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    """Return ``length`` histogram bucket bounds growing geometrically.

    The i-th bound is ``start * width**i``; an empty list is returned
    when ``length`` is 0. Idiomatic comprehension replaces the manual
    append loop.
    """
    return [start * width**i for i in range(length)]
def time_func_latency( def time_func_latency(
func: Callable = None, name: Optional[str] = None func: Callable = None, name: Optional[str] = None
) -> Callable[..., Any]: ) -> Callable[..., Any]:
......
...@@ -46,3 +46,10 @@ def generate_buckets( ...@@ -46,3 +46,10 @@ def generate_buckets(
return sorted(set(default_buckets)) return sorted(set(default_buckets))
assert rule == "customer" assert rule == "customer"
return sorted(set([float(x) for x in buckets_rule[1:]])) return sorted(set([float(x) for x in buckets_rule[1:]]))
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    """Return ``length`` histogram bucket bounds growing geometrically.

    The i-th bound is ``start * width**i``; an empty list is returned
    when ``length`` is 0. Idiomatic comprehension replaces the manual
    append loop.
    """
    return [start * width**i for i in range(length)]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment