Unverified Commit a146d999 authored by Lzhang-hub, committed by GitHub

support prometheus metrics (#1853)


Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
parent f5113e50
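For reviewers who want to try this end to end: once a server is launched with the new --enable-metrics flag, the middleware added below exposes a standard Prometheus endpoint at /metrics. A minimal scraping sketch, assuming a server is already listening on localhost:30000 (the port and launch setup are assumptions, not part of this diff):

# Hypothetical usage sketch; not part of this change.
# Assumes an SGLang server was started with --enable-metrics on localhost:30000.
import urllib.request

with urllib.request.urlopen("http://localhost:30000/metrics") as resp:
    body = resp.read().decode()

# The collectors added in this commit register series under the "sglang:" prefix,
# e.g. sglang:num_requests_running and sglang:gen_throughput.
for line in body.splitlines():
    if line.startswith("sglang:"):
        print(line)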
@@ -31,6 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
import dataclasses
import logging
import time
from typing import List, Optional, Tuple, Union
import torch
@@ -254,6 +255,16 @@ class Req:
# For Qwen2-VL
self.mrope_position_delta = [] # use mutable object
# Lifetime traces
# Time when the request is created and added to the waiting queue
self.created_time = None
# Time when the request is added to a prefill batch
self.queued_time = None
# Time when the request starts being processed
self.started_time = None
# Time when the request is finished
self.finished_time = None
# Whether the request has reached a finish condition
def finished(self) -> bool:
return self.finished_reason is not None
@@ -1028,6 +1039,9 @@ class ScheduleBatch:
f"#req={(len(self.reqs))})"
)
def mark_reqs_started(self):
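# Record the wall-clock time at which each request in this batch starts
# running; get_stats() uses it to compute time-to-first-token.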
for req in self.reqs:
req.started_time = time.time()
@dataclasses.dataclass
class ModelWorkerBatch:
......
@@ -17,6 +17,7 @@ limitations under the License.
import os
import random
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
@@ -306,6 +307,7 @@ class PrefillAdder:
):
# Non-chunked prefill
self.can_run_list.append(req)
req.queued_time = time.time()
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
prefix_len,
@@ -324,6 +326,7 @@
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req)
req.queued_time = time.time()
self.new_inflight_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)
......
@@ -62,6 +62,8 @@ from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.metrics_types import Stats
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
@@ -222,7 +224,8 @@ class Scheduler:
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
self.last_stats_tic = time.time()  # time of the last per-iteration stats collection
self.last_log_tic = time.time()  # time of the last decode-log print
self.stream_interval = server_args.stream_interval
# Init chunked prefill
@@ -291,6 +294,15 @@
],
with_stack=True,
)
# Init metrics stats
self.stats = Stats()
self.metrics_collector = PrometheusMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future.
},
max_model_len=self.max_total_num_tokens,
)
def watchdog_thread(self):
self.watchdog_last_forward_ct = 0
@@ -338,6 +350,11 @@
else:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
# log stats
if self.is_generation and self.server_args.enable_metrics:
stats = self.get_stats(batch)
self.log_stats(stats)
self.last_stats_tic = time.time()
self.last_batch = batch
@@ -476,6 +493,7 @@
self.max_req_len - len(req.origin_input_ids) - 1,
)
req.created_time = time.time()
self.waiting_queue.append(req)
def handle_embedding_request(
@@ -504,9 +522,11 @@
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
throughput = self.num_generated_tokens / (time.time() - self.last_log_tic)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
self.last_log_tic = time.time()
# set system stats
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
logger.info(
f"Decode batch. "
@@ -676,6 +696,9 @@
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
# set system stats
self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2)
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
if num_mixed_running > 0:
logger.info(
@@ -770,6 +793,7 @@
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
batch.mark_reqs_started()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
@@ -789,6 +813,88 @@
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = embeddings, model_worker_batch.bid
return ret
def get_stats(self, batch: ScheduleBatch):
# TODO: get stats for chunked prefill
now = time.time()
# system stats
# Scheduler State
new_seq: int = 0
num_running_req = len(self.running_batch.reqs) if self.running_batch else 0
num_waiting_req = len(self.waiting_queue)
# Cache State
cache_hit_rate: float = 0.0
token_usage: float = 0.0
# set stats from prefill
if self.stats is not None:
# new_seq=self.stats.new_seq
cache_hit_rate = self.stats.cache_hit_rate
token_usage = self.stats.token_usage
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []
# Request stats
# Decode
gen_throughput: float = 0.0
# Latency
time_e2e_requests: List[float] = []
time_waiting_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
finished_reason_requests: List[str] = []
# _, next_token_ids, _ = result
if batch is not None:
num_generation_tokens_iter = len(batch.output_ids)
gen_throughput = round(num_generation_tokens_iter / (now - self.last_stats_tic), 2)
for i, req in enumerate(batch.reqs):
# NOTE: The batch forward mode is extend before decode starts.
if batch.forward_mode.is_extend():
num_prompt_tokens_iter = len(batch.input_ids) + sum(batch.prefix_lens)
time_to_first_tokens_iter.append(now - req.started_time)
else:
time_per_output_tokens_iter.append(now - self.last_stats_tic)
if req.finished():
time_e2e_requests.append(now - req.created_time)
time_waiting_requests.append(req.queued_time - req.created_time)
num_prompt_tokens_requests.append(len(req.origin_input_ids))
num_generation_tokens_requests.append(len(req.output_ids))
finished_reason_requests.append(
req.finished_reason.to_json()
if req.finished_reason is not None
else None)
return Stats(
new_seq=new_seq,
num_running_req=num_running_req,
num_waiting_req=num_waiting_req,
cache_hit_rate=cache_hit_rate,
token_usage=token_usage,
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
gen_throughput=gen_throughput,
time_e2e_requests=time_e2e_requests,
time_waiting_requests=time_waiting_requests,
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
finished_reason_requests=finished_reason_requests,
context_len=self.model_config.context_len,
max_total_num_tokens=self.max_total_num_tokens,
max_prefill_tokens=self.max_prefill_tokens,
max_running_requests=self.max_running_requests,
)
def log_stats(self, stats: Stats):
self.metrics_collector.log_stats(stats)
def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_decode():
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Utilities for Prometheus Metrics Collection."""
import logging
from abc import ABC, abstractmethod
from typing import Counter as CollectionsCounter
from typing import Dict, List, Union
import numpy as np
from prometheus_client import Counter, Gauge, Histogram
from sglang.srt.metrics.metrics_types import Stats
class Metrics:
"""
SGLang Metrics
"""
def __init__(self, labelnames: List[str], max_model_len: int):
# Configuration Stats
self.max_total_num_tokens = Gauge(
name="sglang:max_total_num_tokens",
documentation="Maximum total number of tokens",
labelnames=labelnames,
multiprocess_mode="min",
) # static across processes
self.max_prefill_tokens = Gauge(
name="sglang:max_prefill_tokens",
documentation="Maximum prefill tokens",
labelnames=labelnames,
multiprocess_mode="min",
) # static across processes
self.max_running_requests = Gauge(
name="sglang:max_running_requests",
documentation="Maximum running requests",
labelnames=labelnames,
multiprocess_mode="min",
) # static across processes
self.context_len = Gauge(
name="sglang:context_len",
documentation="Context length",
labelnames=labelnames,
multiprocess_mode="min",
) # static across processes
# Decode Stats
self.num_running_sys = Gauge(
name="sglang:num_requests_running",
documentation="Number of requests currently running on GPU",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.num_waiting_sys = Gauge(
name="sglang:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.gen_throughput = Gauge(
name="sglang:gen_throughput",
documentation="Gen token throughput (token/s)",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.token_usage = Gauge(
name="sglang:token_usage",
documentation="Total token usage",
labelnames=labelnames,
multiprocess_mode="sum",
)
# System Stats
# KV Cache Usage in %
# self.gpu_cache_usage_sys = Gauge(
# "gpu_cache_usage_perc",
# "GPU KV-cache usage. 1 means 100 percent usage.",
# labelnames=labelnames,
# multiprocess_mode="sum")
self.new_seq = Gauge(
name="sglang:new_seq",
documentation="Number of new sequences",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.new_token = Gauge(
name="sglang:new_token",
documentation="Number of new token",
labelnames=labelnames,
multiprocess_mode="sum",
)
# Prefix caching block hit rate
self.cached_token = Gauge(
name="sglang:cached_token",
documentation="Number of cached token",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.cache_hit_rate = Gauge(
name="sglang:cache_hit_rate",
documentation="Cache hit rate",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.queue_req = Gauge(
name="sglang:queue_req",
documentation="Number of queued requests",
labelnames=labelnames,
multiprocess_mode="sum",
)
# Iteration stats
self.counter_prompt_tokens = Counter(
name="sglang:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames)
self.counter_generation_tokens = Counter(
name="sglang:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_time_to_first_token = Histogram(
name="sglang:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 25.0, 30.0
])
self.histogram_time_per_output_token = Histogram(
name="sglang:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
buckets=[
0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.04, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5
])
# Request Stats
# Metadata
self.num_prompt_tokens_requests = Histogram(
name="sglang:request_prompt_tokens",
documentation="Number of prefill tokens processed",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.num_generation_tokens_requests = Histogram(
name="sglang:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.finished_reason_requests = Counter(
name="sglang:request_success_total",
documentation="Count of successfully processed requests.",
labelnames=labelnames + ["finished_reason"],
)
self.histogram_time_e2e_requests = Histogram(
name="sglang:e2e_request_latency_seconds",
documentation="Histogram of End-to-end request latency in seconds",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_time_waiting_requests = Histogram(
name="sglang:waiting_request_latency_seconds",
documentation="Histogram of request waiting time in seconds",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_time_decode_requests = Histogram(
name="sglang:decode_request_latency_seconds",
documentation="Histogram of request decoding time in seconds",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
class MetricsCollector(ABC):
"""
SGLang Metrics Collector
"""
@abstractmethod
def log_stats(self, stats: Stats) -> None:
pass
class PrometheusMetricsCollector(MetricsCollector):
"""
Prometheus-backed SGLang metrics collector
"""
def __init__(self, labels: Dict[str, str], max_model_len: int) -> None:
self.labels = labels
self.metrics = Metrics(
labelnames=list(labels.keys()), max_model_len=max_model_len
)
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _log_counter(self, counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def _log_counter_labels(
self, counter, data: CollectionsCounter, label_key: str
) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def log_stats(self, stats: Stats) -> None:
self._log_gauge(self.metrics.max_total_num_tokens, stats.max_total_num_tokens)
self._log_gauge(self.metrics.max_prefill_tokens, stats.max_prefill_tokens)
self._log_gauge(self.metrics.max_running_requests, stats.max_running_requests)
self._log_gauge(self.metrics.context_len, stats.context_len)
self._log_histogram(
self.metrics.num_prompt_tokens_requests, stats.num_prompt_tokens_requests
)
self._log_histogram(
self.metrics.num_generation_tokens_requests,
stats.num_generation_tokens_requests,
)
self._log_counter(self.metrics.counter_prompt_tokens,
stats.num_prompt_tokens_iter)
self._log_counter(self.metrics.counter_generation_tokens,
stats.num_generation_tokens_iter)
self._log_histogram(self.metrics.histogram_time_to_first_token,
stats.time_to_first_tokens_iter)
self._log_histogram(self.metrics.histogram_time_per_output_token,
stats.time_per_output_tokens_iter)
# self._log_gauge(self.metrics.gpu_cache_usage_sys, stats.gpu_cache_usage_sys)
self._log_gauge(self.metrics.num_running_sys, stats.num_running_req)
self._log_gauge(self.metrics.num_waiting_sys, stats.num_waiting_req)
self._log_gauge(self.metrics.gen_throughput, stats.gen_throughput)
self._log_gauge(self.metrics.token_usage, stats.token_usage)
self._log_histogram(
self.metrics.histogram_time_e2e_requests, stats.time_e2e_requests
)
self._log_histogram(
self.metrics.histogram_time_waiting_requests, stats.time_waiting_requests
)
self._log_histogram(
self.metrics.histogram_time_decode_requests, stats.time_decode_requests
)
self._log_gauge(self.metrics.new_seq, stats.new_seq)
self._log_gauge(self.metrics.new_token, stats.new_token)
self._log_gauge(self.metrics.cached_token, stats.cached_token)
self._log_gauge(self.metrics.cache_hit_rate, stats.cache_hit_rate)
self._log_gauge(self.metrics.queue_req, stats.queue_req)
def build_1_2_5_buckets(max_value: int) -> List[int]:
"""
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values (1, 2, 5) until the value exceeds the specified maximum.
Example:
>>> build_1_2_5_buckets(100)
[1, 2, 5, 10, 20, 50, 100]
"""
mantissa_lst = [1, 2, 5]
exponent = 0
buckets: List[int] = []
while True:
for m in mantissa_lst:
value = m * 10**exponent
if value <= max_value:
buckets.append(value)
else:
return buckets
exponent += 1
\ No newline at end of file
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Metrics Types"""
from dataclasses import dataclass, field
from typing import List
@dataclass
class Stats:
# config
max_total_num_tokens: int = 0
max_prefill_tokens: int = 0
max_running_requests: int = 0
context_len: int = 0
# request stats
num_prompt_tokens_requests: List[int] = field(default_factory=list)
num_generation_tokens_requests: List[int] = field(default_factory=list)
finished_reason_requests: List[str] = field(default_factory=list)
# decode stats
num_running_req: int = 0
num_waiting_req: int = 0
gen_throughput: float = 0.0
num_token: int = 0
token_usage: float = 0.0
waiting_queue: int = 0
time_e2e_requests: List[float] = field(default_factory=list)
time_waiting_requests: List[float] = field(default_factory=list)
time_decode_requests: List[float] = field(default_factory=list)
# system stats
is_mixed_chunk: bool = False
new_seq: int = 0
new_token: int = 0
cached_token: int = 0
cache_hit_rate: float = 0.0
running_req: int = 0
queue_req: int = 0
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int = 0
num_generation_tokens_iter: int = 0
time_to_first_tokens_iter: List[float] = field(default_factory=list)
time_per_output_tokens_iter: List[float] = field(default_factory=list)
\ No newline at end of file
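Taken together, the two new modules form a small self-contained API: the scheduler builds a Stats snapshot each iteration and hands it to a PrometheusMetricsCollector. A rough standalone sketch of that contract, with made-up label values and numbers (in the real server, _set_prometheus_env() below configures the multiprocess directory first):

# Hypothetical standalone sketch; values are illustrative only.
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.metrics_types import Stats

collector = PrometheusMetricsCollector(
    labels={"model_name": "example-model"},  # assumed label set
    max_model_len=4096,                      # assumed context length
)

# One scheduler iteration's worth of stats, shaped like get_stats() output.
collector.log_stats(
    Stats(
        num_running_req=2,
        num_waiting_req=1,
        gen_throughput=150.0,
        token_usage=0.42,
        num_generation_tokens_iter=32,
        time_to_first_tokens_iter=[0.08, 0.12],
    )
)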
@@ -25,12 +25,15 @@ import json
import logging
import multiprocessing as mp
import os
import re
import tempfile
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Dict, List, Optional, Union
import orjson
from starlette.routing import Mount
# Fix a bug in Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -86,6 +89,10 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
# Temporary directory for prometheus multiprocess mode
# Cleaned up automatically when this object is garbage collected
prometheus_multiproc_dir: tempfile.TemporaryDirectory
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -412,6 +419,18 @@ def launch_engine(
for i in range(len(scheduler_pipe_readers)):
scheduler_pipe_readers[i].recv()
def add_prometheus_middleware(app: FastAPI):
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
def launch_server(
server_args: ServerArgs,
@@ -439,6 +458,11 @@ def launch_server(
if server_args.api_key:
add_api_key_middleware(app, server_args.api_key)
# add prometheus middleware
if server_args.enable_metrics:
_set_prometheus_env()
add_prometheus_middleware(app)
# Send a warmup request
t = threading.Thread(
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
@@ -466,6 +490,21 @@
finally:
t.join()
def _set_prometheus_env():
# Set up the prometheus multiprocess directory.
# SGLang uses prometheus multiprocess mode; PROMETHEUS_MULTIPROC_DIR must be
# set before prometheus_client is imported.
# https://prometheus.github.io/client_python/multiprocess/
global prometheus_multiproc_dir
if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
logger.debug(f"User set PROMETHEUS_MULTIPROC_DIR detected.")
prometheus_multiproc_dir = tempfile.TemporaryDirectory(
dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
)
else:
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
......
@@ -70,6 +70,7 @@ class ServerArgs:
log_level_http: Optional[str] = None
log_requests: bool = False
show_time_cost: bool = False
enable_metrics: bool = False
# Other
api_key: Optional[str] = None
@@ -414,6 +415,12 @@ class ServerArgs:
action="store_true",
help="Show time cost of custom marks.",
)
parser.add_argument(
"--enable-metrics",
action="store_true",
help="Enable log prometheus metrics.",
)
parser.add_argument(
"--api-key",
type=str,
......