kv.rs 33.5 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use pythonize::{depythonize, pythonize};
5
use std::collections::HashMap;
6
use std::sync::Arc;
7
use std::sync::atomic::AtomicU32;
8
use std::sync::mpsc;
9
use tokio_stream::StreamExt;
10

11
use super::*;
12
use crate::Endpoint;
13
use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
Yan Ru Pei's avatar
Yan Ru Pei committed
14
use rs::pipeline::{AsyncEngine, SingleIn};
15
use rs::protocols::annotated::Annotated as RsAnnotated;
16
use tracing;
17

18
use llm_rs::kv_router::KvPushRouter as RsKvPushRouter;
19
use llm_rs::kv_router::protocols::*;
20
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks};
21
use llm_rs::protocols::common::timing::RequestTracker;
22
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
23
use serde_json::json;
24

25
26
27
28
/// Deserialize a Python object into a list of optional per-block `BlockExtraInfo` values.
///
/// Any `depythonize` failure is converted to a `PyErr` through `to_pyerr`.
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
    let infos: Vec<Option<BlockExtraInfo>> = depythonize(obj).map_err(to_pyerr)?;
    Ok(infos)
}

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#[pyfunction]
#[pyo3(name = "start_kv_block_indexer", signature = (endpoint, block_size, kv_router_config))]
pub fn start_kv_block_indexer_py<'p>(
    py: Python<'p>,
    endpoint: &Endpoint,
    block_size: u32,
    kv_router_config: &super::entrypoint::KvRouterConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let component = endpoint.inner.component().clone();
    let config = kv_router_config.inner();
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        llm_rs::kv_router::indexer_standalone::start_kv_block_indexer(
            &component, &config, block_size,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

Yan Ru Pei's avatar
Yan Ru Pei committed
49
#[pyfunction]
50
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None))]
51
52
53
54
55
56
pub fn compute_block_hash_for_seq_py(
    _py: Python,
    tokens: Vec<u32>,
    kv_block_size: usize,
    block_mm_infos: Option<Bound<PyAny>>,
) -> PyResult<Vec<u64>> {
Yan Ru Pei's avatar
Yan Ru Pei committed
57
    if kv_block_size == 0 {
58
59
60
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "kv_block_size cannot be 0",
        ));
Yan Ru Pei's avatar
Yan Ru Pei committed
61
62
    }

63
    let mm_infos = block_mm_infos
64
        .as_ref()
65
        .map(depythonize_block_mm_infos)
66
67
        .transpose()?;

68
    let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32, mm_infos.as_deref());
69

Yan Ru Pei's avatar
Yan Ru Pei committed
70
71
72
    Ok(hashes.into_iter().map(|h| h.0).collect())
}

GuanLuo's avatar
GuanLuo committed
73
#[pyclass]
74
75
pub(crate) struct WorkerMetricsPublisher {
    inner: Arc<llm_rs::kv_router::publisher::WorkerMetricsPublisher>,
GuanLuo's avatar
GuanLuo committed
76
77
78
}

#[pymethods]
79
impl WorkerMetricsPublisher {
GuanLuo's avatar
GuanLuo committed
80
81
    #[new]
    fn new() -> PyResult<Self> {
82
83
        let inner =
            llm_rs::kv_router::publisher::WorkerMetricsPublisher::new().map_err(to_pyerr)?;
GuanLuo's avatar
GuanLuo committed
84
85
86
87
88
        Ok(Self {
            inner: inner.into(),
        })
    }

89
    #[pyo3(signature = (endpoint))]
Alec's avatar
Alec committed
90
    fn create_endpoint<'p>(
GuanLuo's avatar
GuanLuo committed
91
92
        &self,
        py: Python<'p>,
93
        endpoint: Endpoint,
GuanLuo's avatar
GuanLuo committed
94
95
    ) -> PyResult<Bound<'p, PyAny>> {
        let rs_publisher = self.inner.clone();
96
        let rs_component = endpoint.inner.component().clone();
GuanLuo's avatar
GuanLuo committed
97
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
98
            rs_publisher
99
                .create_endpoint(rs_component)
GuanLuo's avatar
GuanLuo committed
100
101
102
103
104
105
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

106
107
108
109
110
111
112
    /// Publish worker metrics for load monitoring.
    ///
    /// # Arguments
    /// * `dp_rank` - Data parallel rank of the worker (None defaults to 0)
    /// * `active_decode_blocks` - Number of active KV cache blocks
    #[pyo3(signature = (dp_rank, active_decode_blocks))]
    fn publish(&self, dp_rank: Option<u32>, active_decode_blocks: u64) -> PyResult<()> {
GuanLuo's avatar
GuanLuo committed
113
        self.inner
114
            .publish(dp_rank, active_decode_blocks)
GuanLuo's avatar
GuanLuo committed
115
116
117
            .map_err(to_pyerr)
    }
}
118

119
120
121
#[pyclass]
pub(crate) struct KvEventPublisher {
    inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
122
    kv_block_size: usize,
Yan Ru Pei's avatar
Yan Ru Pei committed
123
    dp_rank: DpRank,
124
    warning_count: Arc<AtomicU32>,
125
126
127
128
129
}

#[pymethods]
impl KvEventPublisher {
    #[new]
130
    #[pyo3(signature = (endpoint, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_endpoint=None, zmq_topic=None))]
131
    fn new(
132
        endpoint: Endpoint,
133
134
135
        worker_id: WorkerId,
        kv_block_size: usize,
        dp_rank: DpRank,
136
        enable_local_indexer: bool,
137
138
        zmq_endpoint: Option<String>,
        zmq_topic: Option<String>,
139
    ) -> PyResult<Self> {
140
141
        let _ = worker_id;

142
143
        let source_config = zmq_endpoint.map(|ep| KvEventSourceConfig::Zmq {
            endpoint: ep,
144
145
            topic: zmq_topic.unwrap_or_default(),
        });
146

Yan Ru Pei's avatar
Yan Ru Pei committed
147
148
149
150
        if kv_block_size == 0 {
            return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
        }

151
152
153
        // Extract component from endpoint
        let component = endpoint.inner.component().clone();

154
        let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
155
            component,
156
            kv_block_size as u32,
157
            source_config,
158
            enable_local_indexer,
159
            dp_rank,
160
161
        )
        .map_err(to_pyerr)?;
162

163
164
        Ok(Self {
            inner: inner.into(),
165
            kv_block_size,
Yan Ru Pei's avatar
Yan Ru Pei committed
166
            dp_rank,
167
            warning_count: Arc::new(AtomicU32::new(0)),
168
169
170
171
        })
    }

    #[allow(clippy::too_many_arguments)]
172
    #[pyo3(signature = (token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, block_mm_infos=None))]
173
    fn publish_stored(
174
        &self,
175
        py: Python,
176
177
        token_ids: Vec<u32>,
        num_block_tokens: Vec<u64>,
178
        block_hashes: Vec<i64>,
179
        lora_id: u64,
180
        parent_hash: Option<i64>,
181
        block_mm_infos: Option<Bound<PyAny>>,
182
    ) -> PyResult<()> {
183
184
185
186
187
        let kv_block_size = self.kv_block_size as u32;
        let dp_rank = self.dp_rank;
        let warning_count = self.warning_count.clone();
        let inner = self.inner.clone();

188
189
190
        // Use shared monotonic event_id counter from the inner publisher
        let event_id = inner.next_event_id();

191
        let mm_infos = block_mm_infos
192
            .as_ref()
193
            .map(depythonize_block_mm_infos)
194
195
            .transpose()?;

196
197
198
199
200
201
202
203
204
205
206
207
208
        py.allow_threads(|| {
            let block_hashes_u64: Vec<u64> = block_hashes.iter().map(|&h| h as u64).collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
                    blocks: create_stored_blocks(
                        kv_block_size,
                        &token_ids,
                        &num_block_tokens,
                        &block_hashes_u64,
                        lora_id,
                        &warning_count,
209
                        mm_infos.as_deref(),
210
211
212
213
214
215
216
                    ),
                }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
217
218
    }

219
    fn publish_removed(&self, py: Python, block_hashes: Vec<i64>) -> PyResult<()> {
220
221
222
        let dp_rank = self.dp_rank;
        let inner = self.inner.clone();

223
224
225
        // Use shared monotonic event_id counter from the inner publisher
        let event_id = inner.next_event_id();

226
227
228
229
230
231
232
233
234
235
236
237
238
        py.allow_threads(|| {
            let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
                .into_iter()
                .map(ExternalSequenceBlockHash::from)
                .collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
239
    }
240
241
242
243
244
245
246
247

    fn shutdown(&mut self) {
        // If no other Arc clones exist, shut down eagerly.
        // Otherwise the Drop impl handles cleanup when the last reference is freed.
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.shutdown();
        }
    }
248
249
}

250
251
252
#[pyclass]
#[derive(Clone)]
pub(crate) struct OverlapScores {
253
    inner: llm_rs::kv_router::protocols::OverlapScores,
254
255
256
257
258
}

#[pymethods]
impl OverlapScores {
    #[getter]
259
    fn scores(&self) -> HashMap<(u64, u32), u32> {
Yan Ru Pei's avatar
Yan Ru Pei committed
260
261
262
263
264
265
        // Return scores with full WorkerWithDpRank granularity as (worker_id, dp_rank) tuples
        self.inner
            .scores
            .iter()
            .map(|(worker, score)| ((worker.worker_id, worker.dp_rank), *score))
            .collect()
266
267
268
269
270
271
272
273
    }

    #[getter]
    fn frequencies(&self) -> Vec<usize> {
        self.inner.frequencies.clone()
    }
}

274
275
276
277
278
/// Messages sent from the Python-facing `RadixTree` methods to the thread that
/// owns the underlying radix tree. Every variant except `Shutdown` carries a
/// `response_tx` on which the handler sends its result.
#[derive(Debug)]
enum RadixTreeRequest {
    /// Look up overlap scores for a sequence of local block hashes.
    FindMatches {
        local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash>,
        early_exit: bool,
        response_tx: mpsc::SyncSender<llm_rs::kv_router::protocols::OverlapScores>,
    },
    /// Apply a JSON-encoded `KvCacheEvent` on behalf of `worker_id`.
    ApplyEvent {
        worker_id: WorkerId,
        kv_cache_event_bytes: Vec<u8>,
        response_tx: mpsc::SyncSender<PyResult<()>>,
    },
    /// Call `remove_worker` on the tree for `worker_id`.
    RemoveWorker {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    /// Call `clear_all_blocks` on the tree for `worker_id`.
    ClearAllBlocks {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    /// Dump the tree contents as a list of `RouterEvent`s.
    DumpTreeAsEvents {
        response_tx: mpsc::SyncSender<Vec<llm_rs::kv_router::protocols::RouterEvent>>,
    },
    /// Ask the owning thread to exit its request loop.
    Shutdown,
}

// NOTE: RadixTree is now thread-safe with pure sync patterns
#[pyclass]
Yan Ru Pei's avatar
Yan Ru Pei committed
302
pub(crate) struct RadixTree {
303
    request_tx: mpsc::Sender<RadixTreeRequest>,
Yan Ru Pei's avatar
Yan Ru Pei committed
304
305
306
307
308
309
310
311
}

#[pymethods]
impl RadixTree {
    #[new]
    #[pyo3(signature = (expiration_duration_secs=None))]
    fn new(expiration_duration_secs: Option<f64>) -> PyResult<Self> {
        let expiration_duration = expiration_duration_secs.map(std::time::Duration::from_secs_f64);
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337

        let (request_tx, request_rx) = mpsc::channel::<RadixTreeRequest>();

        // Spawn dedicated thread with simplified sync processing
        std::thread::spawn(move || {
            let mut radix_tree =
                llm_rs::kv_router::indexer::RadixTree::new_with_frequency(expiration_duration);

            loop {
                match request_rx.recv() {
                    Ok(RadixTreeRequest::Shutdown) => {
                        tracing::debug!("RadixTree thread received shutdown request");
                        break;
                    }
                    Ok(request) => {
                        Self::handle_request(&mut radix_tree, request);
                    }
                    Err(mpsc::RecvError) => {
                        tracing::debug!("RadixTree request channel disconnected");
                        break;
                    }
                }
            }
        });

        Ok(Self { request_tx })
Yan Ru Pei's avatar
Yan Ru Pei committed
338
339
340
341
342
    }

    #[pyo3(signature = (sequence, early_exit=false))]
    fn find_matches(
        &self,
343
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
344
345
346
        sequence: Vec<u64>,
        early_exit: bool,
    ) -> PyResult<OverlapScores> {
347
348
349
350
351
352
353
354
355
356
357
358
359
360
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let local_block_hashes = py.allow_threads(|| {
            sequence
                .into_iter()
                .map(llm_rs::kv_router::protocols::LocalBlockHash)
                .collect()
        });

        let request = RadixTreeRequest::FindMatches {
            local_block_hashes,
            early_exit,
            response_tx,
        };
Yan Ru Pei's avatar
Yan Ru Pei committed
361

362
363
364
365
366
367
368
369
370
371
372
373
374
375
        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })?;

        Ok(OverlapScores { inner: result })
Yan Ru Pei's avatar
Yan Ru Pei committed
376
377
378
    }

    fn apply_event(
379
380
        &self,
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
381
        worker_id: WorkerId,
Yan Ru Pei's avatar
Yan Ru Pei committed
382
383
        kv_cache_event_bytes: &[u8],
    ) -> PyResult<()> {
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ApplyEvent {
            worker_id,
            kv_cache_event_bytes: kv_cache_event_bytes.to_vec(),
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || response_rx.recv());

        result.map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
        })?
    }

    fn remove_worker(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::RemoveWorker {
            worker_id,
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })
    }

    fn clear_all_blocks(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ClearAllBlocks {
            worker_id,
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
449

450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
    fn dump_tree_as_events(&self, py: Python) -> PyResult<Vec<String>> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::DumpTreeAsEvents { response_tx };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Failed to send dump tree request")
        })?;

        // Release GIL while waiting for response from dedicated thread
        let events = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                    "Failed to receive dump tree response",
                )
            })
        })?;

        // Serialize RouterEvent structs to JSON strings with GIL released
        py.allow_threads(move || {
            events
                .into_iter()
                .map(|event| {
                    serde_json::to_string(&event).map_err(|e| {
                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                            "Failed to serialize event to JSON: {}",
                            e
                        ))
                    })
                })
                .collect::<Result<Vec<String>, PyErr>>()
        })
Yan Ru Pei's avatar
Yan Ru Pei committed
482
    }
483
}
Yan Ru Pei's avatar
Yan Ru Pei committed
484

485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
impl RadixTree {
    fn handle_request(
        radix_tree: &mut llm_rs::kv_router::indexer::RadixTree,
        request: RadixTreeRequest,
    ) {
        match request {
            RadixTreeRequest::FindMatches {
                local_block_hashes,
                early_exit,
                response_tx,
            } => {
                let result = radix_tree.find_matches(local_block_hashes, early_exit);
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::ApplyEvent {
                worker_id,
                kv_cache_event_bytes,
                response_tx,
            } => {
                let result = match serde_json::from_slice::<
                    llm_rs::kv_router::protocols::KvCacheEvent,
                >(&kv_cache_event_bytes)
                {
                    Ok(kv_cache_event) => {
509
510
511
512
                        let router_event = llm_rs::kv_router::protocols::RouterEvent::new(
                            worker_id,
                            kv_cache_event,
                        );
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
                        match radix_tree.apply_event(router_event) {
                            Ok(_) => Ok(()),
                            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                                format!("Failed to apply event: {}", e),
                            )),
                        }
                    }
                    Err(e) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                        "Failed to deserialize KvCacheEvent: {}",
                        e
                    ))),
                };
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::RemoveWorker {
                worker_id,
                response_tx,
            } => {
                radix_tree.remove_worker(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::ClearAllBlocks {
                worker_id,
                response_tx,
            } => {
                radix_tree.clear_all_blocks(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::DumpTreeAsEvents { response_tx } => {
                let events = radix_tree.dump_tree_as_events();
                let _ = response_tx.send(events);
            }
            RadixTreeRequest::Shutdown => {
                // This is handled in the main loop
            }
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
549
    }
550
}
Yan Ru Pei's avatar
Yan Ru Pei committed
551

552
553
554
555
556
// Cleanup when RadixTree is dropped
impl Drop for RadixTree {
    /// Send `Shutdown` to the background thread; a send error is ignored because
    /// it means the receiving side is already gone.
    fn drop(&mut self) {
        // Only need graceful shutdown via RadixTreeRequest::Shutdown
        let _ = self.request_tx.send(RadixTreeRequest::Shutdown);
    }
}

560
/// Helper function to create a KV router from an endpoint using the ModelManager
561
562
563
564
565
/// to ensure proper etcd registration.
/// Infers worker type using endpoint naming and router config:
/// - If endpoint name/component contains "prefill", treat as prefill
/// - If router_track_active_blocks is disabled, treat as prefill
/// - Otherwise, default to decode
566
567
568
569
570
async fn create_kv_router_from_endpoint(
    endpoint: &Endpoint,
    block_size: usize,
    kv_router_config: Option<llm_rs::kv_router::KvRouterConfig>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
571
    // Create ModelManager and use it to create KvRouter (ensures registration)
572
    let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
    let endpoint_id = endpoint.inner.id();
    let namespace = endpoint_id.namespace.to_lowercase();
    let component = endpoint_id.component.to_lowercase();
    let name = endpoint_id.name.to_lowercase();
    let endpoint_is_prefill =
        namespace.contains("prefill") || component.contains("prefill") || name.contains("prefill");
    let track_active_blocks = kv_router_config
        .as_ref()
        .map(|cfg| cfg.router_track_active_blocks)
        .unwrap_or(true);
    let worker_type = if endpoint_is_prefill || !track_active_blocks {
        llm_rs::discovery::WORKER_TYPE_PREFILL
    } else {
        llm_rs::discovery::WORKER_TYPE_DECODE
    };
588
    let kv_router = model_manager
589
590
591
592
593
594
        .kv_chooser_for(
            &endpoint.inner,
            block_size as u32,
            kv_router_config,
            worker_type,
        )
595
596
597
598
599
600
        .await
        .map_err(to_pyerr)?;

    Ok(kv_router)
}

601
#[pyclass]
602
603
pub(crate) struct KvRouter {
    inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
604
605
}

606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
/// Copy the tracker's worker info into `data.disaggregated_params` under the
/// `"worker_id"` key.
///
/// Does nothing when the tracker has no worker info. If `disaggregated_params`
/// already holds a JSON object, the key is inserted into it; otherwise the field
/// is replaced with a new object containing only `"worker_id"`.
/// The Python bindings need this because the raw LLMEngineOutput doesn't go
/// through DeltaGenerator (which adds nvext).
fn inject_worker_id_from_tracker(
    data: &mut llm_rs::protocols::common::llm_backend::LLMEngineOutput,
    tracker: &RequestTracker,
) {
    let worker_info = match tracker.get_worker_info() {
        Some(info) => info,
        None => return,
    };

    let encoded =
        serde_json::to_value(&worker_info).expect("WorkerIdInfo serialization should not fail");

    let existing_object = data
        .disaggregated_params
        .as_mut()
        .and_then(|params| params.as_object_mut());

    match existing_object {
        Some(fields) => {
            fields.insert("worker_id".to_string(), encoded);
        }
        None => {
            data.disaggregated_params = Some(json!({"worker_id": encoded}));
        }
    }
}

Yan Ru Pei's avatar
Yan Ru Pei committed
631
// TODO: can this reuse the stream conversion method in Client bindings?
632
impl KvRouter {
Yan Ru Pei's avatar
Yan Ru Pei committed
633
634
635
    /// Helper method to process a request and create a Python async generator
    fn process_request_to_stream<'p>(
        py: Python<'p>,
636
        inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
637
        request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
638
        tracker: Option<Arc<RequestTracker>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
639
640
641
642
    ) -> PyResult<Bound<'p, PyAny>> {
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let single_in = SingleIn::new(request);
            let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
643
            let (tx, rx) = tokio::sync::mpsc::channel::<RsAnnotated<PyObject>>(100);
Yan Ru Pei's avatar
Yan Ru Pei committed
644
645
646

            tokio::spawn(async move {
                let mut stream = stream;
647
                let mut first_item = true;
648
                let mut first_token_gauges_observed = false;
649
650
651
652
653
654
655
656
657

                while let Some(mut response) = stream.next().await {
                    if first_item {
                        first_item = false;
                        if let (Some(tracker), Some(data)) = (&tracker, &mut response.data) {
                            inject_worker_id_from_tracker(data, tracker);
                        }
                    }

658
659
660
661
662
663
664
665
666
667
668
669
670
671
                    if !first_token_gauges_observed {
                        let has_tokens = response
                            .data
                            .as_ref()
                            .map(|d| !d.token_ids.is_empty())
                            .unwrap_or(false);
                        if has_tokens {
                            if let Some(ref tracker) = tracker {
                                tracker.observe_first_token_gauges();
                            }
                            first_token_gauges_observed = true;
                        }
                    }

Yan Ru Pei's avatar
Yan Ru Pei committed
672
673
674
675
676
677
678
679
                    let py_response = Python::with_gil(|py| {
                        pythonize(py, &response.data)
                            .map(|obj| obj.unbind())
                            .map_err(|e| e.to_string())
                    });

                    match py_response {
                        Ok(obj) => {
680
681
                            if tx.send(RsAnnotated::from_data(obj)).await.is_err() {
                                break;
Yan Ru Pei's avatar
Yan Ru Pei committed
682
683
684
685
686
687
688
689
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to pythonize response: {}", e);
                            break;
                        }
                    }
                }
690
691
692
693

                if let Some(ref tracker) = tracker {
                    tracker.observe_finish_gauges();
                }
Yan Ru Pei's avatar
Yan Ru Pei committed
694
695
            });

696
            Ok(crate::AsyncResponseStream::new(rx, false))
Yan Ru Pei's avatar
Yan Ru Pei committed
697
698
        })
    }
699
700
701
}

#[pymethods]
702
703
impl KvRouter {
    /// Create a new KvRouter for KV-aware routing to workers.
704
705
706
707
708
709
710
711
    ///
    /// # Arguments
    /// * `endpoint` - The endpoint to route requests to
    /// * `block_size` - KV cache block size for routing decisions
    /// * `kv_router_config` - Configuration for the KV router
    ///
    /// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
    /// (contains "prefill") or by `router_track_active_blocks` being disabled.
712
    #[new]
713
    #[pyo3(signature = (endpoint, block_size, kv_router_config))]
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
    fn new(
        endpoint: &Endpoint,
        block_size: usize,
        kv_router_config: &super::entrypoint::KvRouterConfig,
    ) -> PyResult<Self> {
        let runtime = pyo3_async_runtimes::tokio::get_runtime();
        runtime.block_on(async move {
            let client = endpoint.inner.client().await.map_err(to_pyerr)?;

            // Create PushRouter with KV router mode
            let push_router = rs::pipeline::PushRouter::<
                llm_rs::protocols::common::preprocessor::PreprocessedRequest,
                rs::protocols::annotated::Annotated<
                    llm_rs::protocols::common::llm_backend::LLMEngineOutput,
                >,
            >::from_client(
                client,
                rs::pipeline::network::egress::push_router::RouterMode::KV,
            )
            .await
            .map_err(to_pyerr)?;

736
737
738
739
740
741
742
            // Create KvRouter using helper function (ensures etcd registration)
            let kv_router = create_kv_router_from_endpoint(
                endpoint,
                block_size,
                Some(kv_router_config.inner()),
            )
            .await?;
743

744
            let kv_push_router = RsKvPushRouter::new(push_router, kv_router);
745
746
747
748
749
750
751
752

            Ok(Self {
                inner: Arc::new(kv_push_router),
            })
        })
    }

    #[allow(clippy::too_many_arguments)]
753
    #[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None, block_mm_infos=None, multi_modal_data=None, mm_routing_info=None))]
754
755
756
757
758
759
760
761
762
    fn generate<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        model: String,
        stop_conditions: Option<PyObject>,
        sampling_options: Option<PyObject>,
        output_options: Option<PyObject>,
        router_config_override: Option<PyObject>,
Yan Ru Pei's avatar
Yan Ru Pei committed
763
764
        worker_id: Option<WorkerId>,
        dp_rank: Option<DpRank>,
765
        extra_args: Option<PyObject>,
766
767
768
        block_mm_infos: Option<PyObject>,
        multi_modal_data: Option<PyObject>,
        mm_routing_info: Option<PyObject>,
769
770
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the options with defaults
771
772
773
774
775
        let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            StopConditions::default()
        };
776

777
778
779
780
781
        let sampling_options: SamplingOptions = if let Some(obj) = sampling_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            SamplingOptions::default()
        };
782

783
784
785
786
787
        let output_options: OutputOptions = if let Some(obj) = output_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            OutputOptions::default()
        };
788

789
790
791
792
793
794
        let router_config_override: Option<llm_rs::kv_router::RouterConfigOverride> =
            if let Some(obj) = router_config_override {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };
795

796
797
798
799
800
        let extra_args: Option<serde_json::Value> = if let Some(obj) = extra_args {
            Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
        } else {
            None
        };
801

802
803
804
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824

        let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> =
            if let Some(obj) = multi_modal_data {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };

        let mm_routing_info: Option<llm_rs::protocols::common::preprocessor::MmRoutingInfo> =
            if let Some(obj) = mm_routing_info {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                block_mm_infos.map(
                    |infos| llm_rs::protocols::common::preprocessor::MmRoutingInfo {
                        routing_token_ids: token_ids.clone(),
                        block_mm_infos: infos,
                    },
                )
            };

825
826
827
        // Create tracker to capture worker routing info from KvRouter
        let tracker = Arc::new(RequestTracker::new());

828
        // Build the PreprocessedRequest
829
830
831
        let mut request_builder =
            llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder();
        request_builder
832
833
834
835
836
            .model(model)
            .token_ids(token_ids)
            .stop_conditions(stop_conditions)
            .sampling_options(sampling_options)
            .output_options(output_options)
837
            .router_config_override(router_config_override)
838
839
            .multi_modal_data(multi_modal_data)
            .mm_routing_info(mm_routing_info)
840
841
            .extra_args(extra_args)
            .tracker(Some(tracker.clone()));
842

843
844
845
846
847
848
849
850
        // Set routing hints if worker_id or dp_rank is provided
        if worker_id.is_some() || dp_rank.is_some() {
            let routing = llm_rs::protocols::common::preprocessor::RoutingHints {
                backend_instance_id: worker_id,
                dp_rank,
                ..Default::default()
            };
            request_builder.routing(Some(routing));
851
852
853
        }

        let request = request_builder.build().map_err(to_pyerr)?;
854

Yan Ru Pei's avatar
Yan Ru Pei committed
855
        // Use the helper method to process the request
856
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
Yan Ru Pei's avatar
Yan Ru Pei committed
857
    }
858

Yan Ru Pei's avatar
Yan Ru Pei committed
859
860
861
862
863
864
    fn generate_from_request<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the request directly into PreprocessedRequest
865
        let mut request: llm_rs::protocols::common::preprocessor::PreprocessedRequest =
866
            depythonize(request.bind(py)).map_err(to_pyerr)?;
867

868
869
870
871
872
873
874
875
876
877
        // Create tracker if not already set, to capture worker routing info
        let tracker = match request.tracker {
            Some(ref t) => t.clone(),
            None => {
                let t = Arc::new(RequestTracker::new());
                request.tracker = Some(t.clone());
                t
            }
        };

Yan Ru Pei's avatar
Yan Ru Pei committed
878
        // Use the helper method to process the request
879
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
880
    }
881

882
    #[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None))]
Yan Ru Pei's avatar
Yan Ru Pei committed
883
884
885
886
887
888
    fn best_worker<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        router_config_override: Option<PyObject>,
        request_id: Option<String>,
889
        block_mm_infos: Option<PyObject>,
Yan Ru Pei's avatar
Yan Ru Pei committed
890
891
    ) -> PyResult<Bound<'p, PyAny>> {
        let router_config_override = if let Some(obj) = router_config_override {
892
893
894
            let override_config: llm_rs::kv_router::RouterConfigOverride =
                depythonize(obj.bind(py)).map_err(to_pyerr)?;
            Some(override_config)
Yan Ru Pei's avatar
Yan Ru Pei committed
895
896
897
898
        } else {
            None
        };

899
900
901
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
902

Yan Ru Pei's avatar
Yan Ru Pei committed
903
904
905
906
907
908
909
910
        let chooser = self.inner.chooser.clone();
        let update_states = request_id.is_some();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let (best_worker, overlap_blocks) = chooser
                .find_best_match(
                    request_id.as_deref(),
                    &token_ids,
911
                    block_mm_infos.as_deref(),
Yan Ru Pei's avatar
Yan Ru Pei committed
912
913
                    router_config_override.as_ref(),
                    update_states,
914
                    None, // lora_name not exposed in Python API yet
915
                    0.0,
Yan Ru Pei's avatar
Yan Ru Pei committed
916
917
918
919
920
921
922
923
                )
                .await
                .map_err(to_pyerr)?;

            Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
        })
    }

924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
    /// Mark prefill as completed for a request.
    ///
    /// Returns a Python awaitable that finishes once the router's
    /// `mark_prefill_completed` call for `request_id` has returned, and raises
    /// if that call reports an error.
    fn mark_prefill_complete<'p>(
        &self,
        py: Python<'p>,
        request_id: String,
    ) -> PyResult<Bound<'p, PyAny>> {
        let router = self.inner.chooser.clone();

        let notify = async move {
            router
                .mark_prefill_completed(&request_id)
                .await
                .map(|_| ())
                .map_err(to_pyerr)
        };

        pyo3_async_runtimes::tokio::future_into_py(py, notify)
    }

    /// Free a request by its ID, signaling the router to release resources.
    ///
    /// Returns a Python awaitable that finishes once the router's `free` call for
    /// `request_id` has returned, and raises if that call reports an error.
    fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
        let router = self.inner.chooser.clone();

        let release = async move {
            router
                .free(&request_id)
                .await
                .map(|_| ())
                .map_err(to_pyerr)
        };

        pyo3_async_runtimes::tokio::future_into_py(py, release)
    }

951
952
953
954
955
    fn get_potential_loads<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
    ) -> PyResult<Bound<'p, PyAny>> {
956
        let chooser = self.inner.chooser.clone();
957
958

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
959
            let loads = chooser
960
                .get_potential_loads(&token_ids, None)
961
962
963
                .await
                .map_err(to_pyerr)?;

Yan Ru Pei's avatar
Yan Ru Pei committed
964
            // Return loads without aggregation - each (worker_id, dp_rank) pair is a separate entry
965
966
967
968
969
970
971
972
973
            // Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
            Python::with_gil(|py| {
                pythonize(py, &loads)
                    .map(|obj| obj.unbind())
                    .map_err(to_pyerr)
            })
        })
    }

974
975
    /// Dump all events from the KV router's indexer as a JSON string.
    ///
    /// Returns a Python awaitable resolving to the `serde_json` serialization of
    /// whatever the router's `dump_events` call returns. The awaitable raises if
    /// the dump or the serialization fails.
    fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let router = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let events = router.dump_events().await.map_err(to_pyerr)?;

            // The serialized JSON string is the value handed back to Python.
            serde_json::to_string(&events).map_err(to_pyerr)
        })
    }
}