"src/vscode:/vscode.git/clone" did not exist on "d74561da2c6f531c6a6061e3582f9fda4fc70500"
Unverified commit fc2c3a3d authored by Yingchun Lai, committed by GitHub

metrics: support customer labels specified in request header (#10143)

parent 8f6a1758
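In short, the change lets callers attach per-request metric labels: the server is started with an allow-list of label names (--tokenizer-metrics-allowed-customer-labels), and clients send a JSON dict of labels in the HTTP header named by --tokenizer-metrics-custom-labels-header (default "x-customer-labels"). A minimal client-side sketch, assuming a local server on port 30000 and a single allowed label named customer_id (the port, label name, values, and model name are illustrative, not part of this diff):

# Server launched with the new flags (model path is a placeholder):
#   python -m sglang.launch_server --model-path <model> --enable-metrics \
#       --tokenizer-metrics-allowed-customer-labels customer_id
import json

import requests  # any HTTP client works; requests is used here for brevity

response = requests.post(
    "http://localhost:30000/v1/chat/completions",
    headers={
        # Header name must match --tokenizer-metrics-custom-labels-header.
        "x-customer-labels": json.dumps({"customer_id": "acme"}),
    },
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(response.json())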
@@ -229,6 +229,9 @@ class CompletionRequest(BaseModel):
     # For request id
     rid: Optional[Union[List[str], str]] = None
 
+    # For customer metric labels
+    customer_labels: Optional[Dict[str, str]] = None
+
     @field_validator("max_tokens")
     @classmethod
     def validate_max_tokens_positive(cls, v):
...
@@ -11,6 +11,7 @@ from fastapi.responses import ORJSONResponse, StreamingResponse
 from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
 from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.server_args import ServerArgs
 
 if TYPE_CHECKING:
     from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -24,6 +25,14 @@ class OpenAIServingBase(ABC):
     def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager
+        self.allowed_custom_labels = (
+            set(
+                self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
+            )
+            if isinstance(self.tokenizer_manager.server_args, ServerArgs)
+            and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
+            else None
+        )
 
     async def handle_request(
         self, request: OpenAIServingRequest, raw_request: Request
@@ -37,7 +46,7 @@ class OpenAIServingBase(ABC):
 
             # Convert to internal format
             adapted_request, processed_request = self._convert_to_internal_request(
-                request
+                request, raw_request
             )
 
             # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -81,6 +90,7 @@ class OpenAIServingBase(ABC):
     def _convert_to_internal_request(
         self,
         request: OpenAIServingRequest,
+        raw_request: Request = None,
     ) -> tuple[GenerateReqInput, OpenAIServingRequest]:
         """Convert OpenAI request to internal format"""
         pass
@@ -154,3 +164,32 @@ class OpenAIServingBase(ABC):
             code=status_code,
         )
         return json.dumps({"error": error.model_dump()})
+
+    def extract_customer_labels(self, raw_request):
+        if (
+            not self.allowed_custom_labels
+            or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
+        ):
+            return None
+
+        customer_labels = None
+        header = (
+            self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
+        )
+        try:
+            raw_labels = (
+                json.loads(raw_request.headers.get(header))
+                if raw_request and raw_request.headers.get(header)
+                else None
+            )
+        except json.JSONDecodeError as e:
+            logger.exception(f"Error in request: {e}")
+            raw_labels = None
+
+        if isinstance(raw_labels, dict):
+            customer_labels = {
+                label: value
+                for label, value in raw_labels.items()
+                if label in self.allowed_custom_labels
+            }
+        return customer_labels
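To illustrate the filtering behavior of extract_customer_labels above: the header value is parsed as a JSON object, keys outside the allow-list are dropped, and malformed JSON or a non-dict payload yields no labels. A standalone sketch of that logic (the header contents and allowed set below are made-up examples, and filter_labels is a hypothetical helper, not part of the diff):

import json
from typing import Dict, Optional, Set


def filter_labels(header_value: Optional[str], allowed: Set[str]) -> Optional[Dict[str, str]]:
    # Mirrors extract_customer_labels: parse JSON, keep only allow-listed keys.
    if not header_value or not allowed:
        return None
    try:
        raw = json.loads(header_value)
    except json.JSONDecodeError:
        return None
    if not isinstance(raw, dict):
        return None
    return {label: value for label, value in raw.items() if label in allowed}


print(filter_labels('{"customer_id": "acme", "team": "infra"}', {"customer_id"}))
# -> {'customer_id': 'acme'}
print(filter_labels("not-json", {"customer_id"}))
# -> None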
@@ -96,6 +96,7 @@ class OpenAIServingChat(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: ChatCompletionRequest,
+        raw_request: Request = None,
     ) -> tuple[GenerateReqInput, ChatCompletionRequest]:
         reasoning_effort = (
             request.chat_template_kwargs.pop("reasoning_effort", None)
@@ -127,6 +128,9 @@ class OpenAIServingChat(OpenAIServingBase):
         else:
             prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
 
+        # Extract customer labels from raw request headers
+        customer_labels = self.extract_customer_labels(raw_request)
+
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             image_data=processed_messages.image_data,
@@ -145,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
+            customer_labels=customer_labels,
         )
 
         return adapted_request, request
...
@@ -59,6 +59,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: CompletionRequest,
+        raw_request: Request = None,
     ) -> tuple[GenerateReqInput, CompletionRequest]:
         """Convert OpenAI completion request to internal format"""
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -89,6 +90,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
         else:
             prompt_kwargs = {"input_ids": prompt}
 
+        # Extract customer labels from raw request headers
+        customer_labels = self.extract_customer_labels(raw_request)
+
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             sampling_params=sampling_params,
@@ -103,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
+            customer_labels=customer_labels,
         )
 
         return adapted_request, request
...
@@ -74,6 +74,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: EmbeddingRequest,
+        raw_request: Request = None,
     ) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
         """Convert OpenAI embedding request to internal format"""
         prompt = request.input
...
@@ -45,7 +45,9 @@ class OpenAIServingRerank(OpenAIServingBase):
         return None
 
     def _convert_to_internal_request(
-        self, request: V1RerankReqInput
+        self,
+        request: V1RerankReqInput,
+        raw_request: Request = None,
     ) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
         """Convert OpenAI rerank request to internal embedding format"""
         # Create pairs of [query, document] for each document
...
@@ -25,6 +25,7 @@ class OpenAIServingScore(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
         request: ScoringRequest,
+        raw_request: Request = None,
     ) -> tuple[ScoringRequest, ScoringRequest]:
         """Convert OpenAI scoring request to internal format"""
         # For scoring, we pass the request directly as the tokenizer_manager
...
@@ -141,6 +141,9 @@ class GenerateReqInput:
     # Image gen grpc migration
     return_bytes: bool = False
 
+    # For customer metric labels
+    customer_labels: Optional[Dict[str, str]] = None
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
...
@@ -306,12 +306,16 @@ class TokenizerManager(TokenizerCommunicatorMixin):
 
         # Metrics
         if self.enable_metrics:
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                # TODO: Add lora name/path in the future,
+            }
+            if server_args.tokenizer_metrics_allowed_customer_labels:
+                for label in server_args.tokenizer_metrics_allowed_customer_labels:
+                    labels[label] = ""
             self.metrics_collector = TokenizerMetricsCollector(
                 server_args=server_args,
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
+                labels=labels,
                 bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                 bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                 bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
@@ -1036,7 +1040,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             return
 
         req = AbortReq(rid, abort_all)
         self.send_to_scheduler.send_pyobj(req)
-
         if self.enable_metrics:
             self.metrics_collector.observe_one_aborted_request()
@@ -1616,6 +1619,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             else 0
         )
 
+        customer_labels = getattr(state.obj, "customer_labels", None)
+        labels = (
+            {**self.metrics_collector.labels, **customer_labels}
+            if customer_labels
+            else self.metrics_collector.labels
+        )
         if (
             state.first_token_time == 0.0
             and self.disaggregation_mode != DisaggregationMode.PREFILL
@@ -1623,7 +1632,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             state.first_token_time = state.last_time = time.time()
             state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
-                state.first_token_time - state.created_time
+                labels, state.first_token_time - state.created_time
             )
         else:
             num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1631,6 +1640,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 new_time = time.time()
                 interval = new_time - state.last_time
                 self.metrics_collector.observe_inter_token_latency(
+                    labels,
                     interval,
                     num_new_tokens,
                 )
@@ -1645,6 +1655,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 or state.obj.sampling_params.get("structural_tag", None)
             )
             self.metrics_collector.observe_one_finished_request(
+                labels,
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
...
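The label handling in this path boils down to a dictionary merge: the base labels configured at startup (model name plus one empty slot per allowed customer label) are overlaid with whatever the request supplied. A tiny sketch of that merge, with assumed label names and values:

# Base labels registered at startup: every allowed customer label gets a
# default empty value so the Prometheus label set stays fixed.
base_labels = {"model_name": "my-model", "customer_id": ""}  # assumed values

customer_labels = {"customer_id": "acme"}  # extracted from one request; may be None

labels = {**base_labels, **customer_labels} if customer_labels else base_labels
print(labels)  # {'model_name': 'my-model', 'customer_id': 'acme'}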
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 """Utilities for Prometheus Metrics Collection."""
-
 import time
 from dataclasses import dataclass, field
 from enum import Enum
@@ -812,36 +811,38 @@ class TokenizerMetricsCollector:
             buckets=bucket_time_to_first_token,
         )
 
-    def _log_histogram(self, histogram, data: Union[int, float]) -> None:
-        histogram.labels(**self.labels).observe(data)
-
     def observe_one_finished_request(
         self,
+        labels: Dict[str, str],
         prompt_tokens: int,
         generation_tokens: int,
         cached_tokens: int,
         e2e_latency: float,
         has_grammar: bool,
     ):
-        self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
-        self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.prompt_tokens_total.labels(**labels).inc(prompt_tokens)
+        self.generation_tokens_total.labels(**labels).inc(generation_tokens)
         if cached_tokens > 0:
-            self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
-        self.num_requests_total.labels(**self.labels).inc(1)
+            self.cached_tokens_total.labels(**labels).inc(cached_tokens)
+        self.num_requests_total.labels(**labels).inc(1)
         if has_grammar:
-            self.num_so_requests_total.labels(**self.labels).inc(1)
-        self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+            self.num_so_requests_total.labels(**labels).inc(1)
+        self.histogram_e2e_request_latency.labels(**labels).observe(float(e2e_latency))
         if self.collect_tokens_histogram:
-            self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
-            self._log_histogram(self.generation_tokens_histogram, generation_tokens)
+            self.prompt_tokens_histogram.labels(**labels).observe(float(prompt_tokens))
+            self.generation_tokens_histogram.labels(**labels).observe(
+                float(generation_tokens)
+            )
 
-    def observe_time_to_first_token(self, value: float, label: str = ""):
-        if label == "batch":
-            self.histogram_time_to_first_token_offline_batch.labels(
-                **self.labels
-            ).observe(value)
+    def observe_time_to_first_token(
+        self, labels: Dict[str, str], value: float, type: str = ""
+    ):
+        if type == "batch":
+            self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(
+                value
+            )
         else:
-            self.histogram_time_to_first_token.labels(**self.labels).observe(value)
+            self.histogram_time_to_first_token.labels(**labels).observe(value)
 
     def check_time_to_first_token_straggler(self, value: float) -> bool:
         his = self.histogram_time_to_first_token.labels(**self.labels)
@@ -856,12 +857,14 @@ class TokenizerMetricsCollector:
             return value >= his._upper_bounds[i]
         return False
 
-    def observe_inter_token_latency(self, internval: float, num_new_tokens: int):
+    def observe_inter_token_latency(
+        self, labels: Dict[str, str], internval: float, num_new_tokens: int
+    ):
         adjusted_interval = internval / num_new_tokens
 
         # A faster version of the Histogram::observe which observes multiple values at the same time.
        # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
-        his = self.histogram_inter_token_latency_seconds.labels(**self.labels)
+        his = self.histogram_inter_token_latency_seconds.labels(**labels)
         his._sum.inc(internval)
 
         for i, bound in enumerate(his._upper_bounds):
...
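The tail of observe_inter_token_latency is truncated in this hunk, but the comment above points at the trick: instead of calling observe() once per generated token, the histogram's private _sum and bucket counters (the same internals the referenced prometheus_client observe() uses) are bumped in one pass. A sketch of that pattern under the assumption that the prometheus_client internals stay as in v0.21.x; observe_batch and the demo histogram below are illustrative, not part of the diff:

from prometheus_client import Histogram


def observe_batch(histogram: Histogram, interval: float, num_new_tokens: int) -> None:
    # Record num_new_tokens observations of (interval / num_new_tokens) at once.
    adjusted_interval = interval / num_new_tokens
    histogram._sum.inc(interval)  # total observed value across all tokens
    for i, bound in enumerate(histogram._upper_bounds):
        if adjusted_interval <= bound:
            histogram._buckets[i].inc(num_new_tokens)
            break


# Usage sketch: 10 samples of 0.05 s recorded with a single pass.
h = Histogram("demo_inter_token_latency_seconds", "demo", buckets=[0.01, 0.1, 1.0])
observe_batch(h, interval=0.5, num_new_tokens=10)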
@@ -205,6 +205,8 @@ class ServerArgs:
     show_time_cost: bool = False
     enable_metrics: bool = False
     enable_metrics_for_all_schedulers: bool = False
+    tokenizer_metrics_custom_labels_header: str = "x-customer-labels"
+    tokenizer_metrics_allowed_customer_labels: Optional[List[str]] = None
     bucket_time_to_first_token: Optional[List[float]] = None
     bucket_inter_token_latency: Optional[List[float]] = None
     bucket_e2e_request_latency: Optional[List[float]] = None
@@ -911,6 +913,14 @@ class ServerArgs:
                 "and cannot be used at the same time. Please use only one of them."
             )
 
+        if (
+            not self.tokenizer_metrics_custom_labels_header
+            and self.tokenizer_metrics_allowed_customer_labels
+        ):
+            raise ValueError(
+                "Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
+            )
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and tokenizer
@@ -1324,6 +1334,21 @@ class ServerArgs:
             "to record request metrics separately. This is especially useful when dp_attention is enabled, as "
             "otherwise all metrics appear to come from TP 0.",
         )
+        parser.add_argument(
+            "--tokenizer-metrics-custom-labels-header",
+            type=str,
+            default=ServerArgs.tokenizer_metrics_custom_labels_header,
+            help="Specify the HTTP header for passing customer labels for tokenizer metrics.",
+        )
+        parser.add_argument(
+            "--tokenizer-metrics-allowed-customer-labels",
+            type=str,
+            nargs="+",
+            default=ServerArgs.tokenizer_metrics_allowed_customer_labels,
+            help="The customer labels allowed for tokenizer metrics. The labels are specified via a dict in "
+            "'--tokenizer-metrics-custom-labels-header' field in HTTP requests, e.g., {'label1': 'value1', 'label2': "
+            "'value2'} is allowed if '--tokenizer-metrics-allowed-labels label1 label2' is set.",
+        )
         parser.add_argument(
             "--bucket-time-to-first-token",
             type=float,
...