".github/vscode:/vscode.git/clone" did not exist on "35fa7129db3e5c00cd8aad5000f30eb6841fca83"
Unverified commit 4c207e0c, authored by Yan Ru Pei, committed by GitHub
Browse files

chore: remove kv metrics scraping and aggregation (#3701)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 2c2f7c7d
......@@ -290,8 +290,7 @@ graph TB
NATS[NATS Server<br/>KV_METRICS_SUBJECT]
end
subgraph "Aggregator Process Components"
AGG[KvMetricsAggregator<br/>llm/src/kv_router/metrics_aggregator.rs]
subgraph "Other Consumers (e.g., KvWorkerMonitor)"
SUB[NATS Subscriber<br/>component/namespace.rs]
end
......@@ -303,7 +302,6 @@ graph TB
style WATCH fill:#ce422b,color:#fff
style PROM1 fill:#ce422b,color:#fff
style NATS fill:#27aae1,color:#fff
style AGG fill:#ce422b,color:#fff
style SUB fill:#ce422b,color:#fff
style SS fill:#6c757d,color:#fff
end
......@@ -316,9 +314,7 @@ graph TB
WMP -->|"tx.send(metrics)"| WATCH
WATCH -->|"publish(KV_METRICS_SUBJECT, LoadEvent)"| NATS
NATS -->|"subscribe_with_type LoadEvent"| SUB
SUB -->|"discover endpoints"| AGG
SS -->|"Worker: gather() from PROM1"| PROM1
SS -->|"Aggregator: scrape_stats()"| AGG
```
##### Method 2: Dynamic Registration - Component View
......
......@@ -146,9 +146,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::OverlapScores>()?;
m.add_class::<llm::kv::KvIndexer>()?;
m.add_class::<llm::kv::ApproxKvIndexer>()?;
m.add_class::<llm::kv::EndpointKvMetrics>()?;
m.add_class::<llm::kv::AggregatedMetrics>()?;
m.add_class::<llm::kv::KvMetricsAggregator>()?;
m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<llm::kv::RadixTree>()?;
m.add_class::<llm::kv::ZmqKvEventListener>()?;
......
......@@ -49,39 +49,35 @@ impl WorkerMetricsPublisher {
}
#[pyo3(signature = (component, metrics_labels = None))]
#[allow(unused_variables)]
fn create_endpoint<'p>(
&self,
py: Python<'p>,
component: Component,
metrics_labels: Option<Vec<(String, String)>>,
metrics_labels: Option<Vec<(String, String)>>, // TODO: fully remove this
) -> PyResult<Bound<'p, PyAny>> {
// Emit deprecation warning if metrics_labels is provided
if metrics_labels.is_some() {
let warnings = py.import("warnings")?;
warnings.call_method1(
"warn",
(
"The 'metrics_labels' parameter is deprecated and no longer used. It will be removed in a future version.",
py.get_type::<pyo3::exceptions::PyDeprecationWarning>(),
),
)?;
}
let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Convert Python labels to Option<&[(&str, &str)]> expected by Rust API
let metrics_labels_ref: Option<Vec<(&str, &str)>> =
if let Some(metrics_labels) = metrics_labels.as_ref() {
if metrics_labels.is_empty() {
None
} else {
Some(
metrics_labels
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect(),
)
}
} else {
None
};
// Register Prometheus metrics first
rs_publisher
.register_prometheus_metrics(&rs_component)
.map_err(to_pyerr)?;
rs_publisher
.create_endpoint(rs_component, metrics_labels_ref.as_deref())
.create_endpoint(rs_component)
.await
.map_err(to_pyerr)?;
Ok(())
......@@ -552,96 +548,6 @@ impl ApproxKvIndexer {
}
}
/// Per-worker KV metrics record exposed to Python.
///
/// `KvMetricsAggregator::get_metrics` builds one value per endpoint from that
/// endpoint's forward-pass metrics. Every field is readable and writable from
/// Python via `#[pyo3(get, set)]`.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EndpointKvMetrics {
    /// Identifier of the worker this record belongs to.
    #[pyo3(get, set)]
    pub worker_id: WorkerId,
    /// Copied from `worker_stats.request_active_slots`.
    #[pyo3(get, set)]
    pub request_active_slots: u64,
    /// Copied from `worker_stats.request_total_slots`.
    #[pyo3(get, set)]
    pub request_total_slots: u64,
    /// Copied from `kv_stats.kv_active_blocks`.
    #[pyo3(get, set)]
    pub kv_active_blocks: u64,
    /// Copied from `kv_stats.kv_total_blocks`.
    #[pyo3(get, set)]
    pub kv_total_blocks: u64,
    /// Copied from `worker_stats.num_requests_waiting`.
    #[pyo3(get, set)]
    pub num_requests_waiting: u64,
    /// Copied from `kv_stats.gpu_cache_usage_perc`.
    // NOTE(review): presumably a 0.0..=1.0 fraction (the test publishes 0.5) — confirm.
    #[pyo3(get, set)]
    pub gpu_cache_usage_perc: f32,
    /// Copied from `kv_stats.gpu_prefix_cache_hit_rate`.
    #[pyo3(get, set)]
    pub gpu_prefix_cache_hit_rate: f32,
}
/// Result of `KvMetricsAggregator::get_metrics`, exposed to Python.
///
/// Bundles the per-endpoint records with the `load_avg` / `load_std` values
/// reported by the underlying Rust aggregator. All fields are readable and
/// writable from Python.
#[pyclass]
#[derive(Clone)]
pub(crate) struct AggregatedMetrics {
    /// One record per endpoint known to the aggregator at the time of the call.
    #[pyo3(get, set)]
    pub endpoints: Vec<EndpointKvMetrics>,
    /// `load_avg` as reported by the aggregator's processed endpoints.
    #[pyo3(get, set)]
    pub load_avg: f64,
    /// `load_std` as reported by the aggregator's processed endpoints.
    #[pyo3(get, set)]
    pub load_std: f64,
}
/// Python-facing wrapper around the Rust
/// `llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator`.
#[pyclass]
pub(crate) struct KvMetricsAggregator {
    // Shared handle to the Rust aggregator that this wrapper delegates to.
    inner: Arc<llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator>,
}
#[pymethods]
impl KvMetricsAggregator {
#[new]
fn new(component: Component) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner = llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator::new(
component.inner.clone(),
component.inner.drt().runtime().child_token(),
)
.await;
Ok(Self {
inner: inner.into(),
})
})
}
fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
// TODO: update EndpointKvMetrics to match the new ForwardPassMetrics struct
let endpoints = self.inner.get_endpoints();
let load_avg = endpoints.load_avg;
let load_std = endpoints.load_std;
let endpoint_kv_metrics = endpoints
.endpoints
.into_iter()
.map(|(worker_id, endpoint)| {
let metrics = endpoint.data;
let LoadMetrics::EngineLoadMetrics(fwd_pass_metrics) = metrics else {
panic!("Endpoints do not contain forward pass metrics.");
};
EndpointKvMetrics {
worker_id,
request_active_slots: fwd_pass_metrics.worker_stats.request_active_slots,
request_total_slots: fwd_pass_metrics.worker_stats.request_total_slots,
kv_active_blocks: fwd_pass_metrics.kv_stats.kv_active_blocks,
kv_total_blocks: fwd_pass_metrics.kv_stats.kv_total_blocks,
num_requests_waiting: fwd_pass_metrics.worker_stats.num_requests_waiting,
gpu_cache_usage_perc: fwd_pass_metrics.kv_stats.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: fwd_pass_metrics.kv_stats.gpu_prefix_cache_hit_rate,
}
})
.collect();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(AggregatedMetrics {
endpoints: endpoint_kv_metrics,
load_avg,
load_std,
})
})
}
}
#[pyclass]
pub(crate) struct KvRecorder {
inner: Arc<llm_rs::kv_router::recorder::KvRecorder>,
......
......@@ -464,6 +464,13 @@ class WorkerMetricsPublisher:
"""
Similar to Component.create_service, but only service created through
this method will interact with KV router of the same component.
Args:
component: The component to create the endpoint for
metrics_labels: [DEPRECATED] This parameter is no longer used and will be removed in a future version
.. deprecated::
The metrics_labels parameter is deprecated and has no effect.
"""
def publish(
......@@ -745,31 +752,6 @@ class KvRecorder:
"""
...
class AggregatedMetrics:
    """
    A collection of metrics of the endpoints.

    Returned by `KvMetricsAggregator.get_metrics`. The Rust binding exposes
    three read/write attributes: `endpoints` (one metrics record per
    endpoint), `load_avg` and `load_std`.
    """

    ...
class KvMetricsAggregator:
    """
    A metrics aggregator will collect KV metrics of the endpoints.
    """

    ...

    def __init__(self, component: Component) -> None:
        """
        Create a `KvMetricsAggregator` object

        Args:
            component: The component whose endpoints are aggregated
        """

    async def get_metrics(self) -> AggregatedMetrics:
        """
        Return the aggregated metrics of the endpoints.

        The binding returns an awaitable, so the call must be awaited:
        `metrics = await aggregator.get_metrics()`.
        """
        ...
class KvEventPublisher:
"""
A KV event publisher will publish KV events corresponding to the component.
......
......@@ -5,8 +5,6 @@
import logging
from dynamo._core import AggregatedMetrics as AggregatedMetrics
try:
from dynamo._core import BlockManager as BlockManager
from dynamo._core import KvbmLeader as KvbmLeader
......@@ -23,7 +21,6 @@ from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpService as HttpService
from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvPushRouter as KvPushRouter
from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouterConfig as KvRouterConfig
......
......@@ -19,17 +19,7 @@ from typing import List
import pytest
from dynamo.llm import (
ApproxKvIndexer,
ForwardPassMetrics,
KvEventPublisher,
KvIndexer,
KvMetricsAggregator,
KvStats,
RadixTree,
WorkerMetricsPublisher,
WorkerStats,
)
from dynamo.llm import ApproxKvIndexer, KvEventPublisher, KvIndexer, RadixTree
from dynamo.runtime import Component, DistributedRuntime
pytestmark = pytest.mark.pre_merge
......@@ -226,80 +216,3 @@ class EventPublisher:
], # block_hashes
)
self.event_id_counter += 1
@pytest.mark.asyncio
@pytest.mark.forked
async def test_metrics_aggregator(distributed_runtime):
    """
    End-to-end check that `KvMetricsAggregator` discovers a worker's metrics.

    The aggregator must report no endpoints before a worker exists, and must
    report the published values once a `WorkerMetricsPublisher` endpoint is
    being served in the background.
    """
    namespace = "kv_test"
    component = "metrics"
    kv_listener = distributed_runtime.namespace(namespace).component(component)
    await kv_listener.create_service()

    # aggregator
    metrics_aggregator = KvMetricsAggregator(kv_listener)

    # has nothing to aggregate as worker has not started
    metrics = await metrics_aggregator.get_metrics()
    assert not metrics.endpoints

    expected_metrics = {
        "request_active_slots": 0,
        "request_total_slots": 1024,
        "kv_active_blocks": 523,
        "kv_total_blocks": 777,
        "num_requests_waiting": 10,
        "gpu_cache_usage_perc": 0.5,
        "gpu_prefix_cache_hit_rate": 0.75,
    }

    # need 'create_task' to put publisher task in the background.
    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    publisher_task = asyncio.create_task(
        metrics_publisher_task(kv_listener, expected_metrics)
    )

    # needs time for publisher to spawn up
    # Using shorter intervals for faster detection in normal cases
    for i in range(20):  # Try up to 20 times (10 seconds total)
        await asyncio.sleep(0.5)  # Wait 500ms between retries
        if publisher_task.done():
            # Re-raise a publisher failure here instead of letting it hide
            # behind the generic "no endpoints" timeout below.
            publisher_task.result()
        metrics = await metrics_aggregator.get_metrics()
        if metrics.endpoints:
            break
    assert metrics.endpoints, f"No metrics endpoints found after {(i+1)*0.5}s"

    for endpoint in metrics.endpoints:
        # [TODO] not really checking id for now, can't get it as create_endpoint()
        # create and serve the endpoint internally
        assert endpoint.worker_id != 0
        assert endpoint.request_active_slots == expected_metrics["request_active_slots"]
        assert endpoint.request_total_slots == expected_metrics["request_total_slots"]
        assert endpoint.kv_active_blocks == expected_metrics["kv_active_blocks"]
        assert endpoint.kv_total_blocks == expected_metrics["kv_total_blocks"]
async def metrics_publisher_task(kv_listener, expected_metrics):
    """
    Publish one metrics sample and serve it through `kv_listener`.

    `WorkerMetricsPublisher.publish` takes a single structured
    `ForwardPassMetrics` object rather than positional scalars, so the values
    in `expected_metrics` are first packed into `WorkerStats` and `KvStats`.
    """
    metrics_publisher = WorkerMetricsPublisher()

    active_slots = expected_metrics["request_active_slots"]
    total_slots = expected_metrics["request_total_slots"]
    waiting = expected_metrics["num_requests_waiting"]
    data_parallel_rank = 0  # 0 = DP not enabled
    worker_stats = WorkerStats(active_slots, total_slots, waiting, data_parallel_rank)

    active_blocks = expected_metrics["kv_active_blocks"]
    total_blocks = expected_metrics["kv_total_blocks"]
    cache_usage = expected_metrics["gpu_cache_usage_perc"]
    prefix_hit_rate = expected_metrics["gpu_prefix_cache_hit_rate"]
    kv_stats = KvStats(active_blocks, total_blocks, cache_usage, prefix_hit_rate)

    # Publish first, then expose the endpoint, so the aggregator test can
    # discover the sample as soon as the endpoint is up.
    metrics_publisher.publish(ForwardPassMetrics(worker_stats, kv_stats, None))
    await metrics_publisher.create_endpoint(kv_listener)
......@@ -22,7 +22,6 @@ use serde::{Deserialize, Serialize};
pub mod approx;
pub mod indexer;
pub mod metrics_aggregator;
pub mod protocols;
pub mod publisher;
pub mod recorder;
......@@ -42,7 +41,6 @@ use crate::{
LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
scoring::ProcessedEndpoints,
subscriber::start_kv_router_background,
},
local_model::runtime_config::ModelRuntimeConfig,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Once;
use crate::kv_router::KV_METRICS_ENDPOINT;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::ProcessedEndpoints;
use crate::kv_router::scoring::Endpoint;
use dynamo_runtime::component::Component;
use dynamo_runtime::{Result, service::EndpointInfo, utils::Duration};
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
static METRICS_WAITING_MESSAGE: Once = Once::new();
static METRICS_FOUND_MESSAGE: Once = Once::new();
/// Tracks the endpoints registered under a component's `generate` endpoint.
///
/// A background task (see [`collect_endpoints_task`]) polls the component and
/// publishes each changed [`ProcessedEndpoints`] snapshot on a watch channel.
/// No metrics payload is decoded for this collector; every endpoint carries
/// `LoadMetrics::default()`.
pub struct EndpointCollector {
    /// Service name of the component being observed.
    pub service_name: String,
    /// Receiving side of the watch channel fed by the background task.
    pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
}

impl EndpointCollector {
    /// Spawns the polling task and returns a handle to its output.
    ///
    /// The task stops once `cancellation_token` is cancelled. Must be called
    /// from within a Tokio runtime, because it uses `tokio::spawn`.
    pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
        let (watch_tx, watch_rx) = watch::channel(ProcessedEndpoints::default());
        // Read the name before `component` is moved into the task, so neither
        // the component nor the token needs to be cloned.
        let service_name = component.service_name();
        tokio::spawn(collect_endpoints_task(
            component,
            watch_tx,
            cancellation_token,
            "generate".to_string(),
        ));
        Self {
            service_name,
            endpoints_rx: watch_rx,
        }
    }

    /// Returns a clone of the most recently published snapshot.
    pub fn get_endpoints(&self) -> ProcessedEndpoints {
        self.endpoints_rx.borrow().clone()
    }

    /// Returns an independent receiver for observing snapshot updates.
    pub fn endpoints_watcher(&self) -> watch::Receiver<ProcessedEndpoints> {
        self.endpoints_rx.clone()
    }
}
/// Tracks the KV load metrics of every worker serving `KV_METRICS_ENDPOINT`.
///
/// A background task (see [`collect_endpoints_task`]) scrapes the component,
/// decodes each endpoint's payload as `ForwardPassMetrics`, and publishes each
/// changed [`ProcessedEndpoints`] snapshot on a watch channel.
pub struct KvMetricsAggregator {
    /// Service name of the component being observed.
    pub service_name: String,
    /// Receiving side of the watch channel fed by the background task.
    pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
}

