Commit 09656f6c authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

feat: Add estimated kv cache hit metric events (#30)

parent a720fa12
...@@ -53,6 +53,8 @@ tensorrtllm_checkpoints/ ...@@ -53,6 +53,8 @@ tensorrtllm_checkpoints/
tensorrtllm_engines/ tensorrtllm_engines/
api_server_models/ api_server_models/
server/ server/
# Replay/Snapshot test artifacts
*.new
**/*backups* **/*backups*
...@@ -75,4 +77,4 @@ __pycache__/ ...@@ -75,4 +77,4 @@ __pycache__/
*$py.class *$py.class
*.so *.so
**/.devcontainer **/.devcontainer
\ No newline at end of file
...@@ -739,6 +739,7 @@ dependencies = [ ...@@ -739,6 +739,7 @@ dependencies = [
"clap", "clap",
"dynemo-llm", "dynemo-llm",
"dynemo-runtime", "dynemo-runtime",
"futures",
"opentelemetry", "opentelemetry",
"opentelemetry-prometheus", "opentelemetry-prometheus",
"prometheus", "prometheus",
......
...@@ -38,6 +38,7 @@ opentelemetry-prometheus = "0.13" ...@@ -38,6 +38,7 @@ opentelemetry-prometheus = "0.13"
prometheus = "0.13" prometheus = "0.13"
rand = "0.8" rand = "0.8"
axum = "0.6" axum = "0.6"
futures = "0.3"
[dev-dependencies] [dev-dependencies]
reqwest = { version = "0.11", features = ["blocking"] } reqwest = { version = "0.11", features = ["blocking"] }
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
//! Library functions for the count application. //! Library functions for the count application.
use axum::{routing::get, Router}; use axum::{routing::get, Router};
use prometheus::register_gauge_vec; use prometheus::{register_counter_vec, register_gauge_vec};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::net::SocketAddr; use std::net::SocketAddr;
...@@ -97,6 +97,18 @@ impl PrometheusMetricsServer { ...@@ -97,6 +97,18 @@ impl PrometheusMetricsServer {
pub fn update(&mut self, config: &LLMWorkerLoadCapacityConfig, processed: &ProcessedEndpoints) { pub fn update(&mut self, config: &LLMWorkerLoadCapacityConfig, processed: &ProcessedEndpoints) {
self.metrics.update(config, processed); self.metrics.update(config, processed);
} }
/// Record a KV hit rate sample by delegating to the wrapped [`PrometheusMetrics`].
///
/// `isl_blocks` is the total input-sequence-length block count for the event;
/// `overlap_blocks` is how many of those were already cached on `worker_id`.
pub fn update_kv_hit_rate(
    &mut self,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: i64,
    isl_blocks: usize,
    overlap_blocks: usize,
) {
    // Thin delegation: all counter bookkeeping lives on `PrometheusMetrics`.
    self.metrics.update_kv_hit_rate(config, worker_id, isl_blocks, overlap_blocks);
}
} }
/// Prometheus metrics collection /// Prometheus metrics collection
...@@ -107,6 +119,9 @@ pub struct PrometheusMetrics { ...@@ -107,6 +119,9 @@ pub struct PrometheusMetrics {
requests_total: prometheus::GaugeVec, requests_total: prometheus::GaugeVec,
load_avg: prometheus::GaugeVec, load_avg: prometheus::GaugeVec,
load_std: prometheus::GaugeVec, load_std: prometheus::GaugeVec,
// KV hit rate metrics
kv_hit_rate_isl_blocks: prometheus::CounterVec,
kv_hit_rate_overlap_blocks: prometheus::CounterVec,
} }
impl PrometheusMetrics { impl PrometheusMetrics {
...@@ -143,6 +158,19 @@ impl PrometheusMetrics { ...@@ -143,6 +158,19 @@ impl PrometheusMetrics {
"Load standard deviation across workers", "Load standard deviation across workers",
&["component", "endpoint"] &["component", "endpoint"]
)?, )?,
// TODO: The cumulative isl/overlap metrics are monotonically increasing
// and may overflow at some point, we may want to periodically reset them.
// KV hit rate metrics
kv_hit_rate_isl_blocks: register_counter_vec!(
"llm_kv_hit_rate_isl_blocks",
"Cumulative count of ISL blocks in KV hit rate events",
&["component", "endpoint", "worker_id"]
)?,
kv_hit_rate_overlap_blocks: register_counter_vec!(
"llm_kv_hit_rate_overlap_blocks",
"Cumulative count of overlapping blocks in KV hit rate events",
&["component", "endpoint", "worker_id"]
)?,
}) })
} }
...@@ -159,6 +187,19 @@ impl PrometheusMetrics { ...@@ -159,6 +187,19 @@ impl PrometheusMetrics {
.set(value); .set(value);
} }
/// Helper: add `value` to `counter` for the 3-label set
/// (component, endpoint, worker_id).
fn increment_worker_counter(
    &self,
    counter: &prometheus::CounterVec,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: &str,
    value: f64,
) {
    // Build the label set once, resolve the per-worker child, then bump it.
    let labels: [&str; 3] = [&config.component_name, &config.endpoint_name, worker_id];
    let child = counter.with_label_values(&labels);
    child.inc_by(value);
}
/// Helper method to set a gauge with component/endpoint labels only (2 labels) /// Helper method to set a gauge with component/endpoint labels only (2 labels)
fn set_endpoint_gauge( fn set_endpoint_gauge(
&self, &self,
...@@ -208,6 +249,61 @@ impl PrometheusMetrics { ...@@ -208,6 +249,61 @@ impl PrometheusMetrics {
self.set_endpoint_gauge(&self.load_avg, config, processed.load_avg); self.set_endpoint_gauge(&self.load_avg, config, processed.load_avg);
self.set_endpoint_gauge(&self.load_std, config, processed.load_std); self.set_endpoint_gauge(&self.load_std, config, processed.load_std);
} }
/// Record one KV hit rate event for `worker_id` and log the running
/// cumulative hit rate for that worker.
///
/// `isl_blocks` counts all input-sequence-length blocks in the event;
/// `overlap_blocks` counts those already resident in the worker's KV cache.
pub fn update_kv_hit_rate(
    &self,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: i64,
    isl_blocks: usize,
    overlap_blocks: usize,
) {
    let worker_id_str = worker_id.to_string();

    // Bump both cumulative counters for this worker.
    self.increment_worker_counter(
        &self.kv_hit_rate_isl_blocks,
        config,
        &worker_id_str,
        isl_blocks as f64,
    );
    self.increment_worker_counter(
        &self.kv_hit_rate_overlap_blocks,
        config,
        &worker_id_str,
        overlap_blocks as f64,
    );

    // TODO: The cumulative hit rate percentage can probably be computed by consumers
    // of Prometheus metrics like Grafana instead, but we'll compute it here for now
    // for convenient debugging/logging.
    //
    // Build the label set once and reuse it for both counter reads.
    let labels: [&str; 3] = [
        &config.component_name,
        &config.endpoint_name,
        &worker_id_str,
    ];
    let cumulative_isl = self.kv_hit_rate_isl_blocks.with_label_values(&labels).get();
    let cumulative_overlap = self
        .kv_hit_rate_overlap_blocks
        .with_label_values(&labels)
        .get();

    // Guard against division by zero before the first ISL blocks arrive.
    if cumulative_isl > 0.0 {
        let cumulative_hit_rate = (cumulative_overlap / cumulative_isl) * 100.0;
        tracing::info!(
            "Estimated Cumulative KV hit rate: {cumulative_hit_rate:.2}% (Overlap: {cumulative_overlap} / ISL: {cumulative_isl})"
        );
    }
}
} }
/// Collect endpoints from a component /// Collect endpoints from a component
......
...@@ -22,14 +22,20 @@ ...@@ -22,14 +22,20 @@
//! - These metrics will be scraped by the LLM NATS Service API's stats request //! - These metrics will be scraped by the LLM NATS Service API's stats request
//! - Request Slots: [Active, Total] //! - Request Slots: [Active, Total]
//! - KV Cache Blocks: [Active, Total] //! - KV Cache Blocks: [Active, Total]
//! - KV Hit Rate:
//! - These metrics will be collected from KV hit rate events published by the KV router
//! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events
//! - Overlap Blocks: Cumulative count of blocks that were already in the KV cache
use clap::Parser; use clap::Parser;
use dynemo_llm::kv_router::scheduler::KVHitRateEvent;
use dynemo_runtime::{ use dynemo_runtime::{
error, logging, error, logging,
traits::events::EventPublisher, traits::events::{EventPublisher, EventSubscriber},
utils::{Duration, Instant}, utils::{Duration, Instant},
DistributedRuntime, ErrorContext, Result, Runtime, Worker, DistributedRuntime, ErrorContext, Result, Runtime, Worker,
}; };
use futures::stream::StreamExt;
use std::sync::Arc;
// Import from our library // Import from our library
use count::{ use count::{
...@@ -111,8 +117,65 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -111,8 +117,65 @@ async fn app(runtime: Runtime) -> Result<()> {
// TODO: Make metrics host/port configurable // TODO: Make metrics host/port configurable
// Initialize Prometheus metrics and start server // Initialize Prometheus metrics and start server
let mut metrics_server = PrometheusMetricsServer::new()?; let metrics_server = PrometheusMetricsServer::new()?;
metrics_server.start(9091); // Metrics will be updated concurrently, so protect it with a mutex:
// - Main loop: Collect and process ForwardPassMetrics at an interval from endpoint stats handlers
// - Subscription task: Collect and process KVHitRateEvent metrics from the KV router as they are published
let metrics_server = Arc::new(tokio::sync::Mutex::new(metrics_server));
metrics_server.lock().await.start(9091);
// Subscribe to KV hit rate events
let kv_hit_rate_subject = "kv-hit-rate";
tracing::info!("Subscribing to KV hit rate events on subject: {kv_hit_rate_subject}");
// Clone the metrics server and config for the subscription task
let metrics_server_clone = metrics_server.clone();
let config_clone = config.clone();
// Clone the namespace for the subscription task
let namespace_clone = namespace.clone();
// Spawn a task to handle KV hit rate events
tokio::spawn(async move {
match namespace_clone.subscribe(kv_hit_rate_subject).await {
Ok(mut subscriber) => {
tracing::info!("Successfully subscribed to KV hit rate events");
while let Some(msg) = subscriber.next().await {
match serde_json::from_slice::<KVHitRateEvent>(&msg.payload) {
Ok(event) => {
// TODO: Lower to debug
let cache_hit_pct =
(event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0;
tracing::info!(
"Received KV hit rate event: worker_id={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%",
event.worker_id,
event.isl_blocks,
event.overlap_blocks,
cache_hit_pct
);
// Update metrics with the event data
let mut metrics = metrics_server_clone.lock().await;
metrics.update_kv_hit_rate(
&config_clone,
event.worker_id,
event.isl_blocks,
event.overlap_blocks,
);
}
Err(e) => {
tracing::warn!("Failed to deserialize KV hit rate event: {:?}", e);
}
}
}
tracing::warn!("KV hit rate event subscription stream ended");
}
Err(e) => {
tracing::error!("Failed to subscribe to KV hit rate events: {:?}", e);
}
}
});
loop { loop {
let next = Instant::now() + Duration::from_secs(args.poll_interval); let next = Instant::now() + Duration::from_secs(args.poll_interval);
...@@ -123,12 +186,14 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -123,12 +186,14 @@ async fn app(runtime: Runtime) -> Result<()> {
collect_endpoints(&target_component, &service_subject, scrape_timeout).await?; collect_endpoints(&target_component, &service_subject, scrape_timeout).await?;
let metrics = extract_metrics(&endpoints); let metrics = extract_metrics(&endpoints);
let processed = postprocess_metrics(&metrics, &endpoints); let processed = postprocess_metrics(&metrics, &endpoints);
tracing::info!("Aggregated metrics: {processed:?}"); tracing::debug!("Aggregated metrics: {processed:?}");
// Update Prometheus metrics // Update Prometheus metrics
metrics_server.update(&config, &processed); metrics_server.lock().await.update(&config, &processed);
// TODO: Who needs to consume these events? // TODO: Enable KV Routers to subscribe to metrics events published here
// for a single view of the aggregated metrics, as opposed to the current
// approach where each KV Router computes and published its own metrics.
// Publish metrics event // Publish metrics event
namespace.publish(&event_name, &processed).await?; namespace.publish(&event_name, &processed).await?;
......
...@@ -16,17 +16,17 @@ ...@@ -16,17 +16,17 @@
] ]
}, },
"copyright": [ "copyright": [
"SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.", "SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.",
"SPDX-License-Identifier: Apache-2.0", "SPDX-License-Identifier: Apache-2.0",
"Licensed under the Apache License, Version 2.0 (the \"License\");", "Licensed under the Apache License, Version 2.0 (the \"License\");",
"you may not use this file except in compliance with the License.", "you may not use this file except in compliance with the License.",
"You may obtain a copy of the License at", "You may obtain a copy of the License at",
"http://www.apache.org/licenses/LICENSE-2.0", "http://www.apache.org/licenses/LICENSE-2.0",
"Unless required by applicable law or agreed to in writing, software", "Unless required by applicable law or agreed to in writing, software",
"distributed under the License is distributed on an \"AS IS\" BASIS,", "distributed under the License is distributed on an \"AS IS\" BASIS,",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
"See the License for the specific language governing permissions and", "See the License for the specific language governing permissions and",
"limitations under the License." "limitations under the License."
], ],
"editable": true, "editable": true,
"fiscalYearStartMonth": 0, "fiscalYearStartMonth": 0,
...@@ -261,7 +261,7 @@ ...@@ -261,7 +261,7 @@
}, },
"gridPos": { "gridPos": {
"h": 8, "h": 8,
"w": 6, "w": 4,
"x": 0, "x": 0,
"y": 8 "y": 8
}, },
...@@ -329,8 +329,8 @@ ...@@ -329,8 +329,8 @@
}, },
"gridPos": { "gridPos": {
"h": 8, "h": 8,
"w": 6, "w": 4,
"x": 6, "x": 4,
"y": 8 "y": 8
}, },
"id": 4, "id": 4,
...@@ -363,6 +363,74 @@ ...@@ -363,6 +363,74 @@
} }
] ]
}, },
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "yellow",
"value": 50
},
{
"color": "red",
"value": 80
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 4,
"x": 8,
"y": 8
},
"id": 7,
"options": {
"orientation": "auto",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showThresholdLabels": false,
"showThresholdMarkers": true
},
"pluginVersion": "10.0.0",
"title": "Cumulative KV Cache Hit Rate",
"type": "gauge",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * sum(llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"}) / sum(llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"})",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
]
},
{ {
"datasource": { "datasource": {
"type": "prometheus", "type": "prometheus",
...@@ -467,6 +535,190 @@ ...@@ -467,6 +535,190 @@
} }
] ]
}, },
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 20,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 16
},
"id": 8,
"options": {
"legend": {
"calcs": [
"mean",
"max"
],
"displayMode": "table",
"placement": "right",
"showLegend": true
},
"tooltip": {
"mode": "multi",
"sort": "none"
}
},
"title": "KV Cache Hit Rate by Worker",
"type": "timeseries",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"} / llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"}",
"legendFormat": "Worker {{worker_id}}",
"range": true,
"refId": "A"
}
]
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 20,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 16
},
"id": 9,
"options": {
"legend": {
"calcs": [
"mean",
"max"
],
"displayMode": "table",
"placement": "right",
"showLegend": true
},
"tooltip": {
"mode": "multi",
"sort": "none"
}
},
"title": "Cumulative KV Cache Hit Rate",
"type": "timeseries",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * sum(llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"}) / sum(llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"})",
"legendFormat": "Overall Hit Rate",
"range": true,
"refId": "A"
}
]
},
{ {
"datasource": { "datasource": {
"type": "prometheus", "type": "prometheus",
...@@ -525,7 +777,7 @@ ...@@ -525,7 +777,7 @@
"h": 8, "h": 8,
"w": 24, "w": 24,
"x": 0, "x": 0,
"y": 16 "y": 24
}, },
"id": 6, "id": 6,
"options": { "options": {
...@@ -647,4 +899,4 @@ ...@@ -647,4 +899,4 @@
"uid": "llm-worker-metrics", "uid": "llm-worker-metrics",
"version": 1, "version": 1,
"weekStart": "" "weekStart": ""
} }
\ No newline at end of file
...@@ -33,7 +33,7 @@ ENDPOINT_NAME=${4:-"dynemo.process.chat/completions"} ...@@ -33,7 +33,7 @@ ENDPOINT_NAME=${4:-"dynemo.process.chat/completions"}
VALID_STRATEGIES=("prefix") VALID_STRATEGIES=("prefix")
SESSION_NAME="v" SESSION_NAME="v"
WORKDIR="/workspace/examples/python_rs/llm/vllm" WORKDIR="/workspace/examples/python_rs/llm/vllm"
INIT_CMD="source /opt/dynemo/venv/bin/activate && cd $WORKDIR" INIT_CMD="cd $WORKDIR"
if [[ ! " ${VALID_STRATEGIES[@]} " =~ " ${ROUTING_STRATEGY} " ]]; then if [[ ! " ${VALID_STRATEGIES[@]} " =~ " ${ROUTING_STRATEGY} " ]]; then
echo "Error: Invalid routing strategy. Must be one of: ${VALID_STRATEGIES[*]}" echo "Error: Invalid routing strategy. Must be one of: ${VALID_STRATEGIES[*]}"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
// limitations under the License. // limitations under the License.
use anyhow::Result; use anyhow::Result;
use dynemo_runtime::{component::Component, DistributedRuntime}; use dynemo_runtime::{component::Component, component::Namespace, DistributedRuntime};
use futures::stream::StreamExt; use futures::stream::StreamExt;
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -57,15 +57,19 @@ impl KvRouter { ...@@ -57,15 +57,19 @@ impl KvRouter {
let nats_client = runtime.nats_client(); let nats_client = runtime.nats_client();
let service_name = backend.service_name(); let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT); let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
let namespace = runtime.namespace(backend.namespace())?;
tracing::info!("Component Namespace {}", backend.namespace());
tracing::info!("Component Service Name {}", service_name); tracing::info!("Component Service Name {}", service_name);
tracing::info!("KV Subject {}", kv_subject); tracing::info!("KV Subject {}", kv_subject);
Self::new(nats_client, service_name, kv_subject).await Self::new(nats_client, service_name, kv_subject, namespace).await
} }
pub async fn new( pub async fn new(
nats_client: dynemo_runtime::transports::nats::Client, nats_client: dynemo_runtime::transports::nats::Client,
service_name: String, service_name: String,
kv_subject: String, kv_subject: String,
namespace: Namespace,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128); let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);
...@@ -78,7 +82,7 @@ impl KvRouter { ...@@ -78,7 +82,7 @@ impl KvRouter {
)); ));
let indexer = KvIndexer::new(cancellation_token.clone()); let indexer = KvIndexer::new(cancellation_token.clone());
let scheduler = KvScheduler::start(ep_rx).await?; let scheduler = KvScheduler::start(ep_rx, namespace).await?;
tracing::debug!("subscribing to kv events: {}", kv_subject); tracing::debug!("subscribing to kv events: {}", kv_subject);
let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?; let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use dynemo_runtime::component::Namespace;
use dynemo_runtime::traits::events::EventPublisher;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use std::cmp::min; use std::cmp::min;
...@@ -21,7 +23,13 @@ use crate::kv_router::indexer::OverlapScores; ...@@ -21,7 +23,13 @@ use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE}; pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::scoring::ProcessedEndpoints;
/// A single KV-cache hit rate observation, published by the KV router's
/// scheduler each time a worker is selected for a request.
///
/// Serialized as JSON on the "kv-hit-rate" subject and consumed by the
/// metrics aggregator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    // ID of the worker the request was routed to.
    pub worker_id: i64,
    // Total number of input-sequence-length (ISL) blocks in the request.
    pub isl_blocks: usize,
    // Number of ISL blocks already present in the worker's KV cache.
    pub overlap_blocks: usize,
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError { pub enum KvSchedulerError {
#[error("no endpoints aviailable to route work")] #[error("no endpoints aviailable to route work")]
...@@ -93,6 +101,7 @@ pub struct KvScheduler { ...@@ -93,6 +101,7 @@ pub struct KvScheduler {
impl KvScheduler { impl KvScheduler {
pub async fn start( pub async fn start(
endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>, endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>,
ns: Namespace,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let mut endpoints_rx = endpoints_rx; let mut endpoints_rx = endpoints_rx;
...@@ -104,6 +113,20 @@ impl KvScheduler { ...@@ -104,6 +113,20 @@ impl KvScheduler {
} }
}; };
// Channel to asynchronously publish metric events on
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
// Publisher task
tokio::spawn(async move {
let mut event_rx = event_rx;
let subject = "kv-hit-rate";
while let Some(event) = event_rx.recv().await {
if let Err(e) = ns.publish(subject, &event).await {
tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
}
}
});
// Channel to accept new scheduling requests // Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16); let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16);
tracing::debug!("scheduler starting"); tracing::debug!("scheduler starting");
...@@ -146,7 +169,7 @@ impl KvScheduler { ...@@ -146,7 +169,7 @@ impl KvScheduler {
}; };
tracing::debug!("selected"); tracing::debug!("selected");
loop { loop {
match select_worker(endpoints.borrow_mut(), &request) { match select_worker(endpoints.borrow_mut(), &request, &event_tx) {
Ok(worker_id) => { Ok(worker_id) => {
request.respond(worker_id); request.respond(worker_id);
continue 'outer; continue 'outer;
...@@ -175,7 +198,6 @@ impl KvScheduler { ...@@ -175,7 +198,6 @@ impl KvScheduler {
Ok(KvScheduler { request_tx }) Ok(KvScheduler { request_tx })
} }
#[allow(dead_code)]
pub async fn schedule( pub async fn schedule(
&self, &self,
overlap: OverlapScores, overlap: OverlapScores,
...@@ -205,6 +227,7 @@ impl KvScheduler { ...@@ -205,6 +227,7 @@ impl KvScheduler {
pub fn select_worker( pub fn select_worker(
workers: &mut ProcessedEndpoints, workers: &mut ProcessedEndpoints,
request: &SchedulingRequest, request: &SchedulingRequest,
event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
) -> Result<i64, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
// balance mode prioritizes balancing load across workers // balance mode prioritizes balancing load across workers
let balance_threshold: f64 = 0.1; let balance_threshold: f64 = 0.1;
...@@ -268,6 +291,23 @@ pub fn select_worker( ...@@ -268,6 +291,23 @@ pub fn select_worker(
workers.endpoints[best_index].data.request_active_slots += 1; workers.endpoints[best_index].data.request_active_slots += 1;
workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64; workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64;
// Optimization - pass this to a channel for emitting events, async task, etc. to avoid blocking the scheduler
let best_worker_id = workers.endpoints[best_index].worker_id();
let isl_blocks = request.isl_tokens / KV_BLOCK_SIZE;
let overlap_blocks = request
.overlap
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0);
if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: best_worker_id,
isl_blocks,
overlap_blocks: overlap_blocks as usize,
}) {
tracing::warn!("Failed to send KV hit rate event: {:?}", e);
}
} }
match best_index { match best_index {
......
...@@ -154,6 +154,11 @@ impl Component { ...@@ -154,6 +154,11 @@ impl Component {
format!("{}/{}", self.namespace, self.name) format!("{}/{}", self.namespace, self.name)
} }
/// Returns a reference to the namespace string of this component.
///
/// This is the same value used to build the component's service name
/// (`"{namespace}/{name}"`).
pub fn namespace(&self) -> &str {
&self.namespace
}
pub fn drt(&self) -> &DistributedRuntime { pub fn drt(&self) -> &DistributedRuntime {
&self.drt &self.drt
} }
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
// limitations under the License. // limitations under the License.
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream::StreamExt;
use futures::{Stream, TryStreamExt};
use super::*; use super::*;
use crate::traits::events::EventPublisher; use crate::traits::events::{EventPublisher, EventSubscriber};
#[async_trait] #[async_trait]
impl EventPublisher for Namespace { impl EventPublisher for Namespace {
...@@ -49,6 +51,32 @@ impl EventPublisher for Namespace { ...@@ -49,6 +51,32 @@ impl EventPublisher for Namespace {
} }
} }
#[async_trait]
impl EventSubscriber for Namespace {
    /// Subscribe to raw NATS messages on `"{namespace subject}.{event_name}"`.
    async fn subscribe(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<async_nats::Subscriber> {
        let subject = format!("{}.{}", self.subject(), event_name.as_ref());
        let subscriber = self.drt().nats_client().client().subscribe(subject).await?;
        Ok(subscriber)
    }

    /// Subscribe and JSON-decode each message payload into `T`.
    ///
    /// Deserialization happens lazily, per message, as the stream is polled;
    /// a payload that fails to decode yields an `Err` item rather than
    /// terminating the stream.
    async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<impl Stream<Item = Result<T>> + Send> {
        let subscriber = self.subscribe(event_name).await?;
        let typed = subscriber.map(|msg| {
            serde_json::from_slice::<T>(&msg.payload)
                .map_err(|e| anyhow::anyhow!("Failed to deserialize event: {}", e))
        });
        Ok(typed)
    }
}
#[cfg(feature = "integration")] #[cfg(feature = "integration")]
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
...@@ -64,4 +92,27 @@ mod tests { ...@@ -64,4 +92,27 @@ mod tests {
ns.publish("test", &"test".to_string()).await.unwrap(); ns.publish("test", &"test".to_string()).await.unwrap();
rt.shutdown(); rt.shutdown();
} }
/// End-to-end check that a message published on a namespace subject is
/// received by a subscriber on the same subject.
#[tokio::test]
async fn test_subscribe() {
    let rt = Runtime::from_current().unwrap();
    let dtr = DistributedRuntime::from_settings(rt.clone()).await.unwrap();
    let ns = dtr.namespace("test".to_string()).unwrap();

    // Subscribe before publishing so the message is not missed.
    // `StreamExt::next` takes `&mut self`, so the binding must be mutable.
    let mut subscriber = ns.subscribe("test").await.unwrap();

    // Publish a JSON-serialized string payload.
    ns.publish("test", &"test_message".to_string())
        .await
        .unwrap();

    // Receive the message; serde_json wraps strings in quotes on the wire.
    if let Some(msg) = subscriber.next().await {
        let received = String::from_utf8(msg.payload.to_vec()).unwrap();
        assert_eq!(received, "\"test_message\"");
    }

    rt.shutdown();
}
} }
...@@ -56,3 +56,24 @@ pub trait EventPublisher { ...@@ -56,3 +56,24 @@ pub trait EventPublisher {
// fn publisher(&self, event_name: impl AsRef<str>) -> Result<Publisher>; // fn publisher(&self, event_name: impl AsRef<str>) -> Result<Publisher>;
// fn publisher_bytes(&self, event_name: impl AsRef<str>) -> &PublisherBytes; // fn publisher_bytes(&self, event_name: impl AsRef<str>) -> &PublisherBytes;
} }
/// A trait for subscribing to events in the event plane.
///
/// This trait provides methods to subscribe to events published on specific
/// subjects. It is the read-side counterpart to [`EventPublisher`].
#[async_trait]
pub trait EventSubscriber {
    /// Subscribe to events with the given event name.
    /// The `event_name` will be `.` concatenated with the base subject provided by the implementation.
    /// Returns a subscriber that can be used to receive raw events.
    async fn subscribe(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<async_nats::Subscriber>;

    /// Subscribe to events with the given event name and deserialize them to the specified type.
    /// This is a convenience method that combines subscribe and deserialization.
    /// Each stream item is an `Err` when a payload fails to deserialize to `T`.
    async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<impl futures::Stream<Item = Result<T>> + Send>;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment