kv.rs 34.9 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use pythonize::{depythonize, pythonize};
5
use std::collections::HashMap;
6
use std::sync::Arc;
7
use std::sync::atomic::AtomicU32;
8
use std::sync::mpsc;
9
use tokio_stream::StreamExt;
10

11
use super::*;
12
use crate::Endpoint;
13
use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
Yan Ru Pei's avatar
Yan Ru Pei committed
14
use rs::pipeline::{AsyncEngine, SingleIn};
15
use rs::protocols::annotated::Annotated as RsAnnotated;
16
use tracing;
17

18
use llm_rs::kv_router::KvPushRouter as RsKvPushRouter;
19
use llm_rs::kv_router::protocols::*;
20
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks};
21
use llm_rs::protocols::common::timing::RequestTracker;
22
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
23
use serde_json::json;
24

25
26
27
28
/// Deserializes a Python object into a list of optional [`BlockExtraInfo`] entries.
///
/// Any `depythonize` failure is converted into a Python exception through `to_pyerr`.
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
    let infos: Vec<Option<BlockExtraInfo>> = depythonize(obj).map_err(to_pyerr)?;
    Ok(infos)
}

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#[pyfunction]
#[pyo3(name = "start_kv_block_indexer", signature = (endpoint, block_size, kv_router_config))]
pub fn start_kv_block_indexer_py<'p>(
    py: Python<'p>,
    endpoint: &Endpoint,
    block_size: u32,
    kv_router_config: &super::entrypoint::KvRouterConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let component = endpoint.inner.component().clone();
    let config = kv_router_config.inner();
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        llm_rs::kv_router::indexer_standalone::start_kv_block_indexer(
            &component, &config, block_size,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

Yan Ru Pei's avatar
Yan Ru Pei committed
49
#[pyfunction]
50
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None))]
51
52
53
54
55
pub fn compute_block_hash_for_seq_py(
    _py: Python,
    tokens: Vec<u32>,
    kv_block_size: usize,
    block_mm_infos: Option<Bound<PyAny>>,
56
    lora_name: Option<String>,
57
) -> PyResult<Vec<u64>> {
Yan Ru Pei's avatar
Yan Ru Pei committed
58
    if kv_block_size == 0 {
59
60
61
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "kv_block_size cannot be 0",
        ));
Yan Ru Pei's avatar
Yan Ru Pei committed
62
63
    }

64
    let mm_infos = block_mm_infos
65
        .as_ref()
66
        .map(depythonize_block_mm_infos)
67
68
        .transpose()?;

69
70
71
72
73
74
    let hashes = compute_block_hash_for_seq(
        &tokens,
        kv_block_size as u32,
        mm_infos.as_deref(),
        lora_name.as_deref(),
    );
75

Yan Ru Pei's avatar
Yan Ru Pei committed
76
77
78
    Ok(hashes.into_iter().map(|h| h.0).collect())
}

GuanLuo's avatar
GuanLuo committed
79
#[pyclass]
80
81
pub(crate) struct WorkerMetricsPublisher {
    inner: Arc<llm_rs::kv_router::publisher::WorkerMetricsPublisher>,
GuanLuo's avatar
GuanLuo committed
82
83
84
}

#[pymethods]
85
impl WorkerMetricsPublisher {
GuanLuo's avatar
GuanLuo committed
86
87
    #[new]
    fn new() -> PyResult<Self> {
88
89
        let inner =
            llm_rs::kv_router::publisher::WorkerMetricsPublisher::new().map_err(to_pyerr)?;
GuanLuo's avatar
GuanLuo committed
90
91
92
93
94
        Ok(Self {
            inner: inner.into(),
        })
    }

95
    #[pyo3(signature = (endpoint))]
Alec's avatar
Alec committed
96
    fn create_endpoint<'p>(
GuanLuo's avatar
GuanLuo committed
97
98
        &self,
        py: Python<'p>,
99
        endpoint: Endpoint,
GuanLuo's avatar
GuanLuo committed
100
101
    ) -> PyResult<Bound<'p, PyAny>> {
        let rs_publisher = self.inner.clone();
102
        let rs_component = endpoint.inner.component().clone();
GuanLuo's avatar
GuanLuo committed
103
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
104
            rs_publisher
105
                .create_endpoint(rs_component)
GuanLuo's avatar
GuanLuo committed
106
107
108
109
110
111
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

112
113
114
115
116
117
118
    /// Publish worker metrics for load monitoring.
    ///
    /// # Arguments
    /// * `dp_rank` - Data parallel rank of the worker (None defaults to 0)
    /// * `active_decode_blocks` - Number of active KV cache blocks
    #[pyo3(signature = (dp_rank, active_decode_blocks))]
    fn publish(&self, dp_rank: Option<u32>, active_decode_blocks: u64) -> PyResult<()> {
GuanLuo's avatar
GuanLuo committed
119
        self.inner
120
            .publish(dp_rank, active_decode_blocks)
GuanLuo's avatar
GuanLuo committed
121
122
123
            .map_err(to_pyerr)
    }
}
124

125
126
127
#[pyclass]
pub(crate) struct KvEventPublisher {
    inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
128
    kv_block_size: usize,
Yan Ru Pei's avatar
Yan Ru Pei committed
129
    dp_rank: DpRank,
130
    warning_count: Arc<AtomicU32>,
131
132
133
134
}

#[pymethods]
impl KvEventPublisher {
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
    /// Create a KV event publisher that batches raw engine events before forwarding
    /// them to NATS / the event plane.
    ///
    /// Args:
    ///     endpoint: The Dynamo component endpoint for this worker.
    ///     worker_id: Identifier of this worker (default 0).
    ///     kv_block_size: KV cache block size in tokens; must be > 0.
    ///     dp_rank: Data-parallel rank of this worker (default 0).
    ///     enable_local_indexer: When True, a local KV indexer is kept in-process
    ///         so that routers can recover events directly from this worker.
    ///     zmq_endpoint: Optional ZMQ SUB endpoint to read raw engine events from.
    ///     zmq_topic: ZMQ topic filter (default "").
    ///     batching_timeout_us: Maximum time (in **microseconds**) to accumulate
    ///         events into a single batch before flushing.
    ///         ``None`` uses the default window of 10000 µs (10 ms).
    ///         ``0`` disables batching: every event is published immediately.
151
    #[new]
152
153
    #[pyo3(signature = (endpoint, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_endpoint=None, zmq_topic=None, batching_timeout_us=None))]
    #[allow(clippy::too_many_arguments)]
154
    fn new(
155
        endpoint: Endpoint,
156
157
158
        worker_id: WorkerId,
        kv_block_size: usize,
        dp_rank: DpRank,
159
        enable_local_indexer: bool,
160
161
        zmq_endpoint: Option<String>,
        zmq_topic: Option<String>,
162
        batching_timeout_us: Option<u64>,
163
    ) -> PyResult<Self> {
164
165
        let _ = worker_id;

166
167
        let source_config = zmq_endpoint.map(|ep| KvEventSourceConfig::Zmq {
            endpoint: ep,
168
169
            topic: zmq_topic.unwrap_or_default(),
        });
170

Yan Ru Pei's avatar
Yan Ru Pei committed
171
172
173
174
        if kv_block_size == 0 {
            return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
        }

175
176
177
        // Extract component from endpoint
        let component = endpoint.inner.component().clone();

178
        let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
179
            component,
180
            kv_block_size as u32,
181
            source_config,
182
            enable_local_indexer,
183
            dp_rank,
184
            batching_timeout_us,
185
186
        )
        .map_err(to_pyerr)?;
187

188
189
        Ok(Self {
            inner: inner.into(),
190
            kv_block_size,
Yan Ru Pei's avatar
Yan Ru Pei committed
191
            dp_rank,
192
            warning_count: Arc::new(AtomicU32::new(0)),
193
194
195
196
        })
    }

    #[allow(clippy::too_many_arguments)]
197
    #[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None))]
198
    fn publish_stored(
199
        &self,
200
        py: Python,
201
202
        token_ids: Vec<u32>,
        num_block_tokens: Vec<u64>,
203
204
        block_hashes: Vec<i64>,
        parent_hash: Option<i64>,
205
        block_mm_infos: Option<Bound<PyAny>>,
206
        lora_name: Option<String>,
207
    ) -> PyResult<()> {
208
209
210
211
212
        let kv_block_size = self.kv_block_size as u32;
        let dp_rank = self.dp_rank;
        let warning_count = self.warning_count.clone();
        let inner = self.inner.clone();

213
214
        let event_id = inner.next_event_id();

215
        let mm_infos = block_mm_infos
216
            .as_ref()
217
            .map(depythonize_block_mm_infos)
218
219
            .transpose()?;

220
221
222
223
224
225
226
227
228
229
230
        py.allow_threads(|| {
            let block_hashes_u64: Vec<u64> = block_hashes.iter().map(|&h| h as u64).collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
                    blocks: create_stored_blocks(
                        kv_block_size,
                        &token_ids,
                        &num_block_tokens,
                        &block_hashes_u64,
231
                        lora_name.as_deref(),
232
                        &warning_count,
233
                        mm_infos.as_deref(),
234
235
236
237
238
239
240
                    ),
                }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
241
242
    }

243
    fn publish_removed(&self, py: Python, block_hashes: Vec<i64>) -> PyResult<()> {
244
245
246
        let dp_rank = self.dp_rank;
        let inner = self.inner.clone();

247
248
249
        // Use shared monotonic event_id counter from the inner publisher
        let event_id = inner.next_event_id();

250
251
252
253
254
255
256
257
258
259
260
261
262
        py.allow_threads(|| {
            let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
                .into_iter()
                .map(ExternalSequenceBlockHash::from)
                .collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
263
    }
264
265
266
267
268
269
270
271

    fn shutdown(&mut self) {
        // If no other Arc clones exist, shut down eagerly.
        // Otherwise the Drop impl handles cleanup when the last reference is freed.
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.shutdown();
        }
    }
272
273
}

274
275
276
#[pyclass]
#[derive(Clone)]
pub(crate) struct OverlapScores {
277
    inner: llm_rs::kv_router::protocols::OverlapScores,
278
279
280
281
282
}

#[pymethods]
impl OverlapScores {
    #[getter]
283
    fn scores(&self) -> HashMap<(u64, u32), u32> {
Yan Ru Pei's avatar
Yan Ru Pei committed
284
285
286
287
288
289
        // Return scores with full WorkerWithDpRank granularity as (worker_id, dp_rank) tuples
        self.inner
            .scores
            .iter()
            .map(|(worker, score)| ((worker.worker_id, worker.dp_rank), *score))
            .collect()
290
291
292
293
294
295
296
297
    }

    #[getter]
    fn frequencies(&self) -> Vec<usize> {
        self.inner.frequencies.clone()
    }
}

298
299
300
301
302
/// Messages sent from the Python-facing `RadixTree` to its worker thread.
///
/// Every variant except `Shutdown` carries a `response_tx` on which the worker
/// sends the result of the operation.
#[derive(Debug)]
enum RadixTreeRequest {
    /// Look up overlap scores for a sequence of local block hashes.
    FindMatches {
        local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash>,
        early_exit: bool,
        response_tx: mpsc::SyncSender<llm_rs::kv_router::protocols::OverlapScores>,
    },
    /// Apply a JSON-encoded `KvCacheEvent` on behalf of `worker_id`.
    ApplyEvent {
        worker_id: WorkerId,
        kv_cache_event_bytes: Vec<u8>,
        response_tx: mpsc::SyncSender<PyResult<()>>,
    },
    /// Call `remove_worker` on the tree for `worker_id`.
    RemoveWorker {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    /// Call `clear_all_blocks` on the tree for `worker_id`.
    ClearAllBlocks {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    /// Dump the tree as a list of `RouterEvent`s.
    DumpTreeAsEvents {
        response_tx: mpsc::SyncSender<Vec<llm_rs::kv_router::protocols::RouterEvent>>,
    },
    /// Ask the worker thread to stop its receive loop.
    Shutdown,
}

// NOTE: RadixTree is now thread-safe with pure sync patterns
#[pyclass]
Yan Ru Pei's avatar
Yan Ru Pei committed
326
pub(crate) struct RadixTree {
327
    request_tx: mpsc::Sender<RadixTreeRequest>,
Yan Ru Pei's avatar
Yan Ru Pei committed
328
329
330
331
332
333
334
335
}

#[pymethods]
impl RadixTree {
    #[new]
    #[pyo3(signature = (expiration_duration_secs=None))]
    fn new(expiration_duration_secs: Option<f64>) -> PyResult<Self> {
        let expiration_duration = expiration_duration_secs.map(std::time::Duration::from_secs_f64);
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361

        let (request_tx, request_rx) = mpsc::channel::<RadixTreeRequest>();

        // Spawn dedicated thread with simplified sync processing
        std::thread::spawn(move || {
            let mut radix_tree =
                llm_rs::kv_router::indexer::RadixTree::new_with_frequency(expiration_duration);

            loop {
                match request_rx.recv() {
                    Ok(RadixTreeRequest::Shutdown) => {
                        tracing::debug!("RadixTree thread received shutdown request");
                        break;
                    }
                    Ok(request) => {
                        Self::handle_request(&mut radix_tree, request);
                    }
                    Err(mpsc::RecvError) => {
                        tracing::debug!("RadixTree request channel disconnected");
                        break;
                    }
                }
            }
        });

        Ok(Self { request_tx })
Yan Ru Pei's avatar
Yan Ru Pei committed
362
363
364
365
366
    }

    #[pyo3(signature = (sequence, early_exit=false))]
    fn find_matches(
        &self,
367
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
368
369
370
        sequence: Vec<u64>,
        early_exit: bool,
    ) -> PyResult<OverlapScores> {
371
372
373
374
375
376
377
378
379
380
381
382
383
384
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let local_block_hashes = py.allow_threads(|| {
            sequence
                .into_iter()
                .map(llm_rs::kv_router::protocols::LocalBlockHash)
                .collect()
        });

        let request = RadixTreeRequest::FindMatches {
            local_block_hashes,
            early_exit,
            response_tx,
        };
Yan Ru Pei's avatar
Yan Ru Pei committed
385

386
387
388
389
390
391
392
393
394
395
396
397
398
399
        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })?;

        Ok(OverlapScores { inner: result })
Yan Ru Pei's avatar
Yan Ru Pei committed
400
401
402
    }

    fn apply_event(
403
404
        &self,
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
405
        worker_id: WorkerId,
Yan Ru Pei's avatar
Yan Ru Pei committed
406
407
        kv_cache_event_bytes: &[u8],
    ) -> PyResult<()> {
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ApplyEvent {
            worker_id,
            kv_cache_event_bytes: kv_cache_event_bytes.to_vec(),
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || response_rx.recv());

        result.map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
        })?
    }

    /// Asks the worker thread to call `remove_worker` for `worker_id` and waits
    /// for its acknowledgement with the GIL released.
    ///
    /// Raises `RuntimeError` if the worker thread is gone or drops the request.
    fn remove_worker(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (reply_tx, reply_rx) = mpsc::sync_channel(1);

        let sent = self.request_tx.send(RadixTreeRequest::RemoveWorker {
            worker_id,
            response_tx: reply_tx,
        });
        if sent.is_err() {
            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            ));
        }

        // Block on the acknowledgement without holding the GIL.
        py.allow_threads(move || reply_rx.recv()).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
        })
    }

    /// Asks the worker thread to call `clear_all_blocks` for `worker_id` and
    /// waits for its acknowledgement with the GIL released.
    ///
    /// Raises `RuntimeError` if the worker thread is gone or drops the request.
    fn clear_all_blocks(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (reply_tx, reply_rx) = mpsc::sync_channel(1);

        let sent = self.request_tx.send(RadixTreeRequest::ClearAllBlocks {
            worker_id,
            response_tx: reply_tx,
        });
        if sent.is_err() {
            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            ));
        }

        // Block on the acknowledgement without holding the GIL.
        py.allow_threads(move || reply_rx.recv()).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
        })
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
473

474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
    fn dump_tree_as_events(&self, py: Python) -> PyResult<Vec<String>> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::DumpTreeAsEvents { response_tx };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Failed to send dump tree request")
        })?;

        // Release GIL while waiting for response from dedicated thread
        let events = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                    "Failed to receive dump tree response",
                )
            })
        })?;

        // Serialize RouterEvent structs to JSON strings with GIL released
        py.allow_threads(move || {
            events
                .into_iter()
                .map(|event| {
                    serde_json::to_string(&event).map_err(|e| {
                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                            "Failed to serialize event to JSON: {}",
                            e
                        ))
                    })
                })
                .collect::<Result<Vec<String>, PyErr>>()
        })
Yan Ru Pei's avatar
Yan Ru Pei committed
506
    }
507
}
Yan Ru Pei's avatar
Yan Ru Pei committed
508

509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
impl RadixTree {
    fn handle_request(
        radix_tree: &mut llm_rs::kv_router::indexer::RadixTree,
        request: RadixTreeRequest,
    ) {
        match request {
            RadixTreeRequest::FindMatches {
                local_block_hashes,
                early_exit,
                response_tx,
            } => {
                let result = radix_tree.find_matches(local_block_hashes, early_exit);
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::ApplyEvent {
                worker_id,
                kv_cache_event_bytes,
                response_tx,
            } => {
                let result = match serde_json::from_slice::<
                    llm_rs::kv_router::protocols::KvCacheEvent,
                >(&kv_cache_event_bytes)
                {
                    Ok(kv_cache_event) => {
533
534
535
536
                        let router_event = llm_rs::kv_router::protocols::RouterEvent::new(
                            worker_id,
                            kv_cache_event,
                        );
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
                        match radix_tree.apply_event(router_event) {
                            Ok(_) => Ok(()),
                            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                                format!("Failed to apply event: {}", e),
                            )),
                        }
                    }
                    Err(e) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                        "Failed to deserialize KvCacheEvent: {}",
                        e
                    ))),
                };
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::RemoveWorker {
                worker_id,
                response_tx,
            } => {
                radix_tree.remove_worker(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::ClearAllBlocks {
                worker_id,
                response_tx,
            } => {
                radix_tree.clear_all_blocks(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::DumpTreeAsEvents { response_tx } => {
                let events = radix_tree.dump_tree_as_events();
                let _ = response_tx.send(events);
            }
            RadixTreeRequest::Shutdown => {
                // This is handled in the main loop
            }
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
573
    }
574
}
Yan Ru Pei's avatar
Yan Ru Pei committed
575

576
577
578
579
580
// Stop the worker thread when the Python-facing handle goes away.
impl Drop for RadixTree {
    fn drop(&mut self) {
        // Best effort: a send failure means the receiving side is already gone.
        let _shutdown_sent = self.request_tx.send(RadixTreeRequest::Shutdown);
    }
}

584
/// Helper function to create a KV router from an endpoint using the ModelManager
585
586
587
588
589
/// to ensure proper etcd registration.
/// Infers worker type using endpoint naming and router config:
/// - If endpoint name/component contains "prefill", treat as prefill
/// - If router_track_active_blocks is disabled, treat as prefill
/// - Otherwise, default to decode
590
591
592
593
594
async fn create_kv_router_from_endpoint(
    endpoint: &Endpoint,
    block_size: usize,
    kv_router_config: Option<llm_rs::kv_router::KvRouterConfig>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
595
    // Create ModelManager and use it to create KvRouter (ensures registration)
596
    let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
    let endpoint_id = endpoint.inner.id();
    let namespace = endpoint_id.namespace.to_lowercase();
    let component = endpoint_id.component.to_lowercase();
    let name = endpoint_id.name.to_lowercase();
    let endpoint_is_prefill =
        namespace.contains("prefill") || component.contains("prefill") || name.contains("prefill");
    let track_active_blocks = kv_router_config
        .as_ref()
        .map(|cfg| cfg.router_track_active_blocks)
        .unwrap_or(true);
    let worker_type = if endpoint_is_prefill || !track_active_blocks {
        llm_rs::discovery::WORKER_TYPE_PREFILL
    } else {
        llm_rs::discovery::WORKER_TYPE_DECODE
    };
612
    let kv_router = model_manager
613
614
615
616
617
618
        .kv_chooser_for(
            &endpoint.inner,
            block_size as u32,
            kv_router_config,
            worker_type,
        )
619
620
621
622
623
624
        .await
        .map_err(to_pyerr)?;

    Ok(kv_router)
}

625
#[pyclass]
626
627
pub(crate) struct KvRouter {
    inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
628
629
}

630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
/// Copies the tracker's worker info into the response's `disaggregated_params`
/// under the `"worker_id"` key.
///
/// Python bindings need this to expose worker routing info, because the raw
/// `LLMEngineOutput` doesn't go through DeltaGenerator (which adds nvext).
///
/// Does nothing when the tracker has no worker info. If `disaggregated_params`
/// already holds a JSON object the key is inserted into it; in every other case
/// (absent, or present but not an object) the field is replaced by a fresh
/// object containing only `"worker_id"`.
fn inject_worker_id_from_tracker(
    data: &mut llm_rs::protocols::common::llm_backend::LLMEngineOutput,
    tracker: &RequestTracker,
) {
    let worker_info = match tracker.get_worker_info() {
        Some(info) => info,
        None => return,
    };

    let worker_id_json =
        serde_json::to_value(&worker_info).expect("WorkerIdInfo serialization should not fail");

    let existing_object = data
        .disaggregated_params
        .as_mut()
        .and_then(|params| params.as_object_mut());

    match existing_object {
        Some(fields) => {
            fields.insert("worker_id".to_string(), worker_id_json);
        }
        None => {
            data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
        }
    }
}

Yan Ru Pei's avatar
Yan Ru Pei committed
655
// TODO: can this reuse the stream conversion method in Client bindings?
656
impl KvRouter {
Yan Ru Pei's avatar
Yan Ru Pei committed
657
658
659
    /// Helper method to process a request and create a Python async generator
    fn process_request_to_stream<'p>(
        py: Python<'p>,
660
        inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
661
        request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
662
        tracker: Option<Arc<RequestTracker>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
663
664
665
666
    ) -> PyResult<Bound<'p, PyAny>> {
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let single_in = SingleIn::new(request);
            let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
667
            let (tx, rx) = tokio::sync::mpsc::channel::<RsAnnotated<PyObject>>(100);
Yan Ru Pei's avatar
Yan Ru Pei committed
668
669
670

            tokio::spawn(async move {
                let mut stream = stream;
671
                let mut first_item = true;
672
                let mut first_token_gauges_observed = false;
673
674
675
676
677
678
679
680
681

                while let Some(mut response) = stream.next().await {
                    if first_item {
                        first_item = false;
                        if let (Some(tracker), Some(data)) = (&tracker, &mut response.data) {
                            inject_worker_id_from_tracker(data, tracker);
                        }
                    }

682
683
684
685
686
687
688
689
690
691
692
693
694
695
                    if !first_token_gauges_observed {
                        let has_tokens = response
                            .data
                            .as_ref()
                            .map(|d| !d.token_ids.is_empty())
                            .unwrap_or(false);
                        if has_tokens {
                            if let Some(ref tracker) = tracker {
                                tracker.observe_first_token_gauges();
                            }
                            first_token_gauges_observed = true;
                        }
                    }

Yan Ru Pei's avatar
Yan Ru Pei committed
696
697
698
699
700
701
702
703
                    let py_response = Python::with_gil(|py| {
                        pythonize(py, &response.data)
                            .map(|obj| obj.unbind())
                            .map_err(|e| e.to_string())
                    });

                    match py_response {
                        Ok(obj) => {
704
705
                            if tx.send(RsAnnotated::from_data(obj)).await.is_err() {
                                break;
Yan Ru Pei's avatar
Yan Ru Pei committed
706
707
708
709
710
711
712
713
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to pythonize response: {}", e);
                            break;
                        }
                    }
                }
714
715
716
717

                if let Some(ref tracker) = tracker {
                    tracker.observe_finish_gauges();
                }
Yan Ru Pei's avatar
Yan Ru Pei committed
718
719
            });

720
            Ok(crate::AsyncResponseStream::new(rx, false))
Yan Ru Pei's avatar
Yan Ru Pei committed
721
722
        })
    }
723
724
725
}

#[pymethods]
726
727
impl KvRouter {
    /// Create a new KvRouter for KV-aware routing to workers.
728
729
730
731
732
733
734
735
    ///
    /// # Arguments
    /// * `endpoint` - The endpoint to route requests to
    /// * `block_size` - KV cache block size for routing decisions
    /// * `kv_router_config` - Configuration for the KV router
    ///
    /// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
    /// (contains "prefill") or by `router_track_active_blocks` being disabled.
736
    #[new]
737
    #[pyo3(signature = (endpoint, block_size, kv_router_config))]
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
    fn new(
        endpoint: &Endpoint,
        block_size: usize,
        kv_router_config: &super::entrypoint::KvRouterConfig,
    ) -> PyResult<Self> {
        let runtime = pyo3_async_runtimes::tokio::get_runtime();
        runtime.block_on(async move {
            let client = endpoint.inner.client().await.map_err(to_pyerr)?;

            // Create PushRouter with KV router mode
            let push_router = rs::pipeline::PushRouter::<
                llm_rs::protocols::common::preprocessor::PreprocessedRequest,
                rs::protocols::annotated::Annotated<
                    llm_rs::protocols::common::llm_backend::LLMEngineOutput,
                >,
            >::from_client(
                client,
                rs::pipeline::network::egress::push_router::RouterMode::KV,
            )
            .await
            .map_err(to_pyerr)?;

760
761
762
763
764
765
766
            // Create KvRouter using helper function (ensures etcd registration)
            let kv_router = create_kv_router_from_endpoint(
                endpoint,
                block_size,
                Some(kv_router_config.inner()),
            )
            .await?;
767

768
            let kv_push_router = RsKvPushRouter::new(push_router, kv_router);
769
770
771
772
773
774
775
776

            Ok(Self {
                inner: Arc::new(kv_push_router),
            })
        })
    }

    #[allow(clippy::too_many_arguments)]
777
    #[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None, block_mm_infos=None, multi_modal_data=None, mm_routing_info=None))]
778
779
780
781
782
783
784
785
786
    fn generate<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        model: String,
        stop_conditions: Option<PyObject>,
        sampling_options: Option<PyObject>,
        output_options: Option<PyObject>,
        router_config_override: Option<PyObject>,
Yan Ru Pei's avatar
Yan Ru Pei committed
787
788
        worker_id: Option<WorkerId>,
        dp_rank: Option<DpRank>,
789
        extra_args: Option<PyObject>,
790
791
792
        block_mm_infos: Option<PyObject>,
        multi_modal_data: Option<PyObject>,
        mm_routing_info: Option<PyObject>,
793
794
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the options with defaults
795
796
797
798
799
        let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            StopConditions::default()
        };
800

801
802
803
804
805
        let sampling_options: SamplingOptions = if let Some(obj) = sampling_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            SamplingOptions::default()
        };
806

807
808
809
810
811
        let output_options: OutputOptions = if let Some(obj) = output_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            OutputOptions::default()
        };
812

813
814
815
816
817
818
        let router_config_override: Option<llm_rs::kv_router::RouterConfigOverride> =
            if let Some(obj) = router_config_override {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };
819

820
821
822
823
824
        let extra_args: Option<serde_json::Value> = if let Some(obj) = extra_args {
            Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
        } else {
            None
        };
825

826
827
828
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848

        let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> =
            if let Some(obj) = multi_modal_data {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };

        let mm_routing_info: Option<llm_rs::protocols::common::preprocessor::MmRoutingInfo> =
            if let Some(obj) = mm_routing_info {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                block_mm_infos.map(
                    |infos| llm_rs::protocols::common::preprocessor::MmRoutingInfo {
                        routing_token_ids: token_ids.clone(),
                        block_mm_infos: infos,
                    },
                )
            };

849
850
851
        // Create tracker to capture worker routing info from KvRouter
        let tracker = Arc::new(RequestTracker::new());

852
        // Build the PreprocessedRequest
853
854
855
        let mut request_builder =
            llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder();
        request_builder
856
857
858
859
860
            .model(model)
            .token_ids(token_ids)
            .stop_conditions(stop_conditions)
            .sampling_options(sampling_options)
            .output_options(output_options)
861
            .router_config_override(router_config_override)
862
863
            .multi_modal_data(multi_modal_data)
            .mm_routing_info(mm_routing_info)
864
865
            .extra_args(extra_args)
            .tracker(Some(tracker.clone()));
866

867
868
869
870
871
872
873
874
        // Set routing hints if worker_id or dp_rank is provided
        if worker_id.is_some() || dp_rank.is_some() {
            let routing = llm_rs::protocols::common::preprocessor::RoutingHints {
                backend_instance_id: worker_id,
                dp_rank,
                ..Default::default()
            };
            request_builder.routing(Some(routing));
875
876
877
        }

        let request = request_builder.build().map_err(to_pyerr)?;
878

Yan Ru Pei's avatar
Yan Ru Pei committed
879
        // Use the helper method to process the request
880
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
Yan Ru Pei's avatar
Yan Ru Pei committed
881
    }
882

Yan Ru Pei's avatar
Yan Ru Pei committed
883
884
885
886
887
888
    fn generate_from_request<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the request directly into PreprocessedRequest
889
        let mut request: llm_rs::protocols::common::preprocessor::PreprocessedRequest =
890
            depythonize(request.bind(py)).map_err(to_pyerr)?;
891

892
893
894
895
896
897
898
899
900
901
        // Create tracker if not already set, to capture worker routing info
        let tracker = match request.tracker {
            Some(ref t) => t.clone(),
            None => {
                let t = Arc::new(RequestTracker::new());
                request.tracker = Some(t.clone());
                t
            }
        };

Yan Ru Pei's avatar
Yan Ru Pei committed
902
        // Use the helper method to process the request
903
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
904
    }
905

906
    #[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None, lora_name=None))]
Yan Ru Pei's avatar
Yan Ru Pei committed
907
908
909
910
911
912
    fn best_worker<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        router_config_override: Option<PyObject>,
        request_id: Option<String>,
913
        block_mm_infos: Option<PyObject>,
914
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
915
916
    ) -> PyResult<Bound<'p, PyAny>> {
        let router_config_override = if let Some(obj) = router_config_override {
917
918
919
            let override_config: llm_rs::kv_router::RouterConfigOverride =
                depythonize(obj.bind(py)).map_err(to_pyerr)?;
            Some(override_config)
Yan Ru Pei's avatar
Yan Ru Pei committed
920
921
922
923
        } else {
            None
        };

924
925
926
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
927

Yan Ru Pei's avatar
Yan Ru Pei committed
928
929
930
931
932
933
934
935
        let chooser = self.inner.chooser.clone();
        let update_states = request_id.is_some();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let (best_worker, overlap_blocks) = chooser
                .find_best_match(
                    request_id.as_deref(),
                    &token_ids,
936
                    block_mm_infos.as_deref(),
Yan Ru Pei's avatar
Yan Ru Pei committed
937
938
                    router_config_override.as_ref(),
                    update_states,
939
                    lora_name,
940
                    0.0,
941
                    None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
Yan Ru Pei's avatar
Yan Ru Pei committed
942
943
944
945
946
947
948
949
                )
                .await
                .map_err(to_pyerr)?;

            Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
        })
    }

950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
    /// Tell the router that prefill has finished for `request_id`.
    ///
    /// Returns an awaitable that resolves to `None`; an error reported by the chooser is
    /// raised as a Python exception.
    fn mark_prefill_complete<'p>(
        &self,
        py: Python<'p>,
        request_id: String,
    ) -> PyResult<Bound<'p, PyAny>> {
        let router = self.inner.chooser.clone();

        let notify = async move {
            router
                .mark_prefill_completed(&request_id)
                .await
                .map_err(to_pyerr)?;
            Ok(())
        };

        pyo3_async_runtimes::tokio::future_into_py(py, notify)
    }

    /// Release the router-side resources held for `request_id`.
    ///
    /// Returns an awaitable that resolves to `None`; an error reported by the chooser is
    /// raised as a Python exception.
    fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
        let router = self.inner.chooser.clone();

        let release = async move {
            router.free(&request_id).await.map_err(to_pyerr)?;
            Ok(())
        };

        pyo3_async_runtimes::tokio::future_into_py(py, release)
    }

977
    #[pyo3(signature = (token_ids, lora_name=None))]
978
979
980
981
    fn get_potential_loads<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
982
        lora_name: Option<String>,
983
    ) -> PyResult<Bound<'p, PyAny>> {
984
        let chooser = self.inner.chooser.clone();
985
986

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
987
            let loads = chooser
988
                .get_potential_loads(&token_ids, None, lora_name.as_deref())
989
990
991
                .await
                .map_err(to_pyerr)?;

Yan Ru Pei's avatar
Yan Ru Pei committed
992
            // Return loads without aggregation - each (worker_id, dp_rank) pair is a separate entry
993
994
995
996
997
998
999
1000
1001
            // Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
            Python::with_gil(|py| {
                pythonize(py, &loads)
                    .map(|obj| obj.unbind())
                    .map_err(to_pyerr)
            })
        })
    }

1002
1003
    /// Dump every event held by the KV router's indexer.
    ///
    /// Returns an awaitable resolving to a JSON string produced by
    /// `serde_json::to_string`; failures while dumping or serializing are raised as
    /// Python exceptions.
    fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let router = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let indexer_events = router.dump_events().await.map_err(to_pyerr)?;

            // The serialized JSON string is the awaitable's result.
            serde_json::to_string(&indexer_events).map_err(to_pyerr)
        })
    }
}