Unverified Commit 57728909 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: Add model label for vllm backend metrics (#2474)


Co-authored-by: default avatarKeiven Chang <keivenchang@users.noreply.github.com>
parent 41a617f8
......@@ -150,8 +150,14 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
# (temp reason): we don't support re-routing prefill requests
# (long-term reason): prefill engine should pull from a global queue so there is
# only a few in-flight requests that can be quickly finished
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=True),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", config.model)],
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=[("model", config.model)]
),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
......@@ -178,7 +184,11 @@ async def init(runtime: DistributedRuntime, config: Config):
.client()
)
factory = StatLoggerFactory(component, config.engine_args.data_parallel_rank or 0)
factory = StatLoggerFactory(
component,
config.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", config.model)],
)
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(
config, factory
)
......@@ -239,8 +249,14 @@ async def init(runtime: DistributedRuntime, config: Config):
await asyncio.gather(
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=False),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=False,
metrics_labels=[("model", config.model)],
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=[("model", config.model)]
),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from typing import List, Optional, Tuple
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase
......@@ -36,9 +36,16 @@ class NullStatLogger(StatLoggerBase):
class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
def __init__(self, component: Component, dp_rank: int) -> None:
def __init__(
self,
component: Component,
dp_rank: int,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.inner = WorkerMetricsPublisher()
self.inner.create_endpoint(component)
# Use labels directly for the new create_endpoint signature
metrics_labels = metrics_labels or []
self.inner.create_endpoint(component, metrics_labels)
self.dp_rank = dp_rank
self.num_gpu_block = 1
self.request_total_slots = 1
......@@ -129,15 +136,23 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
class StatLoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM."""
def __init__(self, component: Component, dp_rank: int = 0) -> None:
def __init__(
self,
component: Component,
dp_rank: int = 0,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.component = component
self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank
self.metrics_labels = metrics_labels or []
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank:
return NullStatLogger()
logger = DynamoStatLoggerPublisher(self.component, dp_rank)
logger = DynamoStatLoggerPublisher(
self.component, dp_rank, metrics_labels=self.metrics_labels
)
self.created_logger = logger
return logger
......
......@@ -140,7 +140,9 @@ class VllmBaseWorker:
# Create vLLM engine with metrics logger and KV event publisher attached
self.stats_logger = StatLoggerFactory(
component, self.engine_args.data_parallel_rank or 0
component,
self.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", self.engine_args.model)],
)
self.engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
......@@ -353,7 +355,9 @@ class VllmPDWorker(VllmBaseWorker):
extra_args.pop("serialized_request", None)
decode_request.sampling_params.extra_args = extra_args
logger.debug("Decode request: %s", decode_request)
async for decode_response in await self.decode_worker_client.round_robin(
async for (
decode_response
) in await self.decode_worker_client.round_robin(
decode_request.model_dump_json()
):
output = MyRequestOutput.model_validate_json(decode_response.data())
......
......@@ -513,19 +513,24 @@ impl Component {
#[pymethods]
impl Endpoint {
#[pyo3(signature = (generator, graceful_shutdown = true))]
#[pyo3(signature = (generator, graceful_shutdown = true, metrics_labels = None))]
fn serve_endpoint<'p>(
&self,
py: Python<'p>,
generator: PyObject,
graceful_shutdown: Option<bool>,
metrics_labels: Option<Vec<(String, String)>>,
) -> PyResult<Bound<'p, PyAny>> {
let engine = Arc::new(engine::PythonAsyncEngine::new(
generator,
self.event_loop.clone(),
)?);
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
let builder = self.inner.endpoint_builder().handler(ingress);
let builder = self
.inner
.endpoint_builder()
.metrics_labels(metrics_labels)
.handler(ingress);
let graceful_shutdown = graceful_shutdown.unwrap_or(true);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
builder
......
......@@ -55,17 +55,35 @@ impl WorkerMetricsPublisher {
})
}
#[pyo3(signature = (component))]
#[pyo3(signature = (component, metrics_labels = None))]
fn create_endpoint<'p>(
&self,
py: Python<'p>,
component: Component,
metrics_labels: Option<Vec<(String, String)>>,
) -> PyResult<Bound<'p, PyAny>> {
let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Convert Python labels to Option<&[(&str, &str)]> expected by Rust API
let metrics_labels_ref: Option<Vec<(&str, &str)>> =
if let Some(metrics_labels) = metrics_labels.as_ref() {
if metrics_labels.is_empty() {
None
} else {
Some(
metrics_labels
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect(),
)
}
} else {
None
};
rs_publisher
.create_endpoint(rs_component)
.create_endpoint(rs_component, metrics_labels_ref.as_deref())
.await
.map_err(to_pyerr)?;
Ok(())
......
......@@ -498,7 +498,11 @@ impl WorkerMetricsPublisher {
self.tx.send(metrics)
}
pub async fn create_endpoint(&self, component: Component) -> Result<()> {
pub async fn create_endpoint(
&self,
component: Component,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
let mut metrics_rx = self.rx.clone();
let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?;
......@@ -514,6 +518,12 @@ impl WorkerMetricsPublisher {
self.start_nats_metrics_publishing(component.namespace().clone(), worker_id);
let metrics_labels = metrics_labels.map(|v| {
v.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect::<Vec<_>>()
});
component
.endpoint(KV_METRICS_ENDPOINT)
.endpoint_builder()
......@@ -521,6 +531,7 @@ impl WorkerMetricsPublisher {
let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap()
})
.metrics_labels(metrics_labels)
.handler(handler)
.start()
.await
......
......@@ -189,7 +189,7 @@ impl MockVllmEngine {
tokio::spawn({
let publisher = metrics_publisher.clone();
async move {
if let Err(e) = publisher.create_endpoint(comp.clone()).await {
if let Err(e) = publisher.create_endpoint(comp.clone(), None).await {
tracing::error!("Metrics endpoint failed: {e}");
}
}
......
......@@ -16,7 +16,6 @@ use std::sync::Arc;
pub const DEFAULT_NAMESPACE: &str = "dyn_example_namespace";
pub const DEFAULT_COMPONENT: &str = "dyn_example_component";
pub const DEFAULT_ENDPOINT: &str = "dyn_example_endpoint";
pub const DEFAULT_MODEL_NAME: &str = "dyn_example_model";
/// Stats structure returned by the endpoint's stats handler
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
......
......@@ -41,6 +41,10 @@ pub struct EndpointConfig {
#[builder(default, private)]
_stats_handler: Option<EndpointStatsHandler>,
/// Additional labels for metrics
#[builder(default, setter(into))]
metrics_labels: Option<Vec<(String, String)>>,
/// Whether to wait for inflight requests to complete during shutdown
#[builder(default = "true")]
graceful_shutdown: bool,
......@@ -59,7 +63,7 @@ impl EndpointConfigBuilder {
}
pub async fn start(self) -> Result<()> {
let (endpoint, lease, handler, stats_handler, graceful_shutdown) =
let (endpoint, lease, handler, stats_handler, metrics_labels, graceful_shutdown) =
self.build_internal()?.dissolve();
let lease = lease.or(endpoint.drt().primary_lease());
let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);
......@@ -74,8 +78,11 @@ impl EndpointConfigBuilder {
// acquire the registry lock
let registry = endpoint.drt().component_registry.inner.lock().await;
let metrics_labels: Option<Vec<(&str, &str)>> = metrics_labels
.as_ref()
.map(|v| v.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
// Add metrics to the handler. The endpoint provides additional information to the handler.
handler.add_metrics(&endpoint)?;
handler.add_metrics(&endpoint, metrics_labels.as_deref())?;
// get the group
let group = registry
......
......@@ -300,8 +300,12 @@ impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
.map_err(|_| anyhow::anyhow!("Segment already set"))
}
pub fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()> {
let metrics = WorkHandlerMetrics::from_endpoint(endpoint)
pub fn add_metrics(
&self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
let metrics = WorkHandlerMetrics::from_endpoint(endpoint, metrics_labels)
.map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?;
self.metrics
......@@ -345,7 +349,11 @@ pub trait PushWorkHandler: Send + Sync {
async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;
/// Add metrics to the handler
fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()>;
fn add_metrics(
&self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()>;
}
/*
......
......@@ -54,43 +54,45 @@ impl WorkHandlerMetrics {
/// Create WorkHandlerMetrics from an endpoint using its built-in labeling
pub fn from_endpoint(
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let metrics_labels = metrics_labels.unwrap_or(&[]);
let request_counter = endpoint.create_intcounter(
"requests_total",
"Total number of requests processed by work handler",
&[],
metrics_labels,
)?;
let request_duration = endpoint.create_histogram(
"request_duration_seconds",
"Time spent processing requests by work handler",
&[],
metrics_labels,
None,
)?;
let inflight_requests = endpoint.create_intgauge(
"inflight_requests",
"Number of requests currently being processed by work handler",
&[],
metrics_labels,
)?;
let request_bytes = endpoint.create_intcounter(
"request_bytes_total",
"Total number of bytes received in requests by work handler",
&[],
metrics_labels,
)?;
let response_bytes = endpoint.create_intcounter(
"response_bytes_total",
"Total number of bytes sent in responses by work handler",
&[],
metrics_labels,
)?;
let error_counter = endpoint.create_intcountervec(
"errors_total",
"Total number of errors in work handler processing",
&["error_type"],
&[],
metrics_labels,
)?;
Ok(Self::new(
......@@ -110,10 +112,14 @@ where
T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
U: Data + Serialize + MaybeError + std::fmt::Debug,
{
fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()> {
fn add_metrics(
&self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
// Call the Ingress-specific add_metrics implementation
use crate::pipeline::network::Ingress;
Ingress::add_metrics(self, endpoint)
Ingress::add_metrics(self, endpoint, metrics_labels)
}
async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment