kv.rs 40.1 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use pythonize::{depythonize, pythonize};
5
use std::collections::HashMap;
6
use std::ffi::OsString;
7
use std::sync::Arc;
8
use std::sync::atomic::AtomicU32;
9
use std::sync::mpsc;
10
use tokio_stream::StreamExt;
11

12
use super::*;
13
use crate::Endpoint;
14
15
16
17
18
19
#[cfg(feature = "kv-indexer")]
use clap::Parser;
#[cfg(feature = "kv-indexer-runtime")]
use dynamo_kv_router::standalone_indexer::RuntimeConfig;
#[cfg(feature = "kv-indexer")]
use dynamo_kv_router::standalone_indexer::{self, IndexerConfig};
20
use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
Yan Ru Pei's avatar
Yan Ru Pei committed
21
use rs::pipeline::{AsyncEngine, SingleIn};
22
use rs::protocols::annotated::Annotated as RsAnnotated;
23
use tracing;
24

25
use llm_rs::kv_router::KvPushRouter as RsKvPushRouter;
26
use llm_rs::kv_router::protocols::*;
27
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks};
28
use llm_rs::protocols::common::timing::RequestTracker;
29
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
30
use serde_json::json;
31

32
33
34
35
/// Convert a Python `block_mm_infos` object into per-block optional extra info
/// (`Vec<Option<BlockExtraInfo>>`) using `pythonize::depythonize`.
///
/// Any conversion failure is turned into a Python exception through `to_pyerr`.
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
    let infos: Vec<Option<BlockExtraInfo>> = depythonize(obj).map_err(to_pyerr)?;
    Ok(infos)
}

36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// Command-line arguments for `python -m dynamo.indexer`, parsed by `run_kv_indexer_cli`.
//
// NOTE: the `///` comments on the fields below are consumed by clap's derive as `--help`
// text, so explanatory notes for maintainers are written as `//` comments instead.
// Fields gated on `kv-indexer-runtime` only exist (and are only accepted on the command
// line) when that feature is compiled in.
#[cfg(feature = "kv-indexer")]
#[derive(Parser)]
#[command(
    name = "python -m dynamo.indexer",
    about = "Standalone KV cache indexer"
)]
struct KvIndexerCli {
    /// KV cache block size for initial workers registered via --workers
    #[arg(long)]
    block_size: Option<u32>,

    /// HTTP server port
    #[arg(long, default_value_t = 8090)]
    port: u16,

    /// Number of indexer threads (1 = single-threaded KvIndexer, >1 = ThreadPoolIndexer)
    #[arg(long, default_value_t = 4)]
    threads: usize,

    /// Initial workers as "worker_id[:dp_rank]=zmq_address,..." (e.g. "1=tcp://host:5557,1:1=tcp://host:5558")
    // Kept as a raw string here; parsing is left to `standalone_indexer` via `IndexerConfig`.
    #[arg(long)]
    workers: Option<String>,

    /// Model name for initial workers registered via --workers
    #[arg(long, default_value = "default")]
    model_name: String,

    /// Tenant ID for initial workers registered via --workers
    #[arg(long, default_value = "default")]
    tenant_id: String,

    /// Comma-separated peer URLs for P2P recovery (e.g. "http://host1:8090,http://host2:8091")
    #[arg(long)]
    peers: Option<String>,

    /// Enable Dynamo runtime integration (discovery, event plane, request plane).
    // When set, `run_kv_indexer_cli` runs the indexer inside a Dynamo runtime worker
    // instead of on a standalone Tokio runtime.
    #[cfg(feature = "kv-indexer-runtime")]
    #[arg(long)]
    dynamo_runtime: bool,

    /// Dynamo namespace to register the indexer component under.
    #[cfg(feature = "kv-indexer-runtime")]
    #[arg(long, default_value = "default")]
    namespace: String,

    /// Component name for this indexer in the Dynamo runtime.
    #[cfg(feature = "kv-indexer-runtime")]
    #[arg(long, default_value = "kv-indexer")]
    component_name: String,

    /// Component name that workers register under.
    #[cfg(feature = "kv-indexer-runtime")]
    #[arg(long, default_value = "backend")]
    worker_component: String,
}

/// Entry point for `python -m dynamo.indexer`.
///
/// `args` are the command-line arguments *without* a program name; one is prepended
/// before parsing them into `KvIndexerCli`.
///
/// With the `kv-indexer-runtime` feature and `--dynamo-runtime`, the indexer runs inside a
/// Dynamo runtime worker (`dynamo_runtime::Worker::execute`). Otherwise it runs
/// `standalone_indexer::run_server` on a freshly created Tokio runtime, after installing a
/// `tracing_subscriber` fmt subscriber.
///
/// # Errors
/// Returns argument-parsing errors from clap, runtime/worker construction errors, errors
/// from the indexer itself, or an error when the crate was built without `kv-indexer`.
pub fn run_kv_indexer_cli<I, T>(args: I) -> anyhow::Result<()>
where
    I: IntoIterator<Item = T>,
    T: Into<OsString>,
{
    #[cfg(feature = "kv-indexer")]
    {
        // clap expects argv[0]; supply the user-facing command name.
        let argv = std::iter::once(OsString::from("python -m dynamo.indexer"))
            .chain(args.into_iter().map(Into::into));
        let cli = KvIndexerCli::try_parse_from(argv)?;

        // Built once and shared by both execution modes so the two paths cannot drift
        // apart when `IndexerConfig` gains or loses fields.
        let indexer_config = IndexerConfig {
            block_size: cli.block_size,
            port: cli.port,
            threads: cli.threads,
            workers: cli.workers,
            model_name: cli.model_name,
            tenant_id: cli.tenant_id,
            peers: cli.peers,
        };

        #[cfg(feature = "kv-indexer-runtime")]
        if cli.dynamo_runtime {
            let runtime_config = RuntimeConfig {
                namespace: cli.namespace,
                component_name: cli.component_name,
                worker_component: cli.worker_component,
            };
            dynamo_runtime::logging::init();
            let worker = dynamo_runtime::Worker::from_settings()?;
            // This branch always returns, so `indexer_config` is still available below
            // on the standalone path.
            return worker.execute(move |runtime| {
                standalone_indexer::run_with_runtime(runtime, indexer_config, runtime_config)
            });
        }

        init_standalone_logging();

        let rt = tokio::runtime::Runtime::new()?;
        rt.block_on(standalone_indexer::run_server(indexer_config))
    }

    #[cfg(not(feature = "kv-indexer"))]
    {
        let _ = args;
        anyhow::bail!(
            "dynamo.indexer is not available in this build; reinstall with --features kv-indexer"
        )
    }
}

/// Install a `tracing_subscriber` fmt subscriber for the standalone (non-runtime) indexer.
///
/// The filter comes from the environment when one is configured there
/// (`EnvFilter::try_from_default_env`); otherwise everything at `info` and above is shown.
#[cfg(feature = "kv-indexer")]
fn init_standalone_logging() {
    let filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));

    // `try_init` fails if a global subscriber is already installed; that is not an error here.
    let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init();
}

Yan Ru Pei's avatar
Yan Ru Pei committed
162
#[pyfunction]
163
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None))]
164
165
166
167
168
pub fn compute_block_hash_for_seq_py(
    _py: Python,
    tokens: Vec<u32>,
    kv_block_size: usize,
    block_mm_infos: Option<Bound<PyAny>>,
169
    lora_name: Option<String>,
170
) -> PyResult<Vec<u64>> {
Yan Ru Pei's avatar
Yan Ru Pei committed
171
    if kv_block_size == 0 {
172
173
174
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "kv_block_size cannot be 0",
        ));
Yan Ru Pei's avatar
Yan Ru Pei committed
175
176
    }

177
    let mm_infos = block_mm_infos
178
        .as_ref()
179
        .map(depythonize_block_mm_infos)
180
181
        .transpose()?;

182
183
184
185
186
187
    let hashes = compute_block_hash_for_seq(
        &tokens,
        kv_block_size as u32,
        mm_infos.as_deref(),
        lora_name.as_deref(),
    );
188

Yan Ru Pei's avatar
Yan Ru Pei committed
189
190
191
    Ok(hashes.into_iter().map(|h| h.0).collect())
}

GuanLuo's avatar
GuanLuo committed
192
#[pyclass]
193
194
pub(crate) struct WorkerMetricsPublisher {
    inner: Arc<llm_rs::kv_router::publisher::WorkerMetricsPublisher>,
GuanLuo's avatar
GuanLuo committed
195
196
197
}

#[pymethods]
198
impl WorkerMetricsPublisher {
GuanLuo's avatar
GuanLuo committed
199
200
    #[new]
    fn new() -> PyResult<Self> {
201
202
        let inner =
            llm_rs::kv_router::publisher::WorkerMetricsPublisher::new().map_err(to_pyerr)?;
GuanLuo's avatar
GuanLuo committed
203
204
205
206
207
        Ok(Self {
            inner: inner.into(),
        })
    }

208
    #[pyo3(signature = (endpoint))]
Alec's avatar
Alec committed
209
    fn create_endpoint<'p>(
GuanLuo's avatar
GuanLuo committed
210
211
        &self,
        py: Python<'p>,
212
        endpoint: Endpoint,
GuanLuo's avatar
GuanLuo committed
213
214
    ) -> PyResult<Bound<'p, PyAny>> {
        let rs_publisher = self.inner.clone();
215
        let rs_component = endpoint.inner.component().clone();
GuanLuo's avatar
GuanLuo committed
216
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
217
            rs_publisher
218
                .create_endpoint(rs_component)
GuanLuo's avatar
GuanLuo committed
219
220
221
222
223
224
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

225
226
227
228
229
230
231
    /// Publish worker metrics for load monitoring.
    ///
    /// # Arguments
    /// * `dp_rank` - Data parallel rank of the worker (None defaults to 0)
    /// * `active_decode_blocks` - Number of active KV cache blocks
    #[pyo3(signature = (dp_rank, active_decode_blocks))]
    fn publish(&self, dp_rank: Option<u32>, active_decode_blocks: u64) -> PyResult<()> {
GuanLuo's avatar
GuanLuo committed
232
        self.inner
233
            .publish(dp_rank, active_decode_blocks)
GuanLuo's avatar
GuanLuo committed
234
235
236
            .map_err(to_pyerr)
    }
}
237

238
239
240
#[pyclass]
pub(crate) struct KvEventPublisher {
    inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
241
    kv_block_size: usize,
Yan Ru Pei's avatar
Yan Ru Pei committed
242
    dp_rank: DpRank,
243
    warning_count: Arc<AtomicU32>,
244
245
246
247
}

#[pymethods]
impl KvEventPublisher {
248
249
250
251
252
253
254
255
256
257
258
259
    /// Create a KV event publisher that batches raw engine events before forwarding
    /// them to NATS / the event plane.
    ///
    /// Args:
    ///     endpoint: The Dynamo component endpoint for this worker.
    ///     worker_id: Identifier of this worker (default 0).
    ///     kv_block_size: KV cache block size in tokens; must be > 0.
    ///     dp_rank: Data-parallel rank of this worker (default 0).
    ///     enable_local_indexer: When True, a local KV indexer is kept in-process
    ///         so that routers can recover events directly from this worker.
    ///     zmq_endpoint: Optional ZMQ SUB endpoint to read raw engine events from.
    ///     zmq_topic: ZMQ topic filter (default "").
260
    ///     batching_timeout_ms: Maximum time (in **milliseconds**) to accumulate
261
    ///         events into a single batch before flushing.
262
263
264
265
    ///         ``None`` disables batching: every event is published immediately.
    ///         ``50`` to enable batching with a 50 ms window.
    ///         ``0`` is treated as ``None`` (also disables batching).
    ///         Maximum allowed is 15_000 (15 seconds); larger values are capped.
266
    #[new]
267
    #[pyo3(signature = (endpoint, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_endpoint=None, zmq_topic=None, batching_timeout_ms=llm_rs::kv_router::publisher::DEFAULT_BATCHING_TIMEOUT_MS))]
268
    #[allow(clippy::too_many_arguments)]
269
    fn new(
270
        endpoint: Endpoint,
271
272
273
        worker_id: WorkerId,
        kv_block_size: usize,
        dp_rank: DpRank,
274
        enable_local_indexer: bool,
275
276
        zmq_endpoint: Option<String>,
        zmq_topic: Option<String>,
277
        batching_timeout_ms: Option<u64>,
278
    ) -> PyResult<Self> {
279
280
        let _ = worker_id;

281
282
        let source_config = zmq_endpoint.map(|ep| KvEventSourceConfig::Zmq {
            endpoint: ep,
283
284
            topic: zmq_topic.unwrap_or_default(),
        });
285

Yan Ru Pei's avatar
Yan Ru Pei committed
286
287
288
289
        if kv_block_size == 0 {
            return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
        }

290
291
292
        // Extract component from endpoint
        let component = endpoint.inner.component().clone();

293
        let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
294
            component,
295
            kv_block_size as u32,
296
            source_config,
297
            enable_local_indexer,
298
            dp_rank,
299
            batching_timeout_ms,
300
301
        )
        .map_err(to_pyerr)?;
302

303
304
        Ok(Self {
            inner: inner.into(),
305
            kv_block_size,
Yan Ru Pei's avatar
Yan Ru Pei committed
306
            dp_rank,
307
            warning_count: Arc::new(AtomicU32::new(0)),
308
309
310
311
        })
    }

    #[allow(clippy::too_many_arguments)]
312
    #[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None))]
313
    fn publish_stored(
314
        &self,
315
        py: Python,
316
317
        token_ids: Vec<u32>,
        num_block_tokens: Vec<u64>,
318
319
        block_hashes: Vec<i64>,
        parent_hash: Option<i64>,
320
        block_mm_infos: Option<Bound<PyAny>>,
321
        lora_name: Option<String>,
322
    ) -> PyResult<()> {
323
324
325
326
327
        let kv_block_size = self.kv_block_size as u32;
        let dp_rank = self.dp_rank;
        let warning_count = self.warning_count.clone();
        let inner = self.inner.clone();

328
329
        let event_id = inner.next_event_id();

330
        let mm_infos = block_mm_infos
331
            .as_ref()
332
            .map(depythonize_block_mm_infos)
333
334
            .transpose()?;

335
336
337
338
339
340
341
342
343
344
345
        py.allow_threads(|| {
            let block_hashes_u64: Vec<u64> = block_hashes.iter().map(|&h| h as u64).collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
                    blocks: create_stored_blocks(
                        kv_block_size,
                        &token_ids,
                        &num_block_tokens,
                        &block_hashes_u64,
346
                        lora_name.as_deref(),
347
                        &warning_count,
348
                        mm_infos.as_deref(),
349
350
351
352
353
354
355
                    ),
                }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
356
357
    }

358
    fn publish_removed(&self, py: Python, block_hashes: Vec<i64>) -> PyResult<()> {
359
360
361
        let dp_rank = self.dp_rank;
        let inner = self.inner.clone();

362
363
364
        // Use shared monotonic event_id counter from the inner publisher
        let event_id = inner.next_event_id();

365
366
367
368
369
370
371
372
373
374
375
376
377
        py.allow_threads(|| {
            let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
                .into_iter()
                .map(ExternalSequenceBlockHash::from)
                .collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
378
    }
379
380
381
382
383
384
385
386

    fn shutdown(&mut self) {
        // If no other Arc clones exist, shut down eagerly.
        // Otherwise the Drop impl handles cleanup when the last reference is freed.
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.shutdown();
        }
    }
387
388
}

389
390
391
#[pyclass]
#[derive(Clone)]
pub(crate) struct OverlapScores {
392
    inner: llm_rs::kv_router::protocols::OverlapScores,
393
394
395
396
397
}

#[pymethods]
impl OverlapScores {
    #[getter]
398
    fn scores(&self) -> HashMap<(u64, u32), u32> {
Yan Ru Pei's avatar
Yan Ru Pei committed
399
400
401
402
403
404
        // Return scores with full WorkerWithDpRank granularity as (worker_id, dp_rank) tuples
        self.inner
            .scores
            .iter()
            .map(|(worker, score)| ((worker.worker_id, worker.dp_rank), *score))
            .collect()
405
406
407
408
409
410
411
412
    }

    #[getter]
    fn frequencies(&self) -> Vec<usize> {
        self.inner.frequencies.clone()
    }
}

413
414
415
416
417
/// Messages sent from the Python-facing `RadixTree` handle to its background thread.
///
/// Every variant except `Shutdown` carries its own `SyncSender` on which the background
/// thread sends the reply.
#[derive(Debug)]
enum RadixTreeRequest {
    /// Compute overlap scores for a sequence of local block hashes.
    FindMatches {
        local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash>,
        early_exit: bool,
        response_tx: mpsc::SyncSender<llm_rs::kv_router::protocols::OverlapScores>,
    },
    /// Apply a JSON-encoded `KvCacheEvent` on behalf of `worker_id`; the reply is the outcome.
    ApplyEvent {
        worker_id: WorkerId,
        kv_cache_event_bytes: Vec<u8>,
        response_tx: mpsc::SyncSender<PyResult<()>>,
    },
    /// Forwarded to the tree's `remove_worker`; replies once done.
    RemoveWorker {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    /// Forwarded to the tree's `clear_all_blocks`; replies once done.
    ClearAllBlocks {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    /// Export the whole tree as router events.
    DumpTreeAsEvents {
        response_tx: mpsc::SyncSender<Vec<llm_rs::kv_router::protocols::RouterEvent>>,
    },
    /// Ask the background thread to leave its receive loop.
    Shutdown,
}

// NOTE: RadixTree is now thread-safe with pure sync patterns
#[pyclass]
Yan Ru Pei's avatar
Yan Ru Pei committed
441
pub(crate) struct RadixTree {
442
    request_tx: mpsc::Sender<RadixTreeRequest>,
Yan Ru Pei's avatar
Yan Ru Pei committed
443
444
445
446
447
448
449
450
}

#[pymethods]
impl RadixTree {
    #[new]
    #[pyo3(signature = (expiration_duration_secs=None))]
    fn new(expiration_duration_secs: Option<f64>) -> PyResult<Self> {
        let expiration_duration = expiration_duration_secs.map(std::time::Duration::from_secs_f64);
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476

        let (request_tx, request_rx) = mpsc::channel::<RadixTreeRequest>();

        // Spawn dedicated thread with simplified sync processing
        std::thread::spawn(move || {
            let mut radix_tree =
                llm_rs::kv_router::indexer::RadixTree::new_with_frequency(expiration_duration);

            loop {
                match request_rx.recv() {
                    Ok(RadixTreeRequest::Shutdown) => {
                        tracing::debug!("RadixTree thread received shutdown request");
                        break;
                    }
                    Ok(request) => {
                        Self::handle_request(&mut radix_tree, request);
                    }
                    Err(mpsc::RecvError) => {
                        tracing::debug!("RadixTree request channel disconnected");
                        break;
                    }
                }
            }
        });

        Ok(Self { request_tx })
Yan Ru Pei's avatar
Yan Ru Pei committed
477
478
479
480
481
    }

    #[pyo3(signature = (sequence, early_exit=false))]
    fn find_matches(
        &self,
482
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
483
484
485
        sequence: Vec<u64>,
        early_exit: bool,
    ) -> PyResult<OverlapScores> {
486
487
488
489
490
491
492
493
494
495
496
497
498
499
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let local_block_hashes = py.allow_threads(|| {
            sequence
                .into_iter()
                .map(llm_rs::kv_router::protocols::LocalBlockHash)
                .collect()
        });

        let request = RadixTreeRequest::FindMatches {
            local_block_hashes,
            early_exit,
            response_tx,
        };
Yan Ru Pei's avatar
Yan Ru Pei committed
500

501
502
503
504
505
506
507
508
509
510
511
512
513
514
        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })?;

        Ok(OverlapScores { inner: result })
Yan Ru Pei's avatar
Yan Ru Pei committed
515
516
517
    }

    fn apply_event(
518
519
        &self,
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
520
        worker_id: WorkerId,
Yan Ru Pei's avatar
Yan Ru Pei committed
521
522
        kv_cache_event_bytes: &[u8],
    ) -> PyResult<()> {
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ApplyEvent {
            worker_id,
            kv_cache_event_bytes: kv_cache_event_bytes.to_vec(),
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || response_rx.recv());

        result.map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
        })?
    }

    fn remove_worker(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::RemoveWorker {
            worker_id,
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })
    }

    fn clear_all_blocks(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ClearAllBlocks {
            worker_id,
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
588

589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
    fn dump_tree_as_events(&self, py: Python) -> PyResult<Vec<String>> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::DumpTreeAsEvents { response_tx };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Failed to send dump tree request")
        })?;

        // Release GIL while waiting for response from dedicated thread
        let events = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                    "Failed to receive dump tree response",
                )
            })
        })?;

        // Serialize RouterEvent structs to JSON strings with GIL released
        py.allow_threads(move || {
            events
                .into_iter()
                .map(|event| {
                    serde_json::to_string(&event).map_err(|e| {
                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                            "Failed to serialize event to JSON: {}",
                            e
                        ))
                    })
                })
                .collect::<Result<Vec<String>, PyErr>>()
        })
Yan Ru Pei's avatar
Yan Ru Pei committed
621
    }
622
}
Yan Ru Pei's avatar
Yan Ru Pei committed
623

624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
impl RadixTree {
    fn handle_request(
        radix_tree: &mut llm_rs::kv_router::indexer::RadixTree,
        request: RadixTreeRequest,
    ) {
        match request {
            RadixTreeRequest::FindMatches {
                local_block_hashes,
                early_exit,
                response_tx,
            } => {
                let result = radix_tree.find_matches(local_block_hashes, early_exit);
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::ApplyEvent {
                worker_id,
                kv_cache_event_bytes,
                response_tx,
            } => {
                let result = match serde_json::from_slice::<
                    llm_rs::kv_router::protocols::KvCacheEvent,
                >(&kv_cache_event_bytes)
                {
                    Ok(kv_cache_event) => {
648
649
650
651
                        let router_event = llm_rs::kv_router::protocols::RouterEvent::new(
                            worker_id,
                            kv_cache_event,
                        );
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
                        match radix_tree.apply_event(router_event) {
                            Ok(_) => Ok(()),
                            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                                format!("Failed to apply event: {}", e),
                            )),
                        }
                    }
                    Err(e) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                        "Failed to deserialize KvCacheEvent: {}",
                        e
                    ))),
                };
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::RemoveWorker {
                worker_id,
                response_tx,
            } => {
                radix_tree.remove_worker(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::ClearAllBlocks {
                worker_id,
                response_tx,
            } => {
                radix_tree.clear_all_blocks(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::DumpTreeAsEvents { response_tx } => {
                let events = radix_tree.dump_tree_as_events();
                let _ = response_tx.send(events);
            }
            RadixTreeRequest::Shutdown => {
                // This is handled in the main loop
            }
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
688
    }
689
}
Yan Ru Pei's avatar
Yan Ru Pei committed
690

691
692
693
694
695
// Stop the background thread when the Python handle goes away.
impl Drop for RadixTree {
    fn drop(&mut self) {
        // Ask the thread to leave its receive loop. A failed send means the receiver is
        // already gone (the thread has exited), so there is nothing left to stop.
        let _ = self.request_tx.send(RadixTreeRequest::Shutdown);
    }
}

699
/// Helper function to create a KV router from an endpoint using the ModelManager
700
701
702
703
704
/// to ensure proper etcd registration.
/// Infers worker type using endpoint naming and router config:
/// - If endpoint name/component contains "prefill", treat as prefill
/// - If router_track_active_blocks is disabled, treat as prefill
/// - Otherwise, default to decode
705
706
707
708
709
async fn create_kv_router_from_endpoint(
    endpoint: &Endpoint,
    block_size: usize,
    kv_router_config: Option<llm_rs::kv_router::KvRouterConfig>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
710
    // Create ModelManager and use it to create KvRouter (ensures registration)
711
    let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
    let endpoint_id = endpoint.inner.id();
    let namespace = endpoint_id.namespace.to_lowercase();
    let component = endpoint_id.component.to_lowercase();
    let name = endpoint_id.name.to_lowercase();
    let endpoint_is_prefill =
        namespace.contains("prefill") || component.contains("prefill") || name.contains("prefill");
    let track_active_blocks = kv_router_config
        .as_ref()
        .map(|cfg| cfg.router_track_active_blocks)
        .unwrap_or(true);
    let worker_type = if endpoint_is_prefill || !track_active_blocks {
        llm_rs::discovery::WORKER_TYPE_PREFILL
    } else {
        llm_rs::discovery::WORKER_TYPE_DECODE
    };
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764

    // Only query discovery for model_name when a remote indexer is configured,
    // since model_name is only needed for the RemoteIndexer path.
    let needs_model_name = kv_router_config
        .as_ref()
        .map(|cfg| cfg.remote_indexer_component.is_some())
        .unwrap_or(false);

    let model_name = if needs_model_name {
        let discovery = endpoint.inner.component().drt().discovery();
        let instances = discovery
            .list(rs::discovery::DiscoveryQuery::EndpointModels {
                namespace: endpoint_id.namespace.clone(),
                component: endpoint_id.component.clone(),
                endpoint: endpoint_id.name.clone(),
            })
            .await
            .map_err(to_pyerr)?;

        Some(
            instances
                .into_iter()
                .find_map(|inst| {
                    inst.deserialize_model::<llm_rs::model_card::ModelDeploymentCard>()
                        .ok()
                        .map(|card| card.display_name)
                })
                .ok_or_else(|| {
                    PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                        "no model card found in discovery for endpoint {}/{}/{}",
                        endpoint_id.namespace, endpoint_id.component, endpoint_id.name
                    ))
                })?,
        )
    } else {
        None
    };

765
    let kv_router = model_manager
766
767
768
769
770
        .kv_chooser_for(
            &endpoint.inner,
            block_size as u32,
            kv_router_config,
            worker_type,
771
            model_name,
772
        )
773
774
775
776
777
778
        .await
        .map_err(to_pyerr)?;

    Ok(kv_router)
}

779
#[pyclass]
780
781
pub(crate) struct KvRouter {
    inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
782
783
}

784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
/// Store the worker routing info held by `tracker` under
/// `data.disaggregated_params["worker_id"]`.
///
/// Python bindings receive the raw `LLMEngineOutput`, which does not pass through
/// `DeltaGenerator` (where `nvext` normally carries this info), so it is injected here.
/// Does nothing when `tracker.get_worker_info()` returns `None`. If
/// `disaggregated_params` is missing or is not a JSON object, it is replaced by
/// `{"worker_id": ...}`.
fn inject_worker_id_from_tracker(
    data: &mut llm_rs::protocols::common::llm_backend::LLMEngineOutput,
    tracker: &RequestTracker,
) {
    if let Some(worker_info) = tracker.get_worker_info() {
        let worker_id_json = serde_json::to_value(&worker_info)
            .expect("WorkerIdInfo serialization should not fail");

        let existing = data
            .disaggregated_params
            .as_mut()
            .and_then(|params| params.as_object_mut());

        match existing {
            Some(obj) => {
                obj.insert("worker_id".to_string(), worker_id_json);
            }
            None => {
                data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
            }
        }
    }
}

Yan Ru Pei's avatar
Yan Ru Pei committed
809
// TODO: can this reuse the stream conversion method in Client bindings?
810
impl KvRouter {
Yan Ru Pei's avatar
Yan Ru Pei committed
811
812
813
    /// Helper method to process a request and create a Python async generator
    fn process_request_to_stream<'p>(
        py: Python<'p>,
814
        inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
815
        request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
816
        tracker: Option<Arc<RequestTracker>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
817
818
819
820
    ) -> PyResult<Bound<'p, PyAny>> {
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let single_in = SingleIn::new(request);
            let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
821
            let (tx, rx) = tokio::sync::mpsc::channel::<RsAnnotated<PyObject>>(100);
Yan Ru Pei's avatar
Yan Ru Pei committed
822
823
824

            tokio::spawn(async move {
                let mut stream = stream;
825
                let mut first_item = true;
826
                let mut first_token_gauges_observed = false;
827
828
829
830
831
832
833
834
835

                while let Some(mut response) = stream.next().await {
                    if first_item {
                        first_item = false;
                        if let (Some(tracker), Some(data)) = (&tracker, &mut response.data) {
                            inject_worker_id_from_tracker(data, tracker);
                        }
                    }

836
837
838
839
840
841
842
843
844
845
846
847
848
849
                    if !first_token_gauges_observed {
                        let has_tokens = response
                            .data
                            .as_ref()
                            .map(|d| !d.token_ids.is_empty())
                            .unwrap_or(false);
                        if has_tokens {
                            if let Some(ref tracker) = tracker {
                                tracker.observe_first_token_gauges();
                            }
                            first_token_gauges_observed = true;
                        }
                    }

Yan Ru Pei's avatar
Yan Ru Pei committed
850
851
852
853
854
855
856
857
                    let py_response = Python::with_gil(|py| {
                        pythonize(py, &response.data)
                            .map(|obj| obj.unbind())
                            .map_err(|e| e.to_string())
                    });

                    match py_response {
                        Ok(obj) => {
858
859
                            if tx.send(RsAnnotated::from_data(obj)).await.is_err() {
                                break;
Yan Ru Pei's avatar
Yan Ru Pei committed
860
861
862
863
864
865
866
867
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to pythonize response: {}", e);
                            break;
                        }
                    }
                }
868
869
870
871

                if let Some(ref tracker) = tracker {
                    tracker.observe_finish_gauges();
                }
Yan Ru Pei's avatar
Yan Ru Pei committed
872
873
            });

874
            Ok(crate::AsyncResponseStream::new(rx, false))
Yan Ru Pei's avatar
Yan Ru Pei committed
875
876
        })
    }
877
878
879
}

#[pymethods]
880
881
impl KvRouter {
    /// Create a new KvRouter for KV-aware routing to workers.
882
883
884
885
886
887
888
889
    ///
    /// # Arguments
    /// * `endpoint` - The endpoint to route requests to
    /// * `block_size` - KV cache block size for routing decisions
    /// * `kv_router_config` - Configuration for the KV router
    ///
    /// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
    /// (contains "prefill") or by `router_track_active_blocks` being disabled.
890
    #[new]
891
    #[pyo3(signature = (endpoint, block_size, kv_router_config))]
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
    fn new(
        endpoint: &Endpoint,
        block_size: usize,
        kv_router_config: &super::entrypoint::KvRouterConfig,
    ) -> PyResult<Self> {
        let runtime = pyo3_async_runtimes::tokio::get_runtime();
        runtime.block_on(async move {
            let client = endpoint.inner.client().await.map_err(to_pyerr)?;

            // Create PushRouter with KV router mode
            let push_router = rs::pipeline::PushRouter::<
                llm_rs::protocols::common::preprocessor::PreprocessedRequest,
                rs::protocols::annotated::Annotated<
                    llm_rs::protocols::common::llm_backend::LLMEngineOutput,
                >,
            >::from_client(
                client,
                rs::pipeline::network::egress::push_router::RouterMode::KV,
            )
            .await
            .map_err(to_pyerr)?;

914
915
916
917
918
919
920
            // Create KvRouter using helper function (ensures etcd registration)
            let kv_router = create_kv_router_from_endpoint(
                endpoint,
                block_size,
                Some(kv_router_config.inner()),
            )
            .await?;
921

922
            let kv_push_router = RsKvPushRouter::new(push_router, kv_router);
923
924
925
926
927
928
929
930

            Ok(Self {
                inner: Arc::new(kv_push_router),
            })
        })
    }

    #[allow(clippy::too_many_arguments)]
931
    #[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None, block_mm_infos=None, multi_modal_data=None, mm_routing_info=None))]
932
933
934
935
936
937
938
939
940
    fn generate<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        model: String,
        stop_conditions: Option<PyObject>,
        sampling_options: Option<PyObject>,
        output_options: Option<PyObject>,
        router_config_override: Option<PyObject>,
Yan Ru Pei's avatar
Yan Ru Pei committed
941
942
        worker_id: Option<WorkerId>,
        dp_rank: Option<DpRank>,
943
        extra_args: Option<PyObject>,
944
945
946
        block_mm_infos: Option<PyObject>,
        multi_modal_data: Option<PyObject>,
        mm_routing_info: Option<PyObject>,
947
948
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the options with defaults
949
950
951
952
953
        let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            StopConditions::default()
        };
954

955
956
957
958
959
        let sampling_options: SamplingOptions = if let Some(obj) = sampling_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            SamplingOptions::default()
        };
960

961
962
963
964
965
        let output_options: OutputOptions = if let Some(obj) = output_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            OutputOptions::default()
        };
966

967
968
969
970
971
972
        let router_config_override: Option<llm_rs::kv_router::RouterConfigOverride> =
            if let Some(obj) = router_config_override {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };
973

974
975
976
977
978
        let extra_args: Option<serde_json::Value> = if let Some(obj) = extra_args {
            Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
        } else {
            None
        };
979

980
981
982
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002

        let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> =
            if let Some(obj) = multi_modal_data {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };

        let mm_routing_info: Option<llm_rs::protocols::common::preprocessor::MmRoutingInfo> =
            if let Some(obj) = mm_routing_info {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                block_mm_infos.map(
                    |infos| llm_rs::protocols::common::preprocessor::MmRoutingInfo {
                        routing_token_ids: token_ids.clone(),
                        block_mm_infos: infos,
                    },
                )
            };

1003
1004
1005
        // Create tracker to capture worker routing info from KvRouter
        let tracker = Arc::new(RequestTracker::new());

1006
        // Build the PreprocessedRequest
1007
1008
1009
        let mut request_builder =
            llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder();
        request_builder
1010
1011
1012
1013
1014
            .model(model)
            .token_ids(token_ids)
            .stop_conditions(stop_conditions)
            .sampling_options(sampling_options)
            .output_options(output_options)
1015
            .router_config_override(router_config_override)
1016
1017
            .multi_modal_data(multi_modal_data)
            .mm_routing_info(mm_routing_info)
1018
1019
            .extra_args(extra_args)
            .tracker(Some(tracker.clone()));
1020

1021
1022
1023
1024
1025
1026
1027
1028
        // Set routing hints if worker_id or dp_rank is provided
        if worker_id.is_some() || dp_rank.is_some() {
            let routing = llm_rs::protocols::common::preprocessor::RoutingHints {
                backend_instance_id: worker_id,
                dp_rank,
                ..Default::default()
            };
            request_builder.routing(Some(routing));
1029
1030
1031
        }

        let request = request_builder.build().map_err(to_pyerr)?;
1032

Yan Ru Pei's avatar
Yan Ru Pei committed
1033
        // Use the helper method to process the request
1034
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
Yan Ru Pei's avatar
Yan Ru Pei committed
1035
    }
1036

Yan Ru Pei's avatar
Yan Ru Pei committed
1037
1038
1039
1040
1041
1042
    fn generate_from_request<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the request directly into PreprocessedRequest
1043
        let mut request: llm_rs::protocols::common::preprocessor::PreprocessedRequest =
1044
            depythonize(request.bind(py)).map_err(to_pyerr)?;
1045

1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
        // Create tracker if not already set, to capture worker routing info
        let tracker = match request.tracker {
            Some(ref t) => t.clone(),
            None => {
                let t = Arc::new(RequestTracker::new());
                request.tracker = Some(t.clone());
                t
            }
        };

Yan Ru Pei's avatar
Yan Ru Pei committed
1056
        // Use the helper method to process the request
1057
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
1058
    }
1059

1060
    #[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None, lora_name=None))]
Yan Ru Pei's avatar
Yan Ru Pei committed
1061
1062
1063
1064
1065
1066
    fn best_worker<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        router_config_override: Option<PyObject>,
        request_id: Option<String>,
1067
        block_mm_infos: Option<PyObject>,
1068
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
1069
1070
    ) -> PyResult<Bound<'p, PyAny>> {
        let router_config_override = if let Some(obj) = router_config_override {
1071
1072
1073
            let override_config: llm_rs::kv_router::RouterConfigOverride =
                depythonize(obj.bind(py)).map_err(to_pyerr)?;
            Some(override_config)
Yan Ru Pei's avatar
Yan Ru Pei committed
1074
1075
1076
1077
        } else {
            None
        };

1078
1079
1080
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
1081

Yan Ru Pei's avatar
Yan Ru Pei committed
1082
1083
1084
1085
1086
1087
1088
1089
        let chooser = self.inner.chooser.clone();
        let update_states = request_id.is_some();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let (best_worker, overlap_blocks) = chooser
                .find_best_match(
                    request_id.as_deref(),
                    &token_ids,
1090
                    block_mm_infos.as_deref(),
Yan Ru Pei's avatar
Yan Ru Pei committed
1091
1092
                    router_config_override.as_ref(),
                    update_states,
1093
                    lora_name,
1094
                    0.0,
1095
                    None,
1096
                    None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
Yan Ru Pei's avatar
Yan Ru Pei committed
1097
1098
1099
1100
1101
1102
1103
1104
                )
                .await
                .map_err(to_pyerr)?;

            Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
        })
    }

1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
    /// Mark prefill as completed for a request
    fn mark_prefill_complete<'p>(
        &self,
        py: Python<'p>,
        request_id: String,
    ) -> PyResult<Bound<'p, PyAny>> {
        let chooser = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            chooser
                .mark_prefill_completed(&request_id)
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

    /// Free a request by its ID, signaling the router to release resources
    fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
        let chooser = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            chooser.free(&request_id).await.map_err(to_pyerr)?;
            Ok(())
        })
    }

1132
    #[pyo3(signature = (token_ids, lora_name=None))]
1133
1134
1135
1136
    fn get_potential_loads<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
1137
        lora_name: Option<String>,
1138
    ) -> PyResult<Bound<'p, PyAny>> {
1139
        let chooser = self.inner.chooser.clone();
1140
1141

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
1142
            let loads = chooser
1143
                .get_potential_loads(&token_ids, None, lora_name.as_deref())
1144
1145
1146
                .await
                .map_err(to_pyerr)?;

Yan Ru Pei's avatar
Yan Ru Pei committed
1147
            // Return loads without aggregation - each (worker_id, dp_rank) pair is a separate entry
1148
1149
1150
1151
1152
1153
1154
1155
1156
            // Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
            Python::with_gil(|py| {
                pythonize(py, &loads)
                    .map(|obj| obj.unbind())
                    .map_err(to_pyerr)
            })
        })
    }

1157
1158
    /// Dump all events from the KV router's indexer as a JSON string
    fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
1159
        let chooser = self.inner.chooser.clone();
1160
1161

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
1162
            let events = chooser.dump_events().await.map_err(to_pyerr)?;
1163
1164
1165
1166
1167
1168
            // Serialize to JSON string
            let json_str = serde_json::to_string(&events).map_err(to_pyerr)?;
            Ok(json_str)
        })
    }
}