Unverified Commit 49eb397a authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(kv-router): split Dynamo-native remote indexer [DYN-2593] (#7973)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent d232b450
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use anyhow::Result;
use dynamo_runtime::pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, async_trait,
};
use dynamo_runtime::stream;
use crate::indexer::{IndexerQueryRequest, IndexerQueryResponse};
use crate::standalone_indexer::registry::{IndexerKey, WorkerRegistry};
pub struct IndexerQueryEngine {
pub registry: Arc<WorkerRegistry>,
}
#[async_trait]
impl AsyncEngine<SingleIn<IndexerQueryRequest>, ManyOut<IndexerQueryResponse>, anyhow::Error>
for IndexerQueryEngine
{
async fn generate(
&self,
request: SingleIn<IndexerQueryRequest>,
) -> Result<ManyOut<IndexerQueryResponse>> {
let (req, ctx) = request.into_parts();
let key = IndexerKey {
model_name: req.model_name.clone(),
tenant_id: req.namespace.clone(),
};
let response = match self.registry.get_indexer(&key) {
Some(entry) => match entry.indexer.find_matches(req.block_hashes).await {
Ok(scores) => IndexerQueryResponse::Scores(scores.into()),
Err(err) => IndexerQueryResponse::Error(err.to_string()),
},
None => IndexerQueryResponse::Error(format!(
"no indexer for model={} namespace={}",
req.model_name, req.namespace
)),
};
let response_stream = stream::iter(vec![response]);
Ok(ResponseStream::new(
Box::pin(response_stream),
ctx.context(),
))
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use anyhow::Result;
use tokio_util::sync::CancellationToken;
use dynamo_runtime::{
DistributedRuntime, discovery::EventTransportKind, transports::event_plane::EventSubscriber,
};
use crate::protocols::{KV_EVENT_SUBJECT, RouterEvent};
use crate::standalone_indexer::registry::WorkerRegistry;
pub async fn spawn_event_subscriber(
drt: &DistributedRuntime,
namespace: &str,
worker_component_name: &str,
registry: Arc<WorkerRegistry>,
cancel_token: CancellationToken,
) -> Result<()> {
let transport_kind = EventTransportKind::from_env_or_default();
let worker_component = drt.namespace(namespace)?.component(worker_component_name)?;
let mut subscriber = EventSubscriber::for_component_with_transport(
&worker_component,
KV_EVENT_SUBJECT,
transport_kind,
)
.await?
.typed::<RouterEvent>();
let kv_event_subject = format!(
"namespace.{}.component.{}.{}",
namespace, worker_component_name, KV_EVENT_SUBJECT
);
match transport_kind {
EventTransportKind::Nats => {
tracing::info!(
subject = %kv_event_subject,
"KV Indexer subscribing to NATS Core events"
);
}
EventTransportKind::Zmq => {
tracing::info!(
subject = %kv_event_subject,
"KV Indexer subscribing to ZMQ event plane"
);
}
}
tokio::spawn(async move {
loop {
tokio::select! {
biased;
_ = cancel_token.cancelled() => {
tracing::debug!("Event subscriber received cancellation signal");
break;
}
Some(result) = subscriber.next() => {
let (_envelope, event) = match result {
Ok((envelope, event)) => (envelope, event),
Err(err) => {
tracing::warn!("Failed to receive RouterEvent from event plane: {err:?}");
continue;
}
};
let worker_id = event.worker_id;
if let Some(indexer) = registry.get_indexer_for_worker(worker_id) {
indexer.apply_event(event).await;
} else {
tracing::trace!(
worker_id,
"Received event for unknown worker (not yet discovered?)"
);
}
}
}
}
tracing::info!("Event subscriber exiting");
});
Ok(())
}
......@@ -31,17 +31,14 @@ use tracing::Instrument;
use validator::Validate;
pub mod indexer;
mod jetstream;
pub mod metrics;
pub mod prefill_router;
pub mod publisher;
pub mod push_router;
pub mod scheduler;
pub mod sequence;
pub mod subscriber;
pub mod worker_query;
pub use indexer::Indexer;
pub use indexer::{Indexer, ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service};
pub use prefill_router::PrefillRouter;
pub use push_router::{DirectRoutingRouter, KvPushRouter};
......@@ -117,6 +114,7 @@ where
cancellation_token: tokio_util::sync::CancellationToken,
client: Client,
is_eagle: bool,
_served_indexer_handle: Option<ServedIndexerHandle>,
}
impl<Sel> KvRouter<Sel>
......@@ -142,7 +140,13 @@ where
let cancellation_token = component.drt().primary_token();
let min_initial_workers = min_initial_workers_from_env()?;
let indexer = Indexer::new(component, &kv_router_config, block_size, model_name).await?;
let indexer = Indexer::new(
component,
&kv_router_config,
block_size,
model_name.as_deref(),
)
.await?;
if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait {
let mut startup_watch = workers_with_configs.clone();
......@@ -168,12 +172,11 @@ where
)
.await?;
// Start KV event subscription if needed — skip when using a remote indexer
// (the standalone indexer handles its own event subscription).
if kv_router_config.remote_indexer_component.is_some() {
// Start KV event subscription if needed — skip when using a remote indexer.
if kv_router_config.use_remote_indexer {
tracing::info!("Skipping KV event subscription (using remote indexer)");
} else if kv_router_config.should_subscribe_to_kv_events() {
subscriber::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
indexer::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
.await?;
} else {
tracing::info!(
......@@ -183,6 +186,23 @@ where
);
}
let served_indexer_handle = if kv_router_config.serve_indexer {
let model_name = model_name.clone().ok_or_else(|| {
anyhow::anyhow!("model_name is required when serve_indexer is configured")
})?;
Some(
ensure_served_indexer_service(
component.clone(),
ServedIndexerMode::from_use_kv_events(kv_router_config.use_kv_events),
model_name,
indexer.clone(),
)
.await?,
)
} else {
None
};
tracing::info!("KV Routing initialized");
Ok(Self {
indexer,
......@@ -193,6 +213,7 @@ where
cancellation_token,
client,
is_eagle,
_served_indexer_handle: served_indexer_handle,
})
}
......
......@@ -18,9 +18,11 @@ use rand::Rng;
use tokio_util::sync::CancellationToken;
use crate::kv_router::{
Indexer, KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, router_discovery_query,
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, router_discovery_query,
};
use super::Indexer;
/// Helper function to create a KV stream name from a component and subject.
///
/// Generates a slugified stream name in the format:
......
......@@ -5,71 +5,28 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use futures::StreamExt;
use dynamo_kv_router::{
ConcurrentRadixTreeCompressed, ThreadPoolIndexer,
approx::PruneConfig,
config::KvRouterConfig,
indexer::{
IndexerQueryRequest, IndexerQueryResponse, KV_INDEXER_QUERY_ENDPOINT, KvIndexer,
KvIndexerInterface, KvIndexerMetrics, KvRouterError,
},
indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError},
protocols::{
LocalBlockHash, OverlapScores, RouterEvent, TokensWithHashes, WorkerId, WorkerWithDpRank,
},
};
use dynamo_runtime::{
component::Component,
pipeline::{ManyOut, RouterMode, SingleIn, network::egress::push_router::PushRouter},
traits::DistributedRuntimeProvider,
};
use dynamo_runtime::{component::Component, traits::DistributedRuntimeProvider};
use dynamo_tokens::SequenceHash;
use tokio::sync::oneshot;
pub struct RemoteIndexer {
router: PushRouter<IndexerQueryRequest, IndexerQueryResponse>,
model_name: String,
namespace: String,
}
impl RemoteIndexer {
async fn new(
component: &Component,
indexer_component_name: &str,
model_name: String,
) -> Result<Self> {
let namespace = component.namespace().name();
let indexer_ns = component.namespace();
let indexer_component = indexer_ns.component(indexer_component_name)?;
let endpoint = indexer_component.endpoint(KV_INDEXER_QUERY_ENDPOINT);
let client = endpoint.client().await?;
let router =
PushRouter::from_client_no_fault_detection(client, RouterMode::RoundRobin).await?;
Ok(Self {
router,
model_name,
namespace,
})
}
mod jetstream;
pub mod remote;
mod subscriber;
mod worker_query;
async fn find_matches(&self, block_hashes: Vec<LocalBlockHash>) -> Result<OverlapScores> {
let request = IndexerQueryRequest {
model_name: self.model_name.clone(),
namespace: self.namespace.clone(),
block_hashes,
};
let mut stream: ManyOut<IndexerQueryResponse> =
self.router.round_robin(SingleIn::new(request)).await?;
match stream.next().await {
Some(IndexerQueryResponse::Scores(scores)) => Ok(scores.into()),
Some(IndexerQueryResponse::Error(msg)) => {
Err(anyhow::anyhow!("Remote indexer error: {}", msg))
}
None => Err(anyhow::anyhow!("Remote indexer returned empty response")),
}
}
}
use self::remote::RemoteIndexer;
pub use self::remote::{ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service};
pub(crate) use subscriber::start_subscriber;
pub(crate) use worker_query::start_worker_kv_query_endpoint;
#[derive(Clone)]
pub enum Indexer {
......@@ -84,24 +41,26 @@ impl Indexer {
component: &Component,
kv_router_config: &KvRouterConfig,
block_size: u32,
model_name: Option<String>,
model_name: Option<&str>,
) -> Result<Self> {
if kv_router_config.overlap_score_weight == 0.0 {
return Ok(Self::None);
}
if let Some(ref indexer_component_name) = kv_router_config.remote_indexer_component {
let model_name = model_name.ok_or_else(|| {
anyhow::anyhow!(
"model_name is required when remote_indexer_component is configured"
)
})?;
if kv_router_config.use_remote_indexer {
let model_name = model_name
.ok_or_else(|| {
anyhow::anyhow!("model_name is required when use_remote_indexer is configured")
})?
.to_string();
let indexer_component_name = component.name();
tracing::info!(
remote_indexer_component = %indexer_component_name,
indexer_component = %indexer_component_name,
model_name,
"Using remote KV indexer"
);
let remote = RemoteIndexer::new(component, indexer_component_name, model_name).await?;
let remote =
RemoteIndexer::new(component, model_name, kv_router_config.use_kv_events).await?;
return Ok(Self::Remote(Arc::new(remote)));
}
......@@ -149,14 +108,46 @@ impl Indexer {
match self {
Self::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Self::Concurrent(tpi) => tpi.find_matches(sequence).await,
Self::Remote(remote) => remote.find_matches(sequence).await.map_err(|e| {
tracing::warn!(error = %e, "Remote indexer query failed");
KvRouterError::IndexerOffline
}),
Self::Remote(remote) => match remote.find_matches(sequence).await {
Ok(scores) => Ok(scores),
Err(error) => {
tracing::warn!(error = %error, "Remote indexer query failed");
Ok(OverlapScores::new())
}
},
Self::None => Ok(OverlapScores::new()),
}
}
pub(crate) async fn record_hashed_routing_decision(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
match self {
Self::KvIndexer(indexer) => {
indexer
.process_routing_decision_with_hashes(worker, local_hashes, sequence_hashes)
.await
}
Self::Concurrent(_) => {
tracing::warn!(
"Hashed routing-decision recording is unsupported for concurrent indexers"
);
Err(KvRouterError::IndexerDroppedRequest)
}
Self::Remote(remote) => remote
.record_hashed_routing_decision(worker, local_hashes, sequence_hashes)
.await
.map_err(|error| {
tracing::warn!(error = %error, "Remote indexer write failed");
KvRouterError::IndexerDroppedRequest
}),
Self::None => Ok(()),
}
}
pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
match self {
Self::KvIndexer(indexer) => indexer.dump_events().await,
......@@ -176,16 +167,17 @@ impl Indexer {
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
match self {
Self::KvIndexer(indexer) => {
indexer
.process_routing_decision_for_request(tokens_with_hashes, worker)
Self::KvIndexer(_) | Self::Remote(_) => {
let local_hashes = tokens_with_hashes.get_or_compute_block_hashes().to_vec();
let sequence_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec();
self.record_hashed_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
Self::Concurrent(tpi) => {
tpi.process_routing_decision_for_request(tokens_with_hashes, worker)
.await
}
Self::Remote(_) | Self::None => Ok(()),
Self::None => Ok(()),
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, LazyLock};
use anyhow::Result;
use dashmap::DashMap;
use dynamo_kv_router::indexer::{
IndexerQueryRequest, IndexerQueryResponse, IndexerRecordRoutingDecisionRequest,
IndexerRecordRoutingDecisionResponse, KV_INDEXER_QUERY_ENDPOINT,
KV_INDEXER_RECORD_ROUTING_DECISION_ENDPOINT,
};
use dynamo_kv_router::protocols::{LocalBlockHash, OverlapScores, WorkerWithDpRank};
use dynamo_runtime::component::{Client, Component};
use dynamo_runtime::discovery::{DiscoveryInstance, DiscoveryQuery};
use dynamo_runtime::pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, RouterMode, SingleIn,
async_trait, network::Ingress, network::egress::push_router::PushRouter,
};
use dynamo_runtime::stream;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_tokens::SequenceHash;
use futures::StreamExt;
use parking_lot::RwLock;
use tokio::sync::Mutex;
use crate::kv_router::metrics::RemoteIndexerMetrics;
use super::Indexer;
pub struct RemoteIndexer {
query_router: PushRouter<IndexerQueryRequest, IndexerQueryResponse>,
query_client: Client,
record_router: Option<
PushRouter<IndexerRecordRoutingDecisionRequest, IndexerRecordRoutingDecisionResponse>,
>,
record_client: Client,
component: Component,
model_name: String,
metrics: Arc<RemoteIndexerMetrics>,
use_kv_events: bool,
}
impl RemoteIndexer {
pub(super) async fn new(
component: &Component,
model_name: String,
use_kv_events: bool,
) -> Result<Self> {
let query_client = component
.endpoint(KV_INDEXER_QUERY_ENDPOINT)
.client()
.await?;
let query_router = PushRouter::from_client_no_fault_detection(
query_client.clone(),
RouterMode::RoundRobin,
)
.await?;
let record_client = component
.endpoint(KV_INDEXER_RECORD_ROUTING_DECISION_ENDPOINT)
.client()
.await?;
let record_router = if use_kv_events {
None
} else {
Some(
PushRouter::from_client_no_fault_detection(
record_client.clone(),
RouterMode::RoundRobin,
)
.await?,
)
};
let metrics = RemoteIndexerMetrics::from_component(component);
Ok(Self {
query_router,
query_client,
record_router,
record_client,
component: component.clone(),
model_name,
metrics,
use_kv_events,
})
}
pub(super) async fn find_matches(
&self,
block_hashes: Vec<LocalBlockHash>,
) -> Result<OverlapScores> {
self.validate_topology_if_ready().await.inspect_err(|_| {
self.metrics.increment_query_failures();
})?;
let request = IndexerQueryRequest {
model_name: self.model_name.clone(),
block_hashes,
};
let mut stream: ManyOut<IndexerQueryResponse> = self
.query_router
.round_robin(SingleIn::new(request))
.await
.inspect_err(|_| {
self.metrics.increment_query_failures();
})?;
match stream.next().await {
Some(IndexerQueryResponse::Scores(scores)) => Ok(scores.into()),
Some(IndexerQueryResponse::Error(msg)) => {
self.metrics.increment_query_failures();
Err(anyhow::anyhow!("Remote indexer error: {}", msg))
}
None => {
self.metrics.increment_query_failures();
Err(anyhow::anyhow!("Remote indexer returned empty response"))
}
}
}
pub(super) async fn record_hashed_routing_decision(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<()> {
self.validate_topology_if_ready().await.inspect_err(|_| {
self.metrics.increment_write_failures();
})?;
let record_router = self.record_router.as_ref().ok_or_else(|| {
self.metrics.increment_write_failures();
anyhow::anyhow!("remote approximate indexer is not configured for writes")
})?;
let request = IndexerRecordRoutingDecisionRequest {
model_name: self.model_name.clone(),
worker,
local_hashes,
sequence_hashes,
};
let mut stream: ManyOut<IndexerRecordRoutingDecisionResponse> = record_router
.round_robin(SingleIn::new(request))
.await
.inspect_err(|_| {
self.metrics.increment_write_failures();
})?;
match stream.next().await {
Some(IndexerRecordRoutingDecisionResponse::Recorded) => Ok(()),
Some(IndexerRecordRoutingDecisionResponse::Error(msg)) => {
self.metrics.increment_write_failures();
Err(anyhow::anyhow!("Remote indexer write error: {}", msg))
}
None => {
self.metrics.increment_write_failures();
Err(anyhow::anyhow!(
"Remote indexer returned empty write response"
))
}
}
}
async fn validate_topology_if_ready(&self) -> Result<()> {
let query_instances = cached_instance_ids(&self.query_client);
let record_instances = cached_instance_ids(&self.record_client);
if query_instances.is_empty() && record_instances.is_empty() {
return Ok(());
}
if self.use_kv_events {
if !record_instances.is_empty() {
anyhow::bail!(
"remote indexer component {}.{} mixes event-driven and approximate endpoints",
self.component.namespace().name(),
self.component.name()
);
}
return Ok(());
}
if query_instances.len() != 1 || record_instances.len() != 1 {
anyhow::bail!(
"approximate remote indexer component {}.{} must expose exactly one query endpoint and one record endpoint",
self.component.namespace().name(),
self.component.name()
);
}
if query_instances != record_instances {
anyhow::bail!(
"approximate remote indexer component {}.{} must expose query and record endpoints from the same singleton instance",
self.component.namespace().name(),
self.component.name()
);
}
Ok(())
}
}
fn cached_instance_ids(client: &Client) -> HashSet<u64> {
client.instance_ids_avail().iter().copied().collect()
}
type ServiceKey = (u64, String, String);
static SERVED_INDEXER_SERVICES: LazyLock<DashMap<ServiceKey, Arc<ServedIndexerService>>> =
LazyLock::new(DashMap::new);
static SERVICE_CREATION_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServedIndexerMode {
EventDriven,
Approximate,
}
impl ServedIndexerMode {
pub fn from_use_kv_events(use_kv_events: bool) -> Self {
if use_kv_events {
Self::EventDriven
} else {
Self::Approximate
}
}
fn topology_label(self) -> &'static str {
match self {
Self::EventDriven => "event-driven",
Self::Approximate => "approximate",
}
}
}
struct ServedIndexerService {
mode: ServedIndexerMode,
bindings: Arc<RwLock<HashMap<String, Indexer>>>,
}
impl ServedIndexerService {
async fn start(component: Component, mode: ServedIndexerMode) -> Result<Arc<Self>> {
verify_service_topology(&component, mode).await?;
let bindings = Arc::new(RwLock::new(HashMap::new()));
start_query_endpoint(component.clone(), bindings.clone())?;
if mode == ServedIndexerMode::Approximate {
start_record_endpoint(component.clone(), bindings.clone())?;
}
Ok(Arc::new(Self { mode, bindings }))
}
}
pub struct ServedIndexerHandle {
service: Arc<ServedIndexerService>,
model_name: String,
}
impl Drop for ServedIndexerHandle {
fn drop(&mut self) {
self.service.bindings.write().remove(&self.model_name);
}
}
pub async fn ensure_served_indexer_service(
component: Component,
mode: ServedIndexerMode,
model_name: String,
indexer: Indexer,
) -> Result<ServedIndexerHandle> {
let service = get_or_start_service(component.clone(), mode).await?;
if service.mode != mode {
anyhow::bail!(
"cannot mix {} and {} served indexers under {}.{}",
service.mode.topology_label(),
mode.topology_label(),
component.namespace().name(),
component.name()
);
}
{
let mut bindings = service.bindings.write();
if bindings.contains_key(&model_name) {
anyhow::bail!(
"served indexer for model {} is already registered under {}.{}",
model_name,
component.namespace().name(),
component.name(),
);
}
bindings.insert(model_name.clone(), indexer);
}
Ok(ServedIndexerHandle {
service,
model_name,
})
}
async fn get_or_start_service(
component: Component,
mode: ServedIndexerMode,
) -> Result<Arc<ServedIndexerService>> {
let key = service_key(&component);
if let Some(existing) = SERVED_INDEXER_SERVICES.get(&key) {
return Ok(existing.clone());
}
let _guard = SERVICE_CREATION_LOCK.lock().await;
if let Some(existing) = SERVED_INDEXER_SERVICES.get(&key) {
return Ok(existing.clone());
}
let service = ServedIndexerService::start(component, mode).await?;
SERVED_INDEXER_SERVICES.insert(key, service.clone());
Ok(service)
}
async fn verify_service_topology(component: &Component, mode: ServedIndexerMode) -> Result<()> {
let discovery = component.drt().discovery();
let endpoints = discovery
.list(DiscoveryQuery::ComponentEndpoints {
namespace: component.namespace().name(),
component: component.name().to_string(),
})
.await?;
let mut query_instances = HashSet::new();
let mut record_instances = HashSet::new();
for endpoint in endpoints {
let DiscoveryInstance::Endpoint(instance) = endpoint else {
continue;
};
match instance.endpoint.as_str() {
KV_INDEXER_QUERY_ENDPOINT => {
query_instances.insert(instance.instance_id);
}
KV_INDEXER_RECORD_ROUTING_DECISION_ENDPOINT => {
record_instances.insert(instance.instance_id);
}
_ => {}
}
}
match mode {
ServedIndexerMode::EventDriven => {
if !record_instances.is_empty() {
anyhow::bail!(
"cannot start event-driven served indexer on {}.{}: approximate endpoint already exists",
component.namespace().name(),
component.name()
);
}
}
ServedIndexerMode::Approximate => {
if !query_instances.is_empty() || !record_instances.is_empty() {
anyhow::bail!(
"cannot start approximate served indexer on {}.{}: indexer endpoint already exists",
component.namespace().name(),
component.name()
);
}
}
}
Ok(())
}
fn start_query_endpoint(
component: Component,
bindings: Arc<RwLock<HashMap<String, Indexer>>>,
) -> Result<()> {
let engine = Arc::new(ServedIndexerQueryEngine { bindings });
let ingress =
Ingress::<SingleIn<IndexerQueryRequest>, ManyOut<IndexerQueryResponse>>::for_engine(
engine,
)?;
tokio::spawn(async move {
if let Err(error) = component
.endpoint(KV_INDEXER_QUERY_ENDPOINT)
.endpoint_builder()
.handler(ingress)
.graceful_shutdown(true)
.start()
.await
{
tracing::error!(error = %error, "served indexer query endpoint failed");
}
});
Ok(())
}
fn start_record_endpoint(
component: Component,
bindings: Arc<RwLock<HashMap<String, Indexer>>>,
) -> Result<()> {
let engine = Arc::new(ServedIndexerRecordEngine { bindings });
let ingress = Ingress::<
SingleIn<IndexerRecordRoutingDecisionRequest>,
ManyOut<IndexerRecordRoutingDecisionResponse>,
>::for_engine(engine)?;
tokio::spawn(async move {
if let Err(error) = component
.endpoint(KV_INDEXER_RECORD_ROUTING_DECISION_ENDPOINT)
.endpoint_builder()
.handler(ingress)
.graceful_shutdown(true)
.start()
.await
{
tracing::error!(error = %error, "served indexer record endpoint failed");
}
});
Ok(())
}
struct ServedIndexerQueryEngine {
bindings: Arc<RwLock<HashMap<String, Indexer>>>,
}
#[async_trait]
impl AsyncEngine<SingleIn<IndexerQueryRequest>, ManyOut<IndexerQueryResponse>, anyhow::Error>
for ServedIndexerQueryEngine
{
async fn generate(
&self,
request: SingleIn<IndexerQueryRequest>,
) -> Result<ManyOut<IndexerQueryResponse>> {
let (request, ctx) = request.into_parts();
let indexer = self.bindings.read().get(&request.model_name).cloned();
let response = match indexer {
Some(indexer) => match indexer.find_matches(request.block_hashes).await {
Ok(scores) => IndexerQueryResponse::Scores(scores.into()),
Err(error) => IndexerQueryResponse::Error(error.to_string()),
},
None => IndexerQueryResponse::Error(format!(
"served indexer model {} is not registered",
request.model_name
)),
};
Ok(ResponseStream::new(
Box::pin(stream::iter(vec![response])),
ctx.context(),
))
}
}
struct ServedIndexerRecordEngine {
bindings: Arc<RwLock<HashMap<String, Indexer>>>,
}
#[async_trait]
impl
AsyncEngine<
SingleIn<IndexerRecordRoutingDecisionRequest>,
ManyOut<IndexerRecordRoutingDecisionResponse>,
anyhow::Error,
> for ServedIndexerRecordEngine
{
async fn generate(
&self,
request: SingleIn<IndexerRecordRoutingDecisionRequest>,
) -> Result<ManyOut<IndexerRecordRoutingDecisionResponse>> {
let (request, ctx) = request.into_parts();
let indexer = self.bindings.read().get(&request.model_name).cloned();
let response = match indexer {
Some(indexer) => match indexer
.record_hashed_routing_decision(
request.worker,
request.local_hashes,
request.sequence_hashes,
)
.await
{
Ok(()) => IndexerRecordRoutingDecisionResponse::Recorded,
Err(error) => IndexerRecordRoutingDecisionResponse::Error(error.to_string()),
},
None => IndexerRecordRoutingDecisionResponse::Error(format!(
"served indexer model {} is not registered",
request.model_name
)),
};
Ok(ResponseStream::new(
Box::pin(stream::iter(vec![response])),
ctx.context(),
))
}
}
fn service_key(component: &Component) -> ServiceKey {
(
component.drt().connection_id(),
component.namespace().name(),
component.name().to_string(),
)
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn query_engine_supports_multiple_model_bindings() {
let bindings = Arc::new(RwLock::new(HashMap::from([
("model-a".to_string(), Indexer::None),
("model-b".to_string(), Indexer::None),
])));
let engine = ServedIndexerQueryEngine { bindings };
let request = SingleIn::new(IndexerQueryRequest {
model_name: "model-b".to_string(),
block_hashes: vec![LocalBlockHash(1)],
});
let mut stream = engine.generate(request).await.unwrap();
assert!(matches!(
stream.next().await,
Some(IndexerQueryResponse::Scores(_))
));
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::kv_router::{Indexer, worker_query::WorkerQueryClient};
use super::{Indexer, worker_query::WorkerQueryClient};
use anyhow::Result;
use dynamo_kv_router::{
config::KvRouterConfig,
......
......@@ -20,7 +20,7 @@ use dynamo_runtime::traits::DistributedRuntimeProvider;
use futures::StreamExt;
use tokio::sync::{Mutex, Semaphore};
use crate::kv_router::Indexer;
use super::Indexer;
use crate::kv_router::worker_kv_indexer_query_endpoint;
use dynamo_kv_router::{
indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse},
......
......@@ -44,7 +44,7 @@ use std::time::Duration;
use dynamo_runtime::component::Component;
use dynamo_runtime::metrics::MetricsHierarchy;
use dynamo_runtime::metrics::prometheus_names::{
frontend_service, labels, name_prefix, router_request, routing_overhead,
frontend_service, labels, name_prefix, router, router_request, routing_overhead,
};
/// Build a router metric name: `"router_" + frontend_service_suffix`.
......@@ -406,6 +406,54 @@ impl RouterRequestMetrics {
}
}
pub struct RemoteIndexerMetrics {
pub query_failures_total: prometheus::IntCounter,
pub write_failures_total: prometheus::IntCounter,
}
static REMOTE_INDEXER_METRICS: OnceLock<Arc<RemoteIndexerMetrics>> = OnceLock::new();
impl RemoteIndexerMetrics {
pub fn from_component(component: &Component) -> Arc<Self> {
REMOTE_INDEXER_METRICS
.get_or_init(|| {
let instance_id = component.drt().discovery().instance_id();
let router_id = instance_id.to_string();
let extra_labels: &[(&str, &str)] = &[(labels::ROUTER_ID, &router_id)];
let metrics = component.metrics();
let query_failures_total = metrics
.create_intcounter(
router::REMOTE_INDEXER_QUERY_FAILURES_TOTAL,
"Total number of remote indexer overlap queries that failed",
extra_labels,
)
.expect("failed to create router_remote_indexer_query_failures_total");
let write_failures_total = metrics
.create_intcounter(
router::REMOTE_INDEXER_WRITE_FAILURES_TOTAL,
"Total number of remote indexer routing-decision writes that failed",
extra_labels,
)
.expect("failed to create router_remote_indexer_write_failures_total");
Arc::new(Self {
query_failures_total,
write_failures_total,
})
})
.clone()
}
pub fn increment_query_failures(&self) {
self.query_failures_total.inc();
}
pub fn increment_write_failures(&self) {
self.write_failures_total.inc();
}
}
#[cfg(test)]
mod tests {
use super::*;
......
......@@ -24,7 +24,7 @@ use dynamo_runtime::{
};
use crate::kv_router::{
KV_EVENT_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE, worker_query::start_worker_kv_query_endpoint,
KV_EVENT_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE, indexer::start_worker_kv_query_endpoint,
};
mod event_processor;
......
......@@ -506,6 +506,14 @@ pub mod router {
/// Total number of requests processed by the router
pub const REQUESTS_TOTAL: &str = "router_requests_total";
/// Total number of remote indexer overlap queries that failed
pub const REMOTE_INDEXER_QUERY_FAILURES_TOTAL: &str =
"router_remote_indexer_query_failures_total";
/// Total number of remote indexer routing-decision writes that failed
pub const REMOTE_INDEXER_WRITE_FAILURES_TOTAL: &str =
"router_remote_indexer_write_failures_total";
/// Time to first token observed at the router (seconds)
pub const TIME_TO_FIRST_TOKEN_SECONDS: &str = "router_time_to_first_token_seconds";
......
......@@ -320,6 +320,234 @@ def _test_router_two_routers(
kv_router.__exit__(None, None, None)
def _test_remote_indexer_decisions(
engine_workers,
model_name: str,
block_size: int = 8,
use_kv_events: bool = True,
test_dp_rank: bool = True,
request_plane: str = "nats",
store_backend: str = "etcd",
):
"""Validate remote-indexer-backed routing decisions using direct KvRouter instances."""
async def wait_for_worker_ids(endpoint, expected_num_workers: int) -> list[int]:
client = await endpoint.client()
for _ in range(120):
worker_ids = sorted(set(client.instance_ids()))
if len(worker_ids) >= expected_num_workers:
return worker_ids
await asyncio.sleep(1)
raise TimeoutError("Timed out waiting for backend worker IDs")
async def wait_for_served_indexer(
runtime,
expected_query_instances: int,
expected_record_instances: int,
) -> None:
query_endpoint = runtime.endpoint(
f"{engine_workers.namespace}.{engine_workers.component_name}.kv_indexer_query"
)
query_client = await query_endpoint.client()
record_endpoint = runtime.endpoint(
f"{engine_workers.namespace}.{engine_workers.component_name}.kv_indexer_record_routing_decision"
)
record_client = await record_endpoint.client()
for _ in range(120):
query_ids = set(query_client.instance_ids())
record_ids = set(record_client.instance_ids())
if use_kv_events:
if len(query_ids) >= expected_query_instances and len(record_ids) == 0:
return
elif (
len(query_ids) == expected_query_instances
and len(record_ids) == expected_record_instances
and query_ids == record_ids
):
return
await asyncio.sleep(0.5)
raise TimeoutError("Timed out waiting for served indexer endpoints to register")
async def test_sync():
endpoint_path = (
f"{engine_workers.namespace}.{engine_workers.component_name}.generate"
)
expected_num_instances = engine_workers.num_workers
async def make_router(*, serve_indexer: bool, use_remote_indexer: bool):
kv_router_config = KvRouterConfig(
router_snapshot_threshold=20,
use_kv_events=use_kv_events,
router_track_prefill_tokens=True,
serve_indexer=serve_indexer,
use_remote_indexer=use_remote_indexer,
)
last_error: Exception | None = None
for _ in range(60):
runtime = get_runtime(
store_backend=store_backend, request_plane=request_plane
)
endpoint = runtime.endpoint(endpoint_path)
try:
with min_initial_workers_env(expected_num_instances):
kv_router = KvRouter(
endpoint=endpoint,
block_size=block_size,
kv_router_config=kv_router_config,
)
return runtime, endpoint, kv_router
except Exception as error:
last_error = error
if not (serve_indexer or use_remote_indexer):
raise
del endpoint
del runtime
await asyncio.sleep(1.0)
raise AssertionError(
"Timed out waiting for model discovery before creating remote-indexer router"
) from last_error
serving_runtimes = []
serving_endpoints = []
serving_routers = []
runtime_a, endpoint_a, router_a = await make_router(
serve_indexer=True, use_remote_indexer=False
)
serving_runtimes.append(runtime_a)
serving_endpoints.append(endpoint_a)
serving_routers.append(router_a)
if use_kv_events:
runtime_b, endpoint_b, router_b = await make_router(
serve_indexer=True, use_remote_indexer=False
)
serving_runtimes.append(runtime_b)
serving_endpoints.append(endpoint_b)
serving_routers.append(router_b)
await wait_for_served_indexer(
serving_runtimes[0],
expected_query_instances=len(serving_routers),
expected_record_instances=0 if use_kv_events else 1,
)
_, consumer_endpoint, consumer_router = await make_router(
serve_indexer=False, use_remote_indexer=True
)
worker_ids = await wait_for_worker_ids(
serving_endpoints[0], expected_num_instances
)
if len(worker_ids) >= 2:
worker_a_id = worker_ids[0]
worker_b_id = worker_ids[1]
elif len(worker_ids) == 1 and test_dp_rank:
worker_a_id = worker_ids[0]
worker_b_id = worker_ids[0]
else:
raise AssertionError(
f"Need at least 2 routing targets but got {len(worker_ids)} worker(s) "
f"with test_dp_rank={test_dp_rank}"
)
dp_rank_a = 0 if test_dp_rank else None
dp_rank_b = 1 if test_dp_rank else None
logger.info(
"Remote-indexer routing targets: worker_a=%s/%s worker_b=%s/%s",
worker_a_id,
dp_rank_a,
worker_b_id,
dp_rank_b,
)
blocks = [
[random.randint(1, 10000) for _ in range(block_size)] for _ in range(7)
]
A, B, C, D, E, F, G = blocks
request_specs = [
(serving_routers[0], A + B, worker_a_id, dp_rank_a, 0.1),
(serving_routers[0], A + C + D, worker_a_id, dp_rank_a, 0.1),
(serving_routers[-1], A + C + E, worker_b_id, dp_rank_b, 2.0),
(consumer_router, A + C + D + F, None, None, 2.0),
(consumer_router, A + C + G, None, None, 2.0),
]
responses: list[dict[str, Optional[int]]] = []
for i, (
kv_router,
token_ids,
forced_worker_id,
forced_dp_rank,
sleep_after,
) in enumerate(request_specs, start=1):
logger.info(
"Sending remote-indexer request %s/5%s%s",
i,
(
f" forced_worker_id={forced_worker_id}"
if forced_worker_id is not None
else ""
),
(
f" forced_dp_rank={forced_dp_rank}"
if forced_dp_rank is not None
else ""
),
)
result = await send_request_via_python_kv_router(
kv_python_router=kv_router,
model_name=model_name,
token_ids=token_ids,
initial_wait=1.0,
max_retries=8,
stop_conditions={
"ignore_eos": True,
"max_tokens": 2,
},
worker_id=forced_worker_id,
dp_rank=forced_dp_rank,
return_worker_ids=True,
)
assert isinstance(result, dict), f"Expected dict result, got {type(result)}"
responses.append(result)
if sleep_after > 0:
await asyncio.sleep(sleep_after)
req4 = responses[3]
assert req4["prefill_worker_id"] == worker_a_id, (
f"Request 4: expected prefill_worker_id={worker_a_id} (longest prefix match), "
f"got {req4['prefill_worker_id']}"
)
if test_dp_rank:
assert req4["prefill_dp_rank"] == dp_rank_a, (
f"Request 4: expected prefill_dp_rank={dp_rank_a} "
f"(longest prefix match), got {req4['prefill_dp_rank']}"
)
req5 = responses[4]
assert req5["prefill_worker_id"] == worker_b_id, (
f"Request 5: expected prefill_worker_id={worker_b_id} (tiebreak by smaller tree), "
f"got {req5['prefill_worker_id']}"
)
if test_dp_rank:
assert req5["prefill_dp_rank"] == dp_rank_b, (
f"Request 5: expected prefill_dp_rank={dp_rank_b} "
f"(tiebreak by smaller tree), got {req5['prefill_dp_rank']}"
)
await wait_for_worker_ids(consumer_endpoint, expected_num_instances)
asyncio.run(test_sync())
def _test_python_router_bindings(
engine_workers,
endpoint,
......
......@@ -30,6 +30,8 @@ class FrontendRouterProcess(ManagedProcess):
router_mode: str = "kv",
min_initial_workers: int | None = None,
router_aic_config: dict[str, str | int] | None = None,
serve_indexer: bool = False,
use_remote_indexer: bool = False,
):
command = [
"python3",
......@@ -65,6 +67,12 @@ class FrontendRouterProcess(ManagedProcess):
if durable_kv_events:
command.append("--router-durable-kv-events")
if serve_indexer:
command.append("--serve-indexer")
if use_remote_indexer:
command.append("--use-remote-indexer")
if router_aic_config is not None:
command.extend(
[
......
......@@ -21,6 +21,7 @@ from tests.router.common import (
_test_busy_threshold_endpoint,
_test_disagg_direct_mode,
_test_python_router_bindings,
_test_remote_indexer_decisions,
_test_router_basic,
_test_router_decisions,
_test_router_decisions_disagg,
......@@ -1014,14 +1015,28 @@ def test_query_instance_id_returns_worker_and_tokens(
@pytest.mark.timeout(300) # bumped for xdist contention (was 29s; ~9.55s serial avg)
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
@pytest.mark.parametrize(
"durable_kv_events,use_kv_events,zmq_kv_events",
"durable_kv_events,use_kv_events,zmq_kv_events,use_remote_indexer",
[
(True, True, False), # JetStream mode with KV events
(False, True, False), # NATS Core mode with local indexer (default)
(False, False, False), # Approximate mode (--no-kv-events) - no KV events
(False, True, True), # ZMQ mode: mocker → ZMQ PUB → relay → NATS
(True, True, False, False), # JetStream mode with KV events
(False, True, False, False), # NATS Core mode with local indexer (default)
(False, True, False, True), # NATS Core mode with a served remote indexer
(False, False, False, False), # Approximate mode (--no-kv-events)
(
False,
False,
False,
True,
), # Approximate mode with a singleton served remote indexer
(False, True, True, False), # ZMQ mode: mocker → ZMQ PUB → relay → NATS
],
ids=[
"jetstream",
"nats_core",
"nats_core_remote",
"no_kv_events",
"no_kv_events_remote",
"zmq",
],
ids=["jetstream", "nats_core", "no_kv_events", "zmq"],
indirect=["durable_kv_events"],
)
def test_router_decisions(
......@@ -1032,18 +1047,24 @@ def test_router_decisions(
use_kv_events,
request_plane,
zmq_kv_events,
use_remote_indexer,
):
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
Parameterized to test:
- JetStream mode: KV events via NATS JetStream (durable)
- NATS Core mode (default): KV events via NATS Core with local indexer on workers
- NATS Core mode with a served remote indexer
- Approximate mode (--no-kv-events): No KV events, router predicts cache state
based on routing decisions with TTL-based expiration and pruning
- Approximate mode with a singleton served remote indexer
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
logger.info(
f"Starting test router decisions: durable_kv_events={durable_kv_events}, use_kv_events={use_kv_events}"
"Starting test router decisions: durable_kv_events=%s, use_kv_events=%s, use_remote_indexer=%s",
durable_kv_events,
use_kv_events,
use_remote_indexer,
)
# Create mocker args dictionary with dp_size=4
......@@ -1066,10 +1087,18 @@ def test_router_decisions(
) as mockers:
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
# Initialize mockers
# Get runtime and create endpoint
if use_remote_indexer:
_test_remote_indexer_decisions(
mockers,
MODEL_NAME,
block_size=8,
use_kv_events=use_kv_events,
test_dp_rank=True,
request_plane=request_plane,
)
return
runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the mockers
endpoint = runtime.endpoint(f"{mockers.namespace}.mocker.generate")
_test_router_decisions(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment