Commit 09656f6c authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

feat: Add estimated kv cache hit metric events (#30)

parent a720fa12
...@@ -53,6 +53,8 @@ tensorrtllm_checkpoints/ ...@@ -53,6 +53,8 @@ tensorrtllm_checkpoints/
tensorrtllm_engines/ tensorrtllm_engines/
api_server_models/ api_server_models/
server/ server/
# Replay/Snapshot test artifacts
*.new
**/*backups* **/*backups*
......
...@@ -739,6 +739,7 @@ dependencies = [ ...@@ -739,6 +739,7 @@ dependencies = [
"clap", "clap",
"dynemo-llm", "dynemo-llm",
"dynemo-runtime", "dynemo-runtime",
"futures",
"opentelemetry", "opentelemetry",
"opentelemetry-prometheus", "opentelemetry-prometheus",
"prometheus", "prometheus",
......
...@@ -38,6 +38,7 @@ opentelemetry-prometheus = "0.13" ...@@ -38,6 +38,7 @@ opentelemetry-prometheus = "0.13"
prometheus = "0.13" prometheus = "0.13"
rand = "0.8" rand = "0.8"
axum = "0.6" axum = "0.6"
futures = "0.3"
[dev-dependencies] [dev-dependencies]
reqwest = { version = "0.11", features = ["blocking"] } reqwest = { version = "0.11", features = ["blocking"] }
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
//! Library functions for the count application. //! Library functions for the count application.
use axum::{routing::get, Router}; use axum::{routing::get, Router};
use prometheus::register_gauge_vec; use prometheus::{register_counter_vec, register_gauge_vec};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::net::SocketAddr; use std::net::SocketAddr;
...@@ -97,6 +97,18 @@ impl PrometheusMetricsServer { ...@@ -97,6 +97,18 @@ impl PrometheusMetricsServer {
pub fn update(&mut self, config: &LLMWorkerLoadCapacityConfig, processed: &ProcessedEndpoints) { pub fn update(&mut self, config: &LLMWorkerLoadCapacityConfig, processed: &ProcessedEndpoints) {
self.metrics.update(config, processed); self.metrics.update(config, processed);
} }
/// Record one KV hit rate event by forwarding it to the wrapped metrics collection.
///
/// `config`, `worker_id`, `isl_blocks` and `overlap_blocks` are handed through
/// unchanged; no processing happens at this level.
pub fn update_kv_hit_rate(
    &mut self,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: i64,
    isl_blocks: usize,
    overlap_blocks: usize,
) {
    let collection = &self.metrics;
    collection.update_kv_hit_rate(config, worker_id, isl_blocks, overlap_blocks);
}
} }
/// Prometheus metrics collection /// Prometheus metrics collection
...@@ -107,6 +119,9 @@ pub struct PrometheusMetrics { ...@@ -107,6 +119,9 @@ pub struct PrometheusMetrics {
requests_total: prometheus::GaugeVec, requests_total: prometheus::GaugeVec,
load_avg: prometheus::GaugeVec, load_avg: prometheus::GaugeVec,
load_std: prometheus::GaugeVec, load_std: prometheus::GaugeVec,
// KV hit rate metrics
kv_hit_rate_isl_blocks: prometheus::CounterVec,
kv_hit_rate_overlap_blocks: prometheus::CounterVec,
} }
impl PrometheusMetrics { impl PrometheusMetrics {
...@@ -143,6 +158,19 @@ impl PrometheusMetrics { ...@@ -143,6 +158,19 @@ impl PrometheusMetrics {
"Load standard deviation across workers", "Load standard deviation across workers",
&["component", "endpoint"] &["component", "endpoint"]
)?, )?,
// TODO: The cumulative isl/overlap metrics are monotonically increasing
// and may overflow at some point, we may want to periodically reset them.
// KV hit rate metrics
kv_hit_rate_isl_blocks: register_counter_vec!(
"llm_kv_hit_rate_isl_blocks",
"Cumulative count of ISL blocks in KV hit rate events",
&["component", "endpoint", "worker_id"]
)?,
kv_hit_rate_overlap_blocks: register_counter_vec!(
"llm_kv_hit_rate_overlap_blocks",
"Cumulative count of overlapping blocks in KV hit rate events",
&["component", "endpoint", "worker_id"]
)?,
}) })
} }
...@@ -159,6 +187,19 @@ impl PrometheusMetrics { ...@@ -159,6 +187,19 @@ impl PrometheusMetrics {
.set(value); .set(value);
} }
/// Add `value` to the child of `counter` selected by the three worker-level
/// labels: component name, endpoint name and worker id (in that order).
fn increment_worker_counter(
    &self,
    counter: &prometheus::CounterVec,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: &str,
    value: f64,
) {
    // Label order must match the order used when the counter vec was registered.
    let labels: [&str; 3] = [&config.component_name, &config.endpoint_name, worker_id];
    let child = counter.with_label_values(&labels);
    child.inc_by(value);
}
/// Helper method to set a gauge with component/endpoint labels only (2 labels) /// Helper method to set a gauge with component/endpoint labels only (2 labels)
fn set_endpoint_gauge( fn set_endpoint_gauge(
&self, &self,
...@@ -208,6 +249,61 @@ impl PrometheusMetrics { ...@@ -208,6 +249,61 @@ impl PrometheusMetrics {
self.set_endpoint_gauge(&self.load_avg, config, processed.load_avg); self.set_endpoint_gauge(&self.load_avg, config, processed.load_avg);
self.set_endpoint_gauge(&self.load_std, config, processed.load_std); self.set_endpoint_gauge(&self.load_std, config, processed.load_std);
} }
/// Fold a single KV hit rate event into the per-worker cumulative counters.
///
/// The ISL block counter and the overlap block counter for `worker_id` are both
/// incremented, after which the cumulative hit rate for that worker is logged
/// (only when the cumulative ISL block count is positive).
pub fn update_kv_hit_rate(
    &self,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: i64,
    isl_blocks: usize,
    overlap_blocks: usize,
) {
    let worker_label = worker_id.to_string();

    // Bump both cumulative counters for this worker.
    self.increment_worker_counter(
        &self.kv_hit_rate_isl_blocks,
        config,
        &worker_label,
        isl_blocks as f64,
    );
    self.increment_worker_counter(
        &self.kv_hit_rate_overlap_blocks,
        config,
        &worker_label,
        overlap_blocks as f64,
    );

    // TODO: The cumulative hit rate percentage can probably be computed by consumers
    // of Prometheus metrics like Grafana instead, but we'll compute it here for now
    // for convenient debugging/logging.

    // Read the running totals back using one shared label set.
    let labels: [&str; 3] = [
        &config.component_name,
        &config.endpoint_name,
        &worker_label,
    ];
    let cumulative_isl = self
        .kv_hit_rate_isl_blocks
        .with_label_values(&labels)
        .get();
    let cumulative_overlap = self
        .kv_hit_rate_overlap_blocks
        .with_label_values(&labels)
        .get();

    // Skip the log line while the denominator is still zero.
    if cumulative_isl > 0.0 {
        let cumulative_hit_rate = (cumulative_overlap / cumulative_isl) * 100.0;
        tracing::info!(
            "Estimated Cumulative KV hit rate: {cumulative_hit_rate:.2}% (Overlap: {cumulative_overlap} / ISL: {cumulative_isl})"
        );
    }
}
} }
/// Collect endpoints from a component /// Collect endpoints from a component
......
...@@ -22,14 +22,20 @@ ...@@ -22,14 +22,20 @@
//! - These metrics will be scraped by the LLM NATS Service API's stats request //! - These metrics will be scraped by the LLM NATS Service API's stats request
//! - Request Slots: [Active, Total] //! - Request Slots: [Active, Total]
//! - KV Cache Blocks: [Active, Total] //! - KV Cache Blocks: [Active, Total]
//! - KV Hit Rate:
//! - These metrics will be collected from KV hit rate events published by the KV router
//! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events
//! - Overlap Blocks: Cumulative count of blocks that were already in the KV cache
use clap::Parser; use clap::Parser;
use dynemo_llm::kv_router::scheduler::KVHitRateEvent;
use dynemo_runtime::{ use dynemo_runtime::{
error, logging, error, logging,
traits::events::EventPublisher, traits::events::{EventPublisher, EventSubscriber},
utils::{Duration, Instant}, utils::{Duration, Instant},
DistributedRuntime, ErrorContext, Result, Runtime, Worker, DistributedRuntime, ErrorContext, Result, Runtime, Worker,
}; };
use futures::stream::StreamExt;
use std::sync::Arc;
// Import from our library // Import from our library
use count::{ use count::{
...@@ -111,8 +117,65 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -111,8 +117,65 @@ async fn app(runtime: Runtime) -> Result<()> {
// TODO: Make metrics host/port configurable // TODO: Make metrics host/port configurable
// Initialize Prometheus metrics and start server // Initialize Prometheus metrics and start server
let mut metrics_server = PrometheusMetricsServer::new()?; let metrics_server = PrometheusMetricsServer::new()?;
metrics_server.start(9091); // Metrics will be updated concurrently, so protect it with a mutex:
// - Main loop: Collect and process ForwardPassMetrics at an interval from endpoint stats handlers
// - Subscription task: Collect and process KVHitRateEvent metrics from the KV router as they are published
let metrics_server = Arc::new(tokio::sync::Mutex::new(metrics_server));
metrics_server.lock().await.start(9091);
// Subscribe to KV hit rate events
let kv_hit_rate_subject = "kv-hit-rate";
tracing::info!("Subscribing to KV hit rate events on subject: {kv_hit_rate_subject}");
// Clone the metrics server and config for the subscription task
let metrics_server_clone = metrics_server.clone();
let config_clone = config.clone();
// Clone the namespace for the subscription task
let namespace_clone = namespace.clone();
// Spawn a task to handle KV hit rate events
tokio::spawn(async move {
match namespace_clone.subscribe(kv_hit_rate_subject).await {
Ok(mut subscriber) => {
tracing::info!("Successfully subscribed to KV hit rate events");
while let Some(msg) = subscriber.next().await {
match serde_json::from_slice::<KVHitRateEvent>(&msg.payload) {
Ok(event) => {
// TODO: Lower to debug
let cache_hit_pct =
(event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0;
tracing::info!(
"Received KV hit rate event: worker_id={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%",
event.worker_id,
event.isl_blocks,
event.overlap_blocks,
cache_hit_pct
);
// Update metrics with the event data
let mut metrics = metrics_server_clone.lock().await;
metrics.update_kv_hit_rate(
&config_clone,
event.worker_id,
event.isl_blocks,
event.overlap_blocks,
);
}
Err(e) => {
tracing::warn!("Failed to deserialize KV hit rate event: {:?}", e);
}
}
}
tracing::warn!("KV hit rate event subscription stream ended");
}
Err(e) => {
tracing::error!("Failed to subscribe to KV hit rate events: {:?}", e);
}
}
});
loop { loop {
let next = Instant::now() + Duration::from_secs(args.poll_interval); let next = Instant::now() + Duration::from_secs(args.poll_interval);
...@@ -123,12 +186,14 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -123,12 +186,14 @@ async fn app(runtime: Runtime) -> Result<()> {
collect_endpoints(&target_component, &service_subject, scrape_timeout).await?; collect_endpoints(&target_component, &service_subject, scrape_timeout).await?;
let metrics = extract_metrics(&endpoints); let metrics = extract_metrics(&endpoints);
let processed = postprocess_metrics(&metrics, &endpoints); let processed = postprocess_metrics(&metrics, &endpoints);
tracing::info!("Aggregated metrics: {processed:?}"); tracing::debug!("Aggregated metrics: {processed:?}");
// Update Prometheus metrics // Update Prometheus metrics
metrics_server.update(&config, &processed); metrics_server.lock().await.update(&config, &processed);
// TODO: Who needs to consume these events? // TODO: Enable KV Routers to subscribe to metrics events published here
// for a single view of the aggregated metrics, as opposed to the current
// approach where each KV Router computes and publishes its own metrics.
// Publish metrics event // Publish metrics event
namespace.publish(&event_name, &processed).await?; namespace.publish(&event_name, &processed).await?;
......
...@@ -261,7 +261,7 @@ ...@@ -261,7 +261,7 @@
}, },
"gridPos": { "gridPos": {
"h": 8, "h": 8,
"w": 6, "w": 4,
"x": 0, "x": 0,
"y": 8 "y": 8
}, },
...@@ -329,8 +329,8 @@ ...@@ -329,8 +329,8 @@
}, },
"gridPos": { "gridPos": {
"h": 8, "h": 8,
"w": 6, "w": 4,
"x": 6, "x": 4,
"y": 8 "y": 8
}, },
"id": 4, "id": 4,
...@@ -363,6 +363,74 @@ ...@@ -363,6 +363,74 @@
} }
] ]
}, },
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "yellow",
"value": 50
},
{
"color": "red",
"value": 80
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 4,
"x": 8,
"y": 8
},
"id": 7,
"options": {
"orientation": "auto",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showThresholdLabels": false,
"showThresholdMarkers": true
},
"pluginVersion": "10.0.0",
"title": "Cumulative KV Cache Hit Rate",
"type": "gauge",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * sum(llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"}) / sum(llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"})",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
]
},
{ {
"datasource": { "datasource": {
"type": "prometheus", "type": "prometheus",
...@@ -467,6 +535,190 @@ ...@@ -467,6 +535,190 @@
} }
] ]
}, },
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 20,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 16
},
"id": 8,
"options": {
"legend": {
"calcs": [
"mean",
"max"
],
"displayMode": "table",
"placement": "right",
"showLegend": true
},
"tooltip": {
"mode": "multi",
"sort": "none"
}
},
"title": "KV Cache Hit Rate by Worker",
"type": "timeseries",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"} / llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"}",
"legendFormat": "Worker {{worker_id}}",
"range": true,
"refId": "A"
}
]
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 20,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 16
},
"id": 9,
"options": {
"legend": {
"calcs": [
"mean",
"max"
],
"displayMode": "table",
"placement": "right",
"showLegend": true
},
"tooltip": {
"mode": "multi",
"sort": "none"
}
},
"title": "Cumulative KV Cache Hit Rate",
"type": "timeseries",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * sum(llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"}) / sum(llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"})",
"legendFormat": "Overall Hit Rate",
"range": true,
"refId": "A"
}
]
},
{ {
"datasource": { "datasource": {
"type": "prometheus", "type": "prometheus",
...@@ -525,7 +777,7 @@ ...@@ -525,7 +777,7 @@
"h": 8, "h": 8,
"w": 24, "w": 24,
"x": 0, "x": 0,
"y": 16 "y": 24
}, },
"id": 6, "id": 6,
"options": { "options": {
......
...@@ -33,7 +33,7 @@ ENDPOINT_NAME=${4:-"dynemo.process.chat/completions"} ...@@ -33,7 +33,7 @@ ENDPOINT_NAME=${4:-"dynemo.process.chat/completions"}
VALID_STRATEGIES=("prefix") VALID_STRATEGIES=("prefix")
SESSION_NAME="v" SESSION_NAME="v"
WORKDIR="/workspace/examples/python_rs/llm/vllm" WORKDIR="/workspace/examples/python_rs/llm/vllm"
INIT_CMD="source /opt/dynemo/venv/bin/activate && cd $WORKDIR" INIT_CMD="cd $WORKDIR"
if [[ ! " ${VALID_STRATEGIES[@]} " =~ " ${ROUTING_STRATEGY} " ]]; then if [[ ! " ${VALID_STRATEGIES[@]} " =~ " ${ROUTING_STRATEGY} " ]]; then
echo "Error: Invalid routing strategy. Must be one of: ${VALID_STRATEGIES[*]}" echo "Error: Invalid routing strategy. Must be one of: ${VALID_STRATEGIES[*]}"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
// limitations under the License. // limitations under the License.
use anyhow::Result; use anyhow::Result;
use dynemo_runtime::{component::Component, DistributedRuntime}; use dynemo_runtime::{component::Component, component::Namespace, DistributedRuntime};
use futures::stream::StreamExt; use futures::stream::StreamExt;
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -57,15 +57,19 @@ impl KvRouter { ...@@ -57,15 +57,19 @@ impl KvRouter {
let nats_client = runtime.nats_client(); let nats_client = runtime.nats_client();
let service_name = backend.service_name(); let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT); let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
let namespace = runtime.namespace(backend.namespace())?;
tracing::info!("Component Namespace {}", backend.namespace());
tracing::info!("Component Service Name {}", service_name); tracing::info!("Component Service Name {}", service_name);
tracing::info!("KV Subject {}", kv_subject); tracing::info!("KV Subject {}", kv_subject);
Self::new(nats_client, service_name, kv_subject).await Self::new(nats_client, service_name, kv_subject, namespace).await
} }
pub async fn new( pub async fn new(
nats_client: dynemo_runtime::transports::nats::Client, nats_client: dynemo_runtime::transports::nats::Client,
service_name: String, service_name: String,
kv_subject: String, kv_subject: String,
namespace: Namespace,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128); let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);
...@@ -78,7 +82,7 @@ impl KvRouter { ...@@ -78,7 +82,7 @@ impl KvRouter {
)); ));
let indexer = KvIndexer::new(cancellation_token.clone()); let indexer = KvIndexer::new(cancellation_token.clone());
let scheduler = KvScheduler::start(ep_rx).await?; let scheduler = KvScheduler::start(ep_rx, namespace).await?;
tracing::debug!("subscribing to kv events: {}", kv_subject); tracing::debug!("subscribing to kv events: {}", kv_subject);
let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?; let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use dynemo_runtime::component::Namespace;
use dynemo_runtime::traits::events::EventPublisher;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use std::cmp::min; use std::cmp::min;
...@@ -21,7 +23,13 @@ use crate::kv_router::indexer::OverlapScores; ...@@ -21,7 +23,13 @@ use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE}; pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::scoring::ProcessedEndpoints;
#[allow(dead_code)] #[derive(Debug, Clone, Serialize, Deserialize)]
/// Event describing the KV cache overlap of one scheduling decision.
///
/// `select_worker` sends one of these over an internal channel after picking a
/// worker, and the scheduler's publisher task publishes it on the `kv-hit-rate`
/// subject.
pub struct KVHitRateEvent {
/// Id of the worker that was selected for the request.
pub worker_id: i64,
/// Request input length in blocks, computed as `isl_tokens / KV_BLOCK_SIZE`.
pub isl_blocks: usize,
/// Overlap score recorded for the selected worker; 0 when it has no entry.
pub overlap_blocks: usize,
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError { pub enum KvSchedulerError {
#[error("no endpoints aviailable to route work")] #[error("no endpoints aviailable to route work")]
...@@ -93,6 +101,7 @@ pub struct KvScheduler { ...@@ -93,6 +101,7 @@ pub struct KvScheduler {
impl KvScheduler { impl KvScheduler {
pub async fn start( pub async fn start(
endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>, endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>,
ns: Namespace,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let mut endpoints_rx = endpoints_rx; let mut endpoints_rx = endpoints_rx;
...@@ -104,6 +113,20 @@ impl KvScheduler { ...@@ -104,6 +113,20 @@ impl KvScheduler {
} }
}; };
// Channel to asynchronously publish metric events on
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
// Publisher task
tokio::spawn(async move {
let mut event_rx = event_rx;
let subject = "kv-hit-rate";
while let Some(event) = event_rx.recv().await {
if let Err(e) = ns.publish(subject, &event).await {
tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
}
}
});
// Channel to accept new scheduling requests // Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16); let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16);
tracing::debug!("scheduler starting"); tracing::debug!("scheduler starting");
...@@ -146,7 +169,7 @@ impl KvScheduler { ...@@ -146,7 +169,7 @@ impl KvScheduler {
}; };
tracing::debug!("selected"); tracing::debug!("selected");
loop { loop {
match select_worker(endpoints.borrow_mut(), &request) { match select_worker(endpoints.borrow_mut(), &request, &event_tx) {
Ok(worker_id) => { Ok(worker_id) => {
request.respond(worker_id); request.respond(worker_id);
continue 'outer; continue 'outer;
...@@ -175,7 +198,6 @@ impl KvScheduler { ...@@ -175,7 +198,6 @@ impl KvScheduler {
Ok(KvScheduler { request_tx }) Ok(KvScheduler { request_tx })
} }
#[allow(dead_code)]
pub async fn schedule( pub async fn schedule(
&self, &self,
overlap: OverlapScores, overlap: OverlapScores,
...@@ -205,6 +227,7 @@ impl KvScheduler { ...@@ -205,6 +227,7 @@ impl KvScheduler {
pub fn select_worker( pub fn select_worker(
workers: &mut ProcessedEndpoints, workers: &mut ProcessedEndpoints,
request: &SchedulingRequest, request: &SchedulingRequest,
event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
) -> Result<i64, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
// balance mode prioritizes balancing load across workers // balance mode prioritizes balancing load across workers
let balance_threshold: f64 = 0.1; let balance_threshold: f64 = 0.1;
...@@ -268,6 +291,23 @@ pub fn select_worker( ...@@ -268,6 +291,23 @@ pub fn select_worker(
workers.endpoints[best_index].data.request_active_slots += 1; workers.endpoints[best_index].data.request_active_slots += 1;
workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64; workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64;
// Optimization - pass this to a channel for emitting events, async task, etc. to avoid blocking the scheduler
let best_worker_id = workers.endpoints[best_index].worker_id();
let isl_blocks = request.isl_tokens / KV_BLOCK_SIZE;
let overlap_blocks = request
.overlap
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0);
if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: best_worker_id,
isl_blocks,
overlap_blocks: overlap_blocks as usize,
}) {
tracing::warn!("Failed to send KV hit rate event: {:?}", e);
}
} }
match best_index { match best_index {
......
...@@ -154,6 +154,11 @@ impl Component { ...@@ -154,6 +154,11 @@ impl Component {
format!("{}/{}", self.namespace, self.name) format!("{}/{}", self.namespace, self.name)
} }
/// Borrows this component's `namespace` field as a string slice.
///
/// No allocation or copy is made; the returned slice lives as long as `self`.
pub fn namespace(&self) -> &str {
&self.namespace
}
pub fn drt(&self) -> &DistributedRuntime { pub fn drt(&self) -> &DistributedRuntime {
&self.drt &self.drt
} }
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
// limitations under the License. // limitations under the License.
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream::StreamExt;
use futures::{Stream, TryStreamExt};
use super::*; use super::*;
use crate::traits::events::EventPublisher; use crate::traits::events::{EventPublisher, EventSubscriber};
#[async_trait] #[async_trait]
impl EventPublisher for Namespace { impl EventPublisher for Namespace {
...@@ -49,6 +51,32 @@ impl EventPublisher for Namespace { ...@@ -49,6 +51,32 @@ impl EventPublisher for Namespace {
} }
} }
#[async_trait]
impl EventSubscriber for Namespace {
    /// Open a raw NATS subscription on `<namespace subject>.<event_name>`.
    async fn subscribe(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<async_nats::Subscriber> {
        let full_subject = format!("{}.{}", self.subject(), event_name.as_ref());
        let subscription = self
            .drt()
            .nats_client()
            .client()
            .subscribe(full_subject)
            .await?;
        Ok(subscription)
    }

    /// Subscribe like [`subscribe`](Self::subscribe) and JSON-decode every message
    /// payload into `T`.
    ///
    /// A payload that fails to decode yields an `Err` item; the stream itself
    /// is not terminated by this adapter.
    async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<impl Stream<Item = Result<T>> + Send> {
        let raw_messages = self.subscribe(event_name).await?;

        // Lazily map each incoming message to its decoded form.
        Ok(raw_messages.map(|message| {
            let decoded = serde_json::from_slice::<T>(&message.payload);
            decoded.map_err(|err| anyhow::anyhow!("Failed to deserialize event: {}", err))
        }))
    }
}
#[cfg(feature = "integration")] #[cfg(feature = "integration")]
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
...@@ -64,4 +92,27 @@ mod tests { ...@@ -64,4 +92,27 @@ mod tests {
ns.publish("test", &"test".to_string()).await.unwrap(); ns.publish("test", &"test".to_string()).await.unwrap();
rt.shutdown(); rt.shutdown();
} }
/// Round-trips one message through `subscribe` and `publish` on the same namespace.
#[tokio::test]
async fn test_subscribe() {
    let rt = Runtime::from_current().unwrap();
    let dtr = DistributedRuntime::from_settings(rt.clone()).await.unwrap();
    let ns = dtr.namespace("test".to_string()).unwrap();

    // Create the subscriber before publishing so the message is not missed.
    // The binding must be mutable because `StreamExt::next` takes `&mut self`.
    let mut subscriber = ns.subscribe("test").await.unwrap();

    // Publish a message
    ns.publish("test", &"test_message".to_string())
        .await
        .unwrap();

    // Receive the message; fail loudly if the stream ends without delivering
    // one instead of letting the test pass without checking anything.
    let msg = subscriber
        .next()
        .await
        .expect("subscription ended before a message was received");
    let received = String::from_utf8(msg.payload.to_vec()).unwrap();
    // The expected payload carries surrounding quotes (presumably because
    // `publish` JSON-encodes the string — confirm against the publisher impl).
    assert_eq!(received, "\"test_message\"");

    rt.shutdown();
}
} }
...@@ -56,3 +56,24 @@ pub trait EventPublisher { ...@@ -56,3 +56,24 @@ pub trait EventPublisher {
// fn publisher(&self, event_name: impl AsRef<str>) -> Result<Publisher>; // fn publisher(&self, event_name: impl AsRef<str>) -> Result<Publisher>;
// fn publisher_bytes(&self, event_name: impl AsRef<str>) -> &PublisherBytes; // fn publisher_bytes(&self, event_name: impl AsRef<str>) -> &PublisherBytes;
} }
/// A trait for subscribing to events in the event plane.
///
/// This is the receiving counterpart of [`EventPublisher`]: implementors provide
/// a base subject, and callers subscribe to named events beneath it.
#[async_trait]
pub trait EventSubscriber {
/// Subscribe to events with the given event name.
///
/// The `event_name` will be `.` concatenated with the base subject provided by
/// the implementation, i.e. the subscription subject is `<base subject>.<event_name>`.
/// Returns the raw `async_nats::Subscriber`; message payloads are not decoded.
async fn subscribe(
&self,
event_name: impl AsRef<str> + Send + Sync,
) -> Result<async_nats::Subscriber>;
/// Subscribe to events with the given event name and deserialize them to the specified type.
///
/// This is a convenience method that combines `subscribe` and deserialization.
/// Each item of the returned stream is a `Result<T>`, so a message that cannot
/// be deserialized can be reported as an `Err` item rather than ending the stream.
async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
&self,
event_name: impl AsRef<str> + Send + Sync,
) -> Result<impl futures::Stream<Item = Result<T>> + Send>;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment