Commit 09656f6c authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

feat: Add estimated kv cache hit metric events (#30)

parent a720fa12
......@@ -53,6 +53,8 @@ tensorrtllm_checkpoints/
tensorrtllm_engines/
api_server_models/
server/
# Replay/Snapshot test artifacts
*.new
**/*backups*
......
......@@ -739,6 +739,7 @@ dependencies = [
"clap",
"dynemo-llm",
"dynemo-runtime",
"futures",
"opentelemetry",
"opentelemetry-prometheus",
"prometheus",
......
......@@ -38,6 +38,7 @@ opentelemetry-prometheus = "0.13"
prometheus = "0.13"
rand = "0.8"
axum = "0.6"
futures = "0.3"
[dev-dependencies]
reqwest = { version = "0.11", features = ["blocking"] }
......@@ -16,7 +16,7 @@
//! Library functions for the count application.
use axum::{routing::get, Router};
use prometheus::register_gauge_vec;
use prometheus::{register_counter_vec, register_gauge_vec};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
......@@ -97,6 +97,18 @@ impl PrometheusMetricsServer {
pub fn update(&mut self, config: &LLMWorkerLoadCapacityConfig, processed: &ProcessedEndpoints) {
self.metrics.update(config, processed);
}
/// Update KV hit rate metrics
///
/// Thin delegation to the inner `PrometheusMetrics::update_kv_hit_rate`:
/// accumulates one KV hit rate event (ISL blocks and overlapping blocks)
/// into the per-worker cumulative counters for `worker_id`.
pub fn update_kv_hit_rate(
    &mut self,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: i64,
    isl_blocks: usize,
    overlap_blocks: usize,
) {
    self.metrics
        .update_kv_hit_rate(config, worker_id, isl_blocks, overlap_blocks);
}
}
/// Prometheus metrics collection
......@@ -107,6 +119,9 @@ pub struct PrometheusMetrics {
requests_total: prometheus::GaugeVec,
load_avg: prometheus::GaugeVec,
load_std: prometheus::GaugeVec,
// KV hit rate metrics
kv_hit_rate_isl_blocks: prometheus::CounterVec,
kv_hit_rate_overlap_blocks: prometheus::CounterVec,
}
impl PrometheusMetrics {
......@@ -143,6 +158,19 @@ impl PrometheusMetrics {
"Load standard deviation across workers",
&["component", "endpoint"]
)?,
// TODO: The cumulative isl/overlap metrics are monotonically increasing
// and may overflow at some point, we may want to periodically reset them.
// KV hit rate metrics
kv_hit_rate_isl_blocks: register_counter_vec!(
"llm_kv_hit_rate_isl_blocks",
"Cumulative count of ISL blocks in KV hit rate events",
&["component", "endpoint", "worker_id"]
)?,
kv_hit_rate_overlap_blocks: register_counter_vec!(
"llm_kv_hit_rate_overlap_blocks",
"Cumulative count of overlapping blocks in KV hit rate events",
&["component", "endpoint", "worker_id"]
)?,
})
}
......@@ -159,6 +187,19 @@ impl PrometheusMetrics {
.set(value);
}
/// Helper method to increment a counter with worker-specific labels (3 labels)
fn increment_worker_counter(
    &self,
    counter: &prometheus::CounterVec,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: &str,
    value: f64,
) {
    // Label order must match counter registration: component, endpoint, worker_id.
    let labels: [&str; 3] = [&config.component_name, &config.endpoint_name, worker_id];
    counter.with_label_values(&labels).inc_by(value);
}
/// Helper method to set a gauge with component/endpoint labels only (2 labels)
fn set_endpoint_gauge(
&self,
......@@ -208,6 +249,61 @@ impl PrometheusMetrics {
self.set_endpoint_gauge(&self.load_avg, config, processed.load_avg);
self.set_endpoint_gauge(&self.load_std, config, processed.load_std);
}
/// Update KV hit rate metrics
///
/// Accumulates one KV hit rate event into the per-worker cumulative
/// ISL/overlap counters, then logs the running hit-rate percentage.
pub fn update_kv_hit_rate(
    &self,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: i64,
    isl_blocks: usize,
    overlap_blocks: usize,
) {
    let worker_id_str = worker_id.to_string();

    // Accumulate this event into both per-worker counters.
    self.increment_worker_counter(
        &self.kv_hit_rate_isl_blocks,
        config,
        &worker_id_str,
        isl_blocks as f64,
    );
    self.increment_worker_counter(
        &self.kv_hit_rate_overlap_blocks,
        config,
        &worker_id_str,
        overlap_blocks as f64,
    );

    // TODO: The cumulative hit rate percentage can probably be computed by consumers
    // of Prometheus metrics like Grafana instead, but we'll compute it here for now
    // for convenient debugging/logging.
    let labels: [&str; 3] = [
        &config.component_name,
        &config.endpoint_name,
        &worker_id_str,
    ];
    let cumulative_isl = self.kv_hit_rate_isl_blocks.with_label_values(&labels).get();
    let cumulative_overlap = self
        .kv_hit_rate_overlap_blocks
        .with_label_values(&labels)
        .get();

    // Guard against division by zero on the very first events.
    if cumulative_isl > 0.0 {
        let cumulative_hit_rate = (cumulative_overlap / cumulative_isl) * 100.0;
        tracing::info!(
            "Estimated Cumulative KV hit rate: {cumulative_hit_rate:.2}% (Overlap: {cumulative_overlap} / ISL: {cumulative_isl})"
        );
    }
}
}
/// Collect endpoints from a component
......
......@@ -22,14 +22,20 @@
//! - These metrics will be scraped by the LLM NATS Service API's stats request
//! - Request Slots: [Active, Total]
//! - KV Cache Blocks: [Active, Total]
//! - KV Hit Rate:
//! - These metrics will be collected from KV hit rate events published by the KV router
//! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events
//! - Overlap Blocks: Cumulative count of blocks that were already in the KV cache
use clap::Parser;
use dynemo_llm::kv_router::scheduler::KVHitRateEvent;
use dynemo_runtime::{
error, logging,
traits::events::EventPublisher,
traits::events::{EventPublisher, EventSubscriber},
utils::{Duration, Instant},
DistributedRuntime, ErrorContext, Result, Runtime, Worker,
};
use futures::stream::StreamExt;
use std::sync::Arc;
// Import from our library
use count::{
......@@ -111,8 +117,65 @@ async fn app(runtime: Runtime) -> Result<()> {
// TODO: Make metrics host/port configurable
// Initialize Prometheus metrics and start server
let mut metrics_server = PrometheusMetricsServer::new()?;
metrics_server.start(9091);
let metrics_server = PrometheusMetricsServer::new()?;
// Metrics will be updated concurrently, so protect it with a mutex:
// - Main loop: Collect and process ForwardPassMetrics at an interval from endpoint stats handlers
// - Subscription task: Collect and process KVHitRateEvent metrics from the KV router as they are published
let metrics_server = Arc::new(tokio::sync::Mutex::new(metrics_server));
metrics_server.lock().await.start(9091);
// Subscribe to KV hit rate events
let kv_hit_rate_subject = "kv-hit-rate";
tracing::info!("Subscribing to KV hit rate events on subject: {kv_hit_rate_subject}");
// Clone the metrics server and config for the subscription task
let metrics_server_clone = metrics_server.clone();
let config_clone = config.clone();
// Clone the namespace for the subscription task
let namespace_clone = namespace.clone();
// Spawn a task to handle KV hit rate events
tokio::spawn(async move {
match namespace_clone.subscribe(kv_hit_rate_subject).await {
Ok(mut subscriber) => {
tracing::info!("Successfully subscribed to KV hit rate events");
while let Some(msg) = subscriber.next().await {
match serde_json::from_slice::<KVHitRateEvent>(&msg.payload) {
Ok(event) => {
// TODO: Lower to debug
let cache_hit_pct =
(event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0;
tracing::info!(
"Received KV hit rate event: worker_id={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%",
event.worker_id,
event.isl_blocks,
event.overlap_blocks,
cache_hit_pct
);
// Update metrics with the event data
let mut metrics = metrics_server_clone.lock().await;
metrics.update_kv_hit_rate(
&config_clone,
event.worker_id,
event.isl_blocks,
event.overlap_blocks,
);
}
Err(e) => {
tracing::warn!("Failed to deserialize KV hit rate event: {:?}", e);
}
}
}
tracing::warn!("KV hit rate event subscription stream ended");
}
Err(e) => {
tracing::error!("Failed to subscribe to KV hit rate events: {:?}", e);
}
}
});
loop {
let next = Instant::now() + Duration::from_secs(args.poll_interval);
......@@ -123,12 +186,14 @@ async fn app(runtime: Runtime) -> Result<()> {
collect_endpoints(&target_component, &service_subject, scrape_timeout).await?;
let metrics = extract_metrics(&endpoints);
let processed = postprocess_metrics(&metrics, &endpoints);
tracing::info!("Aggregated metrics: {processed:?}");
tracing::debug!("Aggregated metrics: {processed:?}");
// Update Prometheus metrics
metrics_server.update(&config, &processed);
metrics_server.lock().await.update(&config, &processed);
// TODO: Who needs to consume these events?
// TODO: Enable KV Routers to subscribe to metrics events published here
// for a single view of the aggregated metrics, as opposed to the current
// approach where each KV Router computes and publishes its own metrics.
// Publish metrics event
namespace.publish(&event_name, &processed).await?;
......
......@@ -261,7 +261,7 @@
},
"gridPos": {
"h": 8,
"w": 6,
"w": 4,
"x": 0,
"y": 8
},
......@@ -329,8 +329,8 @@
},
"gridPos": {
"h": 8,
"w": 6,
"x": 6,
"w": 4,
"x": 4,
"y": 8
},
"id": 4,
......@@ -363,6 +363,74 @@
}
]
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "yellow",
"value": 50
},
{
"color": "red",
"value": 80
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 4,
"x": 8,
"y": 8
},
"id": 7,
"options": {
"orientation": "auto",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showThresholdLabels": false,
"showThresholdMarkers": true
},
"pluginVersion": "10.0.0",
"title": "Cumulative KV Cache Hit Rate",
"type": "gauge",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * sum(llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"}) / sum(llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"})",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
]
},
{
"datasource": {
"type": "prometheus",
......@@ -467,6 +535,190 @@
}
]
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 20,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 16
},
"id": 8,
"options": {
"legend": {
"calcs": [
"mean",
"max"
],
"displayMode": "table",
"placement": "right",
"showLegend": true
},
"tooltip": {
"mode": "multi",
"sort": "none"
}
},
"title": "KV Cache Hit Rate by Worker",
"type": "timeseries",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"} / llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"}",
"legendFormat": "Worker {{worker_id}}",
"range": true,
"refId": "A"
}
]
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 20,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 16
},
"id": 9,
"options": {
"legend": {
"calcs": [
"mean",
"max"
],
"displayMode": "table",
"placement": "right",
"showLegend": true
},
"tooltip": {
"mode": "multi",
"sort": "none"
}
},
"title": "Cumulative KV Cache Hit Rate",
"type": "timeseries",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * sum(llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"}) / sum(llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"})",
"legendFormat": "Overall Hit Rate",
"range": true,
"refId": "A"
}
]
},
{
"datasource": {
"type": "prometheus",
......@@ -525,7 +777,7 @@
"h": 8,
"w": 24,
"x": 0,
"y": 16
"y": 24
},
"id": 6,
"options": {
......
......@@ -33,7 +33,7 @@ ENDPOINT_NAME=${4:-"dynemo.process.chat/completions"}
VALID_STRATEGIES=("prefix")
SESSION_NAME="v"
WORKDIR="/workspace/examples/python_rs/llm/vllm"
INIT_CMD="source /opt/dynemo/venv/bin/activate && cd $WORKDIR"
INIT_CMD="cd $WORKDIR"
if [[ ! " ${VALID_STRATEGIES[@]} " =~ " ${ROUTING_STRATEGY} " ]]; then
echo "Error: Invalid routing strategy. Must be one of: ${VALID_STRATEGIES[*]}"
......
......@@ -14,7 +14,7 @@
// limitations under the License.
use anyhow::Result;
use dynemo_runtime::{component::Component, DistributedRuntime};
use dynemo_runtime::{component::Component, component::Namespace, DistributedRuntime};
use futures::stream::StreamExt;
use std::{sync::Arc, time::Duration};
use tokio_util::sync::CancellationToken;
......@@ -57,15 +57,19 @@ impl KvRouter {
let nats_client = runtime.nats_client();
let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
let namespace = runtime.namespace(backend.namespace())?;
tracing::info!("Component Namespace {}", backend.namespace());
tracing::info!("Component Service Name {}", service_name);
tracing::info!("KV Subject {}", kv_subject);
Self::new(nats_client, service_name, kv_subject).await
Self::new(nats_client, service_name, kv_subject, namespace).await
}
pub async fn new(
nats_client: dynemo_runtime::transports::nats::Client,
service_name: String,
kv_subject: String,
namespace: Namespace,
) -> Result<Arc<Self>> {
let cancellation_token = CancellationToken::new();
let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);
......@@ -78,7 +82,7 @@ impl KvRouter {
));
let indexer = KvIndexer::new(cancellation_token.clone());
let scheduler = KvScheduler::start(ep_rx).await?;
let scheduler = KvScheduler::start(ep_rx, namespace).await?;
tracing::debug!("subscribing to kv events: {}", kv_subject);
let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
......
......@@ -13,6 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use dynemo_runtime::component::Namespace;
use dynemo_runtime::traits::events::EventPublisher;
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::cmp::min;
......@@ -21,7 +23,13 @@ use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
use crate::kv_router::scoring::ProcessedEndpoints;
#[allow(dead_code)]
/// Event published by the KV router describing the estimated KV cache
/// overlap observed when a request was scheduled onto a worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    // Worker the request was routed to.
    pub worker_id: i64,
    // Total input-sequence-length blocks for the request.
    pub isl_blocks: usize,
    // Blocks estimated to already be present in that worker's KV cache.
    pub overlap_blocks: usize,
}
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
#[error("no endpoints aviailable to route work")]
......@@ -93,6 +101,7 @@ pub struct KvScheduler {
impl KvScheduler {
pub async fn start(
endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>,
ns: Namespace,
) -> Result<Self, KvSchedulerError> {
let mut endpoints_rx = endpoints_rx;
......@@ -104,6 +113,20 @@ impl KvScheduler {
}
};
// Channel to asynchronously publish metric events on
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
// Publisher task
tokio::spawn(async move {
let mut event_rx = event_rx;
let subject = "kv-hit-rate";
while let Some(event) = event_rx.recv().await {
if let Err(e) = ns.publish(subject, &event).await {
tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
}
}
});
// Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16);
tracing::debug!("scheduler starting");
......@@ -146,7 +169,7 @@ impl KvScheduler {
};
tracing::debug!("selected");
loop {
match select_worker(endpoints.borrow_mut(), &request) {
match select_worker(endpoints.borrow_mut(), &request, &event_tx) {
Ok(worker_id) => {
request.respond(worker_id);
continue 'outer;
......@@ -175,7 +198,6 @@ impl KvScheduler {
Ok(KvScheduler { request_tx })
}
#[allow(dead_code)]
pub async fn schedule(
&self,
overlap: OverlapScores,
......@@ -205,6 +227,7 @@ impl KvScheduler {
pub fn select_worker(
workers: &mut ProcessedEndpoints,
request: &SchedulingRequest,
event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
) -> Result<i64, KvSchedulerError> {
// balance mode prioritizes balancing load across workers
let balance_threshold: f64 = 0.1;
......@@ -268,6 +291,23 @@ pub fn select_worker(
workers.endpoints[best_index].data.request_active_slots += 1;
workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64;
// Optimization - pass this to a channel for emitting events, async task, etc. to avoid blocking the scheduler
let best_worker_id = workers.endpoints[best_index].worker_id();
let isl_blocks = request.isl_tokens / KV_BLOCK_SIZE;
let overlap_blocks = request
.overlap
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0);
if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: best_worker_id,
isl_blocks,
overlap_blocks: overlap_blocks as usize,
}) {
tracing::warn!("Failed to send KV hit rate event: {:?}", e);
}
}
match best_index {
......
......@@ -154,6 +154,11 @@ impl Component {
format!("{}/{}", self.namespace, self.name)
}
/// Returns a reference to the namespace string of this component.
///
/// The borrow lives as long as `self`; callers needing an owned copy
/// should call `.to_string()` on the result.
pub fn namespace(&self) -> &str {
    &self.namespace
}
pub fn drt(&self) -> &DistributedRuntime {
&self.drt
}
......
......@@ -14,10 +14,12 @@
// limitations under the License.
use async_trait::async_trait;
use futures::stream::StreamExt;
use futures::{Stream, TryStreamExt};
use super::*;
use crate::traits::events::EventPublisher;
use crate::traits::events::{EventPublisher, EventSubscriber};
#[async_trait]
impl EventPublisher for Namespace {
......@@ -49,6 +51,32 @@ impl EventPublisher for Namespace {
}
}
#[async_trait]
impl EventSubscriber for Namespace {
    /// Subscribe to raw NATS messages on `"<namespace subject>.<event_name>"`.
    async fn subscribe(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<async_nats::Subscriber> {
        let subject = format!("{}.{}", self.subject(), event_name.as_ref());
        let subscriber = self.drt().nats_client().client().subscribe(subject).await?;
        Ok(subscriber)
    }

    /// Subscribe and lazily deserialize each message payload into `T`.
    async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<impl Stream<Item = Result<T>> + Send> {
        let raw = self.subscribe(event_name).await?;
        // Each item is Ok(T) on successful JSON parse, Err otherwise.
        Ok(raw.map(|msg| {
            serde_json::from_slice::<T>(&msg.payload)
                .map_err(|e| anyhow::anyhow!("Failed to deserialize event: {}", e))
        }))
    }
}
#[cfg(feature = "integration")]
#[cfg(test)]
mod tests {
......@@ -64,4 +92,27 @@ mod tests {
ns.publish("test", &"test".to_string()).await.unwrap();
rt.shutdown();
}
#[tokio::test]
async fn test_subscribe() {
    let rt = Runtime::from_current().unwrap();
    let dtr = DistributedRuntime::from_settings(rt.clone()).await.unwrap();
    let ns = dtr.namespace("test".to_string()).unwrap();

    // Create a subscriber. Must be `mut`: `StreamExt::next` takes `&mut self`,
    // so an immutable binding here fails to compile.
    let mut subscriber = ns.subscribe("test").await.unwrap();

    // Publish a message
    ns.publish("test", &"test_message".to_string())
        .await
        .unwrap();

    // Receive the message. The payload is JSON-serialized, hence the
    // surrounding quotes in the expected value.
    if let Some(msg) = subscriber.next().await {
        let received = String::from_utf8(msg.payload.to_vec()).unwrap();
        assert_eq!(received, "\"test_message\"");
    }

    rt.shutdown();
}
}
......@@ -56,3 +56,24 @@ pub trait EventPublisher {
// fn publisher(&self, event_name: impl AsRef<str>) -> Result<Publisher>;
// fn publisher_bytes(&self, event_name: impl AsRef<str>) -> &PublisherBytes;
}
/// A trait for subscribing to events in the event plane.
///
/// This trait provides methods to subscribe to events published on specific subjects.
#[async_trait]
pub trait EventSubscriber {
    /// Subscribe to events with the given event name.
    /// The `event_name` will be `.` concatenated with the base subject provided by the implementation.
    /// Returns a subscriber that can be used to receive events.
    ///
    /// # Errors
    /// Returns an error if the underlying transport subscription cannot be established.
    async fn subscribe(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<async_nats::Subscriber>;

    /// Subscribe to events with the given event name and deserialize them to the specified type.
    /// This is a convenience method that combines subscribe and deserialization.
    ///
    /// Each stream item is `Ok(T)` on successful deserialization of a message
    /// payload, or `Err` otherwise (the Namespace implementation uses JSON;
    /// other implementations may differ).
    async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<impl futures::Stream<Item = Result<T>> + Send>;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment