"docs/integrations/flexkv_integration.md" did not exist on "87449e1c97e31acfedf555d8d8d4233db3d560f4"
kv.rs 34.3 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use pythonize::{depythonize, pythonize};
5
use std::collections::HashMap;
6
use std::sync::Arc;
7
use std::sync::atomic::AtomicU32;
8
use std::sync::mpsc;
9
use tokio_stream::StreamExt;
10

11
use super::*;
12
use crate::Endpoint;
13
use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
Yan Ru Pei's avatar
Yan Ru Pei committed
14
use rs::pipeline::{AsyncEngine, SingleIn};
15
use rs::protocols::annotated::Annotated as RsAnnotated;
16
use tracing;
17

18
use llm_rs::kv_router::KvPushRouter as RsKvPushRouter;
19
use llm_rs::kv_router::protocols::*;
20
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks};
21
use llm_rs::protocols::common::timing::RequestTracker;
22
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
23
use serde_json::json;
24

25
26
27
28
/// Convert a Python object into per-block multimodal info.
///
/// Any depythonize failure is surfaced to Python through `to_pyerr`.
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
    let infos: Vec<Option<BlockExtraInfo>> = depythonize(obj).map_err(to_pyerr)?;
    Ok(infos)
}

Yan Ru Pei's avatar
Yan Ru Pei committed
29
#[pyfunction]
30
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None))]
31
32
33
34
35
pub fn compute_block_hash_for_seq_py(
    _py: Python,
    tokens: Vec<u32>,
    kv_block_size: usize,
    block_mm_infos: Option<Bound<PyAny>>,
36
    lora_name: Option<String>,
37
) -> PyResult<Vec<u64>> {
Yan Ru Pei's avatar
Yan Ru Pei committed
38
    if kv_block_size == 0 {
39
40
41
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "kv_block_size cannot be 0",
        ));
Yan Ru Pei's avatar
Yan Ru Pei committed
42
43
    }

44
    let mm_infos = block_mm_infos
45
        .as_ref()
46
        .map(depythonize_block_mm_infos)
47
48
        .transpose()?;

49
50
51
52
53
54
    let hashes = compute_block_hash_for_seq(
        &tokens,
        kv_block_size as u32,
        mm_infos.as_deref(),
        lora_name.as_deref(),
    );
55

Yan Ru Pei's avatar
Yan Ru Pei committed
56
57
58
    Ok(hashes.into_iter().map(|h| h.0).collect())
}

GuanLuo's avatar
GuanLuo committed
59
#[pyclass]
60
61
pub(crate) struct WorkerMetricsPublisher {
    inner: Arc<llm_rs::kv_router::publisher::WorkerMetricsPublisher>,
GuanLuo's avatar
GuanLuo committed
62
63
64
}

#[pymethods]
65
impl WorkerMetricsPublisher {
GuanLuo's avatar
GuanLuo committed
66
67
    #[new]
    fn new() -> PyResult<Self> {
68
69
        let inner =
            llm_rs::kv_router::publisher::WorkerMetricsPublisher::new().map_err(to_pyerr)?;
GuanLuo's avatar
GuanLuo committed
70
71
72
73
74
        Ok(Self {
            inner: inner.into(),
        })
    }

75
    #[pyo3(signature = (endpoint))]
Alec's avatar
Alec committed
76
    fn create_endpoint<'p>(
GuanLuo's avatar
GuanLuo committed
77
78
        &self,
        py: Python<'p>,
79
        endpoint: Endpoint,
GuanLuo's avatar
GuanLuo committed
80
81
    ) -> PyResult<Bound<'p, PyAny>> {
        let rs_publisher = self.inner.clone();
82
        let rs_component = endpoint.inner.component().clone();
GuanLuo's avatar
GuanLuo committed
83
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
84
            rs_publisher
85
                .create_endpoint(rs_component)
GuanLuo's avatar
GuanLuo committed
86
87
88
89
90
91
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

92
93
94
95
96
97
98
    /// Publish worker metrics for load monitoring.
    ///
    /// # Arguments
    /// * `dp_rank` - Data parallel rank of the worker (None defaults to 0)
    /// * `active_decode_blocks` - Number of active KV cache blocks
    #[pyo3(signature = (dp_rank, active_decode_blocks))]
    fn publish(&self, dp_rank: Option<u32>, active_decode_blocks: u64) -> PyResult<()> {
GuanLuo's avatar
GuanLuo committed
99
        self.inner
100
            .publish(dp_rank, active_decode_blocks)
GuanLuo's avatar
GuanLuo committed
101
102
103
            .map_err(to_pyerr)
    }
}
104

105
106
107
#[pyclass]
pub(crate) struct KvEventPublisher {
    inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
108
    kv_block_size: usize,
Yan Ru Pei's avatar
Yan Ru Pei committed
109
    dp_rank: DpRank,
110
    warning_count: Arc<AtomicU32>,
111
112
113
114
}

#[pymethods]
impl KvEventPublisher {
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
    /// Create a KV event publisher that batches raw engine events before forwarding
    /// them to NATS / the event plane.
    ///
    /// Args:
    ///     endpoint: The Dynamo component endpoint for this worker.
    ///     worker_id: Identifier of this worker (default 0).
    ///     kv_block_size: KV cache block size in tokens; must be > 0.
    ///     dp_rank: Data-parallel rank of this worker (default 0).
    ///     enable_local_indexer: When True, a local KV indexer is kept in-process
    ///         so that routers can recover events directly from this worker.
    ///     zmq_endpoint: Optional ZMQ SUB endpoint to read raw engine events from.
    ///     zmq_topic: ZMQ topic filter (default "").
    ///     batching_timeout_us: Maximum time (in **microseconds**) to accumulate
    ///         events into a single batch before flushing.
    ///         ``None`` uses the default window of 10000 µs (10 ms).
    ///         ``0`` disables batching: every event is published immediately.
131
    #[new]
132
133
    #[pyo3(signature = (endpoint, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_endpoint=None, zmq_topic=None, batching_timeout_us=None))]
    #[allow(clippy::too_many_arguments)]
134
    fn new(
135
        endpoint: Endpoint,
136
137
138
        worker_id: WorkerId,
        kv_block_size: usize,
        dp_rank: DpRank,
139
        enable_local_indexer: bool,
140
141
        zmq_endpoint: Option<String>,
        zmq_topic: Option<String>,
142
        batching_timeout_us: Option<u64>,
143
    ) -> PyResult<Self> {
144
145
        let _ = worker_id;

146
147
        let source_config = zmq_endpoint.map(|ep| KvEventSourceConfig::Zmq {
            endpoint: ep,
148
149
            topic: zmq_topic.unwrap_or_default(),
        });
150

Yan Ru Pei's avatar
Yan Ru Pei committed
151
152
153
154
        if kv_block_size == 0 {
            return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
        }

155
156
157
        // Extract component from endpoint
        let component = endpoint.inner.component().clone();

158
        let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
159
            component,
160
            kv_block_size as u32,
161
            source_config,
162
            enable_local_indexer,
163
            dp_rank,
164
            batching_timeout_us,
165
166
        )
        .map_err(to_pyerr)?;
167

168
169
        Ok(Self {
            inner: inner.into(),
170
            kv_block_size,
Yan Ru Pei's avatar
Yan Ru Pei committed
171
            dp_rank,
172
            warning_count: Arc::new(AtomicU32::new(0)),
173
174
175
176
        })
    }

    #[allow(clippy::too_many_arguments)]
177
    #[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None))]
178
    fn publish_stored(
179
        &self,
180
        py: Python,
181
182
        token_ids: Vec<u32>,
        num_block_tokens: Vec<u64>,
183
184
        block_hashes: Vec<i64>,
        parent_hash: Option<i64>,
185
        block_mm_infos: Option<Bound<PyAny>>,
186
        lora_name: Option<String>,
187
    ) -> PyResult<()> {
188
189
190
191
192
        let kv_block_size = self.kv_block_size as u32;
        let dp_rank = self.dp_rank;
        let warning_count = self.warning_count.clone();
        let inner = self.inner.clone();

193
194
        let event_id = inner.next_event_id();

195
        let mm_infos = block_mm_infos
196
            .as_ref()
197
            .map(depythonize_block_mm_infos)
198
199
            .transpose()?;

200
201
202
203
204
205
206
207
208
209
210
        py.allow_threads(|| {
            let block_hashes_u64: Vec<u64> = block_hashes.iter().map(|&h| h as u64).collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
                    blocks: create_stored_blocks(
                        kv_block_size,
                        &token_ids,
                        &num_block_tokens,
                        &block_hashes_u64,
211
                        lora_name.as_deref(),
212
                        &warning_count,
213
                        mm_infos.as_deref(),
214
215
216
217
218
219
220
                    ),
                }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
221
222
    }

223
    fn publish_removed(&self, py: Python, block_hashes: Vec<i64>) -> PyResult<()> {
224
225
226
        let dp_rank = self.dp_rank;
        let inner = self.inner.clone();

227
228
229
        // Use shared monotonic event_id counter from the inner publisher
        let event_id = inner.next_event_id();

230
231
232
233
234
235
236
237
238
239
240
241
242
        py.allow_threads(|| {
            let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
                .into_iter()
                .map(ExternalSequenceBlockHash::from)
                .collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
243
    }
244
245
246
247
248
249
250
251

    fn shutdown(&mut self) {
        // If no other Arc clones exist, shut down eagerly.
        // Otherwise the Drop impl handles cleanup when the last reference is freed.
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.shutdown();
        }
    }
252
253
}

254
255
256
#[pyclass]
#[derive(Clone)]
pub(crate) struct OverlapScores {
257
    inner: llm_rs::kv_router::protocols::OverlapScores,
258
259
260
261
262
}

#[pymethods]
impl OverlapScores {
    #[getter]
263
    fn scores(&self) -> HashMap<(u64, u32), u32> {
Yan Ru Pei's avatar
Yan Ru Pei committed
264
265
266
267
268
269
        // Return scores with full WorkerWithDpRank granularity as (worker_id, dp_rank) tuples
        self.inner
            .scores
            .iter()
            .map(|(worker, score)| ((worker.worker_id, worker.dp_rank), *score))
            .collect()
270
271
272
273
274
275
276
277
    }

    #[getter]
    fn frequencies(&self) -> Vec<usize> {
        self.inner.frequencies.clone()
    }
}

278
279
280
281
282
/// Messages sent from the Python-facing `RadixTree` handle to the background thread
/// that owns the tree. Every variant except `Shutdown` carries a one-slot reply channel.
#[derive(Debug)]
enum RadixTreeRequest {
    /// Run `find_matches` on the tree and reply with the resulting overlap scores.
    FindMatches {
        local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash>,
        early_exit: bool,
        response_tx: mpsc::SyncSender<llm_rs::kv_router::protocols::OverlapScores>,
    },
    /// Deserialize a JSON `KvCacheEvent`, apply it for `worker_id`, and reply with the outcome.
    ApplyEvent {
        worker_id: WorkerId,
        kv_cache_event_bytes: Vec<u8>,
        response_tx: mpsc::SyncSender<PyResult<()>>,
    },
    /// Remove `worker_id` from the tree; reply once done.
    RemoveWorker {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    /// Clear every block held by `worker_id`; reply once done.
    ClearAllBlocks {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    /// Reply with the whole tree rendered as router events.
    DumpTreeAsEvents {
        response_tx: mpsc::SyncSender<Vec<llm_rs::kv_router::protocols::RouterEvent>>,
    },
    /// Stop the background thread's request loop.
    Shutdown,
}

// NOTE: RadixTree is now thread-safe with pure sync patterns
#[pyclass]
Yan Ru Pei's avatar
Yan Ru Pei committed
306
pub(crate) struct RadixTree {
307
    request_tx: mpsc::Sender<RadixTreeRequest>,
Yan Ru Pei's avatar
Yan Ru Pei committed
308
309
310
311
312
313
314
315
}

#[pymethods]
impl RadixTree {
    #[new]
    #[pyo3(signature = (expiration_duration_secs=None))]
    fn new(expiration_duration_secs: Option<f64>) -> PyResult<Self> {
        let expiration_duration = expiration_duration_secs.map(std::time::Duration::from_secs_f64);
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341

        let (request_tx, request_rx) = mpsc::channel::<RadixTreeRequest>();

        // Spawn dedicated thread with simplified sync processing
        std::thread::spawn(move || {
            let mut radix_tree =
                llm_rs::kv_router::indexer::RadixTree::new_with_frequency(expiration_duration);

            loop {
                match request_rx.recv() {
                    Ok(RadixTreeRequest::Shutdown) => {
                        tracing::debug!("RadixTree thread received shutdown request");
                        break;
                    }
                    Ok(request) => {
                        Self::handle_request(&mut radix_tree, request);
                    }
                    Err(mpsc::RecvError) => {
                        tracing::debug!("RadixTree request channel disconnected");
                        break;
                    }
                }
            }
        });

        Ok(Self { request_tx })
Yan Ru Pei's avatar
Yan Ru Pei committed
342
343
344
345
346
    }

    #[pyo3(signature = (sequence, early_exit=false))]
    fn find_matches(
        &self,
347
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
348
349
350
        sequence: Vec<u64>,
        early_exit: bool,
    ) -> PyResult<OverlapScores> {
351
352
353
354
355
356
357
358
359
360
361
362
363
364
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let local_block_hashes = py.allow_threads(|| {
            sequence
                .into_iter()
                .map(llm_rs::kv_router::protocols::LocalBlockHash)
                .collect()
        });

        let request = RadixTreeRequest::FindMatches {
            local_block_hashes,
            early_exit,
            response_tx,
        };
Yan Ru Pei's avatar
Yan Ru Pei committed
365

366
367
368
369
370
371
372
373
374
375
376
377
378
379
        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })?;

        Ok(OverlapScores { inner: result })
Yan Ru Pei's avatar
Yan Ru Pei committed
380
381
382
    }

    fn apply_event(
383
384
        &self,
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
385
        worker_id: WorkerId,
Yan Ru Pei's avatar
Yan Ru Pei committed
386
387
        kv_cache_event_bytes: &[u8],
    ) -> PyResult<()> {
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ApplyEvent {
            worker_id,
            kv_cache_event_bytes: kv_cache_event_bytes.to_vec(),
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || response_rx.recv());

        result.map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
        })?
    }

    fn remove_worker(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::RemoveWorker {
            worker_id,
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })
    }

    fn clear_all_blocks(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ClearAllBlocks {
            worker_id,
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
453

454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
    fn dump_tree_as_events(&self, py: Python) -> PyResult<Vec<String>> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::DumpTreeAsEvents { response_tx };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Failed to send dump tree request")
        })?;

        // Release GIL while waiting for response from dedicated thread
        let events = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                    "Failed to receive dump tree response",
                )
            })
        })?;

        // Serialize RouterEvent structs to JSON strings with GIL released
        py.allow_threads(move || {
            events
                .into_iter()
                .map(|event| {
                    serde_json::to_string(&event).map_err(|e| {
                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                            "Failed to serialize event to JSON: {}",
                            e
                        ))
                    })
                })
                .collect::<Result<Vec<String>, PyErr>>()
        })
Yan Ru Pei's avatar
Yan Ru Pei committed
486
    }
487
}
Yan Ru Pei's avatar
Yan Ru Pei committed
488

489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
impl RadixTree {
    fn handle_request(
        radix_tree: &mut llm_rs::kv_router::indexer::RadixTree,
        request: RadixTreeRequest,
    ) {
        match request {
            RadixTreeRequest::FindMatches {
                local_block_hashes,
                early_exit,
                response_tx,
            } => {
                let result = radix_tree.find_matches(local_block_hashes, early_exit);
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::ApplyEvent {
                worker_id,
                kv_cache_event_bytes,
                response_tx,
            } => {
                let result = match serde_json::from_slice::<
                    llm_rs::kv_router::protocols::KvCacheEvent,
                >(&kv_cache_event_bytes)
                {
                    Ok(kv_cache_event) => {
513
514
515
516
                        let router_event = llm_rs::kv_router::protocols::RouterEvent::new(
                            worker_id,
                            kv_cache_event,
                        );
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
                        match radix_tree.apply_event(router_event) {
                            Ok(_) => Ok(()),
                            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                                format!("Failed to apply event: {}", e),
                            )),
                        }
                    }
                    Err(e) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                        "Failed to deserialize KvCacheEvent: {}",
                        e
                    ))),
                };
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::RemoveWorker {
                worker_id,
                response_tx,
            } => {
                radix_tree.remove_worker(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::ClearAllBlocks {
                worker_id,
                response_tx,
            } => {
                radix_tree.clear_all_blocks(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::DumpTreeAsEvents { response_tx } => {
                let events = radix_tree.dump_tree_as_events();
                let _ = response_tx.send(events);
            }
            RadixTreeRequest::Shutdown => {
                // This is handled in the main loop
            }
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
553
    }
554
}
Yan Ru Pei's avatar
Yan Ru Pei committed
555

556
557
558
559
560
// Ask the background thread to exit when the Python-side handle goes away.
impl Drop for RadixTree {
    /// Send a best-effort `Shutdown` to the background thread.
    ///
    /// If the thread has already exited, the send fails and the error is ignored. Once
    /// `request_tx` itself is dropped, the channel disconnects; the thread loop also
    /// treats that as a signal to stop.
    fn drop(&mut self) {
        let shutdown = RadixTreeRequest::Shutdown;
        let _ = self.request_tx.send(shutdown);
    }
}

564
/// Helper function to create a KV router from an endpoint using the ModelManager
565
566
567
568
569
/// to ensure proper etcd registration.
/// Infers worker type using endpoint naming and router config:
/// - If endpoint name/component contains "prefill", treat as prefill
/// - If router_track_active_blocks is disabled, treat as prefill
/// - Otherwise, default to decode
570
571
572
573
574
async fn create_kv_router_from_endpoint(
    endpoint: &Endpoint,
    block_size: usize,
    kv_router_config: Option<llm_rs::kv_router::KvRouterConfig>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
575
    // Create ModelManager and use it to create KvRouter (ensures registration)
576
    let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
    let endpoint_id = endpoint.inner.id();
    let namespace = endpoint_id.namespace.to_lowercase();
    let component = endpoint_id.component.to_lowercase();
    let name = endpoint_id.name.to_lowercase();
    let endpoint_is_prefill =
        namespace.contains("prefill") || component.contains("prefill") || name.contains("prefill");
    let track_active_blocks = kv_router_config
        .as_ref()
        .map(|cfg| cfg.router_track_active_blocks)
        .unwrap_or(true);
    let worker_type = if endpoint_is_prefill || !track_active_blocks {
        llm_rs::discovery::WORKER_TYPE_PREFILL
    } else {
        llm_rs::discovery::WORKER_TYPE_DECODE
    };
592
    let kv_router = model_manager
593
594
595
596
597
598
        .kv_chooser_for(
            &endpoint.inner,
            block_size as u32,
            kv_router_config,
            worker_type,
        )
599
600
601
602
603
604
        .await
        .map_err(to_pyerr)?;

    Ok(kv_router)
}

605
#[pyclass]
606
607
pub(crate) struct KvRouter {
    inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
608
609
}

610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
/// Copy the tracker's worker routing info into `data.disaggregated_params["worker_id"]`.
///
/// Python bindings need this to expose worker routing info, because the raw
/// `LLMEngineOutput` does not pass through `DeltaGenerator` (which adds nvext).
/// Does nothing if the tracker has no worker info yet. If `disaggregated_params` is
/// absent or not a JSON object, it is replaced by `{"worker_id": ...}`.
fn inject_worker_id_from_tracker(
    data: &mut llm_rs::protocols::common::llm_backend::LLMEngineOutput,
    tracker: &RequestTracker,
) {
    let worker_info = match tracker.get_worker_info() {
        Some(info) => info,
        None => return,
    };

    let worker_id_json =
        serde_json::to_value(&worker_info).expect("WorkerIdInfo serialization should not fail");

    match data
        .disaggregated_params
        .as_mut()
        .and_then(|params| params.as_object_mut())
    {
        Some(params) => {
            params.insert("worker_id".to_string(), worker_id_json);
        }
        None => {
            data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
        }
    }
}

Yan Ru Pei's avatar
Yan Ru Pei committed
635
// TODO: can this reuse the stream conversion method in Client bindings?
636
impl KvRouter {
Yan Ru Pei's avatar
Yan Ru Pei committed
637
638
639
    /// Helper method to process a request and create a Python async generator
    fn process_request_to_stream<'p>(
        py: Python<'p>,
640
        inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
641
        request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
642
        tracker: Option<Arc<RequestTracker>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
643
644
645
646
    ) -> PyResult<Bound<'p, PyAny>> {
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let single_in = SingleIn::new(request);
            let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
647
            let (tx, rx) = tokio::sync::mpsc::channel::<RsAnnotated<PyObject>>(100);
Yan Ru Pei's avatar
Yan Ru Pei committed
648
649
650

            tokio::spawn(async move {
                let mut stream = stream;
651
                let mut first_item = true;
652
                let mut first_token_gauges_observed = false;
653
654
655
656
657
658
659
660
661

                while let Some(mut response) = stream.next().await {
                    if first_item {
                        first_item = false;
                        if let (Some(tracker), Some(data)) = (&tracker, &mut response.data) {
                            inject_worker_id_from_tracker(data, tracker);
                        }
                    }

662
663
664
665
666
667
668
669
670
671
672
673
674
675
                    if !first_token_gauges_observed {
                        let has_tokens = response
                            .data
                            .as_ref()
                            .map(|d| !d.token_ids.is_empty())
                            .unwrap_or(false);
                        if has_tokens {
                            if let Some(ref tracker) = tracker {
                                tracker.observe_first_token_gauges();
                            }
                            first_token_gauges_observed = true;
                        }
                    }

Yan Ru Pei's avatar
Yan Ru Pei committed
676
677
678
679
680
681
682
683
                    let py_response = Python::with_gil(|py| {
                        pythonize(py, &response.data)
                            .map(|obj| obj.unbind())
                            .map_err(|e| e.to_string())
                    });

                    match py_response {
                        Ok(obj) => {
684
685
                            if tx.send(RsAnnotated::from_data(obj)).await.is_err() {
                                break;
Yan Ru Pei's avatar
Yan Ru Pei committed
686
687
688
689
690
691
692
693
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to pythonize response: {}", e);
                            break;
                        }
                    }
                }
694
695
696
697

                if let Some(ref tracker) = tracker {
                    tracker.observe_finish_gauges();
                }
Yan Ru Pei's avatar
Yan Ru Pei committed
698
699
            });

700
            Ok(crate::AsyncResponseStream::new(rx, false))
Yan Ru Pei's avatar
Yan Ru Pei committed
701
702
        })
    }
703
704
705
}

#[pymethods]
706
707
impl KvRouter {
    /// Create a new KvRouter for KV-aware routing to workers.
708
709
710
711
712
713
714
715
    ///
    /// # Arguments
    /// * `endpoint` - The endpoint to route requests to
    /// * `block_size` - KV cache block size for routing decisions
    /// * `kv_router_config` - Configuration for the KV router
    ///
    /// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
    /// (contains "prefill") or by `router_track_active_blocks` being disabled.
716
    #[new]
717
    #[pyo3(signature = (endpoint, block_size, kv_router_config))]
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
    fn new(
        endpoint: &Endpoint,
        block_size: usize,
        kv_router_config: &super::entrypoint::KvRouterConfig,
    ) -> PyResult<Self> {
        let runtime = pyo3_async_runtimes::tokio::get_runtime();
        runtime.block_on(async move {
            let client = endpoint.inner.client().await.map_err(to_pyerr)?;

            // Create PushRouter with KV router mode
            let push_router = rs::pipeline::PushRouter::<
                llm_rs::protocols::common::preprocessor::PreprocessedRequest,
                rs::protocols::annotated::Annotated<
                    llm_rs::protocols::common::llm_backend::LLMEngineOutput,
                >,
            >::from_client(
                client,
                rs::pipeline::network::egress::push_router::RouterMode::KV,
            )
            .await
            .map_err(to_pyerr)?;

740
741
742
743
744
745
746
            // Create KvRouter using helper function (ensures etcd registration)
            let kv_router = create_kv_router_from_endpoint(
                endpoint,
                block_size,
                Some(kv_router_config.inner()),
            )
            .await?;
747

748
            let kv_push_router = RsKvPushRouter::new(push_router, kv_router);
749
750
751
752
753
754
755
756

            Ok(Self {
                inner: Arc::new(kv_push_router),
            })
        })
    }

    #[allow(clippy::too_many_arguments)]
757
    #[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None, block_mm_infos=None, multi_modal_data=None, mm_routing_info=None))]
758
759
760
761
762
763
764
765
766
    fn generate<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        model: String,
        stop_conditions: Option<PyObject>,
        sampling_options: Option<PyObject>,
        output_options: Option<PyObject>,
        router_config_override: Option<PyObject>,
Yan Ru Pei's avatar
Yan Ru Pei committed
767
768
        worker_id: Option<WorkerId>,
        dp_rank: Option<DpRank>,
769
        extra_args: Option<PyObject>,
770
771
772
        block_mm_infos: Option<PyObject>,
        multi_modal_data: Option<PyObject>,
        mm_routing_info: Option<PyObject>,
773
774
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the options with defaults
775
776
777
778
779
        let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            StopConditions::default()
        };
780

781
782
783
784
785
        let sampling_options: SamplingOptions = if let Some(obj) = sampling_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            SamplingOptions::default()
        };
786

787
788
789
790
791
        let output_options: OutputOptions = if let Some(obj) = output_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            OutputOptions::default()
        };
792

793
794
795
796
797
798
        let router_config_override: Option<llm_rs::kv_router::RouterConfigOverride> =
            if let Some(obj) = router_config_override {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };
799

800
801
802
803
804
        let extra_args: Option<serde_json::Value> = if let Some(obj) = extra_args {
            Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
        } else {
            None
        };
805

806
807
808
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828

        let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> =
            if let Some(obj) = multi_modal_data {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };

        let mm_routing_info: Option<llm_rs::protocols::common::preprocessor::MmRoutingInfo> =
            if let Some(obj) = mm_routing_info {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                block_mm_infos.map(
                    |infos| llm_rs::protocols::common::preprocessor::MmRoutingInfo {
                        routing_token_ids: token_ids.clone(),
                        block_mm_infos: infos,
                    },
                )
            };

829
830
831
        // Create tracker to capture worker routing info from KvRouter
        let tracker = Arc::new(RequestTracker::new());

832
        // Build the PreprocessedRequest
833
834
835
        let mut request_builder =
            llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder();
        request_builder
836
837
838
839
840
            .model(model)
            .token_ids(token_ids)
            .stop_conditions(stop_conditions)
            .sampling_options(sampling_options)
            .output_options(output_options)
841
            .router_config_override(router_config_override)
842
843
            .multi_modal_data(multi_modal_data)
            .mm_routing_info(mm_routing_info)
844
845
            .extra_args(extra_args)
            .tracker(Some(tracker.clone()));
846

847
848
849
850
851
852
853
854
        // Set routing hints if worker_id or dp_rank is provided
        if worker_id.is_some() || dp_rank.is_some() {
            let routing = llm_rs::protocols::common::preprocessor::RoutingHints {
                backend_instance_id: worker_id,
                dp_rank,
                ..Default::default()
            };
            request_builder.routing(Some(routing));
855
856
857
        }

        let request = request_builder.build().map_err(to_pyerr)?;
858

Yan Ru Pei's avatar
Yan Ru Pei committed
859
        // Use the helper method to process the request
860
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
Yan Ru Pei's avatar
Yan Ru Pei committed
861
    }
862

Yan Ru Pei's avatar
Yan Ru Pei committed
863
864
865
866
867
868
    fn generate_from_request<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the request directly into PreprocessedRequest
869
        let mut request: llm_rs::protocols::common::preprocessor::PreprocessedRequest =
870
            depythonize(request.bind(py)).map_err(to_pyerr)?;
871

872
873
874
875
876
877
878
879
880
881
        // Create tracker if not already set, to capture worker routing info
        let tracker = match request.tracker {
            Some(ref t) => t.clone(),
            None => {
                let t = Arc::new(RequestTracker::new());
                request.tracker = Some(t.clone());
                t
            }
        };

Yan Ru Pei's avatar
Yan Ru Pei committed
882
        // Use the helper method to process the request
883
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
884
    }
885

886
    #[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None, lora_name=None))]
Yan Ru Pei's avatar
Yan Ru Pei committed
887
888
889
890
891
892
    fn best_worker<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        router_config_override: Option<PyObject>,
        request_id: Option<String>,
893
        block_mm_infos: Option<PyObject>,
894
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
895
896
    ) -> PyResult<Bound<'p, PyAny>> {
        let router_config_override = if let Some(obj) = router_config_override {
897
898
899
            let override_config: llm_rs::kv_router::RouterConfigOverride =
                depythonize(obj.bind(py)).map_err(to_pyerr)?;
            Some(override_config)
Yan Ru Pei's avatar
Yan Ru Pei committed
900
901
902
903
        } else {
            None
        };

904
905
906
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
907

Yan Ru Pei's avatar
Yan Ru Pei committed
908
909
910
911
912
913
914
915
        let chooser = self.inner.chooser.clone();
        let update_states = request_id.is_some();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let (best_worker, overlap_blocks) = chooser
                .find_best_match(
                    request_id.as_deref(),
                    &token_ids,
916
                    block_mm_infos.as_deref(),
Yan Ru Pei's avatar
Yan Ru Pei committed
917
918
                    router_config_override.as_ref(),
                    update_states,
919
                    lora_name,
920
                    0.0,
921
                    None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
Yan Ru Pei's avatar
Yan Ru Pei committed
922
923
924
925
926
927
928
929
                )
                .await
                .map_err(to_pyerr)?;

            Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
        })
    }

930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
    /// Mark prefill as completed for a request
    fn mark_prefill_complete<'p>(
        &self,
        py: Python<'p>,
        request_id: String,
    ) -> PyResult<Bound<'p, PyAny>> {
        let chooser = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            chooser
                .mark_prefill_completed(&request_id)
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

    /// Free a request by its ID, signaling the router to release resources
    fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
        let chooser = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            chooser.free(&request_id).await.map_err(to_pyerr)?;
            Ok(())
        })
    }

957
    #[pyo3(signature = (token_ids, lora_name=None))]
958
959
960
961
    fn get_potential_loads<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
962
        lora_name: Option<String>,
963
    ) -> PyResult<Bound<'p, PyAny>> {
964
        let chooser = self.inner.chooser.clone();
965
966

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
967
            let loads = chooser
968
                .get_potential_loads(&token_ids, None, lora_name.as_deref())
969
970
971
                .await
                .map_err(to_pyerr)?;

Yan Ru Pei's avatar
Yan Ru Pei committed
972
            // Return loads without aggregation - each (worker_id, dp_rank) pair is a separate entry
973
974
975
976
977
978
979
980
981
            // Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
            Python::with_gil(|py| {
                pythonize(py, &loads)
                    .map(|obj| obj.unbind())
                    .map_err(to_pyerr)
            })
        })
    }

982
983
    /// Dump all events from the KV router's indexer as a JSON string
    fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
984
        let chooser = self.inner.chooser.clone();
985
986

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
987
            let events = chooser.dump_events().await.map_err(to_pyerr)?;
988
989
990
991
992
993
            // Serialize to JSON string
            let json_str = serde_json::to_string(&events).map_err(to_pyerr)?;
            Ok(json_str)
        })
    }
}