impl KvMetricsAggregator {
    /// Spawns the scraping task and returns a handle to its output.
    ///
    /// The task stops once `cancellation_token` is cancelled. Must be called
    /// from within a Tokio runtime, because it uses `tokio::spawn`.
    pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
        let (watch_tx, watch_rx) = watch::channel(ProcessedEndpoints::default());
        // Read the name before `component` is moved into the task, so neither
        // the component nor the token needs to be cloned.
        let service_name = component.service_name();
        tokio::spawn(collect_endpoints_task(
            component,
            watch_tx,
            cancellation_token,
            KV_METRICS_ENDPOINT.to_string(),
        ));
        Self {
            service_name,
            endpoints_rx: watch_rx,
        }
    }

    /// Returns a clone of the most recently published snapshot.
    pub fn get_endpoints(&self) -> ProcessedEndpoints {
        self.endpoints_rx.borrow().clone()
    }

    /// Returns an independent receiver for observing snapshot updates.
    pub fn endpoints_watcher(&self) -> watch::Receiver<ProcessedEndpoints> {
        self.endpoints_rx.clone()
    }
}
/// Scrapes `component` and returns the endpoints whose subject starts with `subject`.
///
/// The scrape is bounded by `timeout`; a scrape failure is propagated to the
/// caller. A debug line is logged the first time no endpoint is found and the
/// first time at least one is found. Each message is emitted at most once per
/// process because callers poll this function repeatedly.
///
/// [gluo TODO] 'collect_endpoints' is from component/metrics,
/// should consolidate these functions into generic metrics aggregator
/// functions and shared by KvMetricsAggregator and component/metrics.
pub async fn collect_endpoints(
    component: &Component,
    subject: &str,
    timeout: Duration,
) -> Result<Vec<EndpointInfo>> {
    // Gather stats from every backend, keeping only the requested service subject.
    let matching: Vec<EndpointInfo> = component
        .scrape_stats(timeout)
        .await?
        .into_endpoints()
        .filter(|info| info.subject.starts_with(subject))
        .collect();

    if !matching.is_empty() {
        METRICS_FOUND_MESSAGE.call_once(|| {
            tracing::debug!("Found metrics endpoint");
        });
    } else {
        // Logged once only: this function is polled while the worker starts.
        METRICS_WAITING_MESSAGE.call_once(|| {
            tracing::debug!("Waiting for metrics endpoint..");
        });
    }

    Ok(matching)
}
pub async fn collect_endpoints_task(
component: Component,
watch_tx: watch::Sender<ProcessedEndpoints>,
cancel: CancellationToken,
subject: String,
) {
let backoff_delay = Duration::from_millis(100);
let scrape_timeout = Duration::from_millis(300);
let endpoint = component.endpoint(&subject);
let service_subject = endpoint.subject();
// Keep track of the last sent value to avoid unnecessary updates
let mut last_sent: Option<ProcessedEndpoints> = None;
loop {
tokio::select! {
_ = cancel.cancelled() => {
break;
}
_ = tokio::time::sleep(backoff_delay) => {
tracing::trace!("collecting endpoints for service: {}", service_subject);
let unfiltered_endpoints =
match collect_endpoints(&component, &service_subject, scrape_timeout).await
{
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_subject, e);
continue;
}
};
let endpoints: Vec<Endpoint> = if subject == KV_METRICS_ENDPOINT {
// Original filtering behavior
unfiltered_endpoints
.into_iter()
.filter_map(|s| {
s.data?
.decode::<ForwardPassMetrics>()
.map(|data| Endpoint {
name: s.name,
subject: s.subject,
data: LoadMetrics::EngineLoadMetrics(data),
})
.inspect_err(|e| {
tracing::warn!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e);
})
.ok()
})
.collect()
} else {
// No filtering - just use default LoadMetrics
unfiltered_endpoints
.into_iter()
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: LoadMetrics::default(),
})
.collect()
};
tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len());
let processed = ProcessedEndpoints::new(endpoints);
// Only send if different from last sent value
// This is necessary because the watch channel does not track changes
// https://docs.rs/tokio/latest/tokio/sync/watch/struct.Receiver.html#method.has_changed
let should_send = match &last_sent {
Some(last) => last != &processed,
None => true,
};
if should_send {
tracing::trace!("Endpoints changed, sending update for service: {service_subject}");
if watch_tx.send(processed.clone()).is_err() {
tracing::error!("failed to send processed endpoints; shutting down");
break;
}
last_sent = Some(processed);
} else {
tracing::trace!("Endpoints unchanged, skipping update for service: {service_subject}");
}
}
}
}
}
......@@ -2,25 +2,18 @@
// SPDX-License-Identifier: Apache-2.0
use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, KV_METRICS_SUBJECT,
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT,
indexer::{RouterEvent, compute_block_hash_for_seq},
protocols::*,
scoring::LoadEvent,
};
use async_trait::async_trait;
use dynamo_runtime::metrics::{MetricsRegistry, prometheus_names::kvstats};
use dynamo_runtime::traits::{DistributedRuntimeProvider, events::EventPublisher};
use dynamo_runtime::{
Error, Result,
Result,
component::{Component, Namespace},
pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn,
network::Ingress,
},
protocols::annotated::Annotated,
transports::nats::{NatsQueue, QUEUE_NAME, Slug},
};
use futures::stream;
use std::sync::{Arc, OnceLock};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
......@@ -790,15 +783,7 @@ impl WorkerMetricsPublisher {
Ok(())
}
pub async fn create_endpoint(
&self,
component: Component,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
let mut metrics_rx = self.rx.clone();
let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?;
pub async fn create_endpoint(&self, component: Component) -> Result<()> {
let worker_id = component
.drt()
.primary_lease()
......@@ -809,31 +794,13 @@ impl WorkerMetricsPublisher {
});
self.start_nats_metrics_publishing(component.namespace().clone(), worker_id);
let metrics_labels = metrics_labels.map(|v| {
v.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect::<Vec<_>>()
});
component
.endpoint(KV_METRICS_ENDPOINT)
.endpoint_builder()
.stats_handler(move |_| {
let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap()
})
.metrics_labels(metrics_labels)
.handler(handler)
.start()
.await
Ok(())
}
/// Starts a background task to publish metrics over NATS
///
/// This task monitors metric changes (specifically kv_active_blocks and num_requests_waiting)
/// and publishes stable metrics to NATS after they've been unchanged for 1ms.
#[allow(dead_code)]
fn start_nats_metrics_publishing(&self, namespace: Namespace, worker_id: i64) {
let nats_rx = self.rx.clone();
......@@ -907,32 +874,6 @@ impl WorkerMetricsPublisher {
}
}
struct KvLoadEndpointHandler {
metrics_rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>,
}
impl KvLoadEndpointHandler {
pub fn new(metrics_rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>) -> Self {
Self { metrics_rx }
}
}
#[async_trait]
impl AsyncEngine<SingleIn<()>, ManyOut<Annotated<ForwardPassMetrics>>, Error>
for KvLoadEndpointHandler
{
async fn generate(
&self,
request: SingleIn<()>,
) -> Result<ManyOut<Annotated<ForwardPassMetrics>>> {
let context = request.context();
let metrics = self.metrics_rx.borrow().clone();
let metrics = (*metrics).clone();
let stream = stream::iter(vec![Annotated::from_data(metrics)]);
Ok(ResponseStream::new(Box::pin(stream), context))
}
}
// -------------------------------------------------------------------------
// Testing -----------------------------------------------------------------
// -------------------------------------------------------------------------
......
......@@ -184,7 +184,7 @@ impl MockVllmEngine {
tokio::spawn({
let publisher = metrics_publisher.clone();
async move {
if let Err(e) = publisher.create_endpoint(comp.clone(), None).await {
if let Err(e) = publisher.create_endpoint(comp.clone()).await {
tracing::error!("Metrics endpoint failed: {e}");
}
}
......@@ -715,36 +715,6 @@ mod integration_tests {
}
}
// Use KvMetricsAggregator to get metrics more easily
let cancel_token = test_component.drt().runtime().child_token();
let metrics_aggregator = crate::kv_router::metrics_aggregator::KvMetricsAggregator::new(
test_component.clone(),
cancel_token,
)
.await;
tokio::time::sleep(Duration::from_millis(500)).await;
let processed_endpoints = metrics_aggregator.get_endpoints();
tracing::info!(
"Found {} metrics endpoints",
processed_endpoints.endpoints.len()
);
// Verify we found at least one metrics endpoint
assert!(
!processed_endpoints.endpoints.is_empty(),
"Should find at least one metrics endpoint"
);
tracing::info!(
"✓ Successfully found {} metrics endpoints",
processed_endpoints.endpoints.len()
);
// Verify the metrics endpoints contain valid data
for (worker_id, endpoint) in &processed_endpoints.endpoints {
tracing::info!("✓ Worker {} metrics: {:?}", worker_id, endpoint.data);
}
tracing::info!("🎉 Event verification completed!");
// Cleanup
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment