kv.rs 40.7 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use pythonize::{depythonize, pythonize};
5
use std::collections::HashMap;
6
use std::ffi::OsString;
7
use std::sync::Arc;
8
use std::sync::atomic::AtomicU32;
9
use std::sync::mpsc;
10
use tokio_stream::StreamExt;
11

12
use super::*;
13
use crate::Endpoint;
14
15
#[cfg(feature = "kv-indexer")]
use clap::Parser;
16
17
18
use dynamo_kv_router::config::{KvRouterConfig, RouterConfigOverride};
use dynamo_kv_router::protocols::compute_block_hash_for_seq;
use dynamo_kv_router::protocols::*;
19
20
#[cfg(feature = "kv-indexer")]
use dynamo_kv_router::standalone_indexer::{self, IndexerConfig};
Yan Ru Pei's avatar
Yan Ru Pei committed
21
use rs::pipeline::{AsyncEngine, SingleIn};
22
use rs::protocols::annotated::Annotated as RsAnnotated;
23
use tracing;
24

25
use llm_rs::kv_router::KvPushRouter as RsKvPushRouter;
26
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks};
27
use llm_rs::protocols::common::timing::RequestTracker;
28
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
29
use serde_json::json;
30

31
32
33
use super::aic_callback::create_aic_prefill_load_estimator;
use super::entrypoint::AicPerfConfig;

34
35
36
37
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
    depythonize(obj).map_err(to_pyerr)
}

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Command-line arguments for the standalone KV cache indexer.
//
// Only compiled when the `kv-indexer` feature is enabled. After parsing,
// `run_kv_indexer_cli` copies every field one-to-one into `IndexerConfig`.
//
// NOTE: with clap's derive API the `///` comments on the fields below are
// used as the `--help` text, so treat them as user-facing strings.
#[cfg(feature = "kv-indexer")]
#[derive(Parser)]
#[command(
    name = "python -m dynamo.indexer",
    about = "Standalone KV cache indexer"
)]
struct KvIndexerCli {
    /// KV cache block size for initial workers registered via --workers
    #[arg(long)]
    block_size: Option<u32>,

    /// HTTP server port
    #[arg(long, default_value_t = 8090)]
    port: u16,

    /// Number of indexer threads (1 = single-threaded KvIndexer, >1 = ThreadPoolIndexer)
    #[arg(long, default_value_t = 4)]
    threads: usize,

    /// Initial workers as "worker_id[:dp_rank]=zmq_address,..." (e.g. "1=tcp://host:5557,1:1=tcp://host:5558")
    #[arg(long)]
    workers: Option<String>,

    /// Model name for initial workers registered via --workers
    #[arg(long, default_value = "default")]
    model_name: String,

    /// Tenant ID for initial workers registered via --workers
    #[arg(long, default_value = "default")]
    tenant_id: String,

    /// Comma-separated peer URLs for P2P recovery (e.g. "http://host1:8090,http://host2:8091")
    #[arg(long)]
    peers: Option<String>,
}

/// Entry point for the standalone KV indexer command line.
///
/// `args` are the arguments that follow the program name; a synthetic
/// `python -m dynamo.indexer` program name is prepended before parsing.
///
/// With the `kv-indexer` feature enabled this parses the arguments,
/// initializes logging, creates a Tokio runtime and blocks on the indexer
/// server until it returns. Without the feature it always returns an error.
///
/// # Errors
/// Returns an error if argument parsing fails, the runtime cannot be
/// created, or the server itself returns an error.
pub fn run_kv_indexer_cli<I, T>(args: I) -> anyhow::Result<()>
where
    I: IntoIterator<Item = T>,
    T: Into<OsString>,
{
    #[cfg(feature = "kv-indexer")]
    {
        let program_name = OsString::from("python -m dynamo.indexer");
        let argv = std::iter::once(program_name).chain(args.into_iter().map(Into::into));
        let KvIndexerCli {
            block_size,
            port,
            threads,
            workers,
            model_name,
            tenant_id,
            peers,
        } = KvIndexerCli::try_parse_from(argv)?;

        init_standalone_logging();

        let config = IndexerConfig {
            block_size,
            port,
            threads,
            workers,
            model_name,
            tenant_id,
            peers,
        };
        let runtime = tokio::runtime::Runtime::new()?;
        runtime.block_on(standalone_indexer::run_server(config))
    }

    #[cfg(not(feature = "kv-indexer"))]
    {
        // Nothing to parse in this build; the arguments are intentionally unused.
        let _ = args;
        anyhow::bail!(
            "dynamo.indexer is not available in this build; reinstall with --features kv-indexer"
        )
    }
}

/// Install a `tracing_subscriber` formatter for the standalone indexer.
///
/// The filter is read from the default environment variable and falls back
/// to `info` when that cannot be parsed. The result of `try_init` is
/// discarded, so a failure to install the subscriber is not fatal.
#[cfg(feature = "kv-indexer")]
fn init_standalone_logging() {
    let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));

    let _ = tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .try_init();
}

Yan Ru Pei's avatar
Yan Ru Pei committed
119
#[pyfunction]
120
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None, is_eagle=None))]
121
122
123
124
125
pub fn compute_block_hash_for_seq_py(
    _py: Python,
    tokens: Vec<u32>,
    kv_block_size: usize,
    block_mm_infos: Option<Bound<PyAny>>,
126
    lora_name: Option<String>,
127
    is_eagle: Option<bool>,
128
) -> PyResult<Vec<u64>> {
Yan Ru Pei's avatar
Yan Ru Pei committed
129
    if kv_block_size == 0 {
130
131
132
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "kv_block_size cannot be 0",
        ));
Yan Ru Pei's avatar
Yan Ru Pei committed
133
134
    }

135
    let mm_infos = block_mm_infos
136
        .as_ref()
137
        .map(depythonize_block_mm_infos)
138
139
        .transpose()?;

140
141
142
    let hashes = compute_block_hash_for_seq(
        &tokens,
        kv_block_size as u32,
143
144
145
146
147
        BlockHashOptions {
            block_mm_infos: mm_infos.as_deref(),
            lora_name: lora_name.as_deref(),
            is_eagle,
        },
148
    );
149

Yan Ru Pei's avatar
Yan Ru Pei committed
150
151
152
    Ok(hashes.into_iter().map(|h| h.0).collect())
}

GuanLuo's avatar
GuanLuo committed
153
#[pyclass]
154
155
pub(crate) struct WorkerMetricsPublisher {
    inner: Arc<llm_rs::kv_router::publisher::WorkerMetricsPublisher>,
GuanLuo's avatar
GuanLuo committed
156
157
158
}

#[pymethods]
159
impl WorkerMetricsPublisher {
GuanLuo's avatar
GuanLuo committed
160
161
    #[new]
    fn new() -> PyResult<Self> {
162
163
        let inner =
            llm_rs::kv_router::publisher::WorkerMetricsPublisher::new().map_err(to_pyerr)?;
GuanLuo's avatar
GuanLuo committed
164
165
166
167
168
        Ok(Self {
            inner: inner.into(),
        })
    }

169
    #[pyo3(signature = (endpoint))]
Alec's avatar
Alec committed
170
    fn create_endpoint<'p>(
GuanLuo's avatar
GuanLuo committed
171
172
        &self,
        py: Python<'p>,
173
        endpoint: Endpoint,
GuanLuo's avatar
GuanLuo committed
174
175
    ) -> PyResult<Bound<'p, PyAny>> {
        let rs_publisher = self.inner.clone();
176
        let rs_component = endpoint.inner.component().clone();
GuanLuo's avatar
GuanLuo committed
177
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
178
            rs_publisher
179
                .create_endpoint(rs_component)
GuanLuo's avatar
GuanLuo committed
180
181
182
183
184
185
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

186
187
188
189
    /// Publish worker metrics for load monitoring.
    ///
    /// # Arguments
    /// * `dp_rank` - Data parallel rank of the worker (None defaults to 0)
190
191
192
193
194
195
196
197
198
    /// * `active_decode_blocks` - Scheduler-compatible active decode block count
    /// * `kv_used_blocks` - Authoritative total KV blocks currently in use
    #[pyo3(signature = (dp_rank=None, active_decode_blocks=None, kv_used_blocks=None))]
    fn publish(
        &self,
        dp_rank: Option<u32>,
        active_decode_blocks: Option<u64>,
        kv_used_blocks: Option<u64>,
    ) -> PyResult<()> {
GuanLuo's avatar
GuanLuo committed
199
        self.inner
200
            .publish(dp_rank, active_decode_blocks, kv_used_blocks)
GuanLuo's avatar
GuanLuo committed
201
202
203
            .map_err(to_pyerr)
    }
}
204

205
206
207
#[pyclass]
pub(crate) struct KvEventPublisher {
    inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
208
    kv_block_size: usize,
Yan Ru Pei's avatar
Yan Ru Pei committed
209
    dp_rank: DpRank,
210
    warning_count: Arc<AtomicU32>,
211
212
213
214
}

#[pymethods]
impl KvEventPublisher {
215
216
217
218
219
220
221
222
223
224
225
226
    /// Create a KV event publisher that batches raw engine events before forwarding
    /// them to NATS / the event plane.
    ///
    /// Args:
    ///     endpoint: The Dynamo component endpoint for this worker.
    ///     worker_id: Identifier of this worker (default 0).
    ///     kv_block_size: KV cache block size in tokens; must be > 0.
    ///     dp_rank: Data-parallel rank of this worker (default 0).
    ///     enable_local_indexer: When True, a local KV indexer is kept in-process
    ///         so that routers can recover events directly from this worker.
    ///     zmq_endpoint: Optional ZMQ SUB endpoint to read raw engine events from.
    ///     zmq_topic: ZMQ topic filter (default "").
227
    ///     batching_timeout_ms: Maximum time (in **milliseconds**) to accumulate
228
    ///         events into a single batch before flushing.
229
230
231
232
    ///         ``None`` disables batching: every event is published immediately.
    ///         ``50`` to enable batching with a 50 ms window.
    ///         ``0`` is treated as ``None`` (also disables batching).
    ///         Maximum allowed is 15_000 (15 seconds); larger values are capped.
233
    #[new]
234
    #[pyo3(signature = (endpoint, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_endpoint=None, zmq_topic=None, batching_timeout_ms=llm_rs::kv_router::publisher::DEFAULT_BATCHING_TIMEOUT_MS))]
235
    #[allow(clippy::too_many_arguments)]
236
    fn new(
237
        endpoint: Endpoint,
238
239
240
        worker_id: WorkerId,
        kv_block_size: usize,
        dp_rank: DpRank,
241
        enable_local_indexer: bool,
242
243
        zmq_endpoint: Option<String>,
        zmq_topic: Option<String>,
244
        batching_timeout_ms: Option<u64>,
245
    ) -> PyResult<Self> {
246
247
        let _ = worker_id;

248
249
        let source_config = zmq_endpoint.map(|ep| KvEventSourceConfig::Zmq {
            endpoint: ep,
250
251
            topic: zmq_topic.unwrap_or_default(),
        });
252

Yan Ru Pei's avatar
Yan Ru Pei committed
253
254
255
256
        if kv_block_size == 0 {
            return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
        }

257
258
259
        // Extract component from endpoint
        let component = endpoint.inner.component().clone();

260
        let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
261
            component,
262
            kv_block_size as u32,
263
            source_config,
264
            enable_local_indexer,
265
            dp_rank,
266
            batching_timeout_ms,
267
268
        )
        .map_err(to_pyerr)?;
269

270
271
        Ok(Self {
            inner: inner.into(),
272
            kv_block_size,
Yan Ru Pei's avatar
Yan Ru Pei committed
273
            dp_rank,
274
            warning_count: Arc::new(AtomicU32::new(0)),
275
276
277
278
        })
    }

    #[allow(clippy::too_many_arguments)]
279
    #[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None, is_eagle=None))]
280
    fn publish_stored(
281
        &self,
282
        py: Python,
283
284
        token_ids: Vec<u32>,
        num_block_tokens: Vec<u64>,
285
286
        block_hashes: Vec<i64>,
        parent_hash: Option<i64>,
287
        block_mm_infos: Option<Bound<PyAny>>,
288
        lora_name: Option<String>,
289
        is_eagle: Option<bool>,
290
    ) -> PyResult<()> {
291
292
293
294
295
        let kv_block_size = self.kv_block_size as u32;
        let dp_rank = self.dp_rank;
        let warning_count = self.warning_count.clone();
        let inner = self.inner.clone();

296
297
        let event_id = inner.next_event_id();

298
        let mm_infos = block_mm_infos
299
            .as_ref()
300
            .map(depythonize_block_mm_infos)
301
302
            .transpose()?;

303
304
305
306
307
308
        py.allow_threads(|| {
            let block_hashes_u64: Vec<u64> = block_hashes.iter().map(|&h| h as u64).collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
309
                    start_position: None,
310
311
312
313
314
                    blocks: create_stored_blocks(
                        kv_block_size,
                        &token_ids,
                        &num_block_tokens,
                        &block_hashes_u64,
315
                        lora_name.as_deref(),
316
                        &warning_count,
317
                        mm_infos.as_deref(),
318
                        is_eagle,
319
320
321
322
323
324
325
                    ),
                }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
326
327
    }

328
    fn publish_removed(&self, py: Python, block_hashes: Vec<i64>) -> PyResult<()> {
329
330
331
        let dp_rank = self.dp_rank;
        let inner = self.inner.clone();

332
333
334
        // Use shared monotonic event_id counter from the inner publisher
        let event_id = inner.next_event_id();

335
336
337
338
339
340
341
342
343
344
345
346
347
        py.allow_threads(|| {
            let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
                .into_iter()
                .map(ExternalSequenceBlockHash::from)
                .collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
348
    }
349
350
351
352
353
354
355
356

    fn shutdown(&mut self) {
        // If no other Arc clones exist, shut down eagerly.
        // Otherwise the Drop impl handles cleanup when the last reference is freed.
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.shutdown();
        }
    }
357
358
}

359
360
361
#[pyclass]
#[derive(Clone)]
pub(crate) struct OverlapScores {
362
    inner: dynamo_kv_router::protocols::OverlapScores,
363
364
365
366
367
}

#[pymethods]
impl OverlapScores {
    #[getter]
368
    fn scores(&self) -> HashMap<(u64, u32), u32> {
Yan Ru Pei's avatar
Yan Ru Pei committed
369
370
371
372
373
374
        // Return scores with full WorkerWithDpRank granularity as (worker_id, dp_rank) tuples
        self.inner
            .scores
            .iter()
            .map(|(worker, score)| ((worker.worker_id, worker.dp_rank), *score))
            .collect()
375
376
377
378
379
380
381
382
    }

    #[getter]
    fn frequencies(&self) -> Vec<usize> {
        self.inner.frequencies.clone()
    }
}

383
384
385
#[derive(Debug)]
enum RadixTreeRequest {
    FindMatches {
386
        local_block_hashes: Vec<LocalBlockHash>,
387
        early_exit: bool,
388
        response_tx: mpsc::SyncSender<dynamo_kv_router::protocols::OverlapScores>,
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
    },
    ApplyEvent {
        worker_id: WorkerId,
        kv_cache_event_bytes: Vec<u8>,
        response_tx: mpsc::SyncSender<PyResult<()>>,
    },
    RemoveWorker {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    ClearAllBlocks {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    DumpTreeAsEvents {
404
        response_tx: mpsc::SyncSender<Vec<RouterEvent>>,
405
406
407
408
409
410
    },
    Shutdown,
}

// NOTE: RadixTree is now thread-safe with pure sync patterns
#[pyclass]
Yan Ru Pei's avatar
Yan Ru Pei committed
411
pub(crate) struct RadixTree {
412
    request_tx: mpsc::Sender<RadixTreeRequest>,
Yan Ru Pei's avatar
Yan Ru Pei committed
413
414
415
416
417
418
419
420
}

#[pymethods]
impl RadixTree {
    #[new]
    #[pyo3(signature = (expiration_duration_secs=None))]
    fn new(expiration_duration_secs: Option<f64>) -> PyResult<Self> {
        let expiration_duration = expiration_duration_secs.map(std::time::Duration::from_secs_f64);
421
422
423
424
425
426

        let (request_tx, request_rx) = mpsc::channel::<RadixTreeRequest>();

        // Spawn dedicated thread with simplified sync processing
        std::thread::spawn(move || {
            let mut radix_tree =
427
                dynamo_kv_router::indexer::RadixTree::new_with_frequency(expiration_duration);
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446

            loop {
                match request_rx.recv() {
                    Ok(RadixTreeRequest::Shutdown) => {
                        tracing::debug!("RadixTree thread received shutdown request");
                        break;
                    }
                    Ok(request) => {
                        Self::handle_request(&mut radix_tree, request);
                    }
                    Err(mpsc::RecvError) => {
                        tracing::debug!("RadixTree request channel disconnected");
                        break;
                    }
                }
            }
        });

        Ok(Self { request_tx })
Yan Ru Pei's avatar
Yan Ru Pei committed
447
448
449
450
451
    }

    #[pyo3(signature = (sequence, early_exit=false))]
    fn find_matches(
        &self,
452
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
453
454
455
        sequence: Vec<u64>,
        early_exit: bool,
    ) -> PyResult<OverlapScores> {
456
457
        let (response_tx, response_rx) = mpsc::sync_channel(1);

458
459
        let local_block_hashes =
            py.allow_threads(|| sequence.into_iter().map(LocalBlockHash).collect());
460
461
462
463
464
465

        let request = RadixTreeRequest::FindMatches {
            local_block_hashes,
            early_exit,
            response_tx,
        };
Yan Ru Pei's avatar
Yan Ru Pei committed
466

467
468
469
470
471
472
473
474
475
476
477
478
479
480
        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })?;

        Ok(OverlapScores { inner: result })
Yan Ru Pei's avatar
Yan Ru Pei committed
481
482
483
    }

    fn apply_event(
484
485
        &self,
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
486
        worker_id: WorkerId,
Yan Ru Pei's avatar
Yan Ru Pei committed
487
488
        kv_cache_event_bytes: &[u8],
    ) -> PyResult<()> {
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ApplyEvent {
            worker_id,
            kv_cache_event_bytes: kv_cache_event_bytes.to_vec(),
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || response_rx.recv());

        result.map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
        })?
    }

    fn remove_worker(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::RemoveWorker {
            worker_id,
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })
    }

    fn clear_all_blocks(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ClearAllBlocks {
            worker_id,
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
554

555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
    fn dump_tree_as_events(&self, py: Python) -> PyResult<Vec<String>> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::DumpTreeAsEvents { response_tx };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Failed to send dump tree request")
        })?;

        // Release GIL while waiting for response from dedicated thread
        let events = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                    "Failed to receive dump tree response",
                )
            })
        })?;

        // Serialize RouterEvent structs to JSON strings with GIL released
        py.allow_threads(move || {
            events
                .into_iter()
                .map(|event| {
                    serde_json::to_string(&event).map_err(|e| {
                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                            "Failed to serialize event to JSON: {}",
                            e
                        ))
                    })
                })
                .collect::<Result<Vec<String>, PyErr>>()
        })
Yan Ru Pei's avatar
Yan Ru Pei committed
587
    }
588
}
Yan Ru Pei's avatar
Yan Ru Pei committed
589

590
591
impl RadixTree {
    fn handle_request(
592
        radix_tree: &mut dynamo_kv_router::indexer::RadixTree,
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
        request: RadixTreeRequest,
    ) {
        match request {
            RadixTreeRequest::FindMatches {
                local_block_hashes,
                early_exit,
                response_tx,
            } => {
                let result = radix_tree.find_matches(local_block_hashes, early_exit);
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::ApplyEvent {
                worker_id,
                kv_cache_event_bytes,
                response_tx,
            } => {
609
                let result = match serde_json::from_slice::<KvCacheEvent>(&kv_cache_event_bytes) {
610
                    Ok(kv_cache_event) => {
611
                        let router_event = RouterEvent::new(worker_id, kv_cache_event);
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
                        match radix_tree.apply_event(router_event) {
                            Ok(_) => Ok(()),
                            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                                format!("Failed to apply event: {}", e),
                            )),
                        }
                    }
                    Err(e) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                        "Failed to deserialize KvCacheEvent: {}",
                        e
                    ))),
                };
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::RemoveWorker {
                worker_id,
                response_tx,
            } => {
                radix_tree.remove_worker(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::ClearAllBlocks {
                worker_id,
                response_tx,
            } => {
                radix_tree.clear_all_blocks(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::DumpTreeAsEvents { response_tx } => {
                let events = radix_tree.dump_tree_as_events();
                let _ = response_tx.send(events);
            }
            RadixTreeRequest::Shutdown => {
                // This is handled in the main loop
            }
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
648
    }
649
}
Yan Ru Pei's avatar
Yan Ru Pei committed
650

651
652
653
654
655
// Cleanup when RadixTree is dropped
impl Drop for RadixTree {
    /// Ask the worker thread to exit. A send error is ignored because it only
    /// means the receiving side is already gone.
    fn drop(&mut self) {
        // Only need graceful shutdown via RadixTreeRequest::Shutdown
        let _ = self.request_tx.send(RadixTreeRequest::Shutdown);
    }
}

659
/// Helper function to create a KV router from an endpoint using the ModelManager
660
661
662
663
664
/// to ensure proper etcd registration.
/// Infers worker type using endpoint naming and router config:
/// - If endpoint name/component contains "prefill", treat as prefill
/// - If router_track_active_blocks is disabled, treat as prefill
/// - Otherwise, default to decode
665
666
667
async fn create_kv_router_from_endpoint(
    endpoint: &Endpoint,
    block_size: usize,
668
    kv_router_config: Option<KvRouterConfig>,
669
    prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
670
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
671
    // Create ModelManager and use it to create KvRouter (ensures registration)
672
    let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
    let endpoint_id = endpoint.inner.id();
    let namespace = endpoint_id.namespace.to_lowercase();
    let component = endpoint_id.component.to_lowercase();
    let name = endpoint_id.name.to_lowercase();
    let endpoint_is_prefill =
        namespace.contains("prefill") || component.contains("prefill") || name.contains("prefill");
    let track_active_blocks = kv_router_config
        .as_ref()
        .map(|cfg| cfg.router_track_active_blocks)
        .unwrap_or(true);
    let worker_type = if endpoint_is_prefill || !track_active_blocks {
        llm_rs::discovery::WORKER_TYPE_PREFILL
    } else {
        llm_rs::discovery::WORKER_TYPE_DECODE
    };
688

689
    // Query discovery once so we can derive both model_name (for remote/served indexer)
690
    // and Eagle routing semantics from the model card.
691
692
    let needs_model_name = kv_router_config
        .as_ref()
693
        .map(|cfg| cfg.use_remote_indexer || cfg.serve_indexer)
694
        .unwrap_or(false);
695
    let (model_name, enable_eagle) = {
696
697
698
699
700
701
702
703
704
705
        let discovery = endpoint.inner.component().drt().discovery();
        let instances = discovery
            .list(rs::discovery::DiscoveryQuery::EndpointModels {
                namespace: endpoint_id.namespace.clone(),
                component: endpoint_id.component.clone(),
                endpoint: endpoint_id.name.clone(),
            })
            .await
            .map_err(to_pyerr)?;

706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
        let maybe_card = instances.into_iter().find_map(|inst| {
            inst.deserialize_model::<llm_rs::model_card::ModelDeploymentCard>()
                .ok()
        });

        match maybe_card {
            Some(card) => {
                let model_name = needs_model_name.then(|| card.display_name.clone());
                (model_name, card.runtime_config.enable_eagle)
            }
            None => {
                tracing::warn!(
                    namespace = %endpoint_id.namespace,
                    component = %endpoint_id.component,
                    endpoint = %endpoint_id.name,
                    "No model card found in discovery; defaulting to non-Eagle routing semantics"
                );
                (None, false)
            }
        }
726
727
    };

728
    let kv_router = model_manager
729
730
731
732
        .kv_chooser_for(
            &endpoint.inner,
            block_size as u32,
            kv_router_config,
733
            prefill_load_estimator,
734
            worker_type,
735
            model_name,
736
            enable_eagle,
737
        )
738
739
740
741
742
743
        .await
        .map_err(to_pyerr)?;

    Ok(kv_router)
}

744
#[pyclass]
745
746
pub(crate) struct KvRouter {
    inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
747
748
}

749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
/// Copy the tracker's worker info into the response's `disaggregated_params`
/// under the `"worker_id"` key.
///
/// The Python bindings need this to expose worker routing info, because the
/// raw `LLMEngineOutput` does not pass through `DeltaGenerator` (which adds
/// nvext). Does nothing when the tracker has no worker info. If
/// `disaggregated_params` is missing or is not a JSON object, it is replaced
/// by a new object holding only `"worker_id"`.
fn inject_worker_id_from_tracker(
    data: &mut llm_rs::protocols::common::llm_backend::LLMEngineOutput,
    tracker: &RequestTracker,
) {
    let worker_info = match tracker.get_worker_info() {
        Some(info) => info,
        None => return,
    };

    let worker_id_json =
        serde_json::to_value(&worker_info).expect("WorkerIdInfo serialization should not fail");

    let existing_object = data
        .disaggregated_params
        .as_mut()
        .and_then(|params| params.as_object_mut());

    match existing_object {
        Some(fields) => {
            fields.insert("worker_id".to_string(), worker_id_json);
        }
        None => {
            data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
        }
    }
}

Yan Ru Pei's avatar
Yan Ru Pei committed
774
// TODO: can this reuse the stream conversion method in Client bindings?
775
impl KvRouter {
Yan Ru Pei's avatar
Yan Ru Pei committed
776
777
778
    /// Helper method to process a request and create a Python async generator
    fn process_request_to_stream<'p>(
        py: Python<'p>,
779
        inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
780
        request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
781
        tracker: Option<Arc<RequestTracker>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
782
783
784
785
    ) -> PyResult<Bound<'p, PyAny>> {
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let single_in = SingleIn::new(request);
            let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
786
            let (tx, rx) = tokio::sync::mpsc::channel::<RsAnnotated<PyObject>>(100);
Yan Ru Pei's avatar
Yan Ru Pei committed
787
788
789

            tokio::spawn(async move {
                let mut stream = stream;
790
                let mut first_item = true;
791
                let mut first_token_gauges_observed = false;
792
793
794
795
796
797
798
799
800

                while let Some(mut response) = stream.next().await {
                    if first_item {
                        first_item = false;
                        if let (Some(tracker), Some(data)) = (&tracker, &mut response.data) {
                            inject_worker_id_from_tracker(data, tracker);
                        }
                    }

801
802
803
804
805
806
807
808
809
810
811
812
813
814
                    if !first_token_gauges_observed {
                        let has_tokens = response
                            .data
                            .as_ref()
                            .map(|d| !d.token_ids.is_empty())
                            .unwrap_or(false);
                        if has_tokens {
                            if let Some(ref tracker) = tracker {
                                tracker.observe_first_token_gauges();
                            }
                            first_token_gauges_observed = true;
                        }
                    }

Yan Ru Pei's avatar
Yan Ru Pei committed
815
816
817
818
819
820
821
822
                    let py_response = Python::with_gil(|py| {
                        pythonize(py, &response.data)
                            .map(|obj| obj.unbind())
                            .map_err(|e| e.to_string())
                    });

                    match py_response {
                        Ok(obj) => {
823
824
                            if tx.send(RsAnnotated::from_data(obj)).await.is_err() {
                                break;
Yan Ru Pei's avatar
Yan Ru Pei committed
825
826
827
828
829
830
831
832
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to pythonize response: {}", e);
                            break;
                        }
                    }
                }
833
834
835
836

                if let Some(ref tracker) = tracker {
                    tracker.observe_finish_gauges();
                }
Yan Ru Pei's avatar
Yan Ru Pei committed
837
838
            });

839
            Ok(crate::AsyncResponseStream::new(rx, false))
Yan Ru Pei's avatar
Yan Ru Pei committed
840
841
        })
    }
842
843
844
}

#[pymethods]
845
846
impl KvRouter {
    /// Create a new KvRouter for KV-aware routing to workers.
847
848
849
850
851
852
853
854
    ///
    /// # Arguments
    /// * `endpoint` - The endpoint to route requests to
    /// * `block_size` - KV cache block size for routing decisions
    /// * `kv_router_config` - Configuration for the KV router
    ///
    /// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
    /// (contains "prefill") or by `router_track_active_blocks` being disabled.
855
    #[new]
856
    #[pyo3(signature = (endpoint, block_size, kv_router_config, aic_perf_config=None))]
857
858
859
860
    fn new(
        endpoint: &Endpoint,
        block_size: usize,
        kv_router_config: &super::entrypoint::KvRouterConfig,
861
        aic_perf_config: Option<&AicPerfConfig>,
862
    ) -> PyResult<Self> {
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
        let prefill_load_estimator = aic_perf_config
            .map(|config| {
                Python::with_gil(|py| {
                    create_aic_prefill_load_estimator(
                        py,
                        config.backend_name(),
                        config.system(),
                        config.model_path(),
                        config.tp_size(),
                        config.backend_version(),
                    )
                })
            })
            .transpose()
            .map_err(to_pyerr)?;

879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
        let runtime = pyo3_async_runtimes::tokio::get_runtime();
        runtime.block_on(async move {
            let client = endpoint.inner.client().await.map_err(to_pyerr)?;

            // Create PushRouter with KV router mode
            let push_router = rs::pipeline::PushRouter::<
                llm_rs::protocols::common::preprocessor::PreprocessedRequest,
                rs::protocols::annotated::Annotated<
                    llm_rs::protocols::common::llm_backend::LLMEngineOutput,
                >,
            >::from_client(
                client,
                rs::pipeline::network::egress::push_router::RouterMode::KV,
            )
            .await
            .map_err(to_pyerr)?;

896
897
898
899
900
            // Create KvRouter using helper function (ensures etcd registration)
            let kv_router = create_kv_router_from_endpoint(
                endpoint,
                block_size,
                Some(kv_router_config.inner()),
901
                prefill_load_estimator,
902
903
            )
            .await?;
904

905
            let kv_push_router = RsKvPushRouter::new(push_router, kv_router);
906
907
908
909
910
911
912
913

            Ok(Self {
                inner: Arc::new(kv_push_router),
            })
        })
    }

    #[allow(clippy::too_many_arguments)]
914
    #[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None, block_mm_infos=None, multi_modal_data=None, mm_routing_info=None))]
915
916
917
918
919
920
921
922
923
    fn generate<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        model: String,
        stop_conditions: Option<PyObject>,
        sampling_options: Option<PyObject>,
        output_options: Option<PyObject>,
        router_config_override: Option<PyObject>,
Yan Ru Pei's avatar
Yan Ru Pei committed
924
925
        worker_id: Option<WorkerId>,
        dp_rank: Option<DpRank>,
926
        extra_args: Option<PyObject>,
927
928
929
        block_mm_infos: Option<PyObject>,
        multi_modal_data: Option<PyObject>,
        mm_routing_info: Option<PyObject>,
930
931
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the options with defaults
932
933
934
935
936
        let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            StopConditions::default()
        };
937

938
939
940
941
942
        let sampling_options: SamplingOptions = if let Some(obj) = sampling_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            SamplingOptions::default()
        };
943

944
945
946
947
948
        let output_options: OutputOptions = if let Some(obj) = output_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            OutputOptions::default()
        };
949

950
        let router_config_override: Option<RouterConfigOverride> =
951
952
953
954
955
            if let Some(obj) = router_config_override {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };
956

957
958
959
960
961
        let extra_args: Option<serde_json::Value> = if let Some(obj) = extra_args {
            Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
        } else {
            None
        };
962

963
964
965
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985

        let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> =
            if let Some(obj) = multi_modal_data {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };

        let mm_routing_info: Option<llm_rs::protocols::common::preprocessor::MmRoutingInfo> =
            if let Some(obj) = mm_routing_info {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                block_mm_infos.map(
                    |infos| llm_rs::protocols::common::preprocessor::MmRoutingInfo {
                        routing_token_ids: token_ids.clone(),
                        block_mm_infos: infos,
                    },
                )
            };

986
987
988
        // Create tracker to capture worker routing info from KvRouter
        let tracker = Arc::new(RequestTracker::new());

989
        // Build the PreprocessedRequest
990
991
992
        let mut request_builder =
            llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder();
        request_builder
993
994
995
996
997
            .model(model)
            .token_ids(token_ids)
            .stop_conditions(stop_conditions)
            .sampling_options(sampling_options)
            .output_options(output_options)
998
            .router_config_override(router_config_override)
999
1000
            .multi_modal_data(multi_modal_data)
            .mm_routing_info(mm_routing_info)
1001
1002
            .extra_args(extra_args)
            .tracker(Some(tracker.clone()));
1003

1004
1005
1006
1007
1008
1009
1010
1011
        // Set routing hints if worker_id or dp_rank is provided
        if worker_id.is_some() || dp_rank.is_some() {
            let routing = llm_rs::protocols::common::preprocessor::RoutingHints {
                backend_instance_id: worker_id,
                dp_rank,
                ..Default::default()
            };
            request_builder.routing(Some(routing));
1012
1013
1014
        }

        let request = request_builder.build().map_err(to_pyerr)?;
1015

Yan Ru Pei's avatar
Yan Ru Pei committed
1016
        // Use the helper method to process the request
1017
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
Yan Ru Pei's avatar
Yan Ru Pei committed
1018
    }
1019

Yan Ru Pei's avatar
Yan Ru Pei committed
1020
1021
1022
1023
1024
1025
    fn generate_from_request<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the request directly into PreprocessedRequest
1026
        let mut request: llm_rs::protocols::common::preprocessor::PreprocessedRequest =
1027
            depythonize(request.bind(py)).map_err(to_pyerr)?;
1028

1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
        // Create tracker if not already set, to capture worker routing info
        let tracker = match request.tracker {
            Some(ref t) => t.clone(),
            None => {
                let t = Arc::new(RequestTracker::new());
                request.tracker = Some(t.clone());
                t
            }
        };

Yan Ru Pei's avatar
Yan Ru Pei committed
1039
        // Use the helper method to process the request
1040
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
1041
    }
1042

1043
1044
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (token_ids, router_config_override=None, request_id=None, update_indexer=false, block_mm_infos=None, lora_name=None))]
Yan Ru Pei's avatar
Yan Ru Pei committed
1045
1046
1047
1048
1049
1050
    fn best_worker<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        router_config_override: Option<PyObject>,
        request_id: Option<String>,
1051
        update_indexer: bool,
1052
        block_mm_infos: Option<PyObject>,
1053
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
1054
1055
    ) -> PyResult<Bound<'p, PyAny>> {
        let router_config_override = if let Some(obj) = router_config_override {
1056
            let override_config: RouterConfigOverride =
1057
1058
                depythonize(obj.bind(py)).map_err(to_pyerr)?;
            Some(override_config)
Yan Ru Pei's avatar
Yan Ru Pei committed
1059
1060
1061
1062
        } else {
            None
        };

1063
1064
1065
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
1066

Yan Ru Pei's avatar
Yan Ru Pei committed
1067
1068
1069
1070
1071
1072
1073
1074
        let chooser = self.inner.chooser.clone();
        let update_states = request_id.is_some();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let (best_worker, overlap_blocks) = chooser
                .find_best_match(
                    request_id.as_deref(),
                    &token_ids,
1075
                    block_mm_infos.as_deref(),
Yan Ru Pei's avatar
Yan Ru Pei committed
1076
1077
                    router_config_override.as_ref(),
                    update_states,
1078
                    lora_name.clone(),
1079
                    0.0,
1080
                    None,
1081
                    None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
Yan Ru Pei's avatar
Yan Ru Pei committed
1082
1083
1084
1085
                )
                .await
                .map_err(to_pyerr)?;

1086
            if update_indexer && !chooser.kv_router_config().use_kv_events {
1087
1088
1089
1090
1091
1092
1093
1094
1095
                let mut tokens_with_hashes =
                    TokensWithHashes::new(token_ids.clone(), chooser.block_size())
                        .with_is_eagle(chooser.is_eagle());
                if let Some(infos) = block_mm_infos.as_ref() {
                    tokens_with_hashes = tokens_with_hashes.with_mm_infos(infos.clone());
                }
                if let Some(lora_name) = lora_name.as_ref() {
                    tokens_with_hashes = tokens_with_hashes.with_lora_name(lora_name.clone());
                }
1096
                chooser
1097
                    .record_routing_decision(tokens_with_hashes, best_worker)
1098
1099
1100
1101
                    .await
                    .map_err(to_pyerr)?;
            }

Yan Ru Pei's avatar
Yan Ru Pei committed
1102
1103
1104
1105
            Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
        })
    }

1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
    /// Mark prefill as completed for a request
    fn mark_prefill_complete<'p>(
        &self,
        py: Python<'p>,
        request_id: String,
    ) -> PyResult<Bound<'p, PyAny>> {
        let chooser = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            chooser
                .mark_prefill_completed(&request_id)
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

    /// Free a request by its ID, signaling the router to release resources
    fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
        let chooser = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            chooser.free(&request_id).await.map_err(to_pyerr)?;
            Ok(())
        })
    }

1133
    #[pyo3(signature = (token_ids, block_mm_infos=None, lora_name=None))]
1134
1135
1136
1137
    fn get_potential_loads<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
1138
        block_mm_infos: Option<PyObject>,
1139
        lora_name: Option<String>,
1140
    ) -> PyResult<Bound<'p, PyAny>> {
1141
1142
1143
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
1144
        let chooser = self.inner.chooser.clone();
1145
1146

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
1147
            let loads = chooser
1148
1149
1150
1151
1152
1153
                .get_potential_loads(
                    &token_ids,
                    None,
                    block_mm_infos.as_deref(),
                    lora_name.as_deref(),
                )
1154
1155
1156
                .await
                .map_err(to_pyerr)?;

Yan Ru Pei's avatar
Yan Ru Pei committed
1157
            // Return loads without aggregation - each (worker_id, dp_rank) pair is a separate entry
1158
1159
1160
1161
1162
1163
1164
1165
1166
            // Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
            Python::with_gil(|py| {
                pythonize(py, &loads)
                    .map(|obj| obj.unbind())
                    .map_err(to_pyerr)
            })
        })
    }

1167
1168
    /// Dump all events from the KV router's indexer as a JSON string.
    ///
    /// Returns an awaitable resolving to the serialized events, or raises if
    /// the dump or the JSON serialization fails.
    fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let chooser = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let events = chooser.dump_events().await.map_err(to_pyerr)?;
            serde_json::to_string(&events).map_err(to_pyerr)
        })
    }
}