Commit 35f99f93 authored by Yan Ru Pei, committed by GitHub

feat(kv-indexer): multi-model and multi-tenant isolation (#6830)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent e0a2e7bb
@@ -13,6 +13,14 @@ This is distinct from the [Standalone Router](../../../components/src/dynamo/rou
 The HTTP API follows the [Mooncake KV Indexer RFC](https://github.com/kvcache-ai/Mooncake/issues/1403) conventions.
+## Multi-Model and Multi-Tenant Support
+The indexer maintains one radix tree per `(model_name, tenant_id)` pair. Workers registered with different model names or tenant IDs are isolated into separate indexers — queries against one model/tenant never return scores from another.
+- **`model_name`** (required on `/register` and `/query`): Identifies the model. Workers serving different models get separate radix trees.
+- **`tenant_id`** (optional, defaults to `"default"`): Enables multi-tenant isolation within the same model. Omit for single-tenant deployments.
+- **`block_size`** is per-indexer: the first `/register` call for a given `(model_name, tenant_id)` sets the block size. Subsequent registrations for the same pair must use the same block size or the request will fail.
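The isolation rules in the list above can be sketched with a plain map keyed by `(model_name, tenant_id)`. This is a minimal, std-only illustration and not the commit's implementation: `IndexKey`, `IndexSlot`, `Registry`, and `workers_for` are hypothetical names, and the radix tree itself is left out so that only the keying and the block-size check are visible.

```rust
use std::collections::HashMap;

/// Key selecting one isolated index: one entry per (model_name, tenant_id).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct IndexKey {
    model_name: String,
    tenant_id: String,
}

/// Hypothetical stand-in for a per-key index; only the block size
/// and the registered worker ids are modeled.
#[derive(Debug)]
struct IndexSlot {
    block_size: u32,
    workers: Vec<u64>,
}

#[derive(Default)]
struct Registry {
    slots: HashMap<IndexKey, IndexSlot>,
}

impl Registry {
    /// Register a worker. The first call for a key fixes its block size;
    /// later calls with a different block size are rejected.
    fn register(
        &mut self,
        instance_id: u64,
        model_name: &str,
        tenant_id: Option<&str>,
        block_size: u32,
    ) -> Result<(), String> {
        let key = IndexKey {
            model_name: model_name.to_string(),
            // An omitted tenant falls back to "default", as the README states.
            tenant_id: tenant_id.unwrap_or("default").to_string(),
        };
        let slot = self.slots.entry(key.clone()).or_insert_with(|| IndexSlot {
            block_size,
            workers: Vec::new(),
        });
        if slot.block_size != block_size {
            return Err(format!(
                "block_size mismatch for model={} tenant={}: existing={}, requested={}",
                key.model_name, key.tenant_id, slot.block_size, block_size
            ));
        }
        slot.workers.push(instance_id);
        Ok(())
    }

    /// Workers visible to a query against one (model, tenant) pair.
    fn workers_for(&self, model_name: &str, tenant_id: &str) -> Vec<u64> {
        let key = IndexKey {
            model_name: model_name.to_string(),
            tenant_id: tenant_id.to_string(),
        };
        self.slots
            .get(&key)
            .map(|slot| slot.workers.clone())
            .unwrap_or_default()
    }
}
```

Registering the same model under `"default"` and `"customer-a"` yields two slots, so a lookup for one tenant never sees the other tenant's workers.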
 ## Compatibility
 The standalone indexer works with any engine that publishes KV cache events over ZMQ in the expected msgpack format. This includes bare vLLM and SGLang engines, which emit ZMQ KV events natively — no Dynamo-specific wrapper is required.
@@ -35,56 +43,84 @@ cargo build -p dynamo-kv-router --features indexer-bin --bin dynamo-kv-indexer
 ## CLI
 ```bash
-dynamo-kv-indexer --block-size 16 --port 8090 [--threads 1] [--workers "1=tcp://host:5557,2=tcp://host:5558"]
+dynamo-kv-indexer --port 8090 [--threads 1] [--block-size 16 --model-name my-model --tenant-id default --workers "1=tcp://host:5557,2=tcp://host:5558"]
 ```
 | Flag | Default | Description |
 |------|---------|-------------|
-| `--block-size` | (required) | KV cache block size (must match the engine's block size) |
+| `--block-size` | (none) | KV cache block size for initial `--workers` (required when `--workers` is set) |
 | `--port` | `8090` | HTTP server listen port |
 | `--threads` | `1` | Number of indexer threads (1 = single-threaded, >1 = thread pool) |
 | `--workers` | (none) | Initial workers as `instance_id=zmq_address,...` pairs |
+| `--model-name` | `default` | Model name for initial `--workers` |
+| `--tenant-id` | `default` | Tenant ID for initial `--workers` |
 ## HTTP API
 ### `POST /register` — Register an endpoint
-Register a ZMQ endpoint for an instance. Call once per dp_rank for data-parallel workers:
+Register a ZMQ endpoint for an instance. Each call creates or reuses the indexer for the given `(model_name, tenant_id)` pair.
 ```bash
-# Single dp_rank (dp_rank defaults to 0)
+# Single model, default tenant
 curl -X POST http://localhost:8090/register \
   -H 'Content-Type: application/json' \
-  -d '{"instance_id": 1, "endpoint": "tcp://127.0.0.1:5557"}'
-# Multiple dp_ranks — register each separately
+  -d '{
+    "instance_id": 1,
+    "endpoint": "tcp://127.0.0.1:5557",
+    "model_name": "llama-3-8b",
+    "block_size": 16
+  }'
+# With tenant isolation
 curl -X POST http://localhost:8090/register \
   -H 'Content-Type: application/json' \
-  -d '{"instance_id": 1, "endpoint": "tcp://127.0.0.1:5557", "dp_rank": 0}'
-curl -X POST http://localhost:8090/register \
-  -H 'Content-Type: application/json' \
-  -d '{"instance_id": 1, "endpoint": "tcp://127.0.0.1:5558", "dp_rank": 1}'
+  -d '{
+    "instance_id": 2,
+    "endpoint": "tcp://127.0.0.1:5558",
+    "model_name": "llama-3-8b",
+    "tenant_id": "customer-a",
+    "block_size": 16,
+    "dp_rank": 0
+  }'
 ```
-The indexer spawns a ZMQ SUB listener for each endpoint and begins consuming KV events.
+| Field | Required | Default | Description |
+|-------|----------|---------|-------------|
+| `instance_id` | yes | — | Worker instance identifier |
+| `endpoint` | yes | — | ZMQ PUB address to subscribe to |
+| `model_name` | yes | — | Model name (used to select the indexer) |
+| `block_size` | yes | — | KV cache block size (must match the engine) |
+| `tenant_id` | no | `"default"` | Tenant identifier for isolation |
+| `dp_rank` | no | `0` | Data parallel rank |
 ### `POST /unregister` — Deregister an instance
-Remove all dp_ranks for an instance, or a specific dp_rank:
+Remove an instance. Omitting `tenant_id` removes the instance from **all** tenants for the given model; providing it targets only that tenant's indexer.
 ```bash
-# Remove all dp_ranks
+# Remove from all tenants
+curl -X POST http://localhost:8090/unregister \
+  -H 'Content-Type: application/json' \
+  -d '{"instance_id": 1, "model_name": "llama-3-8b"}'
+# Remove from a specific tenant
 curl -X POST http://localhost:8090/unregister \
   -H 'Content-Type: application/json' \
-  -d '{"instance_id": 1}'
+  -d '{"instance_id": 1, "model_name": "llama-3-8b", "tenant_id": "customer-a"}'
 # Remove a specific dp_rank
 curl -X POST http://localhost:8090/unregister \
   -H 'Content-Type: application/json' \
-  -d '{"instance_id": 1, "dp_rank": 0}'
+  -d '{"instance_id": 1, "model_name": "llama-3-8b", "tenant_id": "default", "dp_rank": 0}'
 ```
-Cancels ZMQ listeners and removes the instance's blocks from the radix tree.
+| Field | Required | Default | Description |
+|-------|----------|---------|-------------|
+| `instance_id` | yes | — | Worker instance to remove |
+| `model_name` | yes | — | Model name (identifies the indexer) |
+| `tenant_id` | no | — | Tenant identifier (omit to remove from all tenants) |
+| `dp_rank` | no | — | Specific dp_rank to remove (omit to remove all) |
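How the two optional fields combine can be sketched as a small decision function. It follows the branch order of the `unregister` handler in this commit, where `tenant_id` is matched first and `dp_rank` is only consulted when a tenant is present. `Removal` and `plan_removal` are hypothetical names used for illustration only.

```rust
/// Which removal an `/unregister` body maps to.
#[derive(Debug, PartialEq, Eq)]
enum Removal {
    /// No tenant given: remove the instance from every tenant of the model.
    AllTenants,
    /// Tenant given, no dp_rank: remove the instance from that tenant's indexer.
    Tenant { tenant_id: String },
    /// Tenant and dp_rank given: remove only that rank.
    DpRank { tenant_id: String, dp_rank: u32 },
}

fn plan_removal(tenant_id: Option<&str>, dp_rank: Option<u32>) -> Removal {
    match (tenant_id, dp_rank) {
        (Some(t), Some(r)) => Removal::DpRank {
            tenant_id: t.to_string(),
            dp_rank: r,
        },
        (Some(t), None) => Removal::Tenant {
            tenant_id: t.to_string(),
        },
        // Without a tenant, a dp_rank sent alongside is not consulted.
        (None, _) => Removal::AllTenants,
    }
}
```

This is why the "Remove a specific dp_rank" example passes `"tenant_id": "default"` explicitly: a request that carries `dp_rank` but no `tenant_id` takes the all-tenants path.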
 ### `GET /workers` — List registered instances
@@ -99,43 +135,66 @@ Returns:
 ### `POST /query` — Query overlap for token IDs
-Given raw token IDs, compute block hashes and return per-instance overlap scores:
+Given raw token IDs, compute block hashes and return per-instance overlap scores (in matched tokens):
 ```bash
 curl -X POST http://localhost:8090/query \
   -H 'Content-Type: application/json' \
-  -d '{"token_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}'
+  -d '{"token_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], "model_name": "llama-3-8b"}'
 ```
 Returns:
 ```json
 {
-  "scores": {"1": {"0": 2}, "2": {"1": 0}},
+  "scores": {"1": {"0": 32}, "2": {"1": 0}},
   "frequencies": [1, 1],
   "tree_sizes": {"1": {"0": 5}, "2": {"1": 3}}
 }
 ```
-Scores are nested by `instance_id` then `dp_rank`. Higher score means more cached prefix blocks on that instance.
+Scores are in **matched tokens** (block overlap count × block size). Nested by `instance_id` then `dp_rank`.
+| Field | Required | Default | Description |
+|-------|----------|---------|-------------|
+| `token_ids` | yes | — | Token sequence to query |
+| `model_name` | yes | — | Model name (selects the indexer) |
+| `tenant_id` | no | `"default"` | Tenant identifier |
+| `lora_name` | no | — | LoRA adapter (overrides indexer-level lora_name for this query) |
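The change in score units, from matched blocks to matched tokens, is a multiplication by the indexer's block size. A small sketch of the conversion and of the nesting by `instance_id` then `dp_rank` follows; `scores_in_tokens` is a hypothetical helper, and the input tuple layout is an assumption standing in for the router's `OverlapScores` type.

```rust
use std::collections::BTreeMap;

/// Convert per-(instance_id, dp_rank) block overlap counts into matched
/// tokens, nested by instance id and then dp_rank, with string keys as in
/// the JSON response.
fn scores_in_tokens(
    block_overlaps: &[((u64, u32), u32)],
    block_size: u32,
) -> BTreeMap<String, BTreeMap<String, u32>> {
    let mut scores: BTreeMap<String, BTreeMap<String, u32>> = BTreeMap::new();
    for &((instance_id, dp_rank), blocks) in block_overlaps {
        scores
            .entry(instance_id.to_string())
            .or_default()
            // matched tokens = matched blocks * block size
            .insert(dp_rank.to_string(), blocks * block_size);
    }
    scores
}
```

With a block size of 16, the old example score of 2 blocks for instance `1`, rank `0` becomes 32 tokens, which matches the updated response in the README.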
 ### `POST /query_by_hash` — Query overlap for pre-computed hashes
 ```bash
 curl -X POST http://localhost:8090/query_by_hash \
   -H 'Content-Type: application/json' \
-  -d '{"block_hashes": [123456, 789012]}'
+  -d '{"block_hashes": [123456, 789012], "model_name": "llama-3-8b"}'
 ```
-Same response format as `/query`.
+Same response format as `/query`. Scores are in matched tokens.
+| Field | Required | Default | Description |
+|-------|----------|---------|-------------|
+| `block_hashes` | yes | — | Pre-computed block hash array |
+| `model_name` | yes | — | Model name (selects the indexer) |
+| `tenant_id` | no | `"default"` | Tenant identifier |
 ### `GET /dump` — Dump all radix tree events
-Returns the full radix tree state as a JSON array of `RouterEvent` objects:
+Returns the full radix tree state as a JSON object keyed by `model_name:tenant_id`:
 ```bash
 curl http://localhost:8090/dump
 ```
+Returns:
+```json
+{
+  "llama-3-8b:default": [<RouterEvent>, ...],
+  "mistral-7b:customer-a": [<RouterEvent>, ...]
+}
+```
+Each indexer is dumped concurrently.
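The fan-out behind `/dump` can be sketched without an async runtime. The version here uses OS threads as a stand-in for the `tokio::spawn` tasks in the handler, and `dump_one` is a hypothetical callback in place of the indexer's dump call; only the `model_name:tenant_id` key format is taken from the commit.

```rust
use std::collections::HashMap;
use std::thread;

/// Build the `/dump` map key for one indexer.
fn dump_key(model_name: &str, tenant_id: &str) -> String {
    format!("{model_name}:{tenant_id}")
}

/// Dump every indexer on its own thread and collect the results by key.
fn dump_all(
    indexers: Vec<(String, String)>,
    dump_one: fn(&str, &str) -> Result<Vec<String>, String>,
) -> HashMap<String, Result<Vec<String>, String>> {
    // Start all dumps first so they run concurrently.
    let handles: Vec<_> = indexers
        .into_iter()
        .map(|(model, tenant)| {
            thread::spawn(move || {
                let events = dump_one(&model, &tenant);
                (dump_key(&model, &tenant), events)
            })
        })
        .collect();

    let mut out = HashMap::new();
    for handle in handles {
        // A panicked task is skipped; the other dumps are still returned.
        if let Ok((key, events)) = handle.join() {
            out.insert(key, events);
        }
    }
    out
}
```

A failed dump stays in the map as an error value under its own key, so one bad indexer does not hide the others. Because the key is a plain `:` join, a model name that itself contains `:` would make the key ambiguous to split back into its parts.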
 ## Limitations
 - **ZMQ only**: Workers must publish KV events via ZMQ PUB sockets. The standalone indexer does not subscribe to NATS event streams.
@@ -153,7 +212,7 @@ graph TD
     subgraph "Standalone Indexer (HTTP)"
         REG[Worker Registry]
         ZMQ[ZMQ SUB Listeners]
-        IDX[Indexer / Radix Tree]
+        IDX["Indexer Map<br/>(model, tenant) → Radix Tree"]
         HTTP[HTTP API<br/>/query /dump /register]
     end
......
@@ -32,6 +32,13 @@ impl Indexer {
         }
     }
+    pub async fn remove_worker_dp_rank(&self, worker_id: WorkerId, dp_rank: u32) {
+        match self {
+            Indexer::Single(idx) => idx.remove_worker_dp_rank(worker_id, dp_rank).await,
+            Indexer::Concurrent(idx) => idx.remove_worker_dp_rank(worker_id, dp_rank).await,
+        }
+    }
     pub async fn find_matches(&self, hashes: Vec<LocalBlockHash>) -> Result<OverlapScores> {
         match self {
             Indexer::Single(idx) => idx.find_matches(hashes).await.map_err(Into::into),
......
@@ -2,7 +2,7 @@
 // SPDX-License-Identifier: Apache-2.0
 use std::sync::Arc;
-use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
+use std::sync::atomic::AtomicU32;
 use std::time::Duration;
 use rmp_serde as rmps;
@@ -48,7 +48,7 @@ pub async fn run_zmq_listener(
         return;
     }
-    let next_event_id = AtomicU64::new(0);
+    let mut next_event_id = 0u64;
     let warning_count = Arc::new(AtomicU32::new(0));
     let mut consecutive_errors = 0u32;
     #[allow(unused_assignments)]
@@ -94,29 +94,28 @@ pub async fn run_zmq_listener(
         consecutive_errors = 0;
-        let mut frames: Vec<Vec<u8>> = msg.into_vec().into_iter().map(|f| f.to_vec()).collect();
-        if frames.len() != 3 {
-            tracing::warn!(worker_id, "Unexpected ZMQ frame count: {}", frames.len());
+        if msg.len() != 3 {
+            tracing::warn!(worker_id, "Unexpected ZMQ frame count: {}", msg.len());
             continue;
         }
-        let payload = frames.pop().unwrap();
-        let seq_bytes = frames.pop().unwrap();
+        let seq_bytes = msg.get(1).unwrap();
         if seq_bytes.len() != 8 {
             tracing::warn!(worker_id, "Invalid sequence number length: {}", seq_bytes.len());
             continue;
         }
-        let batch_result = rmps::from_slice::<KvEventBatch>(&payload);
+        let payload = msg.get(2).unwrap();
+        let batch_result = rmps::from_slice::<KvEventBatch>(payload);
         let Ok(batch) = batch_result else {
             tracing::warn!(worker_id, "Failed to decode KvEventBatch: {}", batch_result.unwrap_err());
             continue;
         };
         let effective_dp_rank = batch.data_parallel_rank.map_or(dp_rank, |r| r as u32);
-        for raw_event in batch.events.into_iter() {
-            let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
+        for raw_event in batch.events {
+            let event_id = next_event_id;
+            next_event_id += 1;
             let kv_event = convert_event(raw_event, event_id, block_size, effective_dp_rank, &warning_count);
             let router_event = RouterEvent::new(worker_id, kv_event);
             indexer.apply_event(router_event).await;
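The listener change borrows frames by index instead of copying every frame into a `Vec<Vec<u8>>` and popping from the end. The same validation can be sketched over plain byte slices. This is an illustration only: `split_frames` is a hypothetical helper, the frame at index 0 is assumed to be a topic frame, and big-endian decoding of the sequence number is an assumption, since the hunk only checks that it is 8 bytes long.

```rust
/// Split a three-frame message into (sequence number, payload bytes).
/// Assumed frame layout: [topic, 8-byte sequence, msgpack payload].
fn split_frames(frames: &[Vec<u8>]) -> Result<(u64, &[u8]), String> {
    if frames.len() != 3 {
        return Err(format!("unexpected frame count: {}", frames.len()));
    }
    let seq_bytes = &frames[1];
    if seq_bytes.len() != 8 {
        return Err(format!(
            "invalid sequence number length: {}",
            seq_bytes.len()
        ));
    }
    let mut seq = [0u8; 8];
    seq.copy_from_slice(seq_bytes);
    // Big-endian byte order is an assumption of this sketch.
    // The payload is returned as a borrowed slice; nothing is copied.
    Ok((u64::from_be_bytes(seq), frames[2].as_slice()))
}
```

The switch from `AtomicU64` to a plain `u64` counter in the same hunk follows the same reasoning: the counter is only touched by the one listener task that owns it, so atomic operations add nothing.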
......
@@ -11,16 +11,15 @@ mod listener;
 mod registry;
 mod server;
-use indexer::create_indexer;
 use registry::WorkerRegistry;
 use server::{AppState, create_router};
 #[derive(Parser)]
 #[command(name = "dynamo-kv-indexer", about = "Standalone KV cache indexer")]
 struct Cli {
-    /// KV cache block size (must match the vLLM engine's block size)
+    /// KV cache block size for initial workers registered via --workers
     #[arg(long)]
-    block_size: u32,
+    block_size: Option<u32>,
     /// HTTP server port
     #[arg(long, default_value_t = 8090)]
@@ -33,6 +32,14 @@ struct Cli {
     /// Initial workers as "worker_id=zmq_address,..." (e.g. "1=tcp://host:5557,2=tcp://host:5558")
     #[arg(long)]
     workers: Option<String>,
+    /// Model name for initial workers registered via --workers
+    #[arg(long, default_value = "default")]
+    model_name: String,
+    /// Tenant ID for initial workers registered via --workers
+    #[arg(long, default_value = "default")]
+    tenant_id: String,
 }
 fn parse_workers(s: &str) -> Vec<(u64, String)> {
@@ -58,26 +65,34 @@ async fn main() -> anyhow::Result<()> {
     let cli = Cli::parse();
     tracing::info!(
-        block_size = cli.block_size,
+        block_size = ?cli.block_size,
         port = cli.port,
         threads = cli.threads,
+        model_name = %cli.model_name,
+        tenant_id = %cli.tenant_id,
         "Starting standalone KV cache indexer"
     );
-    let indexer = create_indexer(cli.block_size, cli.threads);
-    let registry = WorkerRegistry::new(indexer, cli.block_size);
+    let registry = WorkerRegistry::new(cli.threads);
     if let Some(ref workers_str) = cli.workers {
+        let block_size = cli.block_size.ok_or_else(|| {
+            anyhow::anyhow!("--block-size is required when --workers is specified")
+        })?;
         for (instance_id, endpoint) in parse_workers(workers_str) {
             tracing::info!(instance_id, endpoint, "Registering initial worker");
-            registry.register(instance_id, endpoint, 0)?;
+            registry.register(
+                instance_id,
+                endpoint,
+                0,
+                cli.model_name.clone(),
+                cli.tenant_id.clone(),
+                block_size,
+            )?;
         }
     }
-    let state = Arc::new(AppState {
-        registry,
-        block_size: cli.block_size,
-    });
+    let state = Arc::new(AppState { registry });
     let app = create_router(state);
     let listener = TcpListener::bind(("0.0.0.0", cli.port)).await?;
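The startup rule in this hunk, where `--block-size` becomes optional but is required once `--workers` is given, can be sketched without `clap`. The body of the real `parse_workers` is not shown in this diff, so the parser here is a hypothetical version with guessed error handling, and `initial_workers` is an illustrative name.

```rust
/// Parse "id=addr,id=addr" into (instance_id, endpoint) pairs.
/// Hypothetical: the error handling of the real function is not shown.
fn parse_workers(s: &str) -> Result<Vec<(u64, String)>, String> {
    let mut out = Vec::new();
    for part in s.split(',').map(str::trim).filter(|p| !p.is_empty()) {
        let (id, addr) = part
            .split_once('=')
            .ok_or_else(|| format!("expected id=address, got {part:?}"))?;
        let id = id
            .trim()
            .parse::<u64>()
            .map_err(|e| format!("bad instance id: {e}"))?;
        out.push((id, addr.trim().to_string()));
    }
    Ok(out)
}

/// Resolve startup arguments into (instance_id, endpoint, block_size)
/// triples. A block size is only needed when workers are given.
fn initial_workers(
    workers: Option<&str>,
    block_size: Option<u32>,
) -> Result<Vec<(u64, String, u32)>, String> {
    let Some(workers) = workers else {
        // No --workers: start empty and wait for /register calls.
        return Ok(Vec::new());
    };
    let block_size = block_size
        .ok_or_else(|| "--block-size is required when --workers is specified".to_string())?;
    Ok(parse_workers(workers)?
        .into_iter()
        .map(|(id, addr)| (id, addr, block_size))
        .collect())
}
```

Starting with neither flag is valid under this rule: the process comes up with no indexers, and each later `/register` call supplies its own block size.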
......
@@ -5,83 +5,153 @@ use std::collections::HashMap;
 use anyhow::{Result, bail};
 use dashmap::DashMap;
+use dashmap::mapref::one::Ref;
 use tokio_util::sync::CancellationToken;
 use dynamo_kv_router::protocols::WorkerId;
-use super::indexer::Indexer;
+use super::indexer::{Indexer, create_indexer};
 use super::listener::run_zmq_listener;
-pub struct EndpointEntry {
-    pub endpoint: String,
+#[derive(Debug, Clone, Hash, Eq, PartialEq)]
+pub struct IndexerKey {
+    pub model_name: String,
+    pub tenant_id: String,
+}
+pub struct IndexerEntry {
+    pub indexer: Indexer,
+    pub block_size: u32,
 }
 pub struct WorkerEntry {
-    pub endpoints: HashMap<u32, EndpointEntry>,
-    cancel: CancellationToken,
+    pub endpoints: HashMap<u32, String>,
+    cancels: HashMap<u32, CancellationToken>,
 }
 pub struct WorkerRegistry {
     workers: DashMap<WorkerId, WorkerEntry>,
-    indexer: Indexer,
-    block_size: u32,
+    indexers: DashMap<IndexerKey, IndexerEntry>,
+    num_threads: usize,
 }
 impl WorkerRegistry {
-    pub fn new(indexer: Indexer, block_size: u32) -> Self {
+    pub fn new(num_threads: usize) -> Self {
         Self {
             workers: DashMap::new(),
-            indexer,
-            block_size,
+            indexers: DashMap::new(),
+            num_threads,
         }
     }
-    pub fn register(&self, instance_id: WorkerId, endpoint: String, dp_rank: u32) -> Result<()> {
+    pub fn register(
+        &self,
+        instance_id: WorkerId,
+        endpoint: String,
+        dp_rank: u32,
+        model_name: String,
+        tenant_id: String,
+        block_size: u32,
+    ) -> Result<()> {
+        let key = IndexerKey {
+            model_name,
+            tenant_id,
+        };
+        // Get or create the indexer for this (model, tenant) pair.
+        // Use the entry API for atomic check-and-insert.
+        let indexer_entry = self.indexers.entry(key.clone()).or_insert_with(|| {
+            tracing::info!(
+                model_name = %key.model_name,
+                tenant_id = %key.tenant_id,
+                block_size,
+                "Creating new indexer"
+            );
+            IndexerEntry {
+                indexer: create_indexer(block_size, self.num_threads),
+                block_size,
+            }
+        });
+        if indexer_entry.block_size != block_size {
+            bail!(
+                "block_size mismatch for model={} tenant={}: existing={}, requested={}",
+                key.model_name,
+                key.tenant_id,
+                indexer_entry.block_size,
+                block_size
+            );
+        }
+        let indexer = indexer_entry.indexer.clone();
+        let bs = indexer_entry.block_size;
+        drop(indexer_entry);
         let mut entry = self
             .workers
             .entry(instance_id)
             .or_insert_with(|| WorkerEntry {
                 endpoints: HashMap::new(),
-                cancel: CancellationToken::new(),
+                cancels: HashMap::new(),
             });
         if entry.endpoints.contains_key(&dp_rank) {
             bail!("instance {instance_id} dp_rank {dp_rank} already registered");
         }
-        let child_cancel = entry.cancel.child_token();
-        let indexer = self.indexer.clone();
-        let block_size = self.block_size;
+        let cancel = CancellationToken::new();
+        let child_cancel = cancel.child_token();
         let addr = endpoint.clone();
         tokio::spawn(async move {
-            run_zmq_listener(
-                instance_id,
-                dp_rank,
-                addr,
-                block_size,
-                indexer,
-                child_cancel,
-            )
-            .await;
+            run_zmq_listener(instance_id, dp_rank, addr, bs, indexer, child_cancel).await;
         });
-        entry.endpoints.insert(dp_rank, EndpointEntry { endpoint });
+        entry.endpoints.insert(dp_rank, endpoint);
+        entry.cancels.insert(dp_rank, cancel);
         Ok(())
     }
-    pub async fn deregister(&self, instance_id: WorkerId) -> Result<()> {
+    pub async fn deregister(
+        &self,
+        instance_id: WorkerId,
+        model_name: &str,
+        tenant_id: &str,
+    ) -> Result<()> {
         let (_, entry) = self
             .workers
             .remove(&instance_id)
             .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;
-        entry.cancel.cancel();
-        self.indexer.remove_worker(instance_id).await;
+        for cancel in entry.cancels.values() {
+            cancel.cancel();
+        }
+        let key = IndexerKey {
+            model_name: model_name.to_string(),
+            tenant_id: tenant_id.to_string(),
+        };
+        if let Some(ie) = self.indexers.get(&key) {
+            ie.indexer.remove_worker(instance_id).await;
+        } else {
+            tracing::warn!(
+                instance_id,
+                model_name,
+                tenant_id,
+                "indexer key not found on deregister; tree will not be cleaned up"
+            );
+        }
         Ok(())
     }
-    pub async fn deregister_dp_rank(&self, instance_id: WorkerId, dp_rank: u32) -> Result<()> {
+    pub async fn deregister_dp_rank(
+        &self,
+        instance_id: WorkerId,
+        dp_rank: u32,
+        model_name: &str,
+        tenant_id: &str,
+    ) -> Result<()> {
         let mut entry = self
             .workers
             .get_mut(&instance_id)
@@ -91,9 +161,62 @@ impl WorkerRegistry {
             bail!("instance {instance_id} dp_rank {dp_rank} not found");
         }
+        if let Some(cancel) = entry.cancels.remove(&dp_rank) {
+            cancel.cancel();
+        }
         if entry.endpoints.is_empty() {
             drop(entry);
-            return self.deregister(instance_id).await;
+            return self.deregister(instance_id, model_name, tenant_id).await;
+        }
+        drop(entry);
+        let key = IndexerKey {
+            model_name: model_name.to_string(),
+            tenant_id: tenant_id.to_string(),
+        };
+        if let Some(ie) = self.indexers.get(&key) {
+            ie.indexer.remove_worker_dp_rank(instance_id, dp_rank).await;
+        } else {
+            tracing::warn!(
+                instance_id,
+                dp_rank,
+                model_name,
+                tenant_id,
+                "indexer key not found on deregister_dp_rank; tree will not be cleaned up"
+            );
+        }
+        Ok(())
+    }
+    pub async fn deregister_all_tenants(
+        &self,
+        instance_id: WorkerId,
+        model_name: &str,
+    ) -> Result<()> {
+        let (_, entry) = self
+            .workers
+            .remove(&instance_id)
+            .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;
+        for cancel in entry.cancels.values() {
+            cancel.cancel();
+        }
+        let mut found = false;
+        for ie in self.indexers.iter() {
+            if ie.key().model_name == model_name {
+                ie.indexer.remove_worker(instance_id).await;
+                found = true;
+            }
+        }
+        if !found {
+            tracing::warn!(
+                instance_id,
+                model_name,
+                "no indexers found for model on deregister_all_tenants; tree will not be cleaned up"
+            );
         }
         Ok(())
@@ -102,19 +225,18 @@ impl WorkerRegistry {
     pub fn list(&self) -> Vec<(WorkerId, HashMap<u32, String>)> {
         self.workers
             .iter()
-            .map(|entry| {
-                let endpoints: HashMap<u32, String> = entry
-                    .value()
-                    .endpoints
-                    .iter()
-                    .map(|(&dp_rank, e)| (dp_rank, e.endpoint.clone()))
-                    .collect();
-                (*entry.key(), endpoints)
-            })
+            .map(|entry| (*entry.key(), entry.value().endpoints.clone()))
             .collect()
     }
-    pub fn indexer(&self) -> &Indexer {
-        &self.indexer
+    pub fn get_indexer(&self, key: &IndexerKey) -> Option<Ref<'_, IndexerKey, IndexerEntry>> {
+        self.indexers.get(key)
+    }
+    pub fn all_indexers(&self) -> Vec<(IndexerKey, Indexer)> {
+        self.indexers
+            .iter()
+            .map(|entry| (entry.key().clone(), entry.value().indexer.clone()))
+            .collect()
     }
 }
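The registry change replaces one cancellation token per worker with one token per `dp_rank`, so a single rank's listener can be stopped without touching its siblings. A std-only sketch of that bookkeeping follows. `CancelFlag` is a hypothetical stand-in for `tokio_util::sync::CancellationToken`, and `WorkerSlots`, `add_rank`, and `remove_rank` are illustrative names rather than the registry's API.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Minimal stand-in for a cancellation token: a shared flag that a
/// listener loop would poll.
#[derive(Clone, Default)]
struct CancelFlag(Arc<AtomicBool>);

impl CancelFlag {
    fn cancel(&self) {
        self.0.store(true, Ordering::SeqCst);
    }
    fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}

/// Per-worker state: one endpoint and one cancel handle per dp_rank.
#[derive(Default)]
struct WorkerSlots {
    endpoints: HashMap<u32, String>,
    cancels: HashMap<u32, CancelFlag>,
}

impl WorkerSlots {
    /// Register a rank and return the flag its listener should watch.
    fn add_rank(&mut self, dp_rank: u32, endpoint: &str) -> Result<CancelFlag, String> {
        if self.endpoints.contains_key(&dp_rank) {
            return Err(format!("dp_rank {dp_rank} already registered"));
        }
        let flag = CancelFlag::default();
        self.endpoints.insert(dp_rank, endpoint.to_string());
        self.cancels.insert(dp_rank, flag.clone());
        Ok(flag)
    }

    /// Stop one rank's listener; other ranks keep running.
    /// Returns false if the rank was not registered.
    fn remove_rank(&mut self, dp_rank: u32) -> bool {
        self.endpoints.remove(&dp_rank);
        match self.cancels.remove(&dp_rank) {
            Some(flag) => {
                flag.cancel();
                true
            }
            None => false,
        }
    }
}
```

With the earlier single parent token, child tokens had no way to be cancelled individually from the registry; keeping the handles in a map keyed by rank is what makes the per-rank branch of `/unregister` able to stop exactly one listener.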
@@ -13,17 +13,24 @@ use serde::{Deserialize, Serialize};
 use dynamo_kv_router::protocols::{LocalBlockHash, WorkerId, compute_block_hash_for_seq};
-use super::registry::WorkerRegistry;
+use super::registry::{IndexerKey, WorkerRegistry};
 pub struct AppState {
     pub registry: WorkerRegistry,
-    pub block_size: u32,
+}
+fn default_tenant() -> String {
+    "default".to_string()
 }
 #[derive(Deserialize)]
 pub struct RegisterRequest {
     pub instance_id: WorkerId,
     pub endpoint: String,
+    pub model_name: String,
+    #[serde(default = "default_tenant")]
+    pub tenant_id: String,
+    pub block_size: u32,
     #[serde(default)]
     pub dp_rank: Option<u32>,
 }
@@ -31,6 +38,9 @@ pub struct RegisterRequest {
 #[derive(Deserialize)]
 pub struct UnregisterRequest {
     pub instance_id: WorkerId,
+    pub model_name: String,
+    #[serde(default)]
+    pub tenant_id: Option<String>,
     #[serde(default)]
     pub dp_rank: Option<u32>,
 }
@@ -44,13 +54,24 @@ struct WorkerInfo {
 #[derive(Deserialize)]
 pub struct QueryRequest {
     pub token_ids: Vec<u32>,
+    pub model_name: String,
+    #[serde(default = "default_tenant")]
+    pub tenant_id: String,
     #[serde(default)]
     pub lora_name: Option<String>,
 }
+/// Query using pre-computed block hashes.
+///
+/// Callers must include the LoRA salt in their hashes when applicable — use
+/// [`compute_block_hash_for_seq`] with the appropriate `lora_name`. The indexer
+/// cannot retroactively apply a LoRA salt to pre-computed hashes.
 #[derive(Deserialize)]
 pub struct QueryByHashRequest {
     pub block_hashes: Vec<i64>,
+    pub model_name: String,
+    #[serde(default = "default_tenant")]
+    pub tenant_id: String,
 }
 #[derive(Serialize)]
@@ -64,10 +85,14 @@ async fn register(
     State(state): State<Arc<AppState>>,
     Json(req): Json<RegisterRequest>,
 ) -> impl IntoResponse {
-    match state
-        .registry
-        .register(req.instance_id, req.endpoint, req.dp_rank.unwrap_or(0))
-    {
+    match state.registry.register(
+        req.instance_id,
+        req.endpoint,
+        req.dp_rank.unwrap_or(0),
+        req.model_name,
+        req.tenant_id,
+        req.block_size,
+    ) {
         Ok(()) => (
             StatusCode::CREATED,
             Json(serde_json::json!({"status": "ok"})),
@@ -83,14 +108,27 @@ async fn unregister(
     State(state): State<Arc<AppState>>,
     Json(req): Json<UnregisterRequest>,
 ) -> impl IntoResponse {
-    let result = match req.dp_rank {
-        Some(dp_rank) => {
+    let result = match req.tenant_id {
+        Some(tenant_id) => match req.dp_rank {
+            Some(dp_rank) => {
+                state
+                    .registry
+                    .deregister_dp_rank(req.instance_id, dp_rank, &req.model_name, &tenant_id)
+                    .await
+            }
+            None => {
+                state
+                    .registry
+                    .deregister(req.instance_id, &req.model_name, &tenant_id)
+                    .await
+            }
+        },
+        None => {
             state
                 .registry
-                .deregister_dp_rank(req.instance_id, dp_rank)
+                .deregister_all_tenants(req.instance_id, &req.model_name)
                 .await
         }
-        None => state.registry.deregister(req.instance_id).await,
     };
     match result {
         Ok(()) => (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))),
@@ -114,13 +152,16 @@ async fn list_workers(State(state): State<Arc<AppState>>) -> impl IntoResponse {
     Json(workers)
 }
-fn build_score_response(overlap: dynamo_kv_router::protocols::OverlapScores) -> ScoreResponse {
+fn build_score_response(
+    overlap: dynamo_kv_router::protocols::OverlapScores,
+    block_size: u32,
+) -> ScoreResponse {
     let mut scores: HashMap<String, HashMap<String, u32>> = HashMap::new();
     for (k, v) in &overlap.scores {
         scores
             .entry(k.worker_id.to_string())
             .or_default()
-            .insert(k.dp_rank.to_string(), *v);
+            .insert(k.dp_rank.to_string(), v * block_size);
     }
     let mut tree_sizes: HashMap<String, HashMap<String, usize>> = HashMap::new();
     for (k, v) in &overlap.tree_sizes {
@@ -140,16 +181,28 @@ async fn query(
     State(state): State<Arc<AppState>>,
     Json(req): Json<QueryRequest>,
 ) -> impl IntoResponse {
-    let block_hashes = compute_block_hash_for_seq(
-        &req.token_ids,
-        state.block_size,
-        None,
-        req.lora_name.as_deref(),
-    );
-    match state.registry.indexer().find_matches(block_hashes).await {
+    let key = IndexerKey {
+        model_name: req.model_name,
+        tenant_id: req.tenant_id,
+    };
+    let Some(ie) = state.registry.get_indexer(&key) else {
+        return (
+            StatusCode::NOT_FOUND,
+            Json(serde_json::json!({
+                "error": format!("no indexer for model={} tenant={}", key.model_name, key.tenant_id)
+            })),
+        );
+    };
+    let block_size = ie.block_size;
+    let indexer = ie.indexer.clone();
+    drop(ie);
+    let block_hashes =
+        compute_block_hash_for_seq(&req.token_ids, block_size, None, req.lora_name.as_deref());
+    match indexer.find_matches(block_hashes).await {
         Ok(overlap) => (
             StatusCode::OK,
-            Json(serde_json::json!(build_score_response(overlap))),
+            Json(serde_json::json!(build_score_response(overlap, block_size))),
         ),
         Err(e) => (
             StatusCode::INTERNAL_SERVER_ERROR,
@@ -162,15 +215,31 @@ async fn query_by_hash(
     State(state): State<Arc<AppState>>,
     Json(req): Json<QueryByHashRequest>,
 ) -> impl IntoResponse {
+    let key = IndexerKey {
+        model_name: req.model_name,
+        tenant_id: req.tenant_id,
+    };
+    let Some(ie) = state.registry.get_indexer(&key) else {
+        return (
+            StatusCode::NOT_FOUND,
+            Json(serde_json::json!({
+                "error": format!("no indexer for model={} tenant={}", key.model_name, key.tenant_id)
+            })),
+        );
+    };
+    let block_size = ie.block_size;
+    let indexer = ie.indexer.clone();
+    drop(ie);
     let block_hashes: Vec<LocalBlockHash> = req
         .block_hashes
         .iter()
         .map(|h| LocalBlockHash(*h as u64))
         .collect();
-    match state.registry.indexer().find_matches(block_hashes).await {
+    match indexer.find_matches(block_hashes).await {
         Ok(overlap) => (
             StatusCode::OK,
-            Json(serde_json::json!(build_score_response(overlap))),
+            Json(serde_json::json!(build_score_response(overlap, block_size))),
         ),
         Err(e) => (
             StatusCode::INTERNAL_SERVER_ERROR,
@@ -180,13 +249,33 @@ async fn query_by_hash(
 }
 async fn dump_events(State(state): State<Arc<AppState>>) -> impl IntoResponse {
-    match state.registry.indexer().dump_events().await {
-        Ok(events) => (StatusCode::OK, Json(serde_json::json!(events))),
-        Err(e) => (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({"error": e.to_string()})),
-        ),
+    let indexers = state.registry.all_indexers();
+    let mut handles = Vec::with_capacity(indexers.len());
+    for (key, indexer) in indexers {
+        handles.push(tokio::spawn(async move {
+            let events = indexer.dump_events().await;
+            (key, events)
+        }));
+    }
+    let mut result: HashMap<String, serde_json::Value> = HashMap::new();
+    for handle in handles {
+        match handle.await {
+            Ok((key, Ok(events))) => {
+                let map_key = format!("{}:{}", key.model_name, key.tenant_id);
+                result.insert(map_key, serde_json::json!(events));
+            }
+            Ok((key, Err(e))) => {
+                let map_key = format!("{}:{}", key.model_name, key.tenant_id);
+                result.insert(map_key, serde_json::json!({"error": e.to_string()}));
+            }
+            Err(e) => {
+                tracing::warn!("dump task join error: {e}");
+            }
+        }
     }
+    (StatusCode::OK, Json(serde_json::json!(result)))
 }
 pub fn create_router(state: Arc<AppState>) -> Router {
......
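The handlers above resolve an indexer by `(model_name, tenant_id)` and return 404 when none is registered; `/dump` keys its response map as `model_name:tenant_id`. The following is a minimal, std-only sketch of that lookup pattern, not the PR's registry. `Registry`, `IndexerEntry`, `register`, `dump_key`, and `key` are hypothetical names, and the block-size check follows the documented rule that the first registration for a pair fixes its block size.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the PR's `IndexerKey`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct IndexerKey {
    model_name: String,
    tenant_id: String,
}

/// Hypothetical per-key entry; the real entry also holds an indexer handle.
struct IndexerEntry {
    block_size: u32,
}

/// Hypothetical registry: one entry per (model_name, tenant_id) pair.
#[derive(Default)]
struct Registry {
    indexers: HashMap<IndexerKey, IndexerEntry>,
}

impl Registry {
    /// The first registration fixes the block size; later ones must match it.
    fn register(&mut self, key: IndexerKey, block_size: u32) -> Result<(), String> {
        match self.indexers.get(&key) {
            Some(entry) if entry.block_size != block_size => Err(format!(
                "block_size mismatch for model={} tenant={}: have {}, got {}",
                key.model_name, key.tenant_id, entry.block_size, block_size
            )),
            Some(_) => Ok(()),
            None => {
                self.indexers.insert(key, IndexerEntry { block_size });
                Ok(())
            }
        }
    }

    /// `None` corresponds to the 404 branch in the handlers.
    fn get_indexer(&self, key: &IndexerKey) -> Option<&IndexerEntry> {
        self.indexers.get(key)
    }
}

/// Same "model:tenant" format that `dump_events` uses for its map keys.
fn dump_key(key: &IndexerKey) -> String {
    format!("{}:{}", key.model_name, key.tenant_id)
}

/// Small helper for building keys in the example.
fn key(model: &str, tenant: &str) -> IndexerKey {
    IndexerKey {
        model_name: model.to_string(),
        tenant_id: tenant.to_string(),
    }
}

fn main() {
    let mut registry = Registry::default();
    registry.register(key("llama", "default"), 16).unwrap();
    // A different tenant under the same model is a separate entry.
    registry.register(key("llama", "tenant-a"), 32).unwrap();
    // Re-registering the same pair with another block size is rejected.
    assert!(registry.register(key("llama", "default"), 32).is_err());
    // Unknown pairs are simply absent.
    assert!(registry.get_indexer(&key("mistral", "default")).is_none());
    println!("{}", dump_key(&key("llama", "default")));
}
```

One thing this sketch makes visible: because the dump key is a plain `:` join, a model name that itself contains `:` would make the key ambiguous to parse back into its two parts.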
...@@ -515,6 +515,25 @@ impl ConcurrentRadixTree { ...@@ -515,6 +515,25 @@ impl ConcurrentRadixTree {
} }
} }
fn remove_worker_dp_rank(
&self,
lookup: &mut FxHashMap<WorkerWithDpRank, WorkerLookup>,
worker_id: WorkerId,
dp_rank: DpRank,
) {
let key = WorkerWithDpRank { worker_id, dp_rank };
if let Some(worker_lookup) = lookup.remove(&key) {
for (_, block) in worker_lookup.into_iter() {
let mut guard = block.write();
guard.workers.remove(&key);
if guard.workers.is_empty() {
guard.children.clear();
}
}
self.tree_sizes.remove(&key);
}
}
/// Clear all blocks for a worker but keep the worker tracked. /// Clear all blocks for a worker but keep the worker tracked.
fn clear_all_blocks( fn clear_all_blocks(
&self, &self,
...@@ -616,6 +635,9 @@ impl SyncIndexer for ConcurrentRadixTree { ...@@ -616,6 +635,9 @@ impl SyncIndexer for ConcurrentRadixTree {
WorkerTask::RemoveWorker(worker_id) => { WorkerTask::RemoveWorker(worker_id) => {
self.remove_or_clear_worker_blocks(&mut lookup, worker_id, false); self.remove_or_clear_worker_blocks(&mut lookup, worker_id, false);
} }
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank(&mut lookup, worker_id, dp_rank);
}
WorkerTask::DumpEvents(_sender) => { WorkerTask::DumpEvents(_sender) => {
// Handled directly via dump_events() on the shared tree. // Handled directly via dump_events() on the shared tree.
// Should not be reached, but respond with empty to avoid blocking. // Should not be reached, but respond with empty to avoid blocking.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transport abstraction for publishing batched KV cache events.
//!
//! Implementations handle the actual delivery mechanism (NATS event plane,
//! JetStream durable queue, direct indexer application, etc.). The trait lives
//! in this crate so that the batching processor and other routing logic can be
//! written generically; runtime-specific impls stay in `lib/llm`.
use std::future::Future;
use crate::protocols::RouterEvent;
/// Transport abstraction for publishing batched KV cache events.
pub trait EventSink: Send + Sync {
fn publish_event(&self, event: &RouterEvent)
-> impl Future<Output = anyhow::Result<()>> + Send;
}
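The trait above returns `impl Future + Send` from a plain `fn` rather than using `#[async_trait]`, which avoids boxing the future and keeps the `Send` bound explicit. Below is a minimal, std-only sketch of implementing and consuming a trait of that shape. `Event`, `MemorySink`, `publish_all`, and `block_on` are hypothetical names for illustration, and `String` stands in for `anyhow::Error` so the example has no external dependencies.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical stand-in for `RouterEvent`.
#[derive(Clone, Debug, PartialEq)]
struct Event {
    worker_id: u64,
    event_id: u64,
}

/// Same shape as the trait in the diff, with a `String` error for simplicity.
trait EventSink: Send + Sync {
    fn publish_event(&self, event: &Event) -> impl Future<Output = Result<(), String>> + Send;
}

/// In-memory sink that records every published event.
#[derive(Default)]
struct MemorySink {
    events: Mutex<Vec<Event>>,
}

impl EventSink for MemorySink {
    fn publish_event(&self, event: &Event) -> impl Future<Output = Result<(), String>> + Send {
        // Do the work eagerly, then hand back an already-completed future.
        self.events.lock().unwrap().push(event.clone());
        async { Ok(()) }
    }
}

/// Generic consumer: works with any sink, no trait objects or boxing needed.
async fn publish_all<S: EventSink>(sink: &S, events: &[Event]) -> Result<(), String> {
    for event in events {
        sink.publish_event(event).await?;
    }
    Ok(())
}

/// Waker that does nothing; enough for futures that never actually wait.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Tiny polling executor for the example only; a real program would use a runtime.
fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(value) = fut.as_mut().poll(&mut cx) {
            return value;
        }
    }
}

fn main() {
    let sink = MemorySink::default();
    let events = [
        Event { worker_id: 1, event_id: 0 },
        Event { worker_id: 1, event_id: 1 },
    ];
    block_on(publish_all(&sink, &events)).unwrap();
    println!("recorded {} events", sink.events.lock().unwrap().len());
}
```

A trait written this way is not object-safe, so callers take the sink as a generic parameter (`P: EventSink`) instead of `dyn EventSink`.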
...@@ -327,6 +327,14 @@ pub trait KvIndexerInterface { ...@@ -327,6 +327,14 @@ pub trait KvIndexerInterface {
/// * `worker` - The worker to remove from the trie. /// * `worker` - The worker to remove from the trie.
async fn remove_worker(&self, worker: WorkerId); async fn remove_worker(&self, worker: WorkerId);
/// Remove a single dp_rank for a worker from the trie.
///
/// Default implementation falls back to removing the entire worker.
/// Indexers that track dp_rank-level granularity should override this.
async fn remove_worker_dp_rank(&self, worker: WorkerId, _dp_rank: DpRank) {
self.remove_worker(worker).await;
}
/// Shutdown the KV Indexer. /// Shutdown the KV Indexer.
fn shutdown(&self); fn shutdown(&self);
...@@ -363,6 +371,8 @@ pub enum WorkerTask { ...@@ -363,6 +371,8 @@ pub enum WorkerTask {
Event(RouterEvent), Event(RouterEvent),
/// Permanently remove a worker from tracking (keep_worker: false). /// Permanently remove a worker from tracking (keep_worker: false).
RemoveWorker(WorkerId), RemoveWorker(WorkerId),
/// Remove a single dp_rank for a worker.
RemoveWorkerDpRank(WorkerId, DpRank),
DumpEvents(oneshot::Sender<anyhow::Result<Vec<RouterEvent>>>), DumpEvents(oneshot::Sender<anyhow::Result<Vec<RouterEvent>>>),
Terminate, Terminate,
} }
...@@ -568,6 +578,14 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { ...@@ -568,6 +578,14 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
} }
} }
async fn remove_worker_dp_rank(&self, worker_id: WorkerId, dp_rank: DpRank) {
// Broadcast to all threads — the dp_rank may be on any thread.
// Don't remove from worker_assignments since other dp_ranks may still exist.
for channel in &self.worker_event_channels {
let _ = channel.send(WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank));
}
}
fn shutdown(&self) { fn shutdown(&self) {
// Send shutdown signal to all worker threads // Send shutdown signal to all worker threads
for channel in self.worker_event_channels.iter() { for channel in self.worker_event_channels.iter() {
...@@ -668,6 +686,8 @@ pub struct KvIndexer { ...@@ -668,6 +686,8 @@ pub struct KvIndexer {
match_tx: mpsc::Sender<MatchRequest>, match_tx: mpsc::Sender<MatchRequest>,
/// A sender for remove worker requests. /// A sender for remove worker requests.
remove_worker_tx: mpsc::Sender<WorkerId>, remove_worker_tx: mpsc::Sender<WorkerId>,
/// A sender for remove worker dp_rank requests.
remove_worker_dp_rank_tx: mpsc::Sender<(WorkerId, DpRank)>,
/// A sender for get workers requests. /// A sender for get workers requests.
get_workers_tx: mpsc::Sender<GetWorkersRequest>, get_workers_tx: mpsc::Sender<GetWorkersRequest>,
/// A sender for dump requests. /// A sender for dump requests.
...@@ -704,6 +724,8 @@ impl KvIndexer { ...@@ -704,6 +724,8 @@ impl KvIndexer {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048); let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128); let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16); let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (remove_worker_dp_rank_tx, remove_worker_dp_rank_rx) =
mpsc::channel::<(WorkerId, DpRank)>(16);
let (get_workers_tx, get_workers_rx) = mpsc::channel::<GetWorkersRequest>(16); let (get_workers_tx, get_workers_rx) = mpsc::channel::<GetWorkersRequest>(16);
let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16); let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16);
let (routing_tx, mut routing_rx) = mpsc::channel::<RoutingDecisionRequest>(2048); let (routing_tx, mut routing_rx) = mpsc::channel::<RoutingDecisionRequest>(2048);
...@@ -723,6 +745,7 @@ impl KvIndexer { ...@@ -723,6 +745,7 @@ impl KvIndexer {
let mut match_rx = match_rx; let mut match_rx = match_rx;
let mut event_rx = event_rx; let mut event_rx = event_rx;
let mut remove_worker_rx = remove_worker_rx; let mut remove_worker_rx = remove_worker_rx;
let mut remove_worker_dp_rank_rx = remove_worker_dp_rank_rx;
let mut get_workers_rx = get_workers_rx; let mut get_workers_rx = get_workers_rx;
let mut dump_rx = dump_rx; let mut dump_rx = dump_rx;
let mut trie = RadixTree::new_with_frequency(expiration_duration); let mut trie = RadixTree::new_with_frequency(expiration_duration);
...@@ -754,6 +777,10 @@ impl KvIndexer { ...@@ -754,6 +777,10 @@ impl KvIndexer {
trie.remove_worker(worker); trie.remove_worker(worker);
} }
Some((worker_id, dp_rank)) = remove_worker_dp_rank_rx.recv() => {
trie.remove_worker_dp_rank(worker_id, dp_rank);
}
Some(get_workers_req) = get_workers_rx.recv() => { Some(get_workers_req) = get_workers_rx.recv() => {
let workers = trie.get_workers(); let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers); let _ = get_workers_req.resp.send(workers);
...@@ -933,6 +960,7 @@ impl KvIndexer { ...@@ -933,6 +960,7 @@ impl KvIndexer {
event_tx, event_tx,
match_tx, match_tx,
remove_worker_tx, remove_worker_tx,
remove_worker_dp_rank_tx,
get_workers_tx, get_workers_tx,
dump_tx, dump_tx,
routing_tx, routing_tx,
...@@ -1052,6 +1080,13 @@ impl KvIndexerInterface for KvIndexer { ...@@ -1052,6 +1080,13 @@ impl KvIndexerInterface for KvIndexer {
self.remove_worker_tx.send(worker).await.unwrap(); self.remove_worker_tx.send(worker).await.unwrap();
} }
async fn remove_worker_dp_rank(&self, worker: WorkerId, dp_rank: DpRank) {
self.remove_worker_dp_rank_tx
.send((worker, dp_rank))
.await
.unwrap();
}
fn shutdown(&self) { fn shutdown(&self) {
self.cancel.cancel(); self.cancel.cancel();
} }
...@@ -1461,6 +1496,7 @@ pub struct KvIndexerSharded { ...@@ -1461,6 +1496,7 @@ pub struct KvIndexerSharded {
event_tx: Vec<mpsc::Sender<RouterEvent>>, event_tx: Vec<mpsc::Sender<RouterEvent>>,
request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>, request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
remove_worker_tx: Vec<mpsc::Sender<WorkerId>>, remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
remove_worker_dp_rank_tx: Vec<mpsc::Sender<(WorkerId, DpRank)>>,
dump_tx: Vec<mpsc::Sender<DumpRequest>>, dump_tx: Vec<mpsc::Sender<DumpRequest>>,
routing_tx: Vec<mpsc::Sender<RoutingDecisionRequest>>, routing_tx: Vec<mpsc::Sender<RoutingDecisionRequest>>,
tasks: Arc<Mutex<Vec<JoinHandle<()>>>>, tasks: Arc<Mutex<Vec<JoinHandle<()>>>>,
...@@ -1493,6 +1529,7 @@ impl KvIndexerSharded { ...@@ -1493,6 +1529,7 @@ impl KvIndexerSharded {
let mut event_tx = Vec::new(); let mut event_tx = Vec::new();
let mut remove_worker_tx = Vec::new(); let mut remove_worker_tx = Vec::new();
let mut remove_worker_dp_rank_tx = Vec::new();
let mut get_workers_tx = Vec::new(); let mut get_workers_tx = Vec::new();
let mut dump_tx = Vec::new(); let mut dump_tx = Vec::new();
let mut routing_tx = Vec::new(); let mut routing_tx = Vec::new();
...@@ -1504,6 +1541,8 @@ impl KvIndexerSharded { ...@@ -1504,6 +1541,8 @@ impl KvIndexerSharded {
let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048); let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
let (shard_remove_worker_tx, mut shard_remove_worker_rx) = let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
mpsc::channel::<WorkerId>(16); mpsc::channel::<WorkerId>(16);
let (shard_remove_worker_dp_rank_tx, mut shard_remove_worker_dp_rank_rx) =
mpsc::channel::<(WorkerId, DpRank)>(16);
let (shard_get_workers_tx, mut shard_get_workers_rx) = let (shard_get_workers_tx, mut shard_get_workers_rx) =
mpsc::channel::<GetWorkersRequest>(16); mpsc::channel::<GetWorkersRequest>(16);
let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16); let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16);
...@@ -1517,6 +1556,7 @@ impl KvIndexerSharded { ...@@ -1517,6 +1556,7 @@ impl KvIndexerSharded {
event_tx.push(shard_event_tx); event_tx.push(shard_event_tx);
remove_worker_tx.push(shard_remove_worker_tx); remove_worker_tx.push(shard_remove_worker_tx);
remove_worker_dp_rank_tx.push(shard_remove_worker_dp_rank_tx);
get_workers_tx.push(shard_get_workers_tx); get_workers_tx.push(shard_get_workers_tx);
dump_tx.push(shard_dump_tx); dump_tx.push(shard_dump_tx);
routing_tx.push(shard_routing_tx); routing_tx.push(shard_routing_tx);
...@@ -1557,6 +1597,10 @@ impl KvIndexerSharded { ...@@ -1557,6 +1597,10 @@ impl KvIndexerSharded {
trie.remove_worker(worker); trie.remove_worker(worker);
} }
Some((worker_id, dp_rank)) = shard_remove_worker_dp_rank_rx.recv() => {
trie.remove_worker_dp_rank(worker_id, dp_rank);
}
Some(get_workers_req) = shard_get_workers_rx.recv() => { Some(get_workers_req) = shard_get_workers_rx.recv() => {
let workers = trie.get_workers(); let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers); let _ = get_workers_req.resp.send(workers);
...@@ -1736,6 +1780,7 @@ impl KvIndexerSharded { ...@@ -1736,6 +1780,7 @@ impl KvIndexerSharded {
event_tx, event_tx,
request_broadcast_tx, request_broadcast_tx,
remove_worker_tx, remove_worker_tx,
remove_worker_dp_rank_tx,
dump_tx, dump_tx,
routing_tx, routing_tx,
tasks, tasks,
...@@ -1860,6 +1905,17 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -1860,6 +1905,17 @@ impl KvIndexerInterface for KvIndexerSharded {
} }
} }
async fn remove_worker_dp_rank(&self, worker: WorkerId, dp_rank: DpRank) {
// Worker is assigned to a single shard, so route there directly.
// Don't remove from worker_assignments since other dp_ranks may still exist.
if let Some(shard) = self.worker_assignments.get(&worker) {
self.remove_worker_dp_rank_tx[*shard]
.send((worker, dp_rank))
.await
.unwrap();
}
}
/// Shutdown the KV Indexer. /// Shutdown the KV Indexer.
fn shutdown(&self) { fn shutdown(&self) {
self.cancel.cancel(); self.cancel.cancel();
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
pub mod approx; pub mod approx;
pub mod concurrent_radix_tree; pub mod concurrent_radix_tree;
pub mod event_sink;
pub mod indexer; pub mod indexer;
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
pub mod naive_indexers; pub mod naive_indexers;
...@@ -36,6 +37,7 @@ pub use self::multi_worker_sequence::{ ...@@ -36,6 +37,7 @@ pub use self::multi_worker_sequence::{
pub use self::sequence::{ActiveSequences, RequestId}; pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree; pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use config::{KvRouterConfig, RouterConfigOverride}; pub use config::{KvRouterConfig, RouterConfigOverride};
pub use event_sink::EventSink;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer}; pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
pub use naive_indexers::{InvertedIndex, NaiveNestedMap}; pub use naive_indexers::{InvertedIndex, NaiveNestedMap};
......
@@ -26,8 +26,9 @@ use std::sync::atomic::{AtomicUsize, Ordering};
 use crate::indexer::{SyncIndexer, WorkerTask};
 use crate::protocols::{
-    ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheEventError, KvCacheStoreData,
-    KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerId, WorkerWithDpRank,
+    DpRank, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheEventError,
+    KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerId,
+    WorkerWithDpRank,
 };
 /// Entry for the innermost level of the index.
...@@ -150,6 +151,9 @@ impl SyncIndexer for PositionalIndexer { ...@@ -150,6 +151,9 @@ impl SyncIndexer for PositionalIndexer {
WorkerTask::RemoveWorker(worker_id) => { WorkerTask::RemoveWorker(worker_id) => {
self.remove_or_clear_worker_blocks_impl(&mut worker_blocks, worker_id, false); self.remove_or_clear_worker_blocks_impl(&mut worker_blocks, worker_id, false);
} }
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank_impl(&mut worker_blocks, worker_id, dp_rank);
}
WorkerTask::DumpEvents(sender) => { WorkerTask::DumpEvents(sender) => {
let events = self.dump_events(&worker_blocks); let events = self.dump_events(&worker_blocks);
if let Err(e) = sender.send(Ok(events)) { if let Err(e) = sender.send(Ok(events)) {
...@@ -329,6 +333,23 @@ impl PositionalIndexer { ...@@ -329,6 +333,23 @@ impl PositionalIndexer {
self.remove_or_clear_worker_blocks_impl(worker_blocks, worker_id, true); self.remove_or_clear_worker_blocks_impl(worker_blocks, worker_id, true);
} }
fn remove_worker_dp_rank_impl(
&self,
worker_blocks: &mut FxHashMap<WorkerWithDpRank, LevelIndex>,
worker_id: WorkerId,
dp_rank: DpRank,
) {
let key = WorkerWithDpRank { worker_id, dp_rank };
if let Some(worker_map) = worker_blocks.remove(&key) {
for (seq_hash, (position, local_hash)) in worker_map.iter() {
if let Some(mut entry) = self.index.get_mut(&(*position, *local_hash)) {
let _ = entry.remove(*seq_hash, key);
}
}
self.tree_sizes.remove(&key);
}
}
/// Helper function to remove or clear blocks for a worker. /// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains tracked with empty blocks. /// If `keep_worker` is true, the worker remains tracked with empty blocks.
/// If `keep_worker` is false, the worker is completely removed. /// If `keep_worker` is false, the worker is completely removed.
......
...@@ -497,6 +497,18 @@ impl RadixTree { ...@@ -497,6 +497,18 @@ impl RadixTree {
self.remove_or_clear_worker_blocks(worker_id, false); self.remove_or_clear_worker_blocks(worker_id, false);
} }
pub fn remove_worker_dp_rank(&mut self, worker_id: WorkerId, dp_rank: DpRank) {
let key = WorkerWithDpRank { worker_id, dp_rank };
if let Some(blocks) = self.lookup.remove(&key) {
for (_, block) in blocks {
block.borrow_mut().workers.remove(&key);
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
}
}
}
}
pub fn clear_all_blocks(&mut self, worker_id: WorkerId) { pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, true); self.remove_or_clear_worker_blocks(worker_id, true);
} }
......
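The `remove_worker_dp_rank` variants above all follow the same idea: drop the `(worker_id, dp_rank)` entry from the per-worker lookup, then remove that key from every block it owned, leaving other dp_ranks of the same worker untouched. The sketch below illustrates this on a deliberately simplified flat index instead of a radix tree. `FlatIndex`, `store`, and `owners` are hypothetical, and the `u64`/`u32` widths for the worker id and dp_rank are assumptions.

```rust
use std::collections::{HashMap, HashSet};

/// Mirrors the shape of `WorkerWithDpRank`; field widths are assumed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct WorkerWithDpRank {
    worker_id: u64,
    dp_rank: u32,
}

/// Simplified stand-in for the tree: block hash -> set of owners, plus a
/// reverse lookup from owner -> block hashes.
#[derive(Default)]
struct FlatIndex {
    lookup: HashMap<WorkerWithDpRank, HashSet<u64>>,
    blocks: HashMap<u64, HashSet<WorkerWithDpRank>>,
}

impl FlatIndex {
    /// Record that `key` holds the block identified by `block_hash`.
    fn store(&mut self, key: WorkerWithDpRank, block_hash: u64) {
        self.lookup.entry(key).or_default().insert(block_hash);
        self.blocks.entry(block_hash).or_default().insert(key);
    }

    /// Remove one dp_rank of a worker; other dp_ranks keep their blocks.
    fn remove_worker_dp_rank(&mut self, worker_id: u64, dp_rank: u32) {
        let key = WorkerWithDpRank { worker_id, dp_rank };
        if let Some(hashes) = self.lookup.remove(&key) {
            for hash in hashes {
                let now_empty = match self.blocks.get_mut(&hash) {
                    Some(owners) => {
                        owners.remove(&key);
                        owners.is_empty()
                    }
                    None => false,
                };
                // A block with no remaining owners is dropped entirely.
                if now_empty {
                    self.blocks.remove(&hash);
                }
            }
        }
    }

    /// Number of (worker, dp_rank) pairs currently holding `block_hash`.
    fn owners(&self, block_hash: u64) -> usize {
        self.blocks.get(&block_hash).map_or(0, |owners| owners.len())
    }
}

fn main() {
    let mut index = FlatIndex::default();
    index.store(WorkerWithDpRank { worker_id: 1, dp_rank: 0 }, 42);
    index.store(WorkerWithDpRank { worker_id: 1, dp_rank: 1 }, 42);

    // Removing dp_rank 0 leaves dp_rank 1 of the same worker in place.
    index.remove_worker_dp_rank(1, 0);
    assert_eq!(index.owners(42), 1);

    // Removing an unknown pair is a no-op.
    index.remove_worker_dp_rank(7, 0);
    assert_eq!(index.owners(42), 1);
}
```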
 // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 // SPDX-License-Identifier: Apache-2.0
+use std::future::Future;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
 use std::time::{Duration, Instant};
 use anyhow::Result;
-use async_trait::async_trait;
 use rmp_serde as rmps;
 use tokio::sync::mpsc;
 use tokio_util::sync::CancellationToken;
@@ -289,7 +289,7 @@ impl KvEventPublisher {
     };
     start_event_processor(
-        event_publisher,
+        EventPlanePublisher(event_publisher),
         worker_id,
         cancellation_token_clone,
         rx,
@@ -315,7 +315,7 @@ impl KvEventPublisher {
         return;
     }
     start_event_processor_jetstream(
-        nats_queue,
+        JetStreamPublisher(nats_queue),
         worker_id,
         cancellation_token_clone,
         rx,
@@ -366,22 +366,21 @@ impl Drop for KvEventPublisher {
     }
 }
-#[async_trait]
-trait EventSink: Send + Sync {
-    async fn publish_event(&self, event: &RouterEvent) -> Result<()>;
-}
-#[async_trait]
-impl EventSink for EventPublisher {
-    async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
-        self.publish(event).await
+use dynamo_kv_router::EventSink;
+struct EventPlanePublisher(EventPublisher);
+impl EventSink for EventPlanePublisher {
+    fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
+        self.0.publish(event)
     }
 }
-#[async_trait]
-impl EventSink for NatsQueue {
-    async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
-        NatsQueue::publish_event(self, KV_EVENT_SUBJECT, event).await
+struct JetStreamPublisher(NatsQueue);
+impl EventSink for JetStreamPublisher {
+    fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
+        NatsQueue::publish_event(&self.0, KV_EVENT_SUBJECT, event)
     }
 }
@@ -587,8 +586,8 @@ async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
 }
 /// Batched event processor using JetStream (durable).
-async fn start_event_processor_jetstream(
-    publisher: NatsQueue,
+async fn start_event_processor_jetstream<P: EventSink + Send + Sync + 'static>(
+    publisher: P,
     worker_id: u64,
     cancellation_token: CancellationToken,
     rx: mpsc::UnboundedReceiver<KvCacheEvent>,
@@ -1278,15 +1277,17 @@ mod tests_startup_helpers {
         }
     }
-    #[async_trait::async_trait]
     impl EventSink for MockComponent {
-        async fn publish_event(&self, event: &RouterEvent) -> anyhow::Result<()> {
+        fn publish_event(
+            &self,
+            event: &RouterEvent,
+        ) -> impl Future<Output = anyhow::Result<()>> + Send {
             let bytes = rmp_serde::to_vec(event).unwrap();
             self.published
                 .lock()
                 .unwrap()
                 .push((KV_EVENT_SUBJECT.to_string(), bytes));
-            Ok(())
+            async { Ok(()) }
         }
     }
@@ -2253,11 +2254,10 @@ mod event_processor_tests {
         }
     }
-    #[async_trait]
     impl EventSink for MockPublisher {
-        async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
+        fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
             self.events.lock().unwrap().push(event.clone());
-            Ok(())
+            async { Ok(()) }
         }
     }
......
@@ -1669,8 +1669,17 @@ def _test_router_indexers_sync(
         async with aiohttp.ClientSession() as session:
             async with session.get(f"{standalone_indexer_url}/dump") as resp:
                 assert resp.status == 200, f"GET /dump failed: {resp.status}"
-                standalone_state = await resp.json()
+                dump_by_key = await resp.json()
+        # /dump returns {model:tenant -> events}, extract the expected key
+        expected_key = f"{model_name}:default"
+        assert expected_key in dump_by_key, (
+            f"Expected dump key '{expected_key}', "
+            f"got keys={list(dump_by_key.keys())}"
+        )
+        for k, v in dump_by_key.items():
+            assert isinstance(v, list), f"Dump key '{k}' returned error: {v}"
+        standalone_state = dump_by_key[expected_key]
         sorted_standalone = sorted(standalone_state, key=sort_key)
         logger.info(f"Standalone HTTP indexer has {len(sorted_standalone)} events")
@@ -2197,7 +2206,7 @@ def _test_router_decisions(
         async with aiohttp.ClientSession() as session:
             async with session.post(
                 f"{standalone_indexer_url}/query",
-                json={"token_ids": req4_tokens},
+                json={"token_ids": req4_tokens, "model_name": model_name},
             ) as resp:
                 assert resp.status == 200, f"POST /query failed: {resp.status}"
                 scores = (await resp.json())["scores"]
......
...@@ -190,10 +190,12 @@ class MockerProcess: ...@@ -190,10 +190,12 @@ class MockerProcess:
request_plane: str = "nats", request_plane: str = "nats",
zmq_kv_events: bool = False, zmq_kv_events: bool = False,
standalone_indexer: bool = False, standalone_indexer: bool = False,
model_name: str = "mocker",
): ):
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
self.namespace = f"test-namespace-{namespace_suffix}" self.namespace = f"test-namespace-{namespace_suffix}"
self.component_name = "mocker" self.component_name = "mocker"
self.model_name = model_name
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate" self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_workers = num_mockers self.num_workers = num_mockers
self._zmq_kv_events_ports: list[int] = [] self._zmq_kv_events_ports: list[int] = []
...@@ -386,6 +388,10 @@ class MockerProcess: ...@@ -386,6 +388,10 @@ class MockerProcess:
"instance_id": new_worker_id, "instance_id": new_worker_id,
"endpoint": endpoint, "endpoint": endpoint,
"dp_rank": dp_rank, "dp_rank": dp_rank,
"model_name": self.model_name,
"block_size": self._mocker_args_orig.get(
"block_size", BLOCK_SIZE
),
} }
async with session.post(register_url, json=payload) as resp: async with session.post(register_url, json=payload) as resp:
if resp.status != 201: if resp.status != 201:
...@@ -897,6 +903,7 @@ def test_router_decisions( ...@@ -897,6 +903,7 @@ def test_router_decisions(
request_plane=request_plane, request_plane=request_plane,
zmq_kv_events=zmq_kv_events, zmq_kv_events=zmq_kv_events,
standalone_indexer=zmq_kv_events, standalone_indexer=zmq_kv_events,
model_name=MODEL_NAME,
) as mockers: ) as mockers:
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
......