Unverified Commit b441e26a authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Interface for shared kv cache handling in kv routing (#7536)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Ishan Dhanani <ishandhanani@gmail.com>
Co-authored-by: ishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent 1b0334c8
...@@ -176,6 +176,7 @@ where ...@@ -176,6 +176,7 @@ where
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
pinned_worker: Option<WorkerWithDpRank>, pinned_worker: Option<WorkerWithDpRank>,
allowed_worker_ids: Option<HashSet<WorkerId>>, allowed_worker_ids: Option<HashSet<WorkerId>>,
shared_cache_hits: Option<crate::SharedCacheHits>,
) -> Result<SchedulingResponse, KvSchedulerError> { ) -> Result<SchedulingResponse, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let track_prefill_tokens = router_config_override let track_prefill_tokens = router_config_override
...@@ -196,6 +197,7 @@ where ...@@ -196,6 +197,7 @@ where
expected_output_tokens, expected_output_tokens,
pinned_worker, pinned_worker,
allowed_worker_ids, allowed_worker_ids,
shared_cache_hits,
resp_tx: Some(resp_tx), resp_tx: Some(resp_tx),
}; };
...@@ -429,6 +431,7 @@ mod tests { ...@@ -429,6 +431,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -470,6 +473,7 @@ mod tests { ...@@ -470,6 +473,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -511,6 +515,7 @@ mod tests { ...@@ -511,6 +515,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -531,6 +536,7 @@ mod tests { ...@@ -531,6 +536,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
}) })
...@@ -572,6 +578,7 @@ mod tests { ...@@ -572,6 +578,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -592,6 +599,7 @@ mod tests { ...@@ -592,6 +599,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
}) })
...@@ -647,6 +655,7 @@ mod tests { ...@@ -647,6 +655,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -667,6 +676,7 @@ mod tests { ...@@ -667,6 +676,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
}) })
...@@ -721,6 +731,7 @@ mod tests { ...@@ -721,6 +731,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -741,6 +752,7 @@ mod tests { ...@@ -741,6 +752,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
}) })
...@@ -793,6 +805,7 @@ mod tests { ...@@ -793,6 +805,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -894,6 +907,7 @@ mod tests { ...@@ -894,6 +907,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -989,6 +1003,7 @@ mod tests { ...@@ -989,6 +1003,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
......
...@@ -139,6 +139,7 @@ mod tests { ...@@ -139,6 +139,7 @@ mod tests {
expected_output_tokens: None, expected_output_tokens: None,
pinned_worker: None, pinned_worker: None,
allowed_worker_ids: None, allowed_worker_ids: None,
shared_cache_hits: None,
resp_tx: None, resp_tx: None,
} }
} }
......
...@@ -639,6 +639,7 @@ mod tests { ...@@ -639,6 +639,7 @@ mod tests {
expected_output_tokens: None, expected_output_tokens: None,
pinned_worker: None, pinned_worker: None,
allowed_worker_ids: None, allowed_worker_ids: None,
shared_cache_hits: None,
resp_tx: Some(tx), resp_tx: Some(tx),
}; };
(req, rx) (req, rx)
...@@ -1030,6 +1031,7 @@ mod tests { ...@@ -1030,6 +1031,7 @@ mod tests {
expected_output_tokens: None, expected_output_tokens: None,
pinned_worker: None, pinned_worker: None,
allowed_worker_ids: Some(allowed), allowed_worker_ids: Some(allowed),
shared_cache_hits: None,
resp_tx: Some(tx), resp_tx: Some(tx),
}; };
queue.enqueue(req).await; queue.enqueue(req).await;
......
...@@ -119,6 +119,7 @@ impl DefaultWorkerSelector { ...@@ -119,6 +119,7 @@ impl DefaultWorkerSelector {
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
block_size: u32, block_size: u32,
overlap_weight: f64, overlap_weight: f64,
shared_cache_multiplier: f64,
formula_name: &'static str, formula_name: &'static str,
) -> WorkerScore { ) -> WorkerScore {
let isl = request.isl_tokens; let isl = request.isl_tokens;
...@@ -129,14 +130,38 @@ impl DefaultWorkerSelector { ...@@ -129,14 +130,38 @@ impl DefaultWorkerSelector {
.get(&worker) .get(&worker)
.copied() .copied()
.unwrap_or(default_prefill_token); .unwrap_or(default_prefill_token);
let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
// Adjust prefill tokens by shared cache hits beyond this worker's device prefix.
let (adjusted_prefill_token, shared_beyond) =
if let Some(ref shared_hits) = request.shared_cache_hits {
let beyond = shared_hits.hits_beyond(overlap_blocks);
let reduction = shared_cache_multiplier * (beyond as f64) * (block_size as f64);
let adjusted = (prefill_token as f64 - reduction).max(0.0) as usize;
(adjusted, beyond)
} else {
(prefill_token, 0)
};
let potential_prefill_block = (adjusted_prefill_token as f64) / (block_size as f64);
let decode_block_fallback = (prefill_token as f64) / (block_size as f64);
let decode_block = request let decode_block = request
.decode_blocks .decode_blocks
.get(&worker) .get(&worker)
.copied() .copied()
.unwrap_or(potential_prefill_block.floor() as usize) as f64; .unwrap_or(decode_block_fallback.floor() as usize) as f64;
let logit = overlap_weight * potential_prefill_block + decode_block; let logit = overlap_weight * potential_prefill_block + decode_block;
if shared_beyond > 0 {
tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} device blocks, \
{shared_beyond} shared blocks beyond device (multiplier={shared_cache_multiplier:.2}): {logit:.3} \
= {overlap_weight:.1} * adjusted_prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3} \
(prefill_tokens: {prefill_token} -> {adjusted_prefill_token})",
worker.worker_id,
worker.dp_rank
);
} else {
tracing::debug!( tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} cached blocks: {logit:.3} \ "{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \ = {overlap_weight:.1} * prefill_blocks + decode_blocks \
...@@ -144,6 +169,7 @@ impl DefaultWorkerSelector { ...@@ -144,6 +169,7 @@ impl DefaultWorkerSelector {
worker.worker_id, worker.worker_id,
worker.dp_rank worker.dp_rank
); );
}
WorkerScore { WorkerScore {
overlap_blocks, overlap_blocks,
...@@ -177,19 +203,27 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -177,19 +203,27 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
let request_blocks = isl.div_ceil(block_size as usize); let request_blocks = isl.div_ceil(block_size as usize);
let overlaps = &request.overlaps.scores; let overlaps = &request.overlaps.scores;
if let Some(worker) = pinned_worker {
pinned_worker_config(workers, worker)?;
let overlap_weight = request let overlap_weight = request
.router_config_override .router_config_override
.as_ref() .as_ref()
.and_then(|cfg| cfg.overlap_score_weight) .and_then(|cfg| cfg.overlap_score_weight)
.unwrap_or(self.kv_router_config.overlap_score_weight); .unwrap_or(self.kv_router_config.overlap_score_weight);
let shared_cache_multiplier = request
.router_config_override
.as_ref()
.and_then(|cfg| cfg.shared_cache_multiplier)
.unwrap_or(self.kv_router_config.shared_cache_multiplier);
if let Some(worker) = pinned_worker {
pinned_worker_config(workers, worker)?;
let score = self.worker_score( let score = self.worker_score(
request, request,
worker, worker,
block_size, block_size,
overlap_weight, overlap_weight,
shared_cache_multiplier,
"Pinned formula", "Pinned formula",
); );
...@@ -200,12 +234,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -200,12 +234,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
}); });
} }
let overlap_weight = request
.router_config_override
.as_ref()
.and_then(|cfg| cfg.overlap_score_weight)
.unwrap_or(self.kv_router_config.overlap_score_weight);
let temperature = request let temperature = request
.router_config_override .router_config_override
.as_ref() .as_ref()
...@@ -213,7 +241,14 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -213,7 +241,14 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.unwrap_or(self.kv_router_config.router_temperature); .unwrap_or(self.kv_router_config.router_temperature);
let get_score = |worker: WorkerWithDpRank| -> f64 { let get_score = |worker: WorkerWithDpRank| -> f64 {
self.worker_score(request, worker, block_size, overlap_weight, "Formula") self.worker_score(
request,
worker,
block_size,
overlap_weight,
shared_cache_multiplier,
"Formula",
)
.logit .logit
}; };
...@@ -322,6 +357,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -322,6 +357,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::protocols::SharedCacheHits;
#[test] #[test]
fn test_softmax_sample_single_key() { fn test_softmax_sample_single_key() {
...@@ -431,4 +467,118 @@ mod tests { ...@@ -431,4 +467,118 @@ mod tests {
let result = softmax_sample_with_sample(&logits, temperature, sample); let result = softmax_sample_with_sample(&logits, temperature, sample);
assert_eq!(result, entries[target_idx]); assert_eq!(result, entries[target_idx]);
} }
/// Test the scoring formula with shared cache hits.
///
/// Request [A, B, C, D], shared_cache_multiplier=0.5, block_size=1
/// - Worker 0: device=[A,B] (overlap=2), shared has [A,B,C,D] -> shared_beyond=2
///   adjusted_prefill = isl - 0.5*2*1 = 4-1 = 3
/// - Worker 1: device=[] (overlap=0), shared has [A,B,C,D] -> shared_beyond=4
///   adjusted_prefill = isl - 0.5*4*1 = 4-2 = 2
///
/// `decode_blocks` is left empty in this request, so `worker_score` uses its
/// fallback (`floor(prefill_token / block_size)`) for the decode term instead
/// of 0.
/// NOTE(review): assuming the default prefill token equals `isl` for both
/// workers, that fallback is 4 for each, which would give logits
/// 1.0*3 + 4 = 7.0 and 1.0*2 + 4 = 6.0 — confirm against `worker_score`.
///
/// Worker 1 has the lower logit (less work), so it wins.
#[test]
fn test_shared_cache_hits_scoring() {
use crate::protocols::OverlapScores;
use crate::test_utils::SimpleWorkerConfig;
let block_size = 1u32;
let isl = 4usize;
let worker0 = WorkerWithDpRank::from_worker_id(0);
let worker1 = WorkerWithDpRank::from_worker_id(1);
let mut overlaps = OverlapScores::new();
overlaps.scores.insert(worker0, 2);
// worker1 has 0 overlap (not in map)
// The single range inside `vec![]` is intentional: one hit range covering
// block positions 0..4, not a vector of the numbers 0 through 3.
#[allow(clippy::single_range_in_vec_init)]
let shared_hits = SharedCacheHits::from_ranges(vec![0..4]);
let config = KvRouterConfig {
overlap_score_weight: 1.0,
shared_cache_multiplier: 0.5,
// presumably temperature 0.0 makes the selection deterministic — TODO confirm
router_temperature: 0.0,
..Default::default()
};
let selector = DefaultWorkerSelector::new(Some(config), "test");
let mut workers = HashMap::new();
workers.insert(0, SimpleWorkerConfig::default());
workers.insert(1, SimpleWorkerConfig::default());
// The receiver is unused; the request only needs a sender to be well-formed.
let (tx, _rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
maybe_request_id: Some("test".into()),
token_seq: None,
isl_tokens: isl,
overlaps,
// Both per-worker maps are empty, so `worker_score` takes its
// `unwrap_or` fallbacks for the prefill and decode terms.
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
router_config_override: None,
update_states: false,
lora_name: None,
priority_jump: 0.0,
expected_output_tokens: None,
pinned_worker: None,
allowed_worker_ids: None,
shared_cache_hits: Some(shared_hits),
resp_tx: Some(tx),
};
let result = selector
.select_worker(&workers, &request, block_size)
.unwrap();
// Worker 1 should win: its adjusted prefill (2) is lower than worker 0's (3).
assert_eq!(
result.worker, worker1,
"Worker 1 should be selected (lower logit due to shared cache)"
);
}
/// Without shared cache hits, the scoring should be unchanged.
///
/// Builds a request with `shared_cache_hits: None` and checks that selection
/// succeeds and returns the registered worker.
/// NOTE(review): only one worker is registered, so this does not compare
/// scores between workers or against a pre-change baseline — consider adding
/// a second worker if the intent is to pin the logit ordering.
#[test]
fn test_no_shared_cache_unchanged() {
use crate::protocols::OverlapScores;
use crate::test_utils::SimpleWorkerConfig;
let block_size = 16u32;
let isl = 64usize;
let worker0 = WorkerWithDpRank::from_worker_id(0);
let mut overlaps = OverlapScores::new();
overlaps.scores.insert(worker0, 2);
let config = KvRouterConfig::default();
let selector = DefaultWorkerSelector::new(Some(config), "test");
let mut workers = HashMap::new();
workers.insert(0, SimpleWorkerConfig::default());
// The receiver is unused; the request only needs a sender to be well-formed.
let (tx, _rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
maybe_request_id: Some("test".into()),
token_seq: None,
isl_tokens: isl,
overlaps,
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
router_config_override: None,
update_states: false,
lora_name: None,
priority_jump: 0.0,
expected_output_tokens: None,
pinned_worker: None,
allowed_worker_ids: None,
// The case under test: no shared cache information on the request.
shared_cache_hits: None,
resp_tx: Some(tx),
};
let result = selector
.select_worker(&workers, &request, block_size)
.unwrap();
assert_eq!(result.worker, worker0);
}
} }
...@@ -8,7 +8,9 @@ use rustc_hash::FxHashMap; ...@@ -8,7 +8,9 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::config::RouterConfigOverride; use super::config::RouterConfigOverride;
use crate::protocols::{DpRank, OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank}; use crate::protocols::{
DpRank, OverlapScores, SharedCacheHits, WorkerConfigLike, WorkerId, WorkerWithDpRank,
};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad { pub struct PotentialLoad {
...@@ -59,6 +61,10 @@ pub struct SchedulingRequest { ...@@ -59,6 +61,10 @@ pub struct SchedulingRequest {
pub pinned_worker: Option<WorkerWithDpRank>, pub pinned_worker: Option<WorkerWithDpRank>,
/// Optional set of allowed worker IDs to restrict routing decisions (EPP). /// Optional set of allowed worker IDs to restrict routing decisions (EPP).
pub allowed_worker_ids: Option<HashSet<WorkerId>>, pub allowed_worker_ids: Option<HashSet<WorkerId>>,
/// Shared cache hit information from an external shared KV cache pool.
/// When present, the selector adjusts prefill cost by weighting shared hits
/// beyond each worker's device prefix.
pub shared_cache_hits: Option<SharedCacheHits>,
pub resp_tx: Option<tokio::sync::oneshot::Sender<Result<SchedulingResponse, KvSchedulerError>>>, pub resp_tx: Option<tokio::sync::oneshot::Sender<Result<SchedulingResponse, KvSchedulerError>>>,
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Standalone dummy shared KV cache server.
//!
//! This is a minimal implementation intended for development and testing.
//! It stores block hashes in a simple in-memory `HashSet` and responds to
//! `check_blocks` queries with the positions that are present.
pub mod server;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use server::{AppState, SharedCacheStore, create_router};
/// Configuration for the standalone shared KV cache server.
#[derive(Debug, Clone)]
pub struct SharedCacheConfig {
    /// TCP port the HTTP server binds to on `0.0.0.0`.
    pub port: u16,
}
pub async fn run_server(config: SharedCacheConfig) -> anyhow::Result<()> {
let cancel_token = CancellationToken::new();
let shutdown_token = cancel_token.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
tracing::info!("Received shutdown signal");
shutdown_token.cancel();
});
tracing::info!(
port = config.port,
"Starting standalone shared KV cache server"
);
let store = Arc::new(SharedCacheStore::new());
let state = Arc::new(AppState { store });
let app = create_router(state);
let listener = TcpListener::bind(("0.0.0.0", config.port)).await?;
tracing::info!("HTTP server listening on 0.0.0.0:{}", config.port);
axum::serve(listener, app)
.with_graceful_shutdown(async move {
cancel_token.cancelled().await;
tracing::info!("Received shutdown signal, stopping HTTP server");
})
.await?;
Ok(())
}
/// Runs the shared KV cache server as a component of a Dynamo distributed runtime.
///
/// Besides the HTTP routes from `create_router`, this spawns a request-plane
/// endpoint named `server::SHARED_KV_CACHE_QUERY_ENDPOINT` through
/// `runtime().secondary()`. Both front ends share one `SharedCacheStore`.
/// The HTTP server stops once the runtime's primary token is cancelled.
///
/// # Errors
///
/// Returns an error if the distributed runtime, namespace, component or
/// ingress cannot be created, if the listener cannot be bound, or if serving
/// fails. A failure of the spawned query endpoint is only logged.
#[cfg(feature = "indexer-runtime")]
pub async fn run_with_runtime(
    runtime: dynamo_runtime::Runtime,
    config: SharedCacheConfig,
    namespace: String,
    component_name: String,
) -> anyhow::Result<()> {
    use dynamo_runtime::{
        DistributedRuntime,
        pipeline::{ManyOut, SingleIn, network::Ingress},
    };

    tracing::info!(
        namespace,
        component_name,
        port = config.port,
        "Starting standalone shared KV cache server (Dynamo runtime mode)"
    );

    let drt = DistributedRuntime::from_settings(runtime).await?;
    let shutdown = drt.primary_token();
    let component = drt.namespace(namespace)?.component(component_name)?;

    // One store backs both the request-plane endpoint and the HTTP routes.
    let shared_store = Arc::new(SharedCacheStore::new());

    // Register a request-plane endpoint so routers can query via SharedKvCacheRequestPlaneClient.
    let query_engine = Arc::new(server::SharedCacheQueryEngine {
        store: Arc::clone(&shared_store),
    });
    let query_ingress = Ingress::<
        SingleIn<server::SharedCacheQueryRequest>,
        ManyOut<server::SharedCacheQueryResponse>,
    >::for_engine(query_engine)?;
    let endpoint = component
        .endpoint(server::SHARED_KV_CACHE_QUERY_ENDPOINT)
        .endpoint_builder()
        .handler(query_ingress)
        .graceful_shutdown(true);

    // The endpoint runs in the background; a failure there is logged rather
    // than propagated to the caller.
    drt.runtime().secondary().spawn(async move {
        if let Err(err) = endpoint.start().await {
            tracing::error!(error = %err, "Shared cache query endpoint failed");
        }
    });
    tracing::info!(
        endpoint = server::SHARED_KV_CACHE_QUERY_ENDPOINT,
        "Query endpoint registered"
    );

    let app_state = Arc::new(AppState {
        store: Arc::clone(&shared_store),
    });
    let router = create_router(app_state);

    let tcp_listener = TcpListener::bind(("0.0.0.0", config.port)).await?;
    tracing::info!("HTTP server listening on 0.0.0.0:{}", config.port);

    let serving = axum::serve(tcp_listener, router).with_graceful_shutdown(async move {
        shutdown.cancelled().await;
        tracing::info!("Received shutdown signal, stopping HTTP server");
    });
    serving.await?;

    Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! HTTP server and in-memory store for the dummy shared KV cache.
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use dashmap::DashSet;
use serde::{Deserialize, Serialize};
use crate::protocols::SharedCacheHits;
/// Endpoint name for the shared KV cache query service (request plane).
/// Matches the constant in `lib/llm/src/kv_router/shared_cache.rs`.
pub const SHARED_KV_CACHE_QUERY_ENDPOINT: &str = "shared_kv_cache_query";
// ---------------------------------------------------------------------------
// Wire protocol types (shared with lib/llm/src/kv_router/shared_cache.rs)
// ---------------------------------------------------------------------------
/// Request to check which blocks exist in the shared cache.
///
/// The HTTP `check_blocks` handler reads this from a JSON body.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SharedCacheQueryRequest {
/// Block hashes to look up. Hits are reported by position in this list.
pub block_hashes: Vec<u64>,
}
/// Response: sorted non-overlapping half-open ranges of present block positions.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SharedCacheQueryResponse {
/// Each entry is `[start, end]` copied from a hit range's `start` and `end`,
/// i.e. positions `start..end` (end exclusive) in the request's `block_hashes`.
pub ranges: Vec<[u32; 2]>,
}
#[cfg(feature = "indexer-runtime")]
impl crate::indexer::MaybeError for SharedCacheQueryResponse {
    /// Converts an error into a response: the error is logged at warn level
    /// and a response with no ranges is returned, so the error itself is not
    /// carried in the response.
    fn from_err(err: impl std::error::Error + 'static) -> Self {
        tracing::warn!("SharedCacheQueryResponse::from_err: {err}");
        let ranges = Vec::new();
        Self { ranges }
    }

    /// Always `None`; the response type has no field that records an error.
    fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>> {
        None
    }
}
#[cfg(feature = "indexer-runtime")]
impl dynamo_runtime::protocols::maybe_error::MaybeError for SharedCacheQueryResponse {
    /// Converts an error into a response: the error is logged at warn level
    /// and a response with no ranges is returned, so the error itself is not
    /// carried in the response.
    fn from_err(err: impl std::error::Error + 'static) -> Self {
        tracing::warn!("SharedCacheQueryResponse::from_err: {err}");
        let ranges = Vec::new();
        Self { ranges }
    }

    /// Always `None`; the response type has no field that records an error.
    fn err(&self) -> Option<dynamo_runtime::error::DynamoError> {
        None
    }
}
/// Request to store block hashes (for populating the dummy cache).
///
/// Derives `Debug` and `Clone` like the other wire types in this module so
/// requests can be logged and copied.
#[derive(Debug, Clone, Deserialize)]
pub struct StoreRequest {
    /// Block hashes to insert into the store.
    pub block_hashes: Vec<u64>,
}
/// Request to remove block hashes.
///
/// Derives `Debug` and `Clone` like the other wire types in this module so
/// requests can be logged and copied.
#[derive(Debug, Clone, Deserialize)]
pub struct RemoveRequest {
    /// Block hashes to remove from the store.
    pub block_hashes: Vec<u64>,
}
// ---------------------------------------------------------------------------
// In-memory store
// ---------------------------------------------------------------------------
/// Thread-safe set of block hashes that exist in the "shared cache".
pub struct SharedCacheStore {
    blocks: DashSet<u64>,
}

impl Default for SharedCacheStore {
    /// Equivalent to [`SharedCacheStore::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl SharedCacheStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            blocks: DashSet::new(),
        }
    }

    /// Insert block hashes into the store.
    ///
    /// Hashes that are already present are left as they are.
    pub fn store(&self, hashes: &[u64]) {
        for &h in hashes {
            self.blocks.insert(h);
        }
    }

    /// Remove block hashes from the store.
    ///
    /// Hashes that are not present are ignored.
    pub fn remove(&self, hashes: &[u64]) {
        for &h in hashes {
            self.blocks.remove(&h);
        }
    }

    /// Check which positions in the request have their block hash present.
    /// Returns coalesced ranges for the response.
    pub fn check_blocks(&self, block_hashes: &[u64]) -> SharedCacheHits {
        // One flag per query position, in query order.
        let hits: Vec<bool> = block_hashes
            .iter()
            .map(|h| self.blocks.contains(h))
            .collect();
        SharedCacheHits::from_hits(&hits)
    }

    /// Number of blocks currently stored.
    pub fn len(&self) -> usize {
        self.blocks.len()
    }

    /// Returns `true` when no blocks are stored.
    pub fn is_empty(&self) -> bool {
        self.blocks.is_empty()
    }
}
// ---------------------------------------------------------------------------
// Axum handlers
// ---------------------------------------------------------------------------
/// State handed to the axum handlers as `State<Arc<AppState>>`.
pub struct AppState {
/// Store that the handlers query and mutate.
pub store: Arc<SharedCacheStore>,
}
/// POST /check_blocks — query which block hashes exist.
///
/// Looks up every hash in the request body and replies `200 OK` with the hit
/// ranges encoded as `[start, end]` pairs.
async fn check_blocks(
    State(state): State<Arc<AppState>>,
    Json(req): Json<SharedCacheQueryRequest>,
) -> impl IntoResponse {
    let ranges = state
        .store
        .check_blocks(&req.block_hashes)
        .ranges
        .iter()
        .map(|range| [range.start, range.end])
        .collect::<Vec<[u32; 2]>>();
    let body = SharedCacheQueryResponse { ranges };
    (StatusCode::OK, Json(body))
}
/// POST /store — add block hashes to the cache.
///
/// Replies `201 Created`. `stored` is the number of hashes in the request
/// (duplicates included), and `total` is the store size after the insert.
async fn store_blocks(
    State(state): State<Arc<AppState>>,
    Json(req): Json<StoreRequest>,
) -> impl IntoResponse {
    let hashes = req.block_hashes;
    let requested = hashes.len();
    state.store.store(&hashes);

    // Read the size only after inserting so `total` reflects this request.
    let body = serde_json::json!({
        "status": "ok",
        "stored": requested,
        "total": state.store.len(),
    });
    (StatusCode::CREATED, Json(body))
}
/// POST /remove — remove block hashes from the cache.
///
/// Replies `200 OK`. `removed` is the number of hashes in the request
/// (whether or not they were present), and `total` is the store size after
/// the removal.
async fn remove_blocks(
    State(state): State<Arc<AppState>>,
    Json(req): Json<RemoveRequest>,
) -> impl IntoResponse {
    let hashes = req.block_hashes;
    let requested = hashes.len();
    state.store.remove(&hashes);

    // Read the size only after removing so `total` reflects this request.
    let body = serde_json::json!({
        "status": "ok",
        "removed": requested,
        "total": state.store.len(),
    });
    (StatusCode::OK, Json(body))
}
/// GET /health — liveness check.
///
/// Always responds `200 OK`; it does not inspect the store.
async fn health() -> StatusCode {
StatusCode::OK
}
/// GET /stats — number of blocks stored.
///
/// Replies with a JSON object whose `total_blocks` is the current store size.
async fn stats(State(state): State<Arc<AppState>>) -> impl IntoResponse {
    let total_blocks = state.store.len();
    let body = serde_json::json!({
        "total_blocks": total_blocks,
    });
    Json(body)
}
/// Builds the axum router for the shared cache HTTP API.
///
/// Routes: `POST /check_blocks`, `POST /store`, `POST /remove`,
/// `GET /health` and `GET /stats`, all sharing `state`.
pub fn create_router(state: Arc<AppState>) -> Router {
    let routes = Router::new()
        .route("/health", get(health))
        .route("/stats", get(stats))
        .route("/check_blocks", post(check_blocks))
        .route("/store", post(store_blocks))
        .route("/remove", post(remove_blocks));
    routes.with_state(state)
}
// ---------------------------------------------------------------------------
// Request-plane engine (for Dynamo runtime integration)
// ---------------------------------------------------------------------------
/// Engine that answers `SharedCacheQueryRequest`s from a `SharedCacheStore`.
/// Only compiled with the `indexer-runtime` feature.
#[cfg(feature = "indexer-runtime")]
pub struct SharedCacheQueryEngine {
/// Store consulted by `generate`.
pub store: Arc<SharedCacheStore>,
}
#[cfg(feature = "indexer-runtime")]
#[dynamo_runtime::pipeline::async_trait]
impl
    dynamo_runtime::pipeline::AsyncEngine<
        dynamo_runtime::pipeline::SingleIn<SharedCacheQueryRequest>,
        dynamo_runtime::pipeline::ManyOut<SharedCacheQueryResponse>,
        anyhow::Error,
    > for SharedCacheQueryEngine
{
    /// Answers one query with a stream containing exactly one response.
    ///
    /// The hit ranges are computed by `SharedCacheStore::check_blocks` and
    /// encoded as `[start, end]` pairs, the same shape the HTTP
    /// `check_blocks` handler returns.
    async fn generate(
        &self,
        request: dynamo_runtime::pipeline::SingleIn<SharedCacheQueryRequest>,
    ) -> anyhow::Result<dynamo_runtime::pipeline::ManyOut<SharedCacheQueryResponse>> {
        use dynamo_runtime::pipeline::{AsyncEngineContextProvider, ResponseStream};

        let (query, ctx) = request.into_parts();

        let ranges = self
            .store
            .check_blocks(&query.block_hashes)
            .ranges
            .iter()
            .map(|range| [range.start, range.end])
            .collect::<Vec<[u32; 2]>>();

        let responses = vec![SharedCacheQueryResponse { ranges }];
        let stream = dynamo_runtime::stream::iter(responses);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Hits are reported by query position and adjacent hits are coalesced.
    #[test]
    fn test_store_and_check() {
        let cache = SharedCacheStore::new();
        cache.store(&[100, 200, 300]);

        // Positions 0, 2 and 3 of the query are present => ranges [0..1, 2..4].
        let query: [u64; 5] = [100, 999, 200, 300, 888];
        let result = cache.check_blocks(&query);

        assert_eq!(result.total_hits, 3);
        assert_eq!(result.ranges, vec![0..1, 2..4]);
    }

    /// Querying an empty store yields no hits and no ranges.
    #[test]
    fn test_check_empty_cache() {
        let cache = SharedCacheStore::new();

        let result = cache.check_blocks(&[1, 2, 3]);

        assert_eq!(result.total_hits, 0);
        assert!(result.ranges.is_empty());
    }

    /// A removed hash no longer counts as a hit and splits the range.
    #[test]
    fn test_remove_blocks() {
        let cache = SharedCacheStore::new();
        cache.store(&[10, 20, 30]);
        cache.remove(&[20]);

        // Only positions 0 and 2 remain present => ranges [0..1, 2..3].
        let result = cache.check_blocks(&[10, 20, 30]);

        assert_eq!(result.total_hits, 2);
        assert_eq!(result.ranges, vec![0..1, 2..3]);
    }

    /// When every queried hash is present, a single range covers the query.
    #[test]
    fn test_all_hits() {
        let cache = SharedCacheStore::new();
        cache.store(&[1, 2, 3]);

        let result = cache.check_blocks(&[1, 2, 3]);

        assert_eq!(result.total_hits, 3);
        assert_eq!(result.ranges, vec![0..3]);
    }

    /// `len` counts distinct hashes; re-inserting one does not grow the store.
    #[test]
    fn test_store_len() {
        let cache = SharedCacheStore::new();
        assert_eq!(cache.len(), 0);

        cache.store(&[1, 2, 3]);
        assert_eq!(cache.len(), 3);

        // 1 is already present, so only 4 is new.
        cache.store(&[1, 4]);
        assert_eq!(cache.len(), 4);
    }

    /// The response survives a JSON round trip with `[start, end]` pairs intact.
    #[test]
    fn test_response_wire_format() {
        let source = SharedCacheHits::from_ranges(vec![0..2, 5..8]);
        let pairs = source
            .ranges
            .iter()
            .map(|range| [range.start, range.end])
            .collect::<Vec<[u32; 2]>>();
        let original = SharedCacheQueryResponse { ranges: pairs };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: SharedCacheQueryResponse = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.ranges, vec![[0, 2], [5, 8]]);
    }
}
...@@ -81,6 +81,7 @@ parking_lot = { workspace = true } ...@@ -81,6 +81,7 @@ parking_lot = { workspace = true }
prometheus = { workspace = true } prometheus = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
sha2 = "0.10"
strum = { workspace = true } strum = { workspace = true }
tempfile = { workspace = true } tempfile = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
......
...@@ -18,7 +18,10 @@ use dynamo_runtime::{ ...@@ -18,7 +18,10 @@ use dynamo_runtime::{
}; };
use crate::{ use crate::{
kv_router::{KvRouter, router_endpoint_id, scheduler::DefaultWorkerSelector}, kv_router::{
KvRouter, router_endpoint_id, scheduler::DefaultWorkerSelector,
shared_cache::HicacheSharedKvCache,
},
local_model::runtime_config::DisaggregatedEndpoint, local_model::runtime_config::DisaggregatedEndpoint,
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
types::{ types::{
...@@ -610,6 +613,26 @@ impl ModelManager { ...@@ -610,6 +613,26 @@ impl ModelManager {
let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?; let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;
let selector = DefaultWorkerSelector::new(kv_router_config.clone(), worker_type); let selector = DefaultWorkerSelector::new(kv_router_config.clone(), worker_type);
// Build shared cache client based on shared_cache_type.
let shared_cache: Option<Box<dyn dynamo_kv_router::SharedKvCache>> = match kv_router_config
.as_ref()
.map(|c| c.shared_cache_type)
.unwrap_or_default()
{
dynamo_kv_router::SharedCacheType::None => None,
dynamo_kv_router::SharedCacheType::Hicache => {
let worker_component_name = &endpoint.id().component;
tracing::info!(
worker_component = worker_component_name,
"Using HiCache shared KV cache"
);
Some(Box::new(HicacheSharedKvCache::new(
workers_with_configs.clone(),
)))
}
};
let chooser = KvRouter::new( let chooser = KvRouter::new(
endpoint.clone(), endpoint.clone(),
client, client,
...@@ -621,6 +644,7 @@ impl ModelManager { ...@@ -621,6 +644,7 @@ impl ModelManager {
worker_type, worker_type,
model_name, model_name,
is_eagle, is_eagle,
shared_cache,
) )
.await?; .await?;
Ok(Arc::new(chooser)) Ok(Arc::new(chooser))
...@@ -1152,8 +1176,6 @@ mod tests { ...@@ -1152,8 +1176,6 @@ mod tests {
use crate::kv_router::PrefillRouter; use crate::kv_router::PrefillRouter;
/// Helper: make a WorkerSet with an activated PrefillRouter attached. /// Helper: make a WorkerSet with an activated PrefillRouter attached.
/// The router is marked as activated to simulate a real deployment where
/// the prefill endpoint has already rendezvoused with the decode side.
fn make_worker_set_with_prefill_router( fn make_worker_set_with_prefill_router(
namespace: &str, namespace: &str,
mdcsum: &str, mdcsum: &str,
......
...@@ -6,7 +6,7 @@ use std::time::Instant; ...@@ -6,7 +6,7 @@ use std::time::Instant;
use anyhow::Result; use anyhow::Result;
use dynamo_kv_router::{ use dynamo_kv_router::{
PrefillLoadEstimator, PrefillLoadEstimator, SharedKvCache,
config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env}, config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
indexer::KvRouterError, indexer::KvRouterError,
protocols::KV_EVENT_SUBJECT, protocols::KV_EVENT_SUBJECT,
...@@ -45,6 +45,7 @@ pub mod publisher; ...@@ -45,6 +45,7 @@ pub mod publisher;
pub mod push_router; pub mod push_router;
pub mod scheduler; pub mod scheduler;
pub mod sequence; pub mod sequence;
pub mod shared_cache;
pub mod sticky_sessions; pub mod sticky_sessions;
pub use agent_controller::AgentController; pub use agent_controller::AgentController;
...@@ -149,6 +150,9 @@ where ...@@ -149,6 +150,9 @@ where
client: Client, client: Client,
is_eagle: bool, is_eagle: bool,
_served_indexer_handle: Option<ServedIndexerHandle>, _served_indexer_handle: Option<ServedIndexerHandle>,
/// Optional external shared KV cache pool. When present, `find_best_match`
/// queries it in parallel with the indexer and factors shared hits into scoring.
shared_cache: Option<Box<dyn SharedKvCache>>,
} }
impl<Sel> KvRouter<Sel> impl<Sel> KvRouter<Sel>
...@@ -167,6 +171,7 @@ where ...@@ -167,6 +171,7 @@ where
worker_type: &'static str, worker_type: &'static str,
model_name: Option<String>, model_name: Option<String>,
is_eagle: bool, is_eagle: bool,
shared_cache: Option<Box<dyn SharedKvCache>>,
) -> Result<Self> { ) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default(); let kv_router_config = kv_router_config.unwrap_or_default();
kv_router_config.validate()?; kv_router_config.validate()?;
...@@ -249,6 +254,7 @@ where ...@@ -249,6 +254,7 @@ where
client, client,
is_eagle, is_eagle,
_served_indexer_handle: served_indexer_handle, _served_indexer_handle: served_indexer_handle,
shared_cache,
}) })
} }
...@@ -330,13 +336,61 @@ where ...@@ -330,13 +336,61 @@ where
}); });
let seq_hash_elapsed = start.elapsed(); let seq_hash_elapsed = start.elapsed();
let overlap_scores = self // Query indexer and shared cache in parallel when shared cache is configured.
// Time each independently so metrics can separate indexer vs shared cache latency.
let (overlap_scores, shared_cache_hits, indexer_duration, shared_cache_duration) =
if let Some(ref shared_cache) = self.shared_cache {
let indexer_fut = self
.indexer
.find_matches(block_hashes.clone())
.instrument(tracing::info_span!("kv_router.find_matches"));
let shared_fut = shared_cache
.check_blocks(tokens, self.block_size)
.instrument(tracing::info_span!("kv_router.shared_cache_check"));
let indexer_timed = async {
let t = Instant::now();
let r = indexer_fut.await;
(r, t.elapsed())
};
let shared_timed = async {
let t = Instant::now();
let r = shared_fut.await;
(r, t.elapsed())
};
let ((indexer_result, idx_dur), (shared_result, sc_dur)) =
tokio::join!(indexer_timed, shared_timed);
let overlaps = indexer_result?;
// Shared cache failure is non-fatal: log warning and fall back to empty hits.
let hits = match shared_result {
Ok(hits) => Some(hits),
Err(e) => {
tracing::warn!(error = %e, "Shared cache query failed, ignoring");
if let Some(m) = metrics::RoutingOverheadMetrics::get() {
m.inc_shared_cache_errors();
}
None
}
};
(overlaps, hits, idx_dur, Some(sc_dur))
} else {
let t = Instant::now();
let overlaps = self
.indexer .indexer
.find_matches(block_hashes) .find_matches(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches")) .instrument(tracing::info_span!("kv_router.find_matches"))
.await?; .await?;
(overlaps, None, t.elapsed(), None)
};
let find_matches_elapsed = start.elapsed(); let find_matches_elapsed = start.elapsed();
// Capture shared cache info for metrics before moving into schedule().
// Clone the hits so we can compute `hits_beyond(overlap_blocks)` after
// scheduling returns, since `overlap_blocks` isn't known until then.
let num_blocks = isl_tokens / self.block_size as usize;
let sc_hits_for_metrics = shared_cache_hits.clone();
let response = self let response = self
.scheduler .scheduler
.schedule( .schedule(
...@@ -351,6 +405,7 @@ where ...@@ -351,6 +405,7 @@ where
expected_output_tokens, expected_output_tokens,
pinned_worker, pinned_worker,
allowed_worker_ids, allowed_worker_ids,
shared_cache_hits,
) )
.instrument(tracing::info_span!("kv_router.schedule")) .instrument(tracing::info_span!("kv_router.schedule"))
.await?; .await?;
...@@ -360,11 +415,25 @@ where ...@@ -360,11 +415,25 @@ where
m.observe( m.observe(
hash_elapsed, hash_elapsed,
seq_hash_elapsed, seq_hash_elapsed,
indexer_duration,
shared_cache_duration,
find_matches_elapsed, find_matches_elapsed,
total_elapsed, total_elapsed,
); );
} }
// Observe per-request shared cache metrics.
if let Some(hits) = sc_hits_for_metrics
&& let Some(m) = metrics::RouterRequestMetrics::get()
{
if num_blocks > 0 {
m.shared_cache_hit_rate
.observe(hits.total_hits as f64 / num_blocks as f64);
}
let beyond = hits.hits_beyond(response.overlap_blocks);
m.shared_cache_beyond_blocks.observe(beyond as f64);
}
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
tracing::info!( tracing::info!(
isl_tokens, isl_tokens,
...@@ -643,3 +712,189 @@ where ...@@ -643,3 +712,189 @@ where
self.cancellation_token.cancel(); self.cancellation_token.cancel();
} }
} }
// Tests checking how `find_best_match` hands shared KV cache results to the scheduler,
// both when the shared cache answers and when it fails.
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use async_trait::async_trait;
use dynamo_runtime::{DistributedRuntime, Runtime, distributed::DistributedConfig};
use tokio::sync::watch;
use crate::kv_router::scheduler::KvSchedulerError;
use crate::local_model::runtime_config::ModelRuntimeConfig;
/// Test double for `SharedKvCache`: returns the canned `hits` (or the default value
/// when `hits` is `None`), or fails when `should_error` is set.
struct FakeSharedCache {
hits: Option<dynamo_kv_router::protocols::SharedCacheHits>,
should_error: bool,
}
#[async_trait]
impl SharedKvCache for FakeSharedCache {
// Ignores its inputs entirely; the outcome is decided by the struct fields.
async fn check_blocks(
&self,
_tokens: &[u32],
_block_size: u32,
) -> Result<dynamo_kv_router::protocols::SharedCacheHits, KvRouterError> {
if self.should_error {
Err(KvRouterError::IndexerOffline)
} else {
Ok(self.hits.clone().unwrap_or_default())
}
}
}
/// Worker selector that asserts on the `shared_cache_hits` carried by the scheduling
/// request and then always picks `selected_worker`.
struct InspectingSelector {
expected_hits: Option<u32>,
selected_worker: WorkerWithDpRank,
}
impl dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> for InspectingSelector {
fn select_worker(
&self,
_workers: &HashMap<WorkerId, ModelRuntimeConfig>,
request: &dynamo_kv_router::scheduling::SchedulingRequest,
block_size: u32,
) -> Result<dynamo_kv_router::protocols::WorkerSelectionResult, KvSchedulerError> {
// `None` means the router passed no shared cache hits at all.
let observed_hits = request
.shared_cache_hits
.as_ref()
.map(|hits| hits.total_hits);
assert_eq!(observed_hits, self.expected_hits);
Ok(dynamo_kv_router::protocols::WorkerSelectionResult {
worker: self.selected_worker,
required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64,
overlap_blocks: 0,
})
}
}
/// Creates a component in a process-local distributed runtime; `name` is embedded in
/// both the namespace and the component name.
async fn make_test_component(name: &str) -> dynamo_runtime::component::Component {
let runtime = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(runtime, DistributedConfig::process_local())
.await
.unwrap();
let namespace = drt.namespace(format!("test-ns-{name}")).unwrap();
namespace
.component(format!("test-component-{name}"))
.unwrap()
}
/// Builds a `KvRouter` with two default workers (ids 0 and 1), the given selector and
/// the optional shared cache.
async fn make_test_router(
selector: impl dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>
+ Send
+ Sync
+ 'static,
shared_cache: Option<Box<dyn SharedKvCache>>,
) -> KvRouter<
impl dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
> {
let component = make_test_component("shared-cache-router").await;
let endpoint = component.endpoint("backend");
let client = endpoint.client().await.unwrap();
let mut workers = HashMap::new();
workers.insert(0, ModelRuntimeConfig::default());
workers.insert(1, ModelRuntimeConfig::default());
let (_tx, rx) = watch::channel(workers);
let config = KvRouterConfig {
overlap_score_weight: 0.0,
router_temperature: 0.0,
use_kv_events: false,
router_track_active_blocks: false,
shared_cache_multiplier: 0.5,
skip_initial_worker_wait: true,
..Default::default()
};
// NOTE(review): the literal `2` is presumably the KV block size — confirm against
// the `KvRouter::new` signature.
KvRouter::new(
endpoint,
client,
rx,
2,
selector,
Some(config),
None,
"decode",
None,
false,
shared_cache,
)
.await
.unwrap()
}
// A successful shared cache lookup (range 0..2, i.e. two hits) must reach the selector.
#[tokio::test]
async fn test_find_best_match_passes_shared_cache_hits_to_scheduler() {
let router = make_test_router(
InspectingSelector {
expected_hits: Some(2),
selected_worker: WorkerWithDpRank::from_worker_id(1),
},
Some(Box::new(FakeSharedCache {
#[allow(clippy::single_range_in_vec_init)]
hits: Some(dynamo_kv_router::protocols::SharedCacheHits::from_ranges(
vec![0..2],
)),
should_error: false,
})),
)
.await;
let (worker, overlap) = router
.find_best_match(
None,
&[11, 12, 21, 22],
None,
None,
false,
None,
0.0,
None,
None,
None,
)
.await
.unwrap();
assert_eq!(worker, WorkerWithDpRank::from_worker_id(1));
assert_eq!(overlap, 0);
}
// A failing shared cache must not fail routing: the selector sees no hits (`None`)
// and the request is still scheduled.
#[tokio::test]
async fn test_find_best_match_ignores_shared_cache_errors() {
let router = make_test_router(
InspectingSelector {
expected_hits: None,
selected_worker: WorkerWithDpRank::from_worker_id(0),
},
Some(Box::new(FakeSharedCache {
hits: None,
should_error: true,
})),
)
.await;
let (worker, overlap) = router
.find_best_match(
None,
&[11, 12, 21, 22],
None,
None,
false,
None,
0.0,
None,
None,
None,
)
.await
.unwrap();
assert_eq!(worker, WorkerWithDpRank::from_worker_id(0));
assert_eq!(overlap, 0);
}
}
...@@ -208,6 +208,8 @@ pub struct RoutingOverheadMetrics { ...@@ -208,6 +208,8 @@ pub struct RoutingOverheadMetrics {
pub seq_hashing: prometheus::Histogram, pub seq_hashing: prometheus::Histogram,
pub scheduling: prometheus::Histogram, pub scheduling: prometheus::Histogram,
pub total: prometheus::Histogram, pub total: prometheus::Histogram,
pub shared_cache_query: prometheus::Histogram,
pub shared_cache_errors_total: prometheus::IntCounter,
} }
static ROUTING_OVERHEAD_METRICS: OnceLock<Arc<RoutingOverheadMetrics>> = OnceLock::new(); static ROUTING_OVERHEAD_METRICS: OnceLock<Arc<RoutingOverheadMetrics>> = OnceLock::new();
...@@ -261,15 +263,35 @@ impl RoutingOverheadMetrics { ...@@ -261,15 +263,35 @@ impl RoutingOverheadMetrics {
let total = make( let total = make(
routing_overhead::TOTAL_MS, routing_overhead::TOTAL_MS,
"Total routing overhead per request in milliseconds", "Total routing overhead per request in milliseconds",
async_buckets, async_buckets.clone(),
) )
.expect("overhead_total_ms"); .expect("overhead_total_ms");
let shared_cache_query = make(
routing_overhead::SHARED_CACHE_QUERY_MS,
"Time spent querying the shared KV cache in milliseconds",
async_buckets,
)
.expect("overhead_shared_cache_query_ms");
let shared_cache_errors_total = {
let name = format!(
"{}_{}",
name_prefix::ROUTER,
routing_overhead::SHARED_CACHE_ERRORS_TOTAL
);
prometheus::IntCounter::with_opts(
Opts::new(name, "Total shared cache query errors")
.const_label(labels::ROUTER_ID, &router_id),
)
.expect("shared_cache_errors_total")
};
Arc::new(Self { Arc::new(Self {
block_hashing, block_hashing,
indexer_find_matches, indexer_find_matches,
seq_hashing, seq_hashing,
scheduling, scheduling,
total, total,
shared_cache_query,
shared_cache_errors_total,
}) })
}); });
registry.register(Box::new(m.block_hashing.clone()))?; registry.register(Box::new(m.block_hashing.clone()))?;
...@@ -277,6 +299,8 @@ impl RoutingOverheadMetrics { ...@@ -277,6 +299,8 @@ impl RoutingOverheadMetrics {
registry.register(Box::new(m.seq_hashing.clone()))?; registry.register(Box::new(m.seq_hashing.clone()))?;
registry.register(Box::new(m.scheduling.clone()))?; registry.register(Box::new(m.scheduling.clone()))?;
registry.register(Box::new(m.total.clone()))?; registry.register(Box::new(m.total.clone()))?;
registry.register(Box::new(m.shared_cache_query.clone()))?;
registry.register(Box::new(m.shared_cache_errors_total.clone()))?;
Ok(()) Ok(())
} }
...@@ -286,10 +310,16 @@ impl RoutingOverheadMetrics { ...@@ -286,10 +310,16 @@ impl RoutingOverheadMetrics {
} }
/// Observe routing overhead timings in milliseconds. /// Observe routing overhead timings in milliseconds.
///
/// `indexer_duration` and `shared_cache_duration` are independent wall-clock times
/// measured inside the `tokio::join!` block. They run in parallel, so
/// `find_matches_elapsed >= max(indexer_duration, shared_cache_duration)`.
pub fn observe( pub fn observe(
&self, &self,
hash_elapsed: Duration, hash_elapsed: Duration,
seq_hash_elapsed: Duration, seq_hash_elapsed: Duration,
indexer_duration: Duration,
shared_cache_duration: Option<Duration>,
find_matches_elapsed: Duration, find_matches_elapsed: Duration,
total_elapsed: Duration, total_elapsed: Duration,
) { ) {
...@@ -297,12 +327,12 @@ impl RoutingOverheadMetrics { ...@@ -297,12 +327,12 @@ impl RoutingOverheadMetrics {
.observe(hash_elapsed.as_secs_f64() * 1000.0); .observe(hash_elapsed.as_secs_f64() * 1000.0);
self.seq_hashing self.seq_hashing
.observe(seq_hash_elapsed.saturating_sub(hash_elapsed).as_secs_f64() * 1000.0); .observe(seq_hash_elapsed.saturating_sub(hash_elapsed).as_secs_f64() * 1000.0);
self.indexer_find_matches.observe( self.indexer_find_matches
find_matches_elapsed .observe(indexer_duration.as_secs_f64() * 1000.0);
.saturating_sub(seq_hash_elapsed) if let Some(sc_duration) = shared_cache_duration {
.as_secs_f64() self.shared_cache_query
* 1000.0, .observe(sc_duration.as_secs_f64() * 1000.0);
); }
self.scheduling.observe( self.scheduling.observe(
total_elapsed total_elapsed
.saturating_sub(find_matches_elapsed) .saturating_sub(find_matches_elapsed)
...@@ -311,6 +341,11 @@ impl RoutingOverheadMetrics { ...@@ -311,6 +341,11 @@ impl RoutingOverheadMetrics {
); );
self.total.observe(total_elapsed.as_secs_f64() * 1000.0); self.total.observe(total_elapsed.as_secs_f64() * 1000.0);
} }
/// Increment the shared cache error counter.
pub fn inc_shared_cache_errors(&self) {
self.shared_cache_errors_total.inc();
}
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
...@@ -354,11 +389,18 @@ pub struct RouterRequestMetrics { ...@@ -354,11 +389,18 @@ pub struct RouterRequestMetrics {
pub output_sequence_tokens: prometheus::Histogram, pub output_sequence_tokens: prometheus::Histogram,
pub kv_hit_rate: prometheus::Histogram, pub kv_hit_rate: prometheus::Histogram,
pub kv_transfer_estimated_latency_seconds: prometheus::Histogram, pub kv_transfer_estimated_latency_seconds: prometheus::Histogram,
pub shared_cache_hit_rate: prometheus::Histogram,
pub shared_cache_beyond_blocks: prometheus::Histogram,
} }
static ROUTER_REQUEST_METRICS: OnceLock<Arc<RouterRequestMetrics>> = OnceLock::new(); static ROUTER_REQUEST_METRICS: OnceLock<Arc<RouterRequestMetrics>> = OnceLock::new();
impl RouterRequestMetrics { impl RouterRequestMetrics {
/// Returns the registered metrics if `from_component()` was called earlier.
pub fn get() -> Option<Arc<Self>> {
ROUTER_REQUEST_METRICS.get().cloned()
}
/// Create from a Component, memoized in a static OnceLock. /// Create from a Component, memoized in a static OnceLock.
/// Uses the MetricsHierarchy API which auto-prepends `dynamo_component_`, /// Uses the MetricsHierarchy API which auto-prepends `dynamo_component_`,
/// injects hierarchy labels, and registers with the DRT `MetricsRegistry`. /// injects hierarchy labels, and registers with the DRT `MetricsRegistry`.
...@@ -428,6 +470,22 @@ impl RouterRequestMetrics { ...@@ -428,6 +470,22 @@ impl RouterRequestMetrics {
Some(generate_log_buckets(0.001, 10.0, 15)), Some(generate_log_buckets(0.001, 10.0, 15)),
) )
.expect("failed to create router_kv_transfer_estimated_latency_seconds"); .expect("failed to create router_kv_transfer_estimated_latency_seconds");
let shared_cache_hit_rate = metrics
.create_histogram(
&router_metric(frontend_service::SHARED_CACHE_HIT_RATE),
"Fraction of request blocks found in the shared KV cache (0.0-1.0)",
extra_labels,
Some(prometheus::linear_buckets(0.0, 0.05, 21).unwrap()),
)
.expect("failed to create router_shared_cache_hit_rate");
let shared_cache_beyond_blocks = metrics
.create_histogram(
&router_metric(frontend_service::SHARED_CACHE_BEYOND_BLOCKS),
"Shared cache blocks beyond device overlap for the selected worker",
extra_labels,
Some(prometheus::exponential_buckets(1.0, 2.0, 12).unwrap()),
)
.expect("failed to create router_shared_cache_beyond_blocks");
Arc::new(Self { Arc::new(Self {
requests_total, requests_total,
time_to_first_token_seconds, time_to_first_token_seconds,
...@@ -436,6 +494,8 @@ impl RouterRequestMetrics { ...@@ -436,6 +494,8 @@ impl RouterRequestMetrics {
output_sequence_tokens, output_sequence_tokens,
kv_hit_rate, kv_hit_rate,
kv_transfer_estimated_latency_seconds, kv_transfer_estimated_latency_seconds,
shared_cache_hit_rate,
shared_cache_beyond_blocks,
}) })
}) })
.clone() .clone()
...@@ -654,12 +714,20 @@ dynamo_frontend_router_queue_pending_requests{worker_type=\"decode\"} 5 ...@@ -654,12 +714,20 @@ dynamo_frontend_router_queue_pending_requests{worker_type=\"decode\"} 5
seq_hashing: make("test_seq_hashing_ms"), seq_hashing: make("test_seq_hashing_ms"),
scheduling: make("test_scheduling_ms"), scheduling: make("test_scheduling_ms"),
total: make("test_total_ms"), total: make("test_total_ms"),
shared_cache_query: make("test_shared_cache_query_ms"),
shared_cache_errors_total: prometheus::IntCounter::new(
"test_shared_cache_errors_total",
"test",
)
.unwrap(),
}; };
// Out-of-order cumulative durations: each phase < previous (would panic without saturating_sub) // Out-of-order cumulative durations: each phase < previous (would panic without saturating_sub)
metrics.observe( metrics.observe(
Duration::from_millis(10), Duration::from_millis(10),
Duration::from_millis(5), Duration::from_millis(5),
Duration::from_millis(4),
None,
Duration::from_millis(3), Duration::from_millis(3),
Duration::from_millis(1), Duration::from_millis(1),
); );
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_kv_router::protocols::SharedCacheHits;
pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy; pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
pub use dynamo_kv_router::scheduling::{ pub use dynamo_kv_router::scheduling::{
KvSchedulerError, LocalScheduler, PotentialLoad, SchedulingRequest, SchedulingResponse, KvSchedulerError, LocalScheduler, PotentialLoad, SchedulingRequest, SchedulingResponse,
...@@ -138,6 +139,7 @@ where ...@@ -138,6 +139,7 @@ where
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
pinned_worker: Option<WorkerWithDpRank>, pinned_worker: Option<WorkerWithDpRank>,
allowed_worker_ids: Option<HashSet<WorkerId>>, allowed_worker_ids: Option<HashSet<WorkerId>>,
shared_cache_hits: Option<SharedCacheHits>,
) -> Result<SchedulingResponse, KvSchedulerError> { ) -> Result<SchedulingResponse, KvSchedulerError> {
let response = self let response = self
.inner .inner
...@@ -153,6 +155,7 @@ where ...@@ -153,6 +155,7 @@ where
expected_output_tokens, expected_output_tokens,
pinned_worker, pinned_worker,
allowed_worker_ids, allowed_worker_ids,
shared_cache_hits,
) )
.await; .await;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.pending_count()); ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.pending_count());
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! HiCache shared KV cache client for SGLang + Mooncake.
//!
//! Instead of querying a worker endpoint over the request plane, this client:
//! 1. Reads Mooncake HiCache metadata published by SGLang workers in runtime config.
//! 2. Recomputes the logical HiCache page hashes from request tokens using the
//! same token -> page-hash logic as SGLang.
//! 3. Expands those logical page hashes into the concrete Mooncake object keys
//! SGLang uses for the configured TP/PP/MLA layout.
//! 4. Queries the Mooncake master HTTP service directly via `/batch_query_keys`.
use std::collections::HashMap;
use std::time::Duration;
use async_trait::async_trait;
use reqwest::Url;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
/// Timeout configured on the HTTP client used for Mooncake master queries.
const MOONCAKE_HTTP_TIMEOUT: Duration = Duration::from_secs(2);
use dynamo_kv_router::{
SharedKvCache,
indexer::KvRouterError,
protocols::{SharedCacheHits, WorkerId},
};
use crate::{discovery::RuntimeConfigWatch, local_model::runtime_config::ModelRuntimeConfig};
/// Engine-specific runtime config key under which the HiCache/Mooncake settings are looked up.
const SGLANG_HICACHE_MOONCAKE_RUNTIME_KEY: &str = "sglang_hicache_mooncake";
/// Maximum number of object keys sent in a single `/batch_query_keys` request.
const MOONCAKE_BATCH_QUERY_KEYS_CHUNK_SIZE: usize = 128;
/// HiCache/Mooncake settings read from a worker's engine-specific runtime config.
///
/// All workers must publish identical values; otherwise the shared-cache lookup is skipped.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
struct SglangHicacheMooncakeConfig {
/// Storage backend name; lookups only proceed when this equals `"mooncake"`.
backend: String,
/// Tokens per HiCache page; must be non-zero and equal to the router block size.
page_size: u32,
/// Tensor-parallel size; drives how many per-rank keys a page expands to.
tp_size: u32,
/// Pipeline-parallel size; values above 1 add a pipeline rank to each key.
pp_size: u32,
/// When true, only `_k` keys are generated for a page (no `_v` keys).
is_mla_model: bool,
/// When true, pages are hashed over adjacent token pairs instead of single tokens.
is_eagle: bool,
/// Rank count candidate used only when `should_split_heads` is set.
tp_lcm_size: Option<u32>,
/// When true, the rank count is the larger of `tp_lcm_size` and `tp_size`.
should_split_heads: bool,
/// Optional tag prepended to every key as `{tag}_`; empty tags are ignored.
extra_backend_tag: Option<String>,
/// Mooncake master address; its host is reused for the HTTP endpoint. `None` disables lookups.
master_server_address: Option<String>,
/// Port of the master HTTP service queried via `/batch_query_keys`.
master_metrics_port: u16,
}
/// JSON body returned by the Mooncake `/batch_query_keys` endpoint, as consumed here.
#[derive(Debug, Deserialize)]
struct MooncakeBatchQueryKeysResponse {
/// Overall request status; `false` is treated as a failed lookup.
success: bool,
/// Per-key results keyed by object key; defaults to empty when the field is absent.
#[serde(default)]
data: HashMap<String, MooncakeBatchQueryKeyResult>,
}
/// Result entry for a single key in a `/batch_query_keys` response.
#[derive(Debug, Deserialize, Default)]
struct MooncakeBatchQueryKeyResult {
/// `true` when the key counts as present; defaults to `false` when the field is absent.
#[serde(default)]
ok: bool,
}
/// Unit that is fed into the page hash.
#[derive(Debug, Clone, Copy)]
enum QueryToken {
/// One token id (non-eagle mode).
Single(u32),
/// Two adjacent token ids (eagle mode).
Bigram(u32, u32),
}
/// Shared KV cache client that queries the Mooncake master HTTP service for
/// SGLang HiCache (L3) state.
pub struct HicacheSharedKvCache {
/// Watch over per-worker runtime configs, from which the Mooncake settings are resolved.
runtime_configs: RuntimeConfigWatch,
/// HTTP client used for `/batch_query_keys` requests.
http_client: reqwest::Client,
}
impl HicacheSharedKvCache {
    /// Creates a client over the given runtime config watch.
    ///
    /// # Panics
    /// Panics if the underlying HTTP client cannot be constructed.
    pub fn new(runtime_configs: RuntimeConfigWatch) -> Self {
        let http_client = reqwest::Client::builder()
            .timeout(MOONCAKE_HTTP_TIMEOUT)
            .build()
            .expect("failed to build reqwest client");
        Self {
            runtime_configs,
            http_client,
        }
    }

    /// Returns the Mooncake config shared by all workers that publish one.
    ///
    /// Yields `None` when no worker publishes a config, or when the published
    /// configs disagree (a warning is logged in that case).
    fn resolve_mooncake_config(&self) -> Option<SglangHicacheMooncakeConfig> {
        let workers = self.runtime_configs.borrow();
        let configs: Vec<_> = workers
            .iter()
            .filter_map(|(worker_id, runtime_config)| {
                mooncake_config_from_runtime(*worker_id, runtime_config)
                    .map(|config| (*worker_id, config))
            })
            .collect();

        let (_, reference) = configs.first()?;
        let all_agree = configs.iter().all(|(_, config)| config == reference);
        if !all_agree {
            tracing::warn!(
                workers = ?configs.iter().map(|(worker_id, _)| *worker_id).collect::<Vec<_>>(),
                "SGLang Mooncake HiCache runtime configs differ across workers; skipping shared-cache lookup"
            );
            return None;
        }
        Some(reference.clone())
    }

    /// Asks the Mooncake master which of `actual_keys` exist.
    ///
    /// Keys are sent in batches of `MOONCAKE_BATCH_QUERY_KEYS_CHUNK_SIZE`, one request
    /// after another. Any transport error, non-success status, undecodable body or
    /// `success: false` response aborts the whole lookup with
    /// `KvRouterError::IndexerOffline`. Keys missing from a response count as absent.
    async fn fetch_key_presence(
        &self,
        endpoint: &Url,
        actual_keys: &[String],
    ) -> Result<HashMap<String, bool>, KvRouterError> {
        let mut presence = HashMap::with_capacity(actual_keys.len());

        for batch in actual_keys.chunks(MOONCAKE_BATCH_QUERY_KEYS_CHUNK_SIZE) {
            // Mooncake expects a raw comma-separated `keys=` list. If commas are
            // percent-encoded (`%2C`), Mooncake treats the entire value as one key.
            let query = format!("keys={}", batch.join(","));
            let mut url = endpoint.clone();
            url.set_query(Some(&query));

            let response = match self.http_client.get(url.clone()).send().await {
                Ok(response) => response,
                Err(e) => {
                    tracing::warn!(error = %e, url = %url, "Mooncake batch_query_keys request failed");
                    return Err(KvRouterError::IndexerOffline);
                }
            };

            let status = response.status();
            if !status.is_success() {
                tracing::warn!(
                    status = %status,
                    url = %url,
                    "Mooncake batch_query_keys returned non-success status"
                );
                return Err(KvRouterError::IndexerOffline);
            }

            let body = match response.json::<MooncakeBatchQueryKeysResponse>().await {
                Ok(body) => body,
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        url = %url,
                        "Failed to decode Mooncake batch_query_keys response"
                    );
                    return Err(KvRouterError::IndexerOffline);
                }
            };
            if !body.success {
                tracing::warn!(url = %url, "Mooncake batch_query_keys reported failure");
                return Err(KvRouterError::IndexerOffline);
            }

            presence.extend(batch.iter().map(|key| {
                let exists = body.data.get(key).is_some_and(|entry| entry.ok);
                (key.clone(), exists)
            }));
        }

        Ok(presence)
    }
}
#[async_trait]
impl SharedKvCache for HicacheSharedKvCache {
    /// Reports which pages of `tokens` are present in the Mooncake-backed shared cache.
    ///
    /// Returns empty hits (not an error) when no usable config is available, the
    /// backend is not Mooncake, the page size is zero or differs from `block_size`,
    /// the master endpoint cannot be built, or the request has no complete page.
    /// A page counts as a hit only when every one of its object keys is present.
    ///
    /// # Errors
    /// Propagates the error from the Mooncake master query.
    async fn check_blocks(
        &self,
        tokens: &[u32],
        block_size: u32,
    ) -> Result<SharedCacheHits, KvRouterError> {
        let config = match self.resolve_mooncake_config() {
            Some(config) => config,
            None => {
                tracing::debug!("No SGLang Mooncake HiCache runtime config available");
                return Ok(SharedCacheHits::default());
            }
        };

        if config.backend != "mooncake" {
            tracing::debug!(backend = %config.backend, "Skipping non-Mooncake HiCache config");
            return Ok(SharedCacheHits::default());
        }

        if config.page_size == 0 || block_size == 0 {
            tracing::warn!(
                worker_page_size = config.page_size,
                router_page_size = block_size,
                "Invalid HiCache page size; skipping shared-cache lookup"
            );
            return Ok(SharedCacheHits::default());
        }

        if config.page_size != block_size {
            tracing::warn!(
                worker_page_size = config.page_size,
                router_page_size = block_size,
                "HiCache page size mismatch; skipping shared-cache lookup"
            );
            return Ok(SharedCacheHits::default());
        }

        let endpoint = match mooncake_batch_query_endpoint(&config) {
            Some(endpoint) => endpoint,
            None => {
                tracing::debug!("Mooncake master HTTP endpoint is unavailable");
                return Ok(SharedCacheHits::default());
            }
        };

        let page_hashes = logical_page_hashes(tokens, config.page_size, config.is_eagle);
        if page_hashes.is_empty() {
            return Ok(SharedCacheHits::default());
        }

        // One inner list of concrete object keys per logical page.
        let keys_per_page = build_page_query_keys(&page_hashes, &config);
        let flat_keys: Vec<String> = keys_per_page.iter().flatten().cloned().collect();
        let presence = self.fetch_key_presence(&endpoint, &flat_keys).await?;

        let mut page_hits = Vec::with_capacity(keys_per_page.len());
        for page_keys in &keys_per_page {
            let fully_present = page_keys
                .iter()
                .all(|key| presence.get(key).copied().unwrap_or(false));
            page_hits.push(fully_present);
        }

        Ok(SharedCacheHits::from_hits(&page_hits))
    }
}
/// Extracts the HiCache/Mooncake settings from one worker's runtime config.
///
/// Returns `None` when the worker publishes no such entry, or when the entry
/// fails to parse (the parse failure is logged with the worker id).
fn mooncake_config_from_runtime(
    worker_id: WorkerId,
    runtime_config: &ModelRuntimeConfig,
) -> Option<SglangHicacheMooncakeConfig> {
    let lookup = runtime_config
        .get_engine_specific::<SglangHicacheMooncakeConfig>(SGLANG_HICACHE_MOONCAKE_RUNTIME_KEY);
    if let Err(error) = &lookup {
        tracing::warn!(
            worker_id,
            runtime_key = SGLANG_HICACHE_MOONCAKE_RUNTIME_KEY,
            %error,
            "Failed to parse SGLang Mooncake HiCache runtime config"
        );
    }
    lookup.ok().flatten()
}
/// Builds the `http://<master host>:<metrics port>/batch_query_keys` URL.
///
/// Returns `None` when no master address is configured, the address does not
/// parse as a URL authority, or the port cannot be applied; the latter two
/// cases are logged.
fn mooncake_batch_query_endpoint(config: &SglangHicacheMooncakeConfig) -> Option<Url> {
    let master_server_address = config.master_server_address.as_deref()?;

    let candidate = format!("http://{master_server_address}");
    let mut url = match Url::parse(&candidate) {
        Ok(url) => url,
        Err(error) => {
            tracing::warn!(
                master_server_address,
                %error,
                "Failed to parse Mooncake master address"
            );
            return None;
        }
    };

    // The port in the address is replaced with the HTTP service port.
    let port_applied = url.set_port(Some(config.master_metrics_port)).is_ok();
    if !port_applied {
        tracing::warn!(
            master_server_address,
            master_metrics_port = config.master_metrics_port,
            "Failed to set Mooncake master HTTP port"
        );
        return None;
    }

    url.set_path("/batch_query_keys");
    url.set_query(None);
    Some(url)
}
fn logical_page_hashes(tokens: &[u32], page_size: u32, is_eagle: bool) -> Vec<String> {
let page_size = page_size as usize;
if page_size == 0 {
return Vec::new();
}
let query_tokens = if is_eagle {
tokens
.windows(2)
.map(|pair| QueryToken::Bigram(pair[0], pair[1]))
.collect::<Vec<_>>()
} else {
tokens
.iter()
.copied()
.map(QueryToken::Single)
.collect::<Vec<_>>()
};
let aligned_len = (query_tokens.len() / page_size) * page_size;
let aligned_tokens = &query_tokens[..aligned_len];
let mut page_hashes = Vec::with_capacity(aligned_tokens.len() / page_size);
let mut prior_hash = None;
for page_tokens in aligned_tokens.chunks(page_size) {
let digest = hash_query_tokens(page_tokens, prior_hash.as_ref());
page_hashes.push(hex_encode(&digest));
prior_hash = Some(digest);
}
page_hashes
}
/// Hashes one page with SHA-256.
///
/// The input is the previous page's digest (when given) followed by the
/// little-endian bytes of each token id, in order; a bigram contributes its
/// two ids back to back.
fn hash_query_tokens(page_tokens: &[QueryToken], prior_hash: Option<&[u8; 32]>) -> [u8; 32] {
    let mut sha = Sha256::new();
    if let Some(parent_digest) = prior_hash {
        sha.update(parent_digest);
    }
    for unit in page_tokens {
        match *unit {
            QueryToken::Single(id) => sha.update(id.to_le_bytes()),
            QueryToken::Bigram(first, second) => {
                sha.update(first.to_le_bytes());
                sha.update(second.to_le_bytes());
            }
        }
    }
    sha.finalize().into()
}
/// Encodes `bytes` as a lowercase hexadecimal string, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        for nibble in [byte >> 4, byte & 0x0f] {
            let digit = char::from_digit(u32::from(nibble), 16)
                .expect("a nibble is always below 16");
            hex.push(digit);
        }
    }
    hex
}
/// Expands every logical page hash into its concrete object keys, returning one
/// inner list per page in the same order as `page_hashes`.
fn build_page_query_keys(
    page_hashes: &[String],
    config: &SglangHicacheMooncakeConfig,
) -> Vec<Vec<String>> {
    let mut keys_per_page = Vec::with_capacity(page_hashes.len());
    for page_hash in page_hashes {
        keys_per_page.push(expand_actual_query_keys(page_hash, config));
    }
    keys_per_page
}
/// Expands one logical page hash into the object keys queried for it.
///
/// * MLA models: a single `{key}__k` without pipeline parallelism, otherwise
///   `{key}_{pp_rank}_k` for each pipeline rank.
/// * Other models: a `_k` and a `_v` key per rank, with suffix `{rank}` or, when
///   `pp_size > 1`, `{rank}_{pp_rank}`. Keys are ordered by pipeline rank first,
///   then rank.
///
/// Sizes of zero are treated as one. The key is prefixed with the backend tag
/// when one is configured.
fn expand_actual_query_keys(
    logical_page_hash: &str,
    config: &SglangHicacheMooncakeConfig,
) -> Vec<String> {
    let base = maybe_prefix_key(logical_page_hash, config.extra_backend_tag.as_deref());
    let pp_size = config.pp_size.max(1);
    let pipelined = pp_size > 1;

    if config.is_mla_model {
        if !pipelined {
            return vec![format!("{base}__k")];
        }
        return (0..pp_size)
            .map(|pp_rank| format!("{base}_{pp_rank}_k"))
            .collect();
    }

    // With split heads the rank count is the larger of the LCM size and the TP size.
    let mut rank_count = config.tp_size;
    if config.should_split_heads {
        if let Some(lcm_size) = config.tp_lcm_size {
            rank_count = rank_count.max(lcm_size);
        }
    }
    let rank_count = rank_count.max(1);

    let mut keys = Vec::with_capacity((pp_size * rank_count * 2) as usize);
    for pp_rank in 0..pp_size {
        for rank in 0..rank_count {
            let stem = if pipelined {
                format!("{base}_{rank}_{pp_rank}")
            } else {
                format!("{base}_{rank}")
            };
            keys.push(format!("{stem}_k"));
            keys.push(format!("{stem}_v"));
        }
    }
    keys
}
/// Returns `{tag}_{logical_key}` when a non-empty tag is supplied, otherwise
/// `logical_key` unchanged.
fn maybe_prefix_key(logical_key: &str, extra_backend_tag: Option<&str>) -> String {
    if let Some(tag) = extra_backend_tag {
        if !tag.is_empty() {
            return [tag, "_", logical_key].concat();
        }
    }
    logical_key.to_owned()
}
// Unit tests for page hashing, key expansion and the end-to-end Mooncake lookup
// (the latter against a local mock HTTP server).
#[cfg(test)]
mod tests {
use std::ops::Range;
use super::*;
use mockito::{Matcher, Server};
use serde_json::json;
use tokio::sync::watch;
/// Baseline config: Mooncake backend, page size 4, no parallelism, no tag.
fn mooncake_config() -> SglangHicacheMooncakeConfig {
SglangHicacheMooncakeConfig {
backend: "mooncake".to_string(),
page_size: 4,
tp_size: 1,
pp_size: 1,
is_mla_model: false,
is_eagle: false,
tp_lcm_size: None,
should_split_heads: false,
extra_backend_tag: None,
master_server_address: Some("127.0.0.1:50051".to_string()),
master_metrics_port: 9003,
}
}
/// Builds a runtime config watch holding a single worker (id 1) that publishes `config`.
fn runtime_watch_with_config(config: SglangHicacheMooncakeConfig) -> RuntimeConfigWatch {
let mut runtime_config = ModelRuntimeConfig::new();
runtime_config
.set_engine_specific(SGLANG_HICACHE_MOONCAKE_RUNTIME_KEY, config)
.unwrap();
let mut workers = HashMap::new();
workers.insert(1, runtime_config);
let (_tx, rx) = watch::channel(workers);
rx
}
// Ten tokens at page size 4 give two complete pages; the trailing two tokens are dropped.
#[test]
fn test_logical_page_hashes_match_sglang_for_normal_tokens() {
let hashes = logical_page_hashes(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 4, false);
assert_eq!(
hashes,
vec![
"cf97adeedb59e05bfd73a2b4c2a8885708c4f4f70c84c64b27120e72ab733b72".to_string(),
"4ebfa8a1f3c341517621838c6e1b9aa350307e3f00b3cbd1a07ef740f54396d6".to_string(),
]
);
}
// Five tokens form four adjacent pairs, i.e. two pages at page size 2.
#[test]
fn test_logical_page_hashes_match_sglang_for_eagle_tokens() {
let hashes = logical_page_hashes(&[10, 11, 12, 13, 14], 2, true);
assert_eq!(
hashes,
vec![
"4bde82677ba8b6de843da1713b58a439678ec01b642bbdcffec4acfa81b0ec8e".to_string(),
"75ab93a767bad1e254945d1a0ccfa1588d6ebb803303e412d984baedcbbf04b9".to_string(),
]
);
}
// With TP=2 and PP=2 the suffix is `{rank}_{pp_rank}` and pipeline rank is the outer loop.
#[test]
fn test_expand_actual_query_keys_for_mha_tp_pp_layout() {
let config = SglangHicacheMooncakeConfig {
tp_size: 2,
pp_size: 2,
..mooncake_config()
};
let query_keys = expand_actual_query_keys("hash", &config);
assert_eq!(
query_keys,
vec![
"hash_0_0_k",
"hash_0_0_v",
"hash_1_0_k",
"hash_1_0_v",
"hash_0_1_k",
"hash_0_1_v",
"hash_1_1_k",
"hash_1_1_v",
]
);
}
// MLA without pipeline parallelism yields exactly one key with an empty rank segment.
#[test]
fn test_expand_actual_query_keys_for_mla_without_pp_uses_double_underscore() {
let config = SglangHicacheMooncakeConfig {
is_mla_model: true,
..mooncake_config()
};
let query_keys = expand_actual_query_keys("hash", &config);
assert_eq!(query_keys, vec!["hash__k"]);
}
// Split heads uses the larger `tp_lcm_size` (4) as rank count; the tag prefixes every key.
#[test]
fn test_expand_actual_query_keys_for_split_heads() {
let config = SglangHicacheMooncakeConfig {
tp_size: 2,
tp_lcm_size: Some(4),
should_split_heads: true,
extra_backend_tag: Some("tag".to_string()),
..mooncake_config()
};
let query_keys = expand_actual_query_keys("hash", &config);
assert_eq!(
query_keys,
vec![
"tag_hash_0_k",
"tag_hash_0_v",
"tag_hash_1_k",
"tag_hash_1_v",
"tag_hash_2_k",
"tag_hash_2_v",
"tag_hash_3_k",
"tag_hash_3_v",
]
);
}
// End-to-end lookup: page 0 has both keys present, page 1 is missing its `_v` key,
// so only page 0 is reported as a hit.
#[tokio::test]
async fn test_check_blocks_queries_mooncake_master() {
let mut server = Server::new_async().await;
let server_url = Url::parse(&server.url()).unwrap();
let hash0 = "cf97adeedb59e05bfd73a2b4c2a8885708c4f4f70c84c64b27120e72ab733b72".to_string();
let hash1 = "4ebfa8a1f3c341517621838c6e1b9aa350307e3f00b3cbd1a07ef740f54396d6".to_string();
let response = json!({
"success": true,
"data": {
format!("{hash0}_0_k"): {"ok": true, "values": []},
format!("{hash0}_0_v"): {"ok": true, "values": []},
format!("{hash1}_0_k"): {"ok": true, "values": []},
format!("{hash1}_0_v"): {"ok": false, "error": "not found"},
}
});
// The exact matcher also verifies that commas in the key list are sent unencoded.
let mock = server
.mock("GET", "/batch_query_keys")
.match_query(Matcher::Exact(format!(
"keys={hash0}_0_k,{hash0}_0_v,{hash1}_0_k,{hash1}_0_v"
)))
.with_status(200)
.with_header("content-type", "application/json")
.with_body(response.to_string())
.create_async()
.await;
// The port in `master_server_address` (50051) is replaced by `master_metrics_port`,
// which points at the mock server.
let config = SglangHicacheMooncakeConfig {
master_server_address: Some(format!("{}:50051", server_url.host_str().unwrap())),
master_metrics_port: server_url.port().unwrap(),
..mooncake_config()
};
let cache = HicacheSharedKvCache::new(runtime_watch_with_config(config));
let hits = cache
.check_blocks(&[1, 2, 3, 4, 5, 6, 7, 8], 4)
.await
.unwrap();
assert_eq!(hits.ranges, vec![Range { start: 0, end: 1 }]);
assert_eq!(hits.total_hits, 1);
mock.assert_async().await;
}
}
...@@ -153,6 +153,7 @@ impl PendingRequest { ...@@ -153,6 +153,7 @@ impl PendingRequest {
expected_output_tokens: self.expected_output_tokens, expected_output_tokens: self.expected_output_tokens,
pinned_worker: None, pinned_worker: None,
allowed_worker_ids: None, allowed_worker_ids: None,
shared_cache_hits: None,
resp_tx: None, resp_tx: None,
} }
} }
......
...@@ -214,6 +214,7 @@ impl KvReplayRouter { ...@@ -214,6 +214,7 @@ impl KvReplayRouter {
), ),
None, None,
None, None,
None,
) )
.await?; .await?;
usize::try_from(response.best_worker.worker_id) usize::try_from(response.best_worker.worker_id)
......
...@@ -199,6 +199,12 @@ pub mod frontend_service { ...@@ -199,6 +199,12 @@ pub mod frontend_service {
/// Upper-bound estimation of KV cache transfer latency in disaggregated serving (seconds) /// Upper-bound estimation of KV cache transfer latency in disaggregated serving (seconds)
pub const KV_TRANSFER_ESTIMATED_LATENCY_SECONDS: &str = "kv_transfer_estimated_latency_seconds"; pub const KV_TRANSFER_ESTIMATED_LATENCY_SECONDS: &str = "kv_transfer_estimated_latency_seconds";
/// Shared cache hit rate (0.0-1.0): fraction of request blocks found in shared cache
pub const SHARED_CACHE_HIT_RATE: &str = "shared_cache_hit_rate";
/// Shared cache blocks beyond device overlap for the selected worker
pub const SHARED_CACHE_BEYOND_BLOCKS: &str = "shared_cache_beyond_blocks";
/// Number of cached tokens (prefix cache hits) per request /// Number of cached tokens (prefix cache hits) per request
pub const CACHED_TOKENS: &str = "cached_tokens"; pub const CACHED_TOKENS: &str = "cached_tokens";
...@@ -508,6 +514,12 @@ pub mod routing_overhead { ...@@ -508,6 +514,12 @@ pub mod routing_overhead {
/// Total routing overhead per request /// Total routing overhead per request
pub const TOTAL_MS: &str = "overhead_total_ms"; pub const TOTAL_MS: &str = "overhead_total_ms";
/// Time spent querying the shared KV cache (Mooncake)
pub const SHARED_CACHE_QUERY_MS: &str = "overhead_shared_cache_query_ms";
/// Total shared cache query errors (timeouts, HTTP failures)
pub const SHARED_CACHE_ERRORS_TOTAL: &str = "shared_cache_errors_total";
} }
/// Router request metrics (component-scoped aggregate histograms + counter) /// Router request metrics (component-scoped aggregate histograms + counter)
...@@ -543,6 +555,12 @@ pub mod router { ...@@ -543,6 +555,12 @@ pub mod router {
/// Predicted KV cache hit rate at routing time (0.0-1.0) /// Predicted KV cache hit rate at routing time (0.0-1.0)
pub const KV_HIT_RATE: &str = "router_kv_hit_rate"; pub const KV_HIT_RATE: &str = "router_kv_hit_rate";
/// Shared cache hit rate (0.0-1.0): fraction of request blocks found in shared cache
pub const SHARED_CACHE_HIT_RATE: &str = "router_shared_cache_hit_rate";
/// Shared cache blocks beyond device overlap for the selected worker
pub const SHARED_CACHE_BEYOND_BLOCKS: &str = "router_shared_cache_beyond_blocks";
} }
/// Frontend pipeline stage and event-loop metrics /// Frontend pipeline stage and event-loop metrics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment