Unverified Commit 35f99f93 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(kv-indexer): multi-model and multi-tenant isolation (#6830)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent e0a2e7bb
......@@ -13,6 +13,14 @@ This is distinct from the [Standalone Router](../../../components/src/dynamo/rou
The HTTP API follows the [Mooncake KV Indexer RFC](https://github.com/kvcache-ai/Mooncake/issues/1403) conventions.
## Multi-Model and Multi-Tenant Support
The indexer maintains one radix tree per `(model_name, tenant_id)` pair. Workers registered with different model names or tenant IDs are isolated into separate indexers — queries against one model/tenant never return scores from another.
- **`model_name`** (required on `/register` and `/query`): Identifies the model. Workers serving different models get separate radix trees.
- **`tenant_id`** (optional, defaults to `"default"`): Enables multi-tenant isolation within the same model. Omit for single-tenant deployments.
- **`block_size`** is per-indexer: the first `/register` call for a given `(model_name, tenant_id)` sets the block size. Subsequent registrations for the same pair must use the same block size or the request will fail.
## Compatibility
The standalone indexer works with any engine that publishes KV cache events over ZMQ in the expected msgpack format. This includes bare vLLM and SGLang engines, which emit ZMQ KV events natively — no Dynamo-specific wrapper is required.
......@@ -35,56 +43,84 @@ cargo build -p dynamo-kv-router --features indexer-bin --bin dynamo-kv-indexer
## CLI
```bash
dynamo-kv-indexer --block-size 16 --port 8090 [--threads 1] [--workers "1=tcp://host:5557,2=tcp://host:5558"]
dynamo-kv-indexer --port 8090 [--threads 1] [--block-size 16 --model-name my-model --tenant-id default --workers "1=tcp://host:5557,2=tcp://host:5558"]
```
| Flag | Default | Description |
|------|---------|-------------|
| `--block-size` | (required) | KV cache block size (must match the engine's block size) |
| `--block-size` | (none) | KV cache block size for initial `--workers` (required when `--workers` is set) |
| `--port` | `8090` | HTTP server listen port |
| `--threads` | `1` | Number of indexer threads (1 = single-threaded, >1 = thread pool) |
| `--workers` | (none) | Initial workers as `instance_id=zmq_address,...` pairs |
| `--model-name` | `default` | Model name for initial `--workers` |
| `--tenant-id` | `default` | Tenant ID for initial `--workers` |
## HTTP API
### `POST /register` — Register an endpoint
Register a ZMQ endpoint for an instance. Call once per dp_rank for data-parallel workers:
Register a ZMQ endpoint for an instance. Each call creates or reuses the indexer for the given `(model_name, tenant_id)` pair.
```bash
# Single dp_rank (dp_rank defaults to 0)
# Single model, default tenant
curl -X POST http://localhost:8090/register \
-H 'Content-Type: application/json' \
-d '{"instance_id": 1, "endpoint": "tcp://127.0.0.1:5557"}'
# Multiple dp_ranks — register each separately
-d '{
"instance_id": 1,
"endpoint": "tcp://127.0.0.1:5557",
"model_name": "llama-3-8b",
"block_size": 16
}'
# With tenant isolation
curl -X POST http://localhost:8090/register \
-H 'Content-Type: application/json' \
-d '{"instance_id": 1, "endpoint": "tcp://127.0.0.1:5557", "dp_rank": 0}'
curl -X POST http://localhost:8090/register \
-H 'Content-Type: application/json' \
-d '{"instance_id": 1, "endpoint": "tcp://127.0.0.1:5558", "dp_rank": 1}'
-d '{
"instance_id": 2,
"endpoint": "tcp://127.0.0.1:5558",
"model_name": "llama-3-8b",
"tenant_id": "customer-a",
"block_size": 16,
"dp_rank": 0
}'
```
The indexer spawns a ZMQ SUB listener for each endpoint and begins consuming KV events.
| Field | Required | Default | Description |
|-------|----------|---------|-------------|
| `instance_id` | yes | — | Worker instance identifier |
| `endpoint` | yes | — | ZMQ PUB address to subscribe to |
| `model_name` | yes | — | Model name (used to select the indexer) |
| `block_size` | yes | — | KV cache block size (must match the engine) |
| `tenant_id` | no | `"default"` | Tenant identifier for isolation |
| `dp_rank` | no | `0` | Data parallel rank |
### `POST /unregister` — Deregister an instance
Remove all dp_ranks for an instance, or a specific dp_rank:
Remove an instance. Omitting `tenant_id` removes the instance from **all** tenants for the given model; providing it targets only that tenant's indexer.
```bash
# Remove all dp_ranks
# Remove from all tenants
curl -X POST http://localhost:8090/unregister \
-H 'Content-Type: application/json' \
-d '{"instance_id": 1, "model_name": "llama-3-8b"}'
# Remove from a specific tenant
curl -X POST http://localhost:8090/unregister \
-H 'Content-Type: application/json' \
-d '{"instance_id": 1}'
-d '{"instance_id": 1, "model_name": "llama-3-8b", "tenant_id": "customer-a"}'
# Remove a specific dp_rank
curl -X POST http://localhost:8090/unregister \
-H 'Content-Type: application/json' \
-d '{"instance_id": 1, "dp_rank": 0}'
-d '{"instance_id": 1, "model_name": "llama-3-8b", "tenant_id": "default", "dp_rank": 0}'
```
Cancels ZMQ listeners and removes the instance's blocks from the radix tree.
| Field | Required | Default | Description |
|-------|----------|---------|-------------|
| `instance_id` | yes | — | Worker instance to remove |
| `model_name` | yes | — | Model name (identifies the indexer) |
| `tenant_id` | no | — | Tenant identifier (omit to remove from all tenants) |
| `dp_rank` | no | — | Specific dp_rank to remove (omit to remove all) |
### `GET /workers` — List registered instances
......@@ -99,43 +135,66 @@ Returns:
### `POST /query` — Query overlap for token IDs
Given raw token IDs, compute block hashes and return per-instance overlap scores:
Given raw token IDs, compute block hashes and return per-instance overlap scores (in matched tokens):
```bash
curl -X POST http://localhost:8090/query \
-H 'Content-Type: application/json' \
-d '{"token_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}'
-d '{"token_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], "model_name": "llama-3-8b"}'
```
Returns:
```json
{
"scores": {"1": {"0": 2}, "2": {"1": 0}},
"scores": {"1": {"0": 32}, "2": {"1": 0}},
"frequencies": [1, 1],
"tree_sizes": {"1": {"0": 5}, "2": {"1": 3}}
}
```
Scores are nested by `instance_id` then `dp_rank`. Higher score means more cached prefix blocks on that instance.
Scores are in **matched tokens** (block overlap count × block size). Nested by `instance_id` then `dp_rank`.
| Field | Required | Default | Description |
|-------|----------|---------|-------------|
| `token_ids` | yes | — | Token sequence to query |
| `model_name` | yes | — | Model name (selects the indexer) |
| `tenant_id` | no | `"default"` | Tenant identifier |
| `lora_name` | no | — | LoRA adapter (overrides indexer-level lora_name for this query) |
### `POST /query_by_hash` — Query overlap for pre-computed hashes
```bash
curl -X POST http://localhost:8090/query_by_hash \
-H 'Content-Type: application/json' \
-d '{"block_hashes": [123456, 789012]}'
-d '{"block_hashes": [123456, 789012], "model_name": "llama-3-8b"}'
```
Same response format as `/query`.
Same response format as `/query`. Scores are in matched tokens.
| Field | Required | Default | Description |
|-------|----------|---------|-------------|
| `block_hashes` | yes | — | Pre-computed block hash array |
| `model_name` | yes | — | Model name (selects the indexer) |
| `tenant_id` | no | `"default"` | Tenant identifier |
### `GET /dump` — Dump all radix tree events
Returns the full radix tree state as a JSON array of `RouterEvent` objects:
Returns the full radix tree state as a JSON object keyed by `model_name:tenant_id`:
```bash
curl http://localhost:8090/dump
```
Returns:
```json
{
"llama-3-8b:default": [<RouterEvent>, ...],
"mistral-7b:customer-a": [<RouterEvent>, ...]
}
```
Each indexer is dumped concurrently.
## Limitations
- **ZMQ only**: Workers must publish KV events via ZMQ PUB sockets. The standalone indexer does not subscribe to NATS event streams.
......@@ -153,7 +212,7 @@ graph TD
subgraph "Standalone Indexer (HTTP)"
REG[Worker Registry]
ZMQ[ZMQ SUB Listeners]
IDX[Indexer / Radix Tree]
IDX["Indexer Map<br/>(model, tenant) → Radix Tree"]
HTTP[HTTP API<br/>/query /dump /register]
end
......
......@@ -32,6 +32,13 @@ impl Indexer {
}
}
pub async fn remove_worker_dp_rank(&self, worker_id: WorkerId, dp_rank: u32) {
match self {
Indexer::Single(idx) => idx.remove_worker_dp_rank(worker_id, dp_rank).await,
Indexer::Concurrent(idx) => idx.remove_worker_dp_rank(worker_id, dp_rank).await,
}
}
pub async fn find_matches(&self, hashes: Vec<LocalBlockHash>) -> Result<OverlapScores> {
match self {
Indexer::Single(idx) => idx.find_matches(hashes).await.map_err(Into::into),
......
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::atomic::AtomicU32;
use std::time::Duration;
use rmp_serde as rmps;
......@@ -48,7 +48,7 @@ pub async fn run_zmq_listener(
return;
}
let next_event_id = AtomicU64::new(0);
let mut next_event_id = 0u64;
let warning_count = Arc::new(AtomicU32::new(0));
let mut consecutive_errors = 0u32;
#[allow(unused_assignments)]
......@@ -94,29 +94,28 @@ pub async fn run_zmq_listener(
consecutive_errors = 0;
let mut frames: Vec<Vec<u8>> = msg.into_vec().into_iter().map(|f| f.to_vec()).collect();
if frames.len() != 3 {
tracing::warn!(worker_id, "Unexpected ZMQ frame count: {}", frames.len());
if msg.len() != 3 {
tracing::warn!(worker_id, "Unexpected ZMQ frame count: {}", msg.len());
continue;
}
let payload = frames.pop().unwrap();
let seq_bytes = frames.pop().unwrap();
let seq_bytes = msg.get(1).unwrap();
if seq_bytes.len() != 8 {
tracing::warn!(worker_id, "Invalid sequence number length: {}", seq_bytes.len());
continue;
}
let batch_result = rmps::from_slice::<KvEventBatch>(&payload);
let payload = msg.get(2).unwrap();
let batch_result = rmps::from_slice::<KvEventBatch>(payload);
let Ok(batch) = batch_result else {
tracing::warn!(worker_id, "Failed to decode KvEventBatch: {}", batch_result.unwrap_err());
continue;
};
let effective_dp_rank = batch.data_parallel_rank.map_or(dp_rank, |r| r as u32);
for raw_event in batch.events.into_iter() {
let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
for raw_event in batch.events {
let event_id = next_event_id;
next_event_id += 1;
let kv_event = convert_event(raw_event, event_id, block_size, effective_dp_rank, &warning_count);
let router_event = RouterEvent::new(worker_id, kv_event);
indexer.apply_event(router_event).await;
......
......@@ -11,16 +11,15 @@ mod listener;
mod registry;
mod server;
use indexer::create_indexer;
use registry::WorkerRegistry;
use server::{AppState, create_router};
#[derive(Parser)]
#[command(name = "dynamo-kv-indexer", about = "Standalone KV cache indexer")]
struct Cli {
/// KV cache block size (must match the vLLM engine's block size)
/// KV cache block size for initial workers registered via --workers
#[arg(long)]
block_size: u32,
block_size: Option<u32>,
/// HTTP server port
#[arg(long, default_value_t = 8090)]
......@@ -33,6 +32,14 @@ struct Cli {
/// Initial workers as "worker_id=zmq_address,..." (e.g. "1=tcp://host:5557,2=tcp://host:5558")
#[arg(long)]
workers: Option<String>,
/// Model name for initial workers registered via --workers
#[arg(long, default_value = "default")]
model_name: String,
/// Tenant ID for initial workers registered via --workers
#[arg(long, default_value = "default")]
tenant_id: String,
}
fn parse_workers(s: &str) -> Vec<(u64, String)> {
......@@ -58,26 +65,34 @@ async fn main() -> anyhow::Result<()> {
let cli = Cli::parse();
tracing::info!(
block_size = cli.block_size,
block_size = ?cli.block_size,
port = cli.port,
threads = cli.threads,
model_name = %cli.model_name,
tenant_id = %cli.tenant_id,
"Starting standalone KV cache indexer"
);
let indexer = create_indexer(cli.block_size, cli.threads);
let registry = WorkerRegistry::new(indexer, cli.block_size);
let registry = WorkerRegistry::new(cli.threads);
if let Some(ref workers_str) = cli.workers {
let block_size = cli.block_size.ok_or_else(|| {
anyhow::anyhow!("--block-size is required when --workers is specified")
})?;
for (instance_id, endpoint) in parse_workers(workers_str) {
tracing::info!(instance_id, endpoint, "Registering initial worker");
registry.register(instance_id, endpoint, 0)?;
registry.register(
instance_id,
endpoint,
0,
cli.model_name.clone(),
cli.tenant_id.clone(),
block_size,
)?;
}
}
let state = Arc::new(AppState {
registry,
block_size: cli.block_size,
});
let state = Arc::new(AppState { registry });
let app = create_router(state);
let listener = TcpListener::bind(("0.0.0.0", cli.port)).await?;
......
......@@ -5,83 +5,153 @@ use std::collections::HashMap;
use anyhow::{Result, bail};
use dashmap::DashMap;
use dashmap::mapref::one::Ref;
use tokio_util::sync::CancellationToken;
use dynamo_kv_router::protocols::WorkerId;
use super::indexer::Indexer;
use super::indexer::{Indexer, create_indexer};
use super::listener::run_zmq_listener;
pub struct EndpointEntry {
pub endpoint: String,
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct IndexerKey {
pub model_name: String,
pub tenant_id: String,
}
pub struct IndexerEntry {
pub indexer: Indexer,
pub block_size: u32,
}
pub struct WorkerEntry {
pub endpoints: HashMap<u32, EndpointEntry>,
cancel: CancellationToken,
pub endpoints: HashMap<u32, String>,
cancels: HashMap<u32, CancellationToken>,
}
pub struct WorkerRegistry {
workers: DashMap<WorkerId, WorkerEntry>,
indexer: Indexer,
block_size: u32,
indexers: DashMap<IndexerKey, IndexerEntry>,
num_threads: usize,
}
impl WorkerRegistry {
pub fn new(indexer: Indexer, block_size: u32) -> Self {
pub fn new(num_threads: usize) -> Self {
Self {
workers: DashMap::new(),
indexer,
block_size,
indexers: DashMap::new(),
num_threads,
}
}
pub fn register(&self, instance_id: WorkerId, endpoint: String, dp_rank: u32) -> Result<()> {
pub fn register(
&self,
instance_id: WorkerId,
endpoint: String,
dp_rank: u32,
model_name: String,
tenant_id: String,
block_size: u32,
) -> Result<()> {
let key = IndexerKey {
model_name,
tenant_id,
};
// Get or create the indexer for this (model, tenant) pair.
// Use the entry API for atomic check-and-insert.
let indexer_entry = self.indexers.entry(key.clone()).or_insert_with(|| {
tracing::info!(
model_name = %key.model_name,
tenant_id = %key.tenant_id,
block_size,
"Creating new indexer"
);
IndexerEntry {
indexer: create_indexer(block_size, self.num_threads),
block_size,
}
});
if indexer_entry.block_size != block_size {
bail!(
"block_size mismatch for model={} tenant={}: existing={}, requested={}",
key.model_name,
key.tenant_id,
indexer_entry.block_size,
block_size
);
}
let indexer = indexer_entry.indexer.clone();
let bs = indexer_entry.block_size;
drop(indexer_entry);
let mut entry = self
.workers
.entry(instance_id)
.or_insert_with(|| WorkerEntry {
endpoints: HashMap::new(),
cancel: CancellationToken::new(),
cancels: HashMap::new(),
});
if entry.endpoints.contains_key(&dp_rank) {
bail!("instance {instance_id} dp_rank {dp_rank} already registered");
}
let child_cancel = entry.cancel.child_token();
let indexer = self.indexer.clone();
let block_size = self.block_size;
let cancel = CancellationToken::new();
let child_cancel = cancel.child_token();
let addr = endpoint.clone();
tokio::spawn(async move {
run_zmq_listener(
instance_id,
dp_rank,
addr,
block_size,
indexer,
child_cancel,
)
.await;
run_zmq_listener(instance_id, dp_rank, addr, bs, indexer, child_cancel).await;
});
entry.endpoints.insert(dp_rank, EndpointEntry { endpoint });
entry.endpoints.insert(dp_rank, endpoint);
entry.cancels.insert(dp_rank, cancel);
Ok(())
}
pub async fn deregister(&self, instance_id: WorkerId) -> Result<()> {
pub async fn deregister(
&self,
instance_id: WorkerId,
model_name: &str,
tenant_id: &str,
) -> Result<()> {
let (_, entry) = self
.workers
.remove(&instance_id)
.ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;
entry.cancel.cancel();
self.indexer.remove_worker(instance_id).await;
for cancel in entry.cancels.values() {
cancel.cancel();
}
let key = IndexerKey {
model_name: model_name.to_string(),
tenant_id: tenant_id.to_string(),
};
if let Some(ie) = self.indexers.get(&key) {
ie.indexer.remove_worker(instance_id).await;
} else {
tracing::warn!(
instance_id,
model_name,
tenant_id,
"indexer key not found on deregister; tree will not be cleaned up"
);
}
Ok(())
}
pub async fn deregister_dp_rank(&self, instance_id: WorkerId, dp_rank: u32) -> Result<()> {
pub async fn deregister_dp_rank(
&self,
instance_id: WorkerId,
dp_rank: u32,
model_name: &str,
tenant_id: &str,
) -> Result<()> {
let mut entry = self
.workers
.get_mut(&instance_id)
......@@ -91,9 +161,62 @@ impl WorkerRegistry {
bail!("instance {instance_id} dp_rank {dp_rank} not found");
}
if let Some(cancel) = entry.cancels.remove(&dp_rank) {
cancel.cancel();
}
if entry.endpoints.is_empty() {
drop(entry);
return self.deregister(instance_id).await;
return self.deregister(instance_id, model_name, tenant_id).await;
}
drop(entry);
let key = IndexerKey {
model_name: model_name.to_string(),
tenant_id: tenant_id.to_string(),
};
if let Some(ie) = self.indexers.get(&key) {
ie.indexer.remove_worker_dp_rank(instance_id, dp_rank).await;
} else {
tracing::warn!(
instance_id,
dp_rank,
model_name,
tenant_id,
"indexer key not found on deregister_dp_rank; tree will not be cleaned up"
);
}
Ok(())
}
pub async fn deregister_all_tenants(
&self,
instance_id: WorkerId,
model_name: &str,
) -> Result<()> {
let (_, entry) = self
.workers
.remove(&instance_id)
.ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;
for cancel in entry.cancels.values() {
cancel.cancel();
}
let mut found = false;
for ie in self.indexers.iter() {
if ie.key().model_name == model_name {
ie.indexer.remove_worker(instance_id).await;
found = true;
}
}
if !found {
tracing::warn!(
instance_id,
model_name,
"no indexers found for model on deregister_all_tenants; tree will not be cleaned up"
);
}
Ok(())
......@@ -102,19 +225,18 @@ impl WorkerRegistry {
pub fn list(&self) -> Vec<(WorkerId, HashMap<u32, String>)> {
self.workers
.iter()
.map(|entry| {
let endpoints: HashMap<u32, String> = entry
.value()
.endpoints
.iter()
.map(|(&dp_rank, e)| (dp_rank, e.endpoint.clone()))
.collect();
(*entry.key(), endpoints)
})
.map(|entry| (*entry.key(), entry.value().endpoints.clone()))
.collect()
}
pub fn indexer(&self) -> &Indexer {
&self.indexer
pub fn get_indexer(&self, key: &IndexerKey) -> Option<Ref<'_, IndexerKey, IndexerEntry>> {
self.indexers.get(key)
}
pub fn all_indexers(&self) -> Vec<(IndexerKey, Indexer)> {
self.indexers
.iter()
.map(|entry| (entry.key().clone(), entry.value().indexer.clone()))
.collect()
}
}
......@@ -13,17 +13,24 @@ use serde::{Deserialize, Serialize};
use dynamo_kv_router::protocols::{LocalBlockHash, WorkerId, compute_block_hash_for_seq};
use super::registry::WorkerRegistry;
use super::registry::{IndexerKey, WorkerRegistry};
pub struct AppState {
pub registry: WorkerRegistry,
pub block_size: u32,
}
fn default_tenant() -> String {
"default".to_string()
}
#[derive(Deserialize)]
pub struct RegisterRequest {
pub instance_id: WorkerId,
pub endpoint: String,
pub model_name: String,
#[serde(default = "default_tenant")]
pub tenant_id: String,
pub block_size: u32,
#[serde(default)]
pub dp_rank: Option<u32>,
}
......@@ -31,6 +38,9 @@ pub struct RegisterRequest {
#[derive(Deserialize)]
pub struct UnregisterRequest {
pub instance_id: WorkerId,
pub model_name: String,
#[serde(default)]
pub tenant_id: Option<String>,
#[serde(default)]
pub dp_rank: Option<u32>,
}
......@@ -44,13 +54,24 @@ struct WorkerInfo {
#[derive(Deserialize)]
pub struct QueryRequest {
pub token_ids: Vec<u32>,
pub model_name: String,
#[serde(default = "default_tenant")]
pub tenant_id: String,
#[serde(default)]
pub lora_name: Option<String>,
}
/// Query using pre-computed block hashes.
///
/// Callers must include the LoRA salt in their hashes when applicable — use
/// [`compute_block_hash_for_seq`] with the appropriate `lora_name`. The indexer
/// cannot retroactively apply a LoRA salt to pre-computed hashes.
#[derive(Deserialize)]
pub struct QueryByHashRequest {
pub block_hashes: Vec<i64>,
pub model_name: String,
#[serde(default = "default_tenant")]
pub tenant_id: String,
}
#[derive(Serialize)]
......@@ -64,10 +85,14 @@ async fn register(
State(state): State<Arc<AppState>>,
Json(req): Json<RegisterRequest>,
) -> impl IntoResponse {
match state
.registry
.register(req.instance_id, req.endpoint, req.dp_rank.unwrap_or(0))
{
match state.registry.register(
req.instance_id,
req.endpoint,
req.dp_rank.unwrap_or(0),
req.model_name,
req.tenant_id,
req.block_size,
) {
Ok(()) => (
StatusCode::CREATED,
Json(serde_json::json!({"status": "ok"})),
......@@ -83,14 +108,27 @@ async fn unregister(
State(state): State<Arc<AppState>>,
Json(req): Json<UnregisterRequest>,
) -> impl IntoResponse {
let result = match req.dp_rank {
Some(dp_rank) => {
let result = match req.tenant_id {
Some(tenant_id) => match req.dp_rank {
Some(dp_rank) => {
state
.registry
.deregister_dp_rank(req.instance_id, dp_rank, &req.model_name, &tenant_id)
.await
}
None => {
state
.registry
.deregister(req.instance_id, &req.model_name, &tenant_id)
.await
}
},
None => {
state
.registry
.deregister_dp_rank(req.instance_id, dp_rank)
.deregister_all_tenants(req.instance_id, &req.model_name)
.await
}
None => state.registry.deregister(req.instance_id).await,
};
match result {
Ok(()) => (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))),
......@@ -114,13 +152,16 @@ async fn list_workers(State(state): State<Arc<AppState>>) -> impl IntoResponse {
Json(workers)
}
fn build_score_response(overlap: dynamo_kv_router::protocols::OverlapScores) -> ScoreResponse {
fn build_score_response(
overlap: dynamo_kv_router::protocols::OverlapScores,
block_size: u32,
) -> ScoreResponse {
let mut scores: HashMap<String, HashMap<String, u32>> = HashMap::new();
for (k, v) in &overlap.scores {
scores
.entry(k.worker_id.to_string())
.or_default()
.insert(k.dp_rank.to_string(), *v);
.insert(k.dp_rank.to_string(), v * block_size);
}
let mut tree_sizes: HashMap<String, HashMap<String, usize>> = HashMap::new();
for (k, v) in &overlap.tree_sizes {
......@@ -140,16 +181,28 @@ async fn query(
State(state): State<Arc<AppState>>,
Json(req): Json<QueryRequest>,
) -> impl IntoResponse {
let block_hashes = compute_block_hash_for_seq(
&req.token_ids,
state.block_size,
None,
req.lora_name.as_deref(),
);
match state.registry.indexer().find_matches(block_hashes).await {
let key = IndexerKey {
model_name: req.model_name,
tenant_id: req.tenant_id,
};
let Some(ie) = state.registry.get_indexer(&key) else {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": format!("no indexer for model={} tenant={}", key.model_name, key.tenant_id)
})),
);
};
let block_size = ie.block_size;
let indexer = ie.indexer.clone();
drop(ie);
let block_hashes =
compute_block_hash_for_seq(&req.token_ids, block_size, None, req.lora_name.as_deref());
match indexer.find_matches(block_hashes).await {
Ok(overlap) => (
StatusCode::OK,
Json(serde_json::json!(build_score_response(overlap))),
Json(serde_json::json!(build_score_response(overlap, block_size))),
),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
......@@ -162,15 +215,31 @@ async fn query_by_hash(
State(state): State<Arc<AppState>>,
Json(req): Json<QueryByHashRequest>,
) -> impl IntoResponse {
let key = IndexerKey {
model_name: req.model_name,
tenant_id: req.tenant_id,
};
let Some(ie) = state.registry.get_indexer(&key) else {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": format!("no indexer for model={} tenant={}", key.model_name, key.tenant_id)
})),
);
};
let block_size = ie.block_size;
let indexer = ie.indexer.clone();
drop(ie);
let block_hashes: Vec<LocalBlockHash> = req
.block_hashes
.iter()
.map(|h| LocalBlockHash(*h as u64))
.collect();
match state.registry.indexer().find_matches(block_hashes).await {
match indexer.find_matches(block_hashes).await {
Ok(overlap) => (
StatusCode::OK,
Json(serde_json::json!(build_score_response(overlap))),
Json(serde_json::json!(build_score_response(overlap, block_size))),
),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
......@@ -180,13 +249,33 @@ async fn query_by_hash(
}
async fn dump_events(State(state): State<Arc<AppState>>) -> impl IntoResponse {
match state.registry.indexer().dump_events().await {
Ok(events) => (StatusCode::OK, Json(serde_json::json!(events))),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error": e.to_string()})),
),
let indexers = state.registry.all_indexers();
let mut handles = Vec::with_capacity(indexers.len());
for (key, indexer) in indexers {
handles.push(tokio::spawn(async move {
let events = indexer.dump_events().await;
(key, events)
}));
}
let mut result: HashMap<String, serde_json::Value> = HashMap::new();
for handle in handles {
match handle.await {
Ok((key, Ok(events))) => {
let map_key = format!("{}:{}", key.model_name, key.tenant_id);
result.insert(map_key, serde_json::json!(events));
}
Ok((key, Err(e))) => {
let map_key = format!("{}:{}", key.model_name, key.tenant_id);
result.insert(map_key, serde_json::json!({"error": e.to_string()}));
}
Err(e) => {
tracing::warn!("dump task join error: {e}");
}
}
}
(StatusCode::OK, Json(serde_json::json!(result)))
}
pub fn create_router(state: Arc<AppState>) -> Router {
......
......@@ -515,6 +515,25 @@ impl ConcurrentRadixTree {
}
}
fn remove_worker_dp_rank(
&self,
lookup: &mut FxHashMap<WorkerWithDpRank, WorkerLookup>,
worker_id: WorkerId,
dp_rank: DpRank,
) {
let key = WorkerWithDpRank { worker_id, dp_rank };
if let Some(worker_lookup) = lookup.remove(&key) {
for (_, block) in worker_lookup.into_iter() {
let mut guard = block.write();
guard.workers.remove(&key);
if guard.workers.is_empty() {
guard.children.clear();
}
}
self.tree_sizes.remove(&key);
}
}
/// Clear all blocks for a worker but keep the worker tracked.
fn clear_all_blocks(
&self,
......@@ -616,6 +635,9 @@ impl SyncIndexer for ConcurrentRadixTree {
WorkerTask::RemoveWorker(worker_id) => {
self.remove_or_clear_worker_blocks(&mut lookup, worker_id, false);
}
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank(&mut lookup, worker_id, dp_rank);
}
WorkerTask::DumpEvents(_sender) => {
// Handled directly via dump_events() on the shared tree.
// Should not be reached, but respond with empty to avoid blocking.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transport abstraction for publishing batched KV cache events.
//!
//! Implementations handle the actual delivery mechanism (NATS event plane,
//! JetStream durable queue, direct indexer application, etc.). The trait lives
//! in this crate so that the batching processor and other routing logic can be
//! written generically; runtime-specific impls stay in `lib/llm`.
use std::future::Future;
use crate::protocols::RouterEvent;
/// Transport abstraction for publishing batched KV cache events.
pub trait EventSink: Send + Sync {
fn publish_event(&self, event: &RouterEvent)
-> impl Future<Output = anyhow::Result<()>> + Send;
}
......@@ -327,6 +327,14 @@ pub trait KvIndexerInterface {
/// * `worker` - The worker to remove from the trie.
async fn remove_worker(&self, worker: WorkerId);
/// Remove a single dp_rank for a worker from the trie.
///
/// Default implementation falls back to removing the entire worker.
/// Indexers that track dp_rank-level granularity should override this.
async fn remove_worker_dp_rank(&self, worker: WorkerId, _dp_rank: DpRank) {
self.remove_worker(worker).await;
}
/// Shutdown the KV Indexer.
fn shutdown(&self);
......@@ -363,6 +371,8 @@ pub enum WorkerTask {
Event(RouterEvent),
/// Permanently remove a worker from tracking (keep_worker: false).
RemoveWorker(WorkerId),
/// Remove a single dp_rank for a worker.
RemoveWorkerDpRank(WorkerId, DpRank),
DumpEvents(oneshot::Sender<anyhow::Result<Vec<RouterEvent>>>),
Terminate,
}
......@@ -568,6 +578,14 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
}
}
async fn remove_worker_dp_rank(&self, worker_id: WorkerId, dp_rank: DpRank) {
// Broadcast to all threads — the dp_rank may be on any thread.
// Don't remove from worker_assignments since other dp_ranks may still exist.
for channel in &self.worker_event_channels {
let _ = channel.send(WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank));
}
}
fn shutdown(&self) {
// Send shutdown signal to all worker threads
for channel in self.worker_event_channels.iter() {
......@@ -668,6 +686,8 @@ pub struct KvIndexer {
match_tx: mpsc::Sender<MatchRequest>,
/// A sender for remove worker requests.
remove_worker_tx: mpsc::Sender<WorkerId>,
/// A sender for remove worker dp_rank requests.
remove_worker_dp_rank_tx: mpsc::Sender<(WorkerId, DpRank)>,
/// A sender for get workers requests.
get_workers_tx: mpsc::Sender<GetWorkersRequest>,
/// A sender for dump requests.
......@@ -704,6 +724,8 @@ impl KvIndexer {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (remove_worker_dp_rank_tx, remove_worker_dp_rank_rx) =
mpsc::channel::<(WorkerId, DpRank)>(16);
let (get_workers_tx, get_workers_rx) = mpsc::channel::<GetWorkersRequest>(16);
let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16);
let (routing_tx, mut routing_rx) = mpsc::channel::<RoutingDecisionRequest>(2048);
......@@ -723,6 +745,7 @@ impl KvIndexer {
let mut match_rx = match_rx;
let mut event_rx = event_rx;
let mut remove_worker_rx = remove_worker_rx;
let mut remove_worker_dp_rank_rx = remove_worker_dp_rank_rx;
let mut get_workers_rx = get_workers_rx;
let mut dump_rx = dump_rx;
let mut trie = RadixTree::new_with_frequency(expiration_duration);
......@@ -754,6 +777,10 @@ impl KvIndexer {
trie.remove_worker(worker);
}
Some((worker_id, dp_rank)) = remove_worker_dp_rank_rx.recv() => {
trie.remove_worker_dp_rank(worker_id, dp_rank);
}
Some(get_workers_req) = get_workers_rx.recv() => {
let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers);
......@@ -933,6 +960,7 @@ impl KvIndexer {
event_tx,
match_tx,
remove_worker_tx,
remove_worker_dp_rank_tx,
get_workers_tx,
dump_tx,
routing_tx,
......@@ -1052,6 +1080,13 @@ impl KvIndexerInterface for KvIndexer {
self.remove_worker_tx.send(worker).await.unwrap();
}
async fn remove_worker_dp_rank(&self, worker: WorkerId, dp_rank: DpRank) {
self.remove_worker_dp_rank_tx
.send((worker, dp_rank))
.await
.unwrap();
}
fn shutdown(&self) {
self.cancel.cancel();
}
......@@ -1461,6 +1496,7 @@ pub struct KvIndexerSharded {
event_tx: Vec<mpsc::Sender<RouterEvent>>,
request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
remove_worker_dp_rank_tx: Vec<mpsc::Sender<(WorkerId, DpRank)>>,
dump_tx: Vec<mpsc::Sender<DumpRequest>>,
routing_tx: Vec<mpsc::Sender<RoutingDecisionRequest>>,
tasks: Arc<Mutex<Vec<JoinHandle<()>>>>,
......@@ -1493,6 +1529,7 @@ impl KvIndexerSharded {
let mut event_tx = Vec::new();
let mut remove_worker_tx = Vec::new();
let mut remove_worker_dp_rank_tx = Vec::new();
let mut get_workers_tx = Vec::new();
let mut dump_tx = Vec::new();
let mut routing_tx = Vec::new();
......@@ -1504,6 +1541,8 @@ impl KvIndexerSharded {
let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
mpsc::channel::<WorkerId>(16);
let (shard_remove_worker_dp_rank_tx, mut shard_remove_worker_dp_rank_rx) =
mpsc::channel::<(WorkerId, DpRank)>(16);
let (shard_get_workers_tx, mut shard_get_workers_rx) =
mpsc::channel::<GetWorkersRequest>(16);
let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16);
......@@ -1517,6 +1556,7 @@ impl KvIndexerSharded {
event_tx.push(shard_event_tx);
remove_worker_tx.push(shard_remove_worker_tx);
remove_worker_dp_rank_tx.push(shard_remove_worker_dp_rank_tx);
get_workers_tx.push(shard_get_workers_tx);
dump_tx.push(shard_dump_tx);
routing_tx.push(shard_routing_tx);
......@@ -1557,6 +1597,10 @@ impl KvIndexerSharded {
trie.remove_worker(worker);
}
Some((worker_id, dp_rank)) = shard_remove_worker_dp_rank_rx.recv() => {
trie.remove_worker_dp_rank(worker_id, dp_rank);
}
Some(get_workers_req) = shard_get_workers_rx.recv() => {
let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers);
......@@ -1736,6 +1780,7 @@ impl KvIndexerSharded {
event_tx,
request_broadcast_tx,
remove_worker_tx,
remove_worker_dp_rank_tx,
dump_tx,
routing_tx,
tasks,
......@@ -1860,6 +1905,17 @@ impl KvIndexerInterface for KvIndexerSharded {
}
}
async fn remove_worker_dp_rank(&self, worker: WorkerId, dp_rank: DpRank) {
// Worker is assigned to a single shard, so route there directly.
// Don't remove from worker_assignments since other dp_ranks may still exist.
if let Some(shard) = self.worker_assignments.get(&worker) {
self.remove_worker_dp_rank_tx[*shard]
.send((worker, dp_rank))
.await
.unwrap();
}
}
/// Shutdown the KV Indexer.
fn shutdown(&self) {
self.cancel.cancel();
......
......@@ -8,6 +8,7 @@
pub mod approx;
pub mod concurrent_radix_tree;
pub mod event_sink;
pub mod indexer;
#[cfg(feature = "bench")]
pub mod naive_indexers;
......@@ -36,6 +37,7 @@ pub use self::multi_worker_sequence::{
pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use config::{KvRouterConfig, RouterConfigOverride};
pub use event_sink::EventSink;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
#[cfg(feature = "bench")]
pub use naive_indexers::{InvertedIndex, NaiveNestedMap};
......
......@@ -26,8 +26,9 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use crate::indexer::{SyncIndexer, WorkerTask};
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheEventError, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerId, WorkerWithDpRank,
DpRank, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheEventError,
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerId,
WorkerWithDpRank,
};
/// Entry for the innermost level of the index.
......@@ -150,6 +151,9 @@ impl SyncIndexer for PositionalIndexer {
WorkerTask::RemoveWorker(worker_id) => {
self.remove_or_clear_worker_blocks_impl(&mut worker_blocks, worker_id, false);
}
WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => {
self.remove_worker_dp_rank_impl(&mut worker_blocks, worker_id, dp_rank);
}
WorkerTask::DumpEvents(sender) => {
let events = self.dump_events(&worker_blocks);
if let Err(e) = sender.send(Ok(events)) {
......@@ -329,6 +333,23 @@ impl PositionalIndexer {
self.remove_or_clear_worker_blocks_impl(worker_blocks, worker_id, true);
}
fn remove_worker_dp_rank_impl(
&self,
worker_blocks: &mut FxHashMap<WorkerWithDpRank, LevelIndex>,
worker_id: WorkerId,
dp_rank: DpRank,
) {
let key = WorkerWithDpRank { worker_id, dp_rank };
if let Some(worker_map) = worker_blocks.remove(&key) {
for (seq_hash, (position, local_hash)) in worker_map.iter() {
if let Some(mut entry) = self.index.get_mut(&(*position, *local_hash)) {
let _ = entry.remove(*seq_hash, key);
}
}
self.tree_sizes.remove(&key);
}
}
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains tracked with empty blocks.
/// If `keep_worker` is false, the worker is completely removed.
......
......@@ -497,6 +497,18 @@ impl RadixTree {
self.remove_or_clear_worker_blocks(worker_id, false);
}
pub fn remove_worker_dp_rank(&mut self, worker_id: WorkerId, dp_rank: DpRank) {
let key = WorkerWithDpRank { worker_id, dp_rank };
if let Some(blocks) = self.lookup.remove(&key) {
for (_, block) in blocks {
block.borrow_mut().workers.remove(&key);
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
}
}
}
}
pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, true);
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use anyhow::Result;
use async_trait::async_trait;
use rmp_serde as rmps;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
......@@ -289,7 +289,7 @@ impl KvEventPublisher {
};
start_event_processor(
event_publisher,
EventPlanePublisher(event_publisher),
worker_id,
cancellation_token_clone,
rx,
......@@ -315,7 +315,7 @@ impl KvEventPublisher {
return;
}
start_event_processor_jetstream(
nats_queue,
JetStreamPublisher(nats_queue),
worker_id,
cancellation_token_clone,
rx,
......@@ -366,22 +366,21 @@ impl Drop for KvEventPublisher {
}
}
#[async_trait]
trait EventSink: Send + Sync {
async fn publish_event(&self, event: &RouterEvent) -> Result<()>;
}
use dynamo_kv_router::EventSink;
struct EventPlanePublisher(EventPublisher);
#[async_trait]
impl EventSink for EventPublisher {
async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
self.publish(event).await
impl EventSink for EventPlanePublisher {
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
self.0.publish(event)
}
}
#[async_trait]
impl EventSink for NatsQueue {
async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
NatsQueue::publish_event(self, KV_EVENT_SUBJECT, event).await
struct JetStreamPublisher(NatsQueue);
impl EventSink for JetStreamPublisher {
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
NatsQueue::publish_event(&self.0, KV_EVENT_SUBJECT, event)
}
}
......@@ -587,8 +586,8 @@ async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
}
/// Batched event processor using JetStream (durable).
async fn start_event_processor_jetstream(
publisher: NatsQueue,
async fn start_event_processor_jetstream<P: EventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
rx: mpsc::UnboundedReceiver<KvCacheEvent>,
......@@ -1278,15 +1277,17 @@ mod tests_startup_helpers {
}
}
#[async_trait::async_trait]
impl EventSink for MockComponent {
async fn publish_event(&self, event: &RouterEvent) -> anyhow::Result<()> {
fn publish_event(
&self,
event: &RouterEvent,
) -> impl Future<Output = anyhow::Result<()>> + Send {
let bytes = rmp_serde::to_vec(event).unwrap();
self.published
.lock()
.unwrap()
.push((KV_EVENT_SUBJECT.to_string(), bytes));
Ok(())
async { Ok(()) }
}
}
......@@ -2253,11 +2254,10 @@ mod event_processor_tests {
}
}
#[async_trait]
impl EventSink for MockPublisher {
async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
self.events.lock().unwrap().push(event.clone());
Ok(())
async { Ok(()) }
}
}
......
......@@ -1669,8 +1669,17 @@ def _test_router_indexers_sync(
async with aiohttp.ClientSession() as session:
async with session.get(f"{standalone_indexer_url}/dump") as resp:
assert resp.status == 200, f"GET /dump failed: {resp.status}"
standalone_state = await resp.json()
dump_by_key = await resp.json()
# /dump returns {model:tenant -> events}, extract the expected key
expected_key = f"{model_name}:default"
assert expected_key in dump_by_key, (
f"Expected dump key '{expected_key}', "
f"got keys={list(dump_by_key.keys())}"
)
for k, v in dump_by_key.items():
assert isinstance(v, list), f"Dump key '{k}' returned error: {v}"
standalone_state = dump_by_key[expected_key]
sorted_standalone = sorted(standalone_state, key=sort_key)
logger.info(f"Standalone HTTP indexer has {len(sorted_standalone)} events")
......@@ -2197,7 +2206,7 @@ def _test_router_decisions(
async with aiohttp.ClientSession() as session:
async with session.post(
f"{standalone_indexer_url}/query",
json={"token_ids": req4_tokens},
json={"token_ids": req4_tokens, "model_name": model_name},
) as resp:
assert resp.status == 200, f"POST /query failed: {resp.status}"
scores = (await resp.json())["scores"]
......
......@@ -190,10 +190,12 @@ class MockerProcess:
request_plane: str = "nats",
zmq_kv_events: bool = False,
standalone_indexer: bool = False,
model_name: str = "mocker",
):
namespace_suffix = generate_random_suffix()
self.namespace = f"test-namespace-{namespace_suffix}"
self.component_name = "mocker"
self.model_name = model_name
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_workers = num_mockers
self._zmq_kv_events_ports: list[int] = []
......@@ -386,6 +388,10 @@ class MockerProcess:
"instance_id": new_worker_id,
"endpoint": endpoint,
"dp_rank": dp_rank,
"model_name": self.model_name,
"block_size": self._mocker_args_orig.get(
"block_size", BLOCK_SIZE
),
}
async with session.post(register_url, json=payload) as resp:
if resp.status != 201:
......@@ -897,6 +903,7 @@ def test_router_decisions(
request_plane=request_plane,
zmq_kv_events=zmq_kv_events,
standalone_indexer=zmq_kv_events,
model_name=MODEL_NAME,
) as mockers:
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment