Unverified commit 57728909, authored by Tzu-Ling Kan, committed by GitHub
Browse files

feat: Add model label for vllm backend metrics (#2474)


Co-authored-by: Keiven Chang <keivenchang@users.noreply.github.com>
parent 41a617f8
...@@ -150,8 +150,14 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -150,8 +150,14 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
# (temp reason): we don't support re-routing prefill requests # (temp reason): we don't support re-routing prefill requests
# (long-term reason): prefill engine should pull from a global queue so there is # (long-term reason): prefill engine should pull from a global queue so there is
# only a few in-flight requests that can be quickly finished # only a few in-flight requests that can be quickly finished
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=True), generate_endpoint.serve_endpoint(
clear_endpoint.serve_endpoint(handler.clear_kv_blocks), handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", config.model)],
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=[("model", config.model)]
),
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to serve endpoints: {e}") logger.error(f"Failed to serve endpoints: {e}")
...@@ -178,7 +184,11 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -178,7 +184,11 @@ async def init(runtime: DistributedRuntime, config: Config):
.client() .client()
) )
factory = StatLoggerFactory(component, config.engine_args.data_parallel_rank or 0) factory = StatLoggerFactory(
component,
config.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", config.model)],
)
engine_client, vllm_config, default_sampling_params = setup_vllm_engine( engine_client, vllm_config, default_sampling_params = setup_vllm_engine(
config, factory config, factory
) )
...@@ -239,8 +249,14 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -239,8 +249,14 @@ async def init(runtime: DistributedRuntime, config: Config):
await asyncio.gather( await asyncio.gather(
# for decode, we want to transfer the in-flight requests to other decode engines, # for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting them to finish can take a long time for long OSLs # because waiting them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=False), generate_endpoint.serve_endpoint(
clear_endpoint.serve_endpoint(handler.clear_kv_blocks), handler.generate,
graceful_shutdown=False,
metrics_labels=[("model", config.model)],
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=[("model", config.model)]
),
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to serve endpoints: {e}") logger.error(f"Failed to serve endpoints: {e}")
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional from typing import List, Optional, Tuple
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase from vllm.v1.metrics.loggers import StatLoggerBase
...@@ -36,9 +36,16 @@ class NullStatLogger(StatLoggerBase): ...@@ -36,9 +36,16 @@ class NullStatLogger(StatLoggerBase):
class DynamoStatLoggerPublisher(StatLoggerBase): class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface.""" """Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
def __init__(self, component: Component, dp_rank: int) -> None: def __init__(
self,
component: Component,
dp_rank: int,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.inner = WorkerMetricsPublisher() self.inner = WorkerMetricsPublisher()
self.inner.create_endpoint(component) # Use labels directly for the new create_endpoint signature
metrics_labels = metrics_labels or []
self.inner.create_endpoint(component, metrics_labels)
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.num_gpu_block = 1 self.num_gpu_block = 1
self.request_total_slots = 1 self.request_total_slots = 1
...@@ -129,15 +136,23 @@ class DynamoStatLoggerPublisher(StatLoggerBase): ...@@ -129,15 +136,23 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
class StatLoggerFactory: class StatLoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM.""" """Factory for creating stat logger publishers. Required by vLLM."""
def __init__(self, component: Component, dp_rank: int = 0) -> None: def __init__(
self,
component: Component,
dp_rank: int = 0,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.component = component self.component = component
self.created_logger: Optional[DynamoStatLoggerPublisher] = None self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.metrics_labels = metrics_labels or []
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase: def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank: if self.dp_rank != dp_rank:
return NullStatLogger() return NullStatLogger()
logger = DynamoStatLoggerPublisher(self.component, dp_rank) logger = DynamoStatLoggerPublisher(
self.component, dp_rank, metrics_labels=self.metrics_labels
)
self.created_logger = logger self.created_logger = logger
return logger return logger
......
...@@ -140,7 +140,9 @@ class VllmBaseWorker: ...@@ -140,7 +140,9 @@ class VllmBaseWorker:
# Create vLLM engine with metrics logger and KV event publisher attached # Create vLLM engine with metrics logger and KV event publisher attached
self.stats_logger = StatLoggerFactory( self.stats_logger = StatLoggerFactory(
component, self.engine_args.data_parallel_rank or 0 component,
self.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", self.engine_args.model)],
) )
self.engine_client = AsyncLLM.from_vllm_config( self.engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config, vllm_config=vllm_config,
...@@ -353,7 +355,9 @@ class VllmPDWorker(VllmBaseWorker): ...@@ -353,7 +355,9 @@ class VllmPDWorker(VllmBaseWorker):
extra_args.pop("serialized_request", None) extra_args.pop("serialized_request", None)
decode_request.sampling_params.extra_args = extra_args decode_request.sampling_params.extra_args = extra_args
logger.debug("Decode request: %s", decode_request) logger.debug("Decode request: %s", decode_request)
async for decode_response in await self.decode_worker_client.round_robin( async for (
decode_response
) in await self.decode_worker_client.round_robin(
decode_request.model_dump_json() decode_request.model_dump_json()
): ):
output = MyRequestOutput.model_validate_json(decode_response.data()) output = MyRequestOutput.model_validate_json(decode_response.data())
......
...@@ -513,19 +513,24 @@ impl Component { ...@@ -513,19 +513,24 @@ impl Component {
#[pymethods] #[pymethods]
impl Endpoint { impl Endpoint {
#[pyo3(signature = (generator, graceful_shutdown = true))] #[pyo3(signature = (generator, graceful_shutdown = true, metrics_labels = None))]
fn serve_endpoint<'p>( fn serve_endpoint<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
generator: PyObject, generator: PyObject,
graceful_shutdown: Option<bool>, graceful_shutdown: Option<bool>,
metrics_labels: Option<Vec<(String, String)>>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let engine = Arc::new(engine::PythonAsyncEngine::new( let engine = Arc::new(engine::PythonAsyncEngine::new(
generator, generator,
self.event_loop.clone(), self.event_loop.clone(),
)?); )?);
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?; let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
let builder = self.inner.endpoint_builder().handler(ingress); let builder = self
.inner
.endpoint_builder()
.metrics_labels(metrics_labels)
.handler(ingress);
let graceful_shutdown = graceful_shutdown.unwrap_or(true); let graceful_shutdown = graceful_shutdown.unwrap_or(true);
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
builder builder
......
...@@ -55,17 +55,35 @@ impl WorkerMetricsPublisher { ...@@ -55,17 +55,35 @@ impl WorkerMetricsPublisher {
}) })
} }
#[pyo3(signature = (component))] #[pyo3(signature = (component, metrics_labels = None))]
fn create_endpoint<'p>( fn create_endpoint<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
component: Component, component: Component,
metrics_labels: Option<Vec<(String, String)>>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let rs_publisher = self.inner.clone(); let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone(); let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Convert Python labels to Option<&[(&str, &str)]> expected by Rust API
let metrics_labels_ref: Option<Vec<(&str, &str)>> =
if let Some(metrics_labels) = metrics_labels.as_ref() {
if metrics_labels.is_empty() {
None
} else {
Some(
metrics_labels
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect(),
)
}
} else {
None
};
rs_publisher rs_publisher
.create_endpoint(rs_component) .create_endpoint(rs_component, metrics_labels_ref.as_deref())
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(()) Ok(())
......
...@@ -498,7 +498,11 @@ impl WorkerMetricsPublisher { ...@@ -498,7 +498,11 @@ impl WorkerMetricsPublisher {
self.tx.send(metrics) self.tx.send(metrics)
} }
pub async fn create_endpoint(&self, component: Component) -> Result<()> { pub async fn create_endpoint(
&self,
component: Component,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
let mut metrics_rx = self.rx.clone(); let mut metrics_rx = self.rx.clone();
let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone())); let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?; let handler = Ingress::for_engine(handler)?;
...@@ -514,6 +518,12 @@ impl WorkerMetricsPublisher { ...@@ -514,6 +518,12 @@ impl WorkerMetricsPublisher {
self.start_nats_metrics_publishing(component.namespace().clone(), worker_id); self.start_nats_metrics_publishing(component.namespace().clone(), worker_id);
let metrics_labels = metrics_labels.map(|v| {
v.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect::<Vec<_>>()
});
component component
.endpoint(KV_METRICS_ENDPOINT) .endpoint(KV_METRICS_ENDPOINT)
.endpoint_builder() .endpoint_builder()
...@@ -521,6 +531,7 @@ impl WorkerMetricsPublisher { ...@@ -521,6 +531,7 @@ impl WorkerMetricsPublisher {
let metrics = metrics_rx.borrow_and_update().clone(); let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap() serde_json::to_value(&*metrics).unwrap()
}) })
.metrics_labels(metrics_labels)
.handler(handler) .handler(handler)
.start() .start()
.await .await
......
...@@ -189,7 +189,7 @@ impl MockVllmEngine { ...@@ -189,7 +189,7 @@ impl MockVllmEngine {
tokio::spawn({ tokio::spawn({
let publisher = metrics_publisher.clone(); let publisher = metrics_publisher.clone();
async move { async move {
if let Err(e) = publisher.create_endpoint(comp.clone()).await { if let Err(e) = publisher.create_endpoint(comp.clone(), None).await {
tracing::error!("Metrics endpoint failed: {e}"); tracing::error!("Metrics endpoint failed: {e}");
} }
} }
......
...@@ -16,7 +16,6 @@ use std::sync::Arc; ...@@ -16,7 +16,6 @@ use std::sync::Arc;
pub const DEFAULT_NAMESPACE: &str = "dyn_example_namespace"; pub const DEFAULT_NAMESPACE: &str = "dyn_example_namespace";
pub const DEFAULT_COMPONENT: &str = "dyn_example_component"; pub const DEFAULT_COMPONENT: &str = "dyn_example_component";
pub const DEFAULT_ENDPOINT: &str = "dyn_example_endpoint"; pub const DEFAULT_ENDPOINT: &str = "dyn_example_endpoint";
pub const DEFAULT_MODEL_NAME: &str = "dyn_example_model";
/// Stats structure returned by the endpoint's stats handler /// Stats structure returned by the endpoint's stats handler
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
......
...@@ -41,6 +41,10 @@ pub struct EndpointConfig { ...@@ -41,6 +41,10 @@ pub struct EndpointConfig {
#[builder(default, private)] #[builder(default, private)]
_stats_handler: Option<EndpointStatsHandler>, _stats_handler: Option<EndpointStatsHandler>,
/// Additional labels for metrics
#[builder(default, setter(into))]
metrics_labels: Option<Vec<(String, String)>>,
/// Whether to wait for inflight requests to complete during shutdown /// Whether to wait for inflight requests to complete during shutdown
#[builder(default = "true")] #[builder(default = "true")]
graceful_shutdown: bool, graceful_shutdown: bool,
...@@ -59,7 +63,7 @@ impl EndpointConfigBuilder { ...@@ -59,7 +63,7 @@ impl EndpointConfigBuilder {
} }
pub async fn start(self) -> Result<()> { pub async fn start(self) -> Result<()> {
let (endpoint, lease, handler, stats_handler, graceful_shutdown) = let (endpoint, lease, handler, stats_handler, metrics_labels, graceful_shutdown) =
self.build_internal()?.dissolve(); self.build_internal()?.dissolve();
let lease = lease.or(endpoint.drt().primary_lease()); let lease = lease.or(endpoint.drt().primary_lease());
let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0); let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);
...@@ -74,8 +78,11 @@ impl EndpointConfigBuilder { ...@@ -74,8 +78,11 @@ impl EndpointConfigBuilder {
// acquire the registry lock // acquire the registry lock
let registry = endpoint.drt().component_registry.inner.lock().await; let registry = endpoint.drt().component_registry.inner.lock().await;
let metrics_labels: Option<Vec<(&str, &str)>> = metrics_labels
.as_ref()
.map(|v| v.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
// Add metrics to the handler. The endpoint provides additional information to the handler. // Add metrics to the handler. The endpoint provides additional information to the handler.
handler.add_metrics(&endpoint)?; handler.add_metrics(&endpoint, metrics_labels.as_deref())?;
// get the group // get the group
let group = registry let group = registry
......
...@@ -300,8 +300,12 @@ impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> { ...@@ -300,8 +300,12 @@ impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
.map_err(|_| anyhow::anyhow!("Segment already set")) .map_err(|_| anyhow::anyhow!("Segment already set"))
} }
pub fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()> { pub fn add_metrics(
let metrics = WorkHandlerMetrics::from_endpoint(endpoint) &self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
let metrics = WorkHandlerMetrics::from_endpoint(endpoint, metrics_labels)
.map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?;
self.metrics self.metrics
...@@ -345,7 +349,11 @@ pub trait PushWorkHandler: Send + Sync { ...@@ -345,7 +349,11 @@ pub trait PushWorkHandler: Send + Sync {
async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>; async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;
/// Add metrics to the handler /// Add metrics to the handler
fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()>; fn add_metrics(
&self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()>;
} }
/* /*
......
...@@ -54,43 +54,45 @@ impl WorkHandlerMetrics { ...@@ -54,43 +54,45 @@ impl WorkHandlerMetrics {
/// Create WorkHandlerMetrics from an endpoint using its built-in labeling /// Create WorkHandlerMetrics from an endpoint using its built-in labeling
pub fn from_endpoint( pub fn from_endpoint(
endpoint: &crate::component::Endpoint, endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let metrics_labels = metrics_labels.unwrap_or(&[]);
let request_counter = endpoint.create_intcounter( let request_counter = endpoint.create_intcounter(
"requests_total", "requests_total",
"Total number of requests processed by work handler", "Total number of requests processed by work handler",
&[], metrics_labels,
)?; )?;
let request_duration = endpoint.create_histogram( let request_duration = endpoint.create_histogram(
"request_duration_seconds", "request_duration_seconds",
"Time spent processing requests by work handler", "Time spent processing requests by work handler",
&[], metrics_labels,
None, None,
)?; )?;
let inflight_requests = endpoint.create_intgauge( let inflight_requests = endpoint.create_intgauge(
"inflight_requests", "inflight_requests",
"Number of requests currently being processed by work handler", "Number of requests currently being processed by work handler",
&[], metrics_labels,
)?; )?;
let request_bytes = endpoint.create_intcounter( let request_bytes = endpoint.create_intcounter(
"request_bytes_total", "request_bytes_total",
"Total number of bytes received in requests by work handler", "Total number of bytes received in requests by work handler",
&[], metrics_labels,
)?; )?;
let response_bytes = endpoint.create_intcounter( let response_bytes = endpoint.create_intcounter(
"response_bytes_total", "response_bytes_total",
"Total number of bytes sent in responses by work handler", "Total number of bytes sent in responses by work handler",
&[], metrics_labels,
)?; )?;
let error_counter = endpoint.create_intcountervec( let error_counter = endpoint.create_intcountervec(
"errors_total", "errors_total",
"Total number of errors in work handler processing", "Total number of errors in work handler processing",
&["error_type"], &["error_type"],
&[], metrics_labels,
)?; )?;
Ok(Self::new( Ok(Self::new(
...@@ -110,10 +112,14 @@ where ...@@ -110,10 +112,14 @@ where
T: Data + for<'de> Deserialize<'de> + std::fmt::Debug, T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
U: Data + Serialize + MaybeError + std::fmt::Debug, U: Data + Serialize + MaybeError + std::fmt::Debug,
{ {
fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()> { fn add_metrics(
&self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
// Call the Ingress-specific add_metrics implementation // Call the Ingress-specific add_metrics implementation
use crate::pipeline::network::Ingress; use crate::pipeline::network::Ingress;
Ingress::add_metrics(self, endpoint) Ingress::add_metrics(self, endpoint, metrics_labels)
} }
async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> { async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment