Unverified Commit fc2c3a3d authored by Yingchun Lai, committed by GitHub

metrics: support customer labels specified in request header (#10143)

parent 8f6a1758
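
For context, a minimal end-to-end sketch of the feature (host, port, model, and the label names team/env are illustrative, not taken from this commit): launch the server with the new flags, then pass the labels as a JSON dict in the configured header.

# Hypothetical launch, assuming a local model:
#   python -m sglang.launch_server --model-path <model> --enable-metrics \
#     --tokenizer-metrics-allowed-customer-labels team env
import json

import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    headers={
        # "x-customer-labels" is the default of --tokenizer-metrics-custom-labels-header.
        "x-customer-labels": json.dumps({"team": "search", "env": "prod"}),
    },
    json={"model": "default", "prompt": "Hello", "max_tokens": 8},
)
print(resp.json())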
......@@ -229,6 +229,9 @@ class CompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
@field_validator("max_tokens")
@classmethod
def validate_max_tokens_positive(cls, v):
......
......@@ -11,6 +11,7 @@ from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.managers.tokenizer_manager import TokenizerManager
......@@ -24,6 +25,14 @@ class OpenAIServingBase(ABC):
def __init__(self, tokenizer_manager: TokenizerManager):
self.tokenizer_manager = tokenizer_manager
self.allowed_custom_labels = (
set(
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
)
if isinstance(self.tokenizer_manager.server_args, ServerArgs)
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
else None
)
async def handle_request(
self, request: OpenAIServingRequest, raw_request: Request
......@@ -37,7 +46,7 @@ class OpenAIServingBase(ABC):
# Convert to internal format
adapted_request, processed_request = self._convert_to_internal_request(
request
request, raw_request
)
# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
......@@ -81,6 +90,7 @@ class OpenAIServingBase(ABC):
def _convert_to_internal_request(
self,
request: OpenAIServingRequest,
raw_request: Request = None,
) -> tuple[GenerateReqInput, OpenAIServingRequest]:
"""Convert OpenAI request to internal format"""
pass
......@@ -154,3 +164,32 @@ class OpenAIServingBase(ABC):
code=status_code,
)
return json.dumps({"error": error.model_dump()})
def extract_customer_labels(self, raw_request):
if (
not self.allowed_custom_labels
or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
):
return None
customer_labels = None
header = (
self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
)
try:
raw_labels = (
json.loads(raw_request.headers.get(header))
if raw_request and raw_request.headers.get(header)
else None
)
except json.JSONDecodeError as e:
logger.exception(f"Failed to parse customer labels header: {e}")
raw_labels = None
if isinstance(raw_labels, dict):
customer_labels = {
label: value
for label, value in raw_labels.items()
if label in self.allowed_custom_labels
}
return customer_labels
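
A quick sketch of the filtering behavior above, with hypothetical values: keys outside the allow-list are dropped, and a missing or malformed header yields None.

allowed_custom_labels = {"team", "env"}           # from --tokenizer-metrics-allowed-customer-labels
raw_labels = {"team": "search", "region": "eu"}   # JSON dict parsed from the header
customer_labels = {
    label: value for label, value in raw_labels.items() if label in allowed_custom_labels
}
assert customer_labels == {"team": "search"}      # "region" is silently discarded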
......@@ -96,6 +96,7 @@ class OpenAIServingChat(OpenAIServingBase):
def _convert_to_internal_request(
self,
request: ChatCompletionRequest,
raw_request: Request = None,
) -> tuple[GenerateReqInput, ChatCompletionRequest]:
reasoning_effort = (
request.chat_template_kwargs.pop("reasoning_effort", None)
......@@ -127,6 +128,9 @@ class OpenAIServingChat(OpenAIServingBase):
else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
# Extract customer labels from raw request headers
customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=processed_messages.image_data,
......@@ -145,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
customer_labels=customer_labels,
)
return adapted_request, request
......
......@@ -59,6 +59,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
def _convert_to_internal_request(
self,
request: CompletionRequest,
raw_request: Request = None,
) -> tuple[GenerateReqInput, CompletionRequest]:
"""Convert OpenAI completion request to internal format"""
# NOTE: with openai API, the prompt's logprobs are always not computed
......@@ -89,6 +90,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
else:
prompt_kwargs = {"input_ids": prompt}
# Extract customer labels from raw request headers
customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput(
**prompt_kwargs,
sampling_params=sampling_params,
......@@ -103,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
customer_labels=customer_labels,
)
return adapted_request, request
......
......@@ -74,6 +74,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
def _convert_to_internal_request(
self,
request: EmbeddingRequest,
raw_request: Request = None,
) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
"""Convert OpenAI embedding request to internal format"""
prompt = request.input
......
......@@ -45,7 +45,9 @@ class OpenAIServingRerank(OpenAIServingBase):
return None
def _convert_to_internal_request(
self, request: V1RerankReqInput
self,
request: V1RerankReqInput,
raw_request: Request = None,
) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
"""Convert OpenAI rerank request to internal embedding format"""
# Create pairs of [query, document] for each document
......
......@@ -25,6 +25,7 @@ class OpenAIServingScore(OpenAIServingBase):
def _convert_to_internal_request(
self,
request: ScoringRequest,
raw_request: Request = None,
) -> tuple[ScoringRequest, ScoringRequest]:
"""Convert OpenAI scoring request to internal format"""
# For scoring, we pass the request directly as the tokenizer_manager
......
......@@ -141,6 +141,9 @@ class GenerateReqInput:
# Image gen grpc migration
return_bytes: bool = False
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
def contains_mm_input(self) -> bool:
return (
has_valid_data(self.image_data)
......
......@@ -306,12 +306,16 @@ class TokenizerManager(TokenizerCommunicatorMixin):
# Metrics
if self.enable_metrics:
labels = {
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
}
if server_args.tokenizer_metrics_allowed_customer_labels:
for label in server_args.tokenizer_metrics_allowed_customer_labels:
labels[label] = ""
self.metrics_collector = TokenizerMetricsCollector(
server_args=server_args,
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
labels=labels,
bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
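
Pre-seeding every allowed label with an empty string matters because prometheus_client fixes a metric's label names when the metric is created; per-request values are then merged over these defaults (see the merge in the hunk below). A sketch with illustrative names and values:

base_labels = {"model_name": "llama-3-8b", "team": "", "env": ""}  # as built above
customer_labels = {"team": "search"}                               # extracted per request
merged = {**base_labels, **customer_labels}
assert merged == {"model_name": "llama-3-8b", "team": "search", "env": ""}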
......@@ -1036,7 +1040,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return
req = AbortReq(rid, abort_all)
self.send_to_scheduler.send_pyobj(req)
if self.enable_metrics:
self.metrics_collector.observe_one_aborted_request()
......@@ -1616,6 +1619,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
else 0
)
customer_labels = getattr(state.obj, "customer_labels", None)
labels = (
{**self.metrics_collector.labels, **customer_labels}
if customer_labels
else self.metrics_collector.labels
)
if (
state.first_token_time == 0.0
and self.disaggregation_mode != DisaggregationMode.PREFILL
......@@ -1623,7 +1632,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.first_token_time = state.last_time = time.time()
state.last_completion_tokens = completion_tokens
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
labels, state.first_token_time - state.created_time
)
else:
num_new_tokens = completion_tokens - state.last_completion_tokens
......@@ -1631,6 +1640,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
new_time = time.time()
interval = new_time - state.last_time
self.metrics_collector.observe_inter_token_latency(
labels,
interval,
num_new_tokens,
)
......@@ -1645,6 +1655,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
or state.obj.sampling_params.get("structural_tag", None)
)
self.metrics_collector.observe_one_finished_request(
labels,
recv_obj.prompt_tokens[i],
completion_tokens,
recv_obj.cached_tokens[i],
......
......@@ -12,7 +12,6 @@
# limitations under the License.
# ==============================================================================
"""Utilities for Prometheus Metrics Collection."""
import time
from dataclasses import dataclass, field
from enum import Enum
......@@ -812,36 +811,38 @@ class TokenizerMetricsCollector:
buckets=bucket_time_to_first_token,
)
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
histogram.labels(**self.labels).observe(data)
def observe_one_finished_request(
self,
labels: Dict[str, str],
prompt_tokens: int,
generation_tokens: int,
cached_tokens: int,
e2e_latency: float,
has_grammar: bool,
):
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
self.prompt_tokens_total.labels(**labels).inc(prompt_tokens)
self.generation_tokens_total.labels(**labels).inc(generation_tokens)
if cached_tokens > 0:
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
self.num_requests_total.labels(**self.labels).inc(1)
self.cached_tokens_total.labels(**labels).inc(cached_tokens)
self.num_requests_total.labels(**labels).inc(1)
if has_grammar:
self.num_so_requests_total.labels(**self.labels).inc(1)
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
self.num_so_requests_total.labels(**labels).inc(1)
self.histogram_e2e_request_latency.labels(**labels).observe(float(e2e_latency))
if self.collect_tokens_histogram:
self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
self._log_histogram(self.generation_tokens_histogram, generation_tokens)
def observe_time_to_first_token(self, value: float, label: str = ""):
if label == "batch":
self.histogram_time_to_first_token_offline_batch.labels(
**self.labels
).observe(value)
self.prompt_tokens_histogram.labels(**labels).observe(float(prompt_tokens))
self.generation_tokens_histogram.labels(**labels).observe(
float(generation_tokens)
)
def observe_time_to_first_token(
self, labels: Dict[str, str], value: float, type: str = ""
):
if type == "batch":
self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(
value
)
else:
self.histogram_time_to_first_token.labels(**self.labels).observe(value)
self.histogram_time_to_first_token.labels(**labels).observe(value)
def check_time_to_first_token_straggler(self, value: float) -> bool:
his = self.histogram_time_to_first_token.labels(**self.labels)
......@@ -856,12 +857,14 @@ class TokenizerMetricsCollector:
return value >= his._upper_bounds[i]
return False
def observe_inter_token_latency(self, interval: float, num_new_tokens: int):
def observe_inter_token_latency(
self, labels: Dict[str, str], interval: float, num_new_tokens: int
):
adjusted_interval = interval / num_new_tokens
# A faster version of Histogram.observe that records multiple values at the same time.
# reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
his = self.histogram_inter_token_latency_seconds.labels(**self.labels)
his = self.histogram_inter_token_latency_seconds.labels(**labels)
his._sum.inc(interval)
for i, bound in enumerate(his._upper_bounds):
......
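
The "faster version" comment above refers to folding num_new_tokens identical observations into a single histogram update. A hedged sketch of the idea; it relies on prometheus_client private attributes (_sum, _upper_bounds, _buckets), as the linked reference shows, so treat it as an implementation detail rather than a public API:

def observe_many(his, interval: float, num_new_tokens: int) -> None:
    # Equivalent to calling his.observe(interval / num_new_tokens)
    # num_new_tokens times, but with one sum update and one bucket update.
    value = interval / num_new_tokens
    his._sum.inc(interval)
    for i, bound in enumerate(his._upper_bounds):
        if value <= bound:
            his._buckets[i].inc(num_new_tokens)
            break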
......@@ -205,6 +205,8 @@ class ServerArgs:
show_time_cost: bool = False
enable_metrics: bool = False
enable_metrics_for_all_schedulers: bool = False
tokenizer_metrics_custom_labels_header: str = "x-customer-labels"
tokenizer_metrics_allowed_customer_labels: Optional[List[str]] = None
bucket_time_to_first_token: Optional[List[float]] = None
bucket_inter_token_latency: Optional[List[float]] = None
bucket_e2e_request_latency: Optional[List[float]] = None
......@@ -911,6 +913,14 @@ class ServerArgs:
"and cannot be used at the same time. Please use only one of them."
)
if (
not self.tokenizer_metrics_custom_labels_header
and self.tokenizer_metrics_allowed_customer_labels
):
raise ValueError(
"Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
)
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and tokenizer
......@@ -1324,6 +1334,21 @@ class ServerArgs:
"to record request metrics separately. This is especially useful when dp_attention is enabled, as "
"otherwise all metrics appear to come from TP 0.",
)
parser.add_argument(
"--tokenizer-metrics-custom-labels-header",
type=str,
default=ServerArgs.tokenizer_metrics_custom_labels_header,
help="Specify the HTTP header for passing customer labels for tokenizer metrics.",
)
parser.add_argument(
"--tokenizer-metrics-allowed-customer-labels",
type=str,
nargs="+",
default=ServerArgs.tokenizer_metrics_allowed_customer_labels,
help="The customer labels allowed for tokenizer metrics. The labels are specified via a dict in "
"'--tokenizer-metrics-custom-labels-header' field in HTTP requests, e.g., {'label1': 'value1', 'label2': "
"'value2'} is allowed if '--tokenizer-metrics-allowed-labels label1 label2' is set.",
)
parser.add_argument(
"--bucket-time-to-first-token",
type=float,
......
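
Once a labeled request has been served, the customer labels show up on the tokenizer metrics scrape. A hedged illustration; the host, metric name, and sample values below are assumptions for demonstration, not taken from this diff:

import requests

# Requests that omit the header fall back to the "" defaults seeded at startup.
text = requests.get("http://localhost:30000/metrics").text
# Expect samples along the lines of (metric name is hypothetical):
#   sglang:num_requests_total{model_name="llama-3-8b",team="search",env=""} 1.0
print("\n".join(line for line in text.splitlines() if 'team="search"' in line))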