Unverified commit 1929c067, authored by Lianmin Zheng and committed by GitHub

Simplify prometheus metrics (#1981)


Co-authored-by: Mohit Reddy <mohitreddy1996@users.noreply.github.com>
parent ed53ac84
@@ -31,7 +31,6 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
import dataclasses
import logging
import time
from typing import List, Optional, Tuple, Union
import torch
@@ -255,16 +254,6 @@ class Req:
# For Qwen2-VL
self.mrope_position_delta = [] # use mutable object
# Lifetime traces
# time when request is created and added to waitlist
self.created_time = None
# time when request is added to prefill batch
self.queued_time = None
# time when request is being processed
self.started_time = None
# time when request is finished
self.finished_time = None
# whether request reached finished condition
def finished(self) -> bool:
return self.finished_reason is not None
@@ -1038,10 +1027,6 @@ class ScheduleBatch:
f"#req={(len(self.reqs))})"
)
def mark_reqs_started(self):
for req in self.reqs:
req.started_time = time.time()
@dataclasses.dataclass
class ModelWorkerBatch:
...
@@ -17,7 +17,6 @@ limitations under the License.
import os
import random
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
@@ -307,7 +306,6 @@ class PrefillAdder:
):
# Non-chunked prefill
self.can_run_list.append(req)
req.queued_time = time.time()
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
prefix_len,
@@ -326,7 +324,6 @@ class PrefillAdder:
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req)
req.queued_time = time.time()
self.new_inflight_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)
...
@@ -62,8 +62,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.metrics.metrics_types import Stats
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
@@ -106,6 +105,7 @@ class Scheduler:
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = server_args.enable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics
# Init inter-process communication
context = zmq.Context(2)
@@ -224,8 +224,7 @@ class Scheduler:
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.last_stats_tic = time.time() # time of last stats for every iter
self.last_decode_stats_tic = time.time()
self.last_log_tic = time.time() # time of last log for print decode log
self.stream_interval = server_args.stream_interval
# Init chunked prefill
@@ -294,14 +293,15 @@ class Scheduler:
],
with_stack=True,
)
# Init metrics stats
self.stats = Stats()
self.stats = SchedulerStats()
self.metrics_collector = PrometheusMetricsCollector(
if self.enable_metrics:
self.metrics_collector = SchedulerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
max_model_len=self.max_total_num_tokens,
)
def watchdog_thread(self):
@@ -350,11 +350,6 @@ class Scheduler:
else:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
# log stats
if self.is_generation and self.server_args.enable_metrics:
stats = self.get_stats(batch)
self.log_stats(stats)
self.last_stats_tic = time.time()
self.last_batch = batch
@@ -493,7 +488,6 @@ class Scheduler:
self.max_req_len - len(req.origin_input_ids) - 1,
)
req.created_time = time.time()
self.waiting_queue.append(req)
def handle_embedding_request(
@@ -518,25 +512,68 @@ class Scheduler:
self.waiting_queue.append(req)
def print_decode_stats(self):
def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
if isinstance(self.tree_cache, RadixCache):
self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens
) / 10**9
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
else:
tree_cache_hit_rate = 0.0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (time.time() - self.last_log_tic)
self.num_generated_tokens = 0
self.last_log_tic = time.time()
# set system stats
logger.info(
f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
if self.enable_metrics:
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
self.stats.cache_hit_rate = tree_cache_hit_rate
self.metrics_collector.log_stats(self.stats)
def log_decode_stats(self):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
gen_throughput = self.num_generated_tokens / (
time.time() - self.last_decode_stats_tic
)
self.num_generated_tokens = 0
self.last_decode_stats_tic = time.time()
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
logger.info(
f"Decode batch. "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"gen throughput (token/s): {gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.gen_throughput = gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.metrics_collector.log_stats(self.stats)
def check_memory(self):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -612,7 +649,6 @@ class Scheduler:
prefix_computed = self.policy.calc_priority(self.waiting_queue)
# Prefill policy
num_mixed_running = running_bs if self.is_mixed_chunk else 0
adder = PrefillAdder(
self.tree_cache,
self.running_batch,
@@ -620,7 +656,7 @@ class Scheduler:
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
self.max_prefill_tokens,
self.chunked_prefill_size,
num_mixed_running,
running_bs if self.is_mixed_chunk else 0,
)
has_inflight = self.being_chunked_req is not None
@@ -677,47 +713,7 @@ class Scheduler:
# Print stats
if self.tp_rank == 0:
self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
if isinstance(self.tree_cache, RadixCache):
self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens
) / 10**9
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
else:
tree_cache_hit_rate = 0.0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
# set system stats
self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2)
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
if num_mixed_running > 0:
logger.info(
f"Prefill batch"
f"(mixed #running-req: {num_mixed_running}). "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
else:
logger.info(
f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
# Create a new batch
new_batch = ScheduleBatch.init_new(
@@ -789,7 +785,6 @@ class Scheduler:
if self.is_generation:
model_worker_batch = batch.get_model_worker_batch()
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
batch.mark_reqs_started()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
@@ -810,94 +805,6 @@ class Scheduler:
ret = embeddings, model_worker_batch.bid
return ret
def get_stats(self, batch: ScheduleBatch):
# TODO: get stats for chunked prefill
now = time.time()
# system stats
# Scheduler State
new_seq: int = 0
num_running_req = len(self.running_batch.reqs) if self.running_batch else 0
num_waiting_req = len(self.waiting_queue)
# Cache State
cache_hit_rate: float = 0.0
token_usage: float = 0.0
# set stats from prefill
if self.stats is not None:
# new_seq=self.stats.new_seq
cache_hit_rate = self.stats.cache_hit_rate
token_usage = self.stats.token_usage
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []
# Request stats
# Decode
gen_throughput: float = 0.0
# Latency
time_e2e_requests: List[float] = []
time_waiting_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
finished_reason_requests: List[str] = []
# _, next_token_ids, _ = result
if batch is not None:
num_generation_tokens_iter = len(batch.output_ids)
gen_throughput = round(
num_generation_tokens_iter / (now - self.last_stats_tic), 2
)
for i, req in enumerate(batch.reqs):
# NOTE: Batch forward mode is extend befor start decode,
if batch.forward_mode.is_extend():
num_prompt_tokens_iter = len(batch.input_ids) + sum(
batch.prefix_lens
)
time_to_first_tokens_iter.append(now - req.started_time)
else:
time_per_output_tokens_iter.append(now - self.last_stats_tic)
if req.finished():
time_e2e_requests.append(now - req.created_time)
time_waiting_requests.append(req.queued_time - req.created_time)
num_prompt_tokens_requests.append(len(req.origin_input_ids))
num_generation_tokens_requests.append(len(req.output_ids))
finished_reason_requests.append(
req.finished_reason.to_json()
if req.finished_reason is not None
else None
)
return Stats(
new_seq=new_seq,
num_running_req=num_running_req,
num_waiting_req=num_waiting_req,
cache_hit_rate=cache_hit_rate,
token_usage=token_usage,
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
gen_throughput=gen_throughput,
time_e2e_requests=time_e2e_requests,
time_waiting_requests=time_waiting_requests,
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
finished_reason_requests=finished_reason_requests,
context_len=self.model_config.context_len,
max_total_num_tokens=self.max_total_num_tokens,
max_prefill_tokens=self.max_prefill_tokens,
max_running_requests=self.max_running_requests,
)
def log_stats(self, stats: Stats):
self.metrics_collector.log_stats(stats)
def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
@@ -1035,7 +942,7 @@ class Scheduler:
self.tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.print_decode_stats()
self.log_decode_stats()
def add_logprob_return_values(
self,
...
@@ -22,6 +22,7 @@ import logging
import os
import signal
import sys
import time
from typing import Dict, List, Optional, Tuple, Union
import fastapi
@@ -52,6 +53,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_child_process
@@ -69,6 +71,10 @@ class ReqState:
finished: bool
event: asyncio.Event
# For metrics
created_time: float
first_token_time: Optional[float] = None
class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text."""
@@ -80,6 +86,7 @@ class TokenizerManager:
):
# Parse args
self.server_args = server_args
self.enable_metrics = server_args.enable_metrics
# Init inter-process communication
context = zmq.asyncio.Context(2)
@@ -142,11 +149,22 @@ class TokenizerManager:
# Others
self.gracefully_exit = False
# Metrics
if self.enable_metrics:
self.metrics_collector = TokenizerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
created_time = time.time()
if self.to_create_loop:
self.create_handle_loop()
@@ -164,10 +182,12 @@ class TokenizerManager:
if is_single:
tokenized_obj = await self._tokenize_one_request(obj)
self.send_to_scheduler.send_pyobj(tokenized_obj)
async for response in self._wait_one_response(obj, request):
async for response in self._wait_one_response(obj, request, created_time):
yield response
else:
async for response in self._handle_batch_request(obj, request):
async for response in self._handle_batch_request(
obj, request, created_time
):
yield response
async def _tokenize_one_request(
...@@ -231,10 +251,11 @@ class TokenizerManager: ...@@ -231,10 +251,11 @@ class TokenizerManager:
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
created_time: Optional[float] = None,
):
"""Wait for the response of one request."""
event = asyncio.Event()
state = ReqState([], False, event)
state = ReqState([], False, event, created_time=created_time)
self.rid_to_state[obj.rid] = state
while True:
...@@ -272,6 +293,7 @@ class TokenizerManager: ...@@ -272,6 +293,7 @@ class TokenizerManager:
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
created_time: Optional[float] = None,
):
batch_size = obj.batch_size
...@@ -283,7 +305,9 @@ class TokenizerManager: ...@@ -283,7 +305,9 @@ class TokenizerManager:
tmp_obj = obj[i]
tokenized_obj = await self._tokenize_one_request(tmp_obj)
self.send_to_scheduler.send_pyobj(tokenized_obj)
generators.append(self._wait_one_response(tmp_obj, request))
generators.append(
self._wait_one_response(tmp_obj, request, created_time)
)
rids.append(tmp_obj.rid)
else:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
...@@ -303,7 +327,9 @@ class TokenizerManager: ...@@ -303,7 +327,9 @@ class TokenizerManager:
tokenized_obj.sampling_params.max_new_tokens = 0
tokenized_obj.stream = False
self.send_to_scheduler.send_pyobj(tokenized_obj)
await self._wait_one_response(tmp_obj, request).__anext__()
await self._wait_one_response(
tmp_obj, request, created_time
).__anext__()
# Expand requests, assign new rids for them, and send them
for i in range(batch_size):
...@@ -312,7 +338,9 @@ class TokenizerManager: ...@@ -312,7 +338,9 @@ class TokenizerManager:
tokenized_obj = copy.copy(tokenized_objs[i])
tokenized_obj.rid = tmp_obj.regenerate_rid()
self.send_to_scheduler.send_pyobj(tokenized_obj)
generators.append(self._wait_one_response(tmp_obj, request))
generators.append(
self._wait_one_response(tmp_obj, request, created_time)
)
rids.append(tmp_obj.rid)
# Wait for all requests
...@@ -524,6 +552,34 @@ class TokenizerManager: ...@@ -524,6 +552,34 @@ class TokenizerManager:
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()
if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time)
/ (completion_tokens - 1)
)
if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
)
self.metrics_collector.inc_generation_tokens(completion_tokens)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
if completion_tokens >= 1:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time) / completion_tokens
)
def convert_logprob_style(
self,
ret: dict,
...
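To make the timing bookkeeping added to `TokenizerManager` above concrete, here is a tiny worked example with invented timestamps (not part of the commit): time to first token is measured from request creation, and per-output-token time divides the time since the first token by the number of subsequent tokens.

```python
# Worked example of the TTFT / per-output-token arithmetic above.
# All timestamps are invented for illustration.
created_time = 100.00      # request enqueued in generate_request()
first_token_time = 100.25  # first streamed token observed
now = 101.75               # later chunk arrives with completion_tokens = 11

ttft = first_token_time - created_time          # 0.25 s
tpot = (now - first_token_time) / (11 - 1)      # 0.15 s per output token
e2e = now - created_time                        # 1.75 s once the request finishes
print(ttft, tpot, e2e)
```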
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Utilities for Prometheus Metrics Collection."""
from dataclasses import dataclass
from typing import Dict, Union
@dataclass
class SchedulerStats:
num_running_reqs: int = 0
num_used_tokens: int = 0
token_usage: float = 0.0
gen_throughput: float = 0.0
num_queue_reqs: int = 0
cache_hit_rate: float = 0.0
class SchedulerMetricsCollector:
def __init__(self, labels: Dict[str, str]) -> None:
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from prometheus_client import Gauge
self.labels = labels
self.num_running_reqs = Gauge(
name="sglang:num_running_reqs",
documentation="The number of running requests",
labelnames=labels.keys(),
multiprocess_mode="sum",
)
self.num_used_tokens = Gauge(
name="sglang:num_used_tokens",
documentation="The number of used tokens",
labelnames=labels.keys(),
multiprocess_mode="sum",
)
self.token_usage = Gauge(
name="sglang:token_usage",
documentation="The token usage",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.gen_throughput = Gauge(
name="sglang:gen_throughput",
documentation="The generate throughput (token/s)",
labelnames=labels.keys(),
multiprocess_mode="sum",
)
self.num_queue_reqs = Gauge(
name="sglang:num_queue_reqs",
documentation="The number of requests in the waiting queue",
labelnames=labels.keys(),
multiprocess_mode="sum",
)
self.cache_hit_rate = Gauge(
name="sglang:cache_hit_rate",
documentation="The cache hit rate",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def log_stats(self, stats: SchedulerStats) -> None:
self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
self._log_gauge(self.token_usage, stats.token_usage)
self._log_gauge(self.gen_throughput, stats.gen_throughput)
self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
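For orientation, a minimal usage sketch of the scheduler-side collector follows; the temporary multiprocess directory and the label values are invented for the example, not taken from this commit.

```python
# Illustrative sketch only: exercise SchedulerMetricsCollector outside the scheduler.
import os
import tempfile

# Mirror the server's multiprocess setup; the directory here is made up.
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", tempfile.mkdtemp())

from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats

collector = SchedulerMetricsCollector(labels={"model_name": "example-model"})
stats = SchedulerStats(num_running_reqs=2, token_usage=0.35, gen_throughput=123.4)
collector.log_stats(stats)  # sets the sglang:* gauges for these label values
```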
class TokenizerMetricsCollector:
def __init__(self, labels: Dict[str, str]) -> None:
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from prometheus_client import Counter, Histogram
self.labels = labels
self.prompt_tokens_total = Counter(
name="sglang:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labels.keys(),
)
self.generation_tokens_total = Counter(
name="sglang:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labels.keys(),
)
self.histogram_time_to_first_token = Histogram(
name="sglang:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labels.keys(),
buckets=[
0.001,
0.005,
0.01,
0.02,
0.04,
0.06,
0.08,
0.1,
0.25,
0.5,
0.75,
1.0,
2.5,
5.0,
7.5,
10.0,
15.0,
20.0,
25.0,
30.0,
],
)
self.histogram_time_per_output_token = Histogram(
name="sglang:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labels.keys(),
buckets=[
0.005,
0.01,
0.015,
0.02,
0.025,
0.03,
0.04,
0.05,
0.075,
0.1,
0.15,
0.2,
0.3,
0.4,
0.5,
0.75,
1.0,
2.5,
],
)
self.histogram_e2e_request_latency = Histogram(
name="sglang:e2e_request_latency_seconds",
documentation="Histogram of End-to-end request latency in seconds",
labelnames=labels.keys(),
buckets=[
0.3,
0.5,
0.8,
1.0,
1.5,
2.0,
2.5,
5.0,
10.0,
15.0,
20.0,
30.0,
40.0,
50.0,
60.0,
],
)
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
histogram.labels(**self.labels).observe(data)
def _log_counter(self, counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def inc_prompt_tokens(self, value: int):
self._log_counter(self.prompt_tokens_total, value)
def inc_generation_tokens(self, value: int):
self._log_counter(self.generation_tokens_total, value)
def observe_time_to_first_token(self, value: Union[float, int]):
self._log_histogram(self.histogram_time_to_first_token, value)
def observe_time_per_output_token(self, value: Union[float, int]):
self._log_histogram(self.histogram_time_per_output_token, value)
def observe_e2e_request_latency(self, value: Union[float, int]):
self._log_histogram(self.histogram_e2e_request_latency, value)
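Likewise, a hedged sketch of driving the tokenizer-side counters and histograms by hand; the numbers and label value are invented, and in the server these observations are made from `TokenizerManager` as shown in the diff above.

```python
# Illustrative sketch only: record one fake request through TokenizerMetricsCollector.
from sglang.srt.metrics.collector import TokenizerMetricsCollector

collector = TokenizerMetricsCollector(labels={"model_name": "example-model"})
collector.inc_prompt_tokens(128)                 # prefill tokens of the request
collector.inc_generation_tokens(32)              # generated tokens of the request
collector.observe_time_to_first_token(0.042)     # seconds until the first token
collector.observe_time_per_output_token(0.015)   # seconds per subsequent token
collector.observe_e2e_request_latency(0.52)      # whole-request latency in seconds
```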
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Records the latency of some functions
"""
import asyncio
import time
from functools import wraps
from typing import Any, Callable, List, Optional
enable_metrics = False
def enable_func_timer():
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from prometheus_client import Histogram
global enable_metrics, FUNC_LATENCY
enable_metrics = True
FUNC_LATENCY = Histogram(
"sglang:func_latency_seconds",
"Function latency in seconds",
# captures latency in range [50ms - ~50s]
buckets=exponential_buckets(start=0.05, width=1.5, length=18),
labelnames=["name"],
)
FUNC_LATENCY = None
def exponential_buckets(start: float, width: float, length: int) -> List[float]:
buckets = []
for i in range(length):
buckets.append(start * (width**i))
return buckets
def time_func_latency(
func: Callable = None, name: Optional[str] = None
) -> Callable[..., Any]:
"""
A decorator to observe the latency of a function's execution. Supports both sync and async functions.
NOTE: We use our own implementation of a timer decorator since prometheus_client does not support async
context manager yet.
Overhead: for async functions, the extra `await` means one more coroutine/awaitable is created, so under heavy
load the measured wall time can grow slightly due to the additional scheduling delay.
"""
def measure(func: Callable[..., Any]) -> Callable[..., Any]:
nonlocal name
name = name or func.__name__
@wraps(func)
async def async_wrapper(*args, **kwargs):
if not enable_metrics:
return await func(*args, **kwargs)
metric = FUNC_LATENCY
start = time.monotonic()
ret = func(*args, **kwargs)
if isinstance(ret, asyncio.Future) or asyncio.iscoroutine(ret):
try:
ret = await ret
finally:
metric.labels(name=name).observe(time.monotonic() - start)
return ret
@wraps(func)
def sync_wrapper(*args, **kwargs):
if not enable_metrics:
return func(*args, **kwargs)
metric = FUNC_LATENCY
start = time.monotonic()
try:
ret = func(*args, **kwargs)
finally:
metric.labels(name=name).observe(time.monotonic() - start)
return ret
if asyncio.iscoroutinefunction(func):
return async_wrapper
return sync_wrapper
if func:
return measure(func)
else:
return measure
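A small usage sketch of the decorator above: it is a no-op until `enable_func_timer()` has been called (the server calls it from `launch_server` when metrics are enabled). The function name below is made up for the example.

```python
# Illustrative sketch only: time an async function with time_func_latency.
import asyncio

from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency

@time_func_latency(name="demo_sleep")
async def demo_sleep():
    await asyncio.sleep(0.01)

enable_func_timer()        # creates the sglang:func_latency_seconds histogram
asyncio.run(demo_sleep())  # records one observation labeled name="demo_sleep"
```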
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Utilities for Prometheus Metrics Collection."""
import logging
from abc import ABC, abstractmethod
from typing import Counter as CollectionsCounter
from typing import Dict, List, Union
import numpy as np
from prometheus_client import Counter, Gauge, Histogram
from sglang.srt.metrics.metrics_types import Stats
class Metrics:
"""
SGLang Metrics
"""
def __init__(self, labelnames: List[str], max_model_len):
# Configuration Stats
self.max_total_num_tokens = Gauge(
name="sglang:max_total_num_tokens",
documentation="Maximum total number of tokens",
labelnames=labelnames,
multiprocess_mode="min",
) # static across processes
self.max_prefill_tokens = Gauge(
name="sglang:max_prefill_tokens",
documentation="Maximum prefill tokens",
labelnames=labelnames,
multiprocess_mode="min",
) # static across processes
self.max_running_requests = Gauge(
name="sglang:max_running_requests",
documentation="Maximum running requests",
labelnames=labelnames,
multiprocess_mode="min",
) # static across processes
self.context_len = Gauge(
name="sglang:context_len",
documentation="Context length",
labelnames=labelnames,
multiprocess_mode="min",
) # static across processes
# Decode Stats
self.num_running_sys = Gauge(
name="sglang:num_requests_running",
documentation="Number of requests currently running on GPU",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.num_waiting_sys = Gauge(
name="sglang:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.gen_throughput = Gauge(
name="sglang:gen_throughput",
documentation="Gen token throughput (token/s)",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.token_usage = Gauge(
name="sglang:token_usage",
documentation="Total token usage",
labelnames=labelnames,
multiprocess_mode="sum",
)
# System Stats
# KV Cache Usage in %
# self.gpu_cache_usage_sys = Gauge(
# "gpu_cache_usage_perc",
# "GPU KV-cache usage. 1 means 100 percent usage.",
# labelnames=labelnames,
# multiprocess_mode="sum")
self.new_seq = Gauge(
name="sglang:new_seq",
documentation="Number of new sequences",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.new_token = Gauge(
name="sglang:new_token",
documentation="Number of new token",
labelnames=labelnames,
multiprocess_mode="sum",
)
# Prefix caching block hit rate
self.cached_token = Gauge(
name="sglang:cached_token",
documentation="Number of cached token",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.cache_hit_rate = Gauge(
name="sglang:cache_hit_rate",
documentation="Cache hit rate",
labelnames=labelnames,
multiprocess_mode="sum",
)
self.queue_req = Gauge(
name="sglang:queue_req",
documentation="Number of queued requests",
labelnames=labelnames,
multiprocess_mode="sum",
)
# Iteration stats
self.counter_prompt_tokens = Counter(
name="sglang:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
)
self.counter_generation_tokens = Counter(
name="sglang:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
)
self.histogram_time_to_first_token = Histogram(
name="sglang:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
buckets=[
0.001,
0.005,
0.01,
0.02,
0.04,
0.06,
0.08,
0.1,
0.25,
0.5,
0.75,
1.0,
2.5,
5.0,
7.5,
10.0,
15.0,
20.0,
25.0,
30.0,
],
)
self.histogram_time_per_output_token = Histogram(
name="sglang:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
buckets=[
0.005,
0.01,
0.015,
0.02,
0.025,
0.03,
0.04,
0.05,
0.075,
0.1,
0.15,
0.2,
0.3,
0.4,
0.5,
0.75,
1.0,
2.5,
],
)
# Request Stats
# Metadata
self.num_prompt_tokens_requests = Histogram(
name="sglang:request_prompt_tokens",
documentation="Number of prefill tokens processed",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.num_generation_tokens_requests = Histogram(
name="sglang:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.finished_reason_requests = Counter(
name="sglang:request_success_total",
documentation="Count of successfully processed requests.",
labelnames=labelnames + ["finished_reason"],
)
self.histogram_time_e2e_requests = Histogram(
name="sglang:e2e_request_latency_seconds",
documentation="Histogram of End-to-end request latency in seconds",
labelnames=labelnames,
buckets=[
0.3,
0.5,
0.8,
1.0,
1.5,
2.0,
2.5,
5.0,
10.0,
15.0,
20.0,
30.0,
40.0,
50.0,
60.0,
],
)
self.histogram_time_waiting_requests = Histogram(
name="sglang:waiting_request_latency_seconds",
documentation="Histogram of request waiting time in seconds",
labelnames=labelnames,
buckets=[
0.3,
0.5,
0.8,
1.0,
1.5,
2.0,
2.5,
5.0,
10.0,
15.0,
20.0,
30.0,
40.0,
50.0,
60.0,
],
)
self.histogram_time_decode_requests = Histogram(
name="sglang:decode_request_latency_seconds",
documentation="Histogram of request decoding time in seconds",
labelnames=labelnames,
buckets=[
0.3,
0.5,
0.8,
1.0,
1.5,
2.0,
2.5,
5.0,
10.0,
15.0,
20.0,
30.0,
40.0,
50.0,
60.0,
],
)
class MetricsCollector(ABC):
"""
SGLang Metrics Collector
"""
@abstractmethod
def log_stats(self, stats: Stats) -> None:
pass
class PrometheusMetricsCollector(MetricsCollector):
"""
SGLang Metrics Collector
"""
def __init__(self, labels: Dict[str, str], max_model_len: int) -> None:
self.labels = labels
self.metrics = Metrics(
labelnames=list(labels.keys()), max_model_len=max_model_len
)
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _log_counter(self, counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def _log_counter_labels(
self, counter, data: CollectionsCounter, label_key: str
) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def log_stats(self, stats: Stats) -> None:
self._log_gauge(self.metrics.max_total_num_tokens, stats.max_total_num_tokens)
self._log_gauge(self.metrics.max_prefill_tokens, stats.max_prefill_tokens)
self._log_gauge(self.metrics.max_running_requests, stats.max_running_requests)
self._log_gauge(self.metrics.context_len, stats.context_len)
self._log_histogram(
self.metrics.num_prompt_tokens_requests, stats.num_prompt_tokens_requests
)
self._log_histogram(
self.metrics.num_generation_tokens_requests,
stats.num_generation_tokens_requests,
)
self._log_counter(
self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter
)
self._log_counter(
self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter
)
self._log_histogram(
self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter
)
self._log_histogram(
self.metrics.histogram_time_per_output_token,
stats.time_per_output_tokens_iter,
)
# self._log_gauge(self.metrics.gpu_cache_usage_sys, stats.gpu_cache_usage_sys)
self._log_gauge(self.metrics.num_running_sys, stats.num_running_req)
self._log_gauge(self.metrics.num_waiting_sys, stats.num_waiting_req)
self._log_gauge(self.metrics.gen_throughput, stats.gen_throughput)
self._log_gauge(self.metrics.token_usage, stats.token_usage)
self._log_histogram(
self.metrics.histogram_time_e2e_requests, stats.time_e2e_requests
)
self._log_histogram(
self.metrics.histogram_time_waiting_requests, stats.time_waiting_requests
)
self._log_histogram(
self.metrics.histogram_time_decode_requests, stats.time_decode_requests
)
self._log_gauge(self.metrics.new_seq, stats.new_seq)
self._log_gauge(self.metrics.new_token, stats.new_token)
self._log_gauge(self.metrics.cached_token, stats.cached_token)
self._log_gauge(self.metrics.cache_hit_rate, stats.cache_hit_rate)
self._log_gauge(self.metrics.queue_req, stats.queue_req)
def build_1_2_5_buckets(max_value: int) -> List[int]:
"""
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values (1, 2, 5) until the value exceeds the specified maximum.
Example:
>>> build_1_2_5_buckets(100)
[1, 2, 5, 10, 20, 50, 100]
"""
mantissa_lst = [1, 2, 5]
exponent = 0
buckets: List[int] = []
while True:
for m in mantissa_lst:
value = m * 10**exponent
if value <= max_value:
buckets.append(value)
else:
return buckets
exponent += 1
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Metrics Types"""
from dataclasses import dataclass, field
from typing import List
@dataclass
class Stats:
# config
max_total_num_tokens: int = 0
max_prefill_tokens: int = 0
max_running_requests: int = 0
context_len: int = 0
# request stats
num_prompt_tokens_requests: List[int] = field(default_factory=list)
num_generation_tokens_requests: List[int] = field(default_factory=list)
finished_reason_requests: List[str] = field(default_factory=list)
# decode stats
num_running_req: int = 0
num_waiting_req: int = 0
gen_throughput: float = 0.0
waiting_queue: int = 0
time_e2e_requests: List[float] = field(default_factory=list)
time_waiting_requests: List[float] = field(default_factory=list)
time_decode_requests: List[float] = field(default_factory=list)
# system stats
token_usage: float = 0.0
new_seq: int = 0
new_token: int = 0
cached_token: int = 0
cache_hit_rate: float = 0.0
running_req: int = 0
queue_req: int = 0
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int = 0
num_generation_tokens_iter: int = 0
time_to_first_tokens_iter: List[float] = field(default_factory=list)
time_per_output_tokens_iter: List[float] = field(default_factory=list)
@@ -56,6 +56,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency
from sglang.srt.openai_api.adapter import (
load_chat_template_for_openai_api,
v1_batches,
@@ -196,6 +197,7 @@ async def get_memory_pool_size():
@app.post("/update_weights")
@time_func_latency
async def update_weights(obj: UpdateWeightReqInput, request: Request):
"""Update the weights inplace without re-launching the server."""
success, message = await tokenizer_manager.update_weights(obj, request)
@@ -212,7 +214,7 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
)
# fastapi implicitly converts json in the request to obj (dataclass)
@time_func_latency
async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""
if obj.stream:
@@ -245,10 +247,12 @@ async def generate_request(obj: GenerateReqInput, request: Request):
)
# fastapi implicitly converts json in the request to obj (dataclass)
app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
"""Handle an embedding request."""
try:
@@ -264,6 +268,7 @@ app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
@time_func_latency
async def classify_request(obj: EmbeddingReqInput, request: Request):
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
try:
@@ -283,16 +288,19 @@ app.put("/classify")(classify_request)
@app.post("/v1/completions")
@time_func_latency
async def openai_v1_completions(raw_request: Request):
return await v1_completions(tokenizer_manager, raw_request)
@app.post("/v1/chat/completions")
@time_func_latency
async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(tokenizer_manager, raw_request)
@app.post("/v1/embeddings", response_class=ORJSONResponse)
@time_func_latency
async def openai_v1_embeddings(raw_request: Request):
response = await v1_embeddings(tokenizer_manager, raw_request)
return response
@@ -455,6 +463,7 @@ def launch_server(
# add prometheus middleware
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
# Send a warmup request
t = threading.Thread(
@@ -492,6 +501,10 @@ def _set_envs_and_config(server_args: ServerArgs):
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
# Set prometheus env vars
if server_args.enable_metrics:
set_prometheus_multiproc_dir()
# Set ulimit
set_ulimit()
@@ -510,10 +523,6 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html.",
)
# Set prometheus env vars
if server_args.enable_metrics:
set_prometheus_multiproc_dir()
mp.set_start_method("spawn", force=True)
...
@@ -781,6 +781,7 @@ def set_prometheus_multiproc_dir():
def add_prometheus_middleware(app):
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
registry = CollectorRegistry()
...
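The body of `add_prometheus_middleware` is truncated above. For rough orientation only, a typical multiprocess Prometheus mount for an ASGI app looks like the sketch below; this is not the commit's implementation, the function name `add_metrics_route` is hypothetical, and only the `prometheus_client` and FastAPI calls are real API.

```python
# Hedged sketch of a typical multiprocess /metrics mount, not the repo's code.
from fastapi import FastAPI
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

def add_metrics_route(app: FastAPI) -> None:
    registry = CollectorRegistry()
    # Aggregate samples written by all worker processes into PROMETHEUS_MULTIPROC_DIR.
    multiprocess.MultiProcessCollector(registry)
    app.mount("/metrics", make_asgi_app(registry=registry))
```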
...@@ -22,23 +22,41 @@ class TestEnableMetrics(unittest.TestCase): ...@@ -22,23 +22,41 @@ class TestEnableMetrics(unittest.TestCase):
)
try:
# Make a request to generate some metrics
# Make some requests to generate some metrics
response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
self.assertEqual(response.status_code, 200)
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
"stream": True,
},
stream=True,
)
for _ in response.iter_lines(decode_unicode=False):
pass
# Get metrics
metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics")
self.assertEqual(metrics_response.status_code, 200)
metrics_content = metrics_response.text
print(f"{metrics_content=}")
print(f"metrics_content=\n{metrics_content}")
# Verify essential metrics are present
essential_metrics = [
"sglang:num_running_reqs",
"sglang:token_usage",
"sglang:gen_throughput",
"sglang:cache_hit_rate",
"sglang:func_latency_seconds",
"sglang:prompt_tokens_total", "sglang:prompt_tokens_total",
"sglang:generation_tokens_total", "sglang:generation_tokens_total",
"sglang:max_total_num_tokens",
"sglang:context_len",
"sglang:time_to_first_token_seconds", "sglang:time_to_first_token_seconds",
"sglang:time_per_output_token_seconds", "sglang:time_per_output_token_seconds",
"sglang:e2e_request_latency_seconds", "sglang:e2e_request_latency_seconds",
@@ -50,6 +68,7 @@ class TestEnableMetrics(unittest.TestCase):
# Verify model name label is present and correct
expected_model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
self.assertIn(f'model_name="{expected_model_name}"', metrics_content)
# Verify metrics have values (not empty)
self.assertIn("_sum{", metrics_content)
self.assertIn("_count{", metrics_content)
...