kv.rs 33.8 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use pythonize::{depythonize, pythonize};
5
use std::collections::HashMap;
6
use std::sync::Arc;
7
use std::sync::atomic::AtomicU32;
8
use std::sync::mpsc;
9
use tokio_stream::StreamExt;
10

11
use super::*;
12
use crate::Endpoint;
13
use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
Yan Ru Pei's avatar
Yan Ru Pei committed
14
use rs::pipeline::{AsyncEngine, SingleIn};
15
use rs::protocols::annotated::Annotated as RsAnnotated;
16
use tracing;
17

18
use llm_rs::kv_router::KvPushRouter as RsKvPushRouter;
19
use llm_rs::kv_router::protocols::*;
20
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks};
21
use llm_rs::protocols::common::timing::RequestTracker;
22
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
23
use serde_json::json;
24

25
26
27
28
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
    depythonize(obj).map_err(to_pyerr)
}

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/// Python entry point `start_kv_block_indexer`.
///
/// Returns a Python awaitable that starts the standalone KV block indexer on
/// the component owning `endpoint`. The awaitable resolves to `None` on
/// success; a failure from the Rust side is converted with `to_pyerr`.
///
/// # Arguments
/// * `endpoint` - Endpoint whose component the indexer is started on
/// * `block_size` - KV block size forwarded to the indexer
/// * `kv_router_config` - Router configuration forwarded to the indexer
#[pyfunction]
#[pyo3(name = "start_kv_block_indexer", signature = (endpoint, block_size, kv_router_config))]
pub fn start_kv_block_indexer_py<'p>(
    py: Python<'p>,
    endpoint: &Endpoint,
    block_size: u32,
    kv_router_config: &super::entrypoint::KvRouterConfig,
) -> PyResult<Bound<'p, PyAny>> {
    // Take owned copies so the future does not borrow from the Python-side arguments.
    let router_config = kv_router_config.inner();
    let target_component = endpoint.inner.component().clone();

    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        llm_rs::kv_router::indexer_standalone::start_kv_block_indexer(
            &target_component,
            &router_config,
            block_size,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

Yan Ru Pei's avatar
Yan Ru Pei committed
49
#[pyfunction]
50
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None))]
51
52
53
54
55
pub fn compute_block_hash_for_seq_py(
    _py: Python,
    tokens: Vec<u32>,
    kv_block_size: usize,
    block_mm_infos: Option<Bound<PyAny>>,
56
    lora_name: Option<String>,
57
) -> PyResult<Vec<u64>> {
Yan Ru Pei's avatar
Yan Ru Pei committed
58
    if kv_block_size == 0 {
59
60
61
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "kv_block_size cannot be 0",
        ));
Yan Ru Pei's avatar
Yan Ru Pei committed
62
63
    }

64
    let mm_infos = block_mm_infos
65
        .as_ref()
66
        .map(depythonize_block_mm_infos)
67
68
        .transpose()?;

69
70
71
72
73
74
    let hashes = compute_block_hash_for_seq(
        &tokens,
        kv_block_size as u32,
        mm_infos.as_deref(),
        lora_name.as_deref(),
    );
75

Yan Ru Pei's avatar
Yan Ru Pei committed
76
77
78
    Ok(hashes.into_iter().map(|h| h.0).collect())
}

GuanLuo's avatar
GuanLuo committed
79
#[pyclass]
80
81
pub(crate) struct WorkerMetricsPublisher {
    inner: Arc<llm_rs::kv_router::publisher::WorkerMetricsPublisher>,
GuanLuo's avatar
GuanLuo committed
82
83
84
}

#[pymethods]
85
impl WorkerMetricsPublisher {
GuanLuo's avatar
GuanLuo committed
86
87
    #[new]
    fn new() -> PyResult<Self> {
88
89
        let inner =
            llm_rs::kv_router::publisher::WorkerMetricsPublisher::new().map_err(to_pyerr)?;
GuanLuo's avatar
GuanLuo committed
90
91
92
93
94
        Ok(Self {
            inner: inner.into(),
        })
    }

95
    #[pyo3(signature = (endpoint))]
Alec's avatar
Alec committed
96
    fn create_endpoint<'p>(
GuanLuo's avatar
GuanLuo committed
97
98
        &self,
        py: Python<'p>,
99
        endpoint: Endpoint,
GuanLuo's avatar
GuanLuo committed
100
101
    ) -> PyResult<Bound<'p, PyAny>> {
        let rs_publisher = self.inner.clone();
102
        let rs_component = endpoint.inner.component().clone();
GuanLuo's avatar
GuanLuo committed
103
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
104
            rs_publisher
105
                .create_endpoint(rs_component)
GuanLuo's avatar
GuanLuo committed
106
107
108
109
110
111
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

112
113
114
115
116
117
118
    /// Publish worker metrics for load monitoring.
    ///
    /// # Arguments
    /// * `dp_rank` - Data parallel rank of the worker (None defaults to 0)
    /// * `active_decode_blocks` - Number of active KV cache blocks
    #[pyo3(signature = (dp_rank, active_decode_blocks))]
    fn publish(&self, dp_rank: Option<u32>, active_decode_blocks: u64) -> PyResult<()> {
GuanLuo's avatar
GuanLuo committed
119
        self.inner
120
            .publish(dp_rank, active_decode_blocks)
GuanLuo's avatar
GuanLuo committed
121
122
123
            .map_err(to_pyerr)
    }
}
124

125
126
127
#[pyclass]
pub(crate) struct KvEventPublisher {
    inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
128
    kv_block_size: usize,
Yan Ru Pei's avatar
Yan Ru Pei committed
129
    dp_rank: DpRank,
130
    warning_count: Arc<AtomicU32>,
131
132
133
134
135
}

#[pymethods]
impl KvEventPublisher {
    #[new]
136
    #[pyo3(signature = (endpoint, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_endpoint=None, zmq_topic=None))]
137
    fn new(
138
        endpoint: Endpoint,
139
140
141
        worker_id: WorkerId,
        kv_block_size: usize,
        dp_rank: DpRank,
142
        enable_local_indexer: bool,
143
144
        zmq_endpoint: Option<String>,
        zmq_topic: Option<String>,
145
    ) -> PyResult<Self> {
146
147
        let _ = worker_id;

148
149
        let source_config = zmq_endpoint.map(|ep| KvEventSourceConfig::Zmq {
            endpoint: ep,
150
151
            topic: zmq_topic.unwrap_or_default(),
        });
152

Yan Ru Pei's avatar
Yan Ru Pei committed
153
154
155
156
        if kv_block_size == 0 {
            return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
        }

157
158
159
        // Extract component from endpoint
        let component = endpoint.inner.component().clone();

160
        let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
161
            component,
162
            kv_block_size as u32,
163
            source_config,
164
            enable_local_indexer,
165
            dp_rank,
166
167
        )
        .map_err(to_pyerr)?;
168

169
170
        Ok(Self {
            inner: inner.into(),
171
            kv_block_size,
Yan Ru Pei's avatar
Yan Ru Pei committed
172
            dp_rank,
173
            warning_count: Arc::new(AtomicU32::new(0)),
174
175
176
177
        })
    }

    #[allow(clippy::too_many_arguments)]
178
    #[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None))]
179
    fn publish_stored(
180
        &self,
181
        py: Python,
182
183
        token_ids: Vec<u32>,
        num_block_tokens: Vec<u64>,
184
185
        block_hashes: Vec<i64>,
        parent_hash: Option<i64>,
186
        block_mm_infos: Option<Bound<PyAny>>,
187
        lora_name: Option<String>,
188
    ) -> PyResult<()> {
189
190
191
192
193
        let kv_block_size = self.kv_block_size as u32;
        let dp_rank = self.dp_rank;
        let warning_count = self.warning_count.clone();
        let inner = self.inner.clone();

194
195
        let event_id = inner.next_event_id();

196
        let mm_infos = block_mm_infos
197
            .as_ref()
198
            .map(depythonize_block_mm_infos)
199
200
            .transpose()?;

201
202
203
204
205
206
207
208
209
210
211
        py.allow_threads(|| {
            let block_hashes_u64: Vec<u64> = block_hashes.iter().map(|&h| h as u64).collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
                    blocks: create_stored_blocks(
                        kv_block_size,
                        &token_ids,
                        &num_block_tokens,
                        &block_hashes_u64,
212
                        lora_name.as_deref(),
213
                        &warning_count,
214
                        mm_infos.as_deref(),
215
216
217
218
219
220
221
                    ),
                }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
222
223
    }

224
    fn publish_removed(&self, py: Python, block_hashes: Vec<i64>) -> PyResult<()> {
225
226
227
        let dp_rank = self.dp_rank;
        let inner = self.inner.clone();

228
229
230
        // Use shared monotonic event_id counter from the inner publisher
        let event_id = inner.next_event_id();

231
232
233
234
235
236
237
238
239
240
241
242
243
        py.allow_threads(|| {
            let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
                .into_iter()
                .map(ExternalSequenceBlockHash::from)
                .collect();
            let event = KvCacheEvent {
                event_id,
                data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
                dp_rank,
            };

            inner.publish(event).map_err(to_pyerr)
        })
244
    }
245
246
247
248
249
250
251
252

    fn shutdown(&mut self) {
        // If no other Arc clones exist, shut down eagerly.
        // Otherwise the Drop impl handles cleanup when the last reference is freed.
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.shutdown();
        }
    }
253
254
}

255
256
257
#[pyclass]
#[derive(Clone)]
pub(crate) struct OverlapScores {
258
    inner: llm_rs::kv_router::protocols::OverlapScores,
259
260
261
262
263
}

#[pymethods]
impl OverlapScores {
    #[getter]
264
    fn scores(&self) -> HashMap<(u64, u32), u32> {
Yan Ru Pei's avatar
Yan Ru Pei committed
265
266
267
268
269
270
        // Return scores with full WorkerWithDpRank granularity as (worker_id, dp_rank) tuples
        self.inner
            .scores
            .iter()
            .map(|(worker, score)| ((worker.worker_id, worker.dp_rank), *score))
            .collect()
271
272
273
274
275
276
277
278
    }

    #[getter]
    fn frequencies(&self) -> Vec<usize> {
        self.inner.frequencies.clone()
    }
}

279
280
281
282
283
/// Messages sent from the Python-facing `RadixTree` handle to its worker thread.
///
/// Every variant except `Shutdown` carries a `response_tx` on which the worker
/// sends the outcome of the operation back to the caller.
#[derive(Debug)]
enum RadixTreeRequest {
    /// Look up overlap scores for a sequence of local block hashes.
    FindMatches {
        local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash>,
        early_exit: bool,
        response_tx: mpsc::SyncSender<llm_rs::kv_router::protocols::OverlapScores>,
    },
    /// Apply a JSON-encoded `KvCacheEvent` on behalf of `worker_id`.
    ApplyEvent {
        worker_id: WorkerId,
        kv_cache_event_bytes: Vec<u8>,
        response_tx: mpsc::SyncSender<PyResult<()>>,
    },
    /// Remove `worker_id` from the tree.
    RemoveWorker {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    /// Clear all blocks recorded for `worker_id`.
    ClearAllBlocks {
        worker_id: WorkerId,
        response_tx: mpsc::SyncSender<()>,
    },
    /// Dump the current tree contents as router events.
    DumpTreeAsEvents {
        response_tx: mpsc::SyncSender<Vec<llm_rs::kv_router::protocols::RouterEvent>>,
    },
    /// Ask the worker thread to leave its receive loop.
    Shutdown,
}

// NOTE: RadixTree is now thread-safe with pure sync patterns
#[pyclass]
Yan Ru Pei's avatar
Yan Ru Pei committed
307
pub(crate) struct RadixTree {
308
    request_tx: mpsc::Sender<RadixTreeRequest>,
Yan Ru Pei's avatar
Yan Ru Pei committed
309
310
311
312
313
314
315
316
}

#[pymethods]
impl RadixTree {
    #[new]
    #[pyo3(signature = (expiration_duration_secs=None))]
    fn new(expiration_duration_secs: Option<f64>) -> PyResult<Self> {
        let expiration_duration = expiration_duration_secs.map(std::time::Duration::from_secs_f64);
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342

        let (request_tx, request_rx) = mpsc::channel::<RadixTreeRequest>();

        // Spawn dedicated thread with simplified sync processing
        std::thread::spawn(move || {
            let mut radix_tree =
                llm_rs::kv_router::indexer::RadixTree::new_with_frequency(expiration_duration);

            loop {
                match request_rx.recv() {
                    Ok(RadixTreeRequest::Shutdown) => {
                        tracing::debug!("RadixTree thread received shutdown request");
                        break;
                    }
                    Ok(request) => {
                        Self::handle_request(&mut radix_tree, request);
                    }
                    Err(mpsc::RecvError) => {
                        tracing::debug!("RadixTree request channel disconnected");
                        break;
                    }
                }
            }
        });

        Ok(Self { request_tx })
Yan Ru Pei's avatar
Yan Ru Pei committed
343
344
345
346
347
    }

    #[pyo3(signature = (sequence, early_exit=false))]
    fn find_matches(
        &self,
348
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
349
350
351
        sequence: Vec<u64>,
        early_exit: bool,
    ) -> PyResult<OverlapScores> {
352
353
354
355
356
357
358
359
360
361
362
363
364
365
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let local_block_hashes = py.allow_threads(|| {
            sequence
                .into_iter()
                .map(llm_rs::kv_router::protocols::LocalBlockHash)
                .collect()
        });

        let request = RadixTreeRequest::FindMatches {
            local_block_hashes,
            early_exit,
            response_tx,
        };
Yan Ru Pei's avatar
Yan Ru Pei committed
366

367
368
369
370
371
372
373
374
375
376
377
378
379
380
        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })?;

        Ok(OverlapScores { inner: result })
Yan Ru Pei's avatar
Yan Ru Pei committed
381
382
383
    }

    fn apply_event(
384
385
        &self,
        py: Python,
Yan Ru Pei's avatar
Yan Ru Pei committed
386
        worker_id: WorkerId,
Yan Ru Pei's avatar
Yan Ru Pei committed
387
388
        kv_cache_event_bytes: &[u8],
    ) -> PyResult<()> {
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ApplyEvent {
            worker_id,
            kv_cache_event_bytes: kv_cache_event_bytes.to_vec(),
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        let result = py.allow_threads(move || response_rx.recv());

        result.map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
        })?
    }

    fn remove_worker(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::RemoveWorker {
            worker_id,
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })
    }

    fn clear_all_blocks(&self, py: Python, worker_id: WorkerId) -> PyResult<()> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::ClearAllBlocks {
            worker_id,
            response_tx,
        };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "RadixTree background task has shut down",
            )
        })?;

        // Release GIL while waiting for response
        py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("RadixTree request was cancelled")
            })
        })
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
454

455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
    fn dump_tree_as_events(&self, py: Python) -> PyResult<Vec<String>> {
        let (response_tx, response_rx) = mpsc::sync_channel(1);

        let request = RadixTreeRequest::DumpTreeAsEvents { response_tx };

        self.request_tx.send(request).map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Failed to send dump tree request")
        })?;

        // Release GIL while waiting for response from dedicated thread
        let events = py.allow_threads(move || {
            response_rx.recv().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                    "Failed to receive dump tree response",
                )
            })
        })?;

        // Serialize RouterEvent structs to JSON strings with GIL released
        py.allow_threads(move || {
            events
                .into_iter()
                .map(|event| {
                    serde_json::to_string(&event).map_err(|e| {
                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                            "Failed to serialize event to JSON: {}",
                            e
                        ))
                    })
                })
                .collect::<Result<Vec<String>, PyErr>>()
        })
Yan Ru Pei's avatar
Yan Ru Pei committed
487
    }
488
}
Yan Ru Pei's avatar
Yan Ru Pei committed
489

490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
impl RadixTree {
    fn handle_request(
        radix_tree: &mut llm_rs::kv_router::indexer::RadixTree,
        request: RadixTreeRequest,
    ) {
        match request {
            RadixTreeRequest::FindMatches {
                local_block_hashes,
                early_exit,
                response_tx,
            } => {
                let result = radix_tree.find_matches(local_block_hashes, early_exit);
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::ApplyEvent {
                worker_id,
                kv_cache_event_bytes,
                response_tx,
            } => {
                let result = match serde_json::from_slice::<
                    llm_rs::kv_router::protocols::KvCacheEvent,
                >(&kv_cache_event_bytes)
                {
                    Ok(kv_cache_event) => {
514
515
516
517
                        let router_event = llm_rs::kv_router::protocols::RouterEvent::new(
                            worker_id,
                            kv_cache_event,
                        );
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
                        match radix_tree.apply_event(router_event) {
                            Ok(_) => Ok(()),
                            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                                format!("Failed to apply event: {}", e),
                            )),
                        }
                    }
                    Err(e) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                        "Failed to deserialize KvCacheEvent: {}",
                        e
                    ))),
                };
                let _ = response_tx.send(result);
            }
            RadixTreeRequest::RemoveWorker {
                worker_id,
                response_tx,
            } => {
                radix_tree.remove_worker(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::ClearAllBlocks {
                worker_id,
                response_tx,
            } => {
                radix_tree.clear_all_blocks(worker_id);
                let _ = response_tx.send(());
            }
            RadixTreeRequest::DumpTreeAsEvents { response_tx } => {
                let events = radix_tree.dump_tree_as_events();
                let _ = response_tx.send(events);
            }
            RadixTreeRequest::Shutdown => {
                // This is handled in the main loop
            }
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
554
    }
555
}
Yan Ru Pei's avatar
Yan Ru Pei committed
556

557
558
559
560
561
// Cleanup when RadixTree is dropped
impl Drop for RadixTree {
    /// Ask the worker thread to stop by sending `RadixTreeRequest::Shutdown`.
    fn drop(&mut self) {
        // A failed send is deliberately ignored; there is nothing further to do here.
        drop(self.request_tx.send(RadixTreeRequest::Shutdown));
    }
}

565
/// Helper function to create a KV router from an endpoint using the ModelManager
566
567
568
569
570
/// to ensure proper etcd registration.
/// Infers worker type using endpoint naming and router config:
/// - If endpoint name/component contains "prefill", treat as prefill
/// - If router_track_active_blocks is disabled, treat as prefill
/// - Otherwise, default to decode
571
572
573
574
575
async fn create_kv_router_from_endpoint(
    endpoint: &Endpoint,
    block_size: usize,
    kv_router_config: Option<llm_rs::kv_router::KvRouterConfig>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
576
    // Create ModelManager and use it to create KvRouter (ensures registration)
577
    let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
    let endpoint_id = endpoint.inner.id();
    let namespace = endpoint_id.namespace.to_lowercase();
    let component = endpoint_id.component.to_lowercase();
    let name = endpoint_id.name.to_lowercase();
    let endpoint_is_prefill =
        namespace.contains("prefill") || component.contains("prefill") || name.contains("prefill");
    let track_active_blocks = kv_router_config
        .as_ref()
        .map(|cfg| cfg.router_track_active_blocks)
        .unwrap_or(true);
    let worker_type = if endpoint_is_prefill || !track_active_blocks {
        llm_rs::discovery::WORKER_TYPE_PREFILL
    } else {
        llm_rs::discovery::WORKER_TYPE_DECODE
    };
593
    let kv_router = model_manager
594
595
596
597
598
599
        .kv_chooser_for(
            &endpoint.inner,
            block_size as u32,
            kv_router_config,
            worker_type,
        )
600
601
602
603
604
605
        .await
        .map_err(to_pyerr)?;

    Ok(kv_router)
}

606
#[pyclass]
607
608
pub(crate) struct KvRouter {
    inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
609
610
}

611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
/// Copy the tracker's worker info into the response's `disaggregated_params`
/// under the `"worker_id"` key.
///
/// The Python bindings need this because the raw LLMEngineOutput does not pass
/// through DeltaGenerator (which adds nvext), so routing info would otherwise
/// not be exposed. Does nothing when the tracker has no worker info.
///
/// # Panics
/// Panics if the worker info cannot be serialized to JSON.
fn inject_worker_id_from_tracker(
    data: &mut llm_rs::protocols::common::llm_backend::LLMEngineOutput,
    tracker: &RequestTracker,
) {
    let worker_info = match tracker.get_worker_info() {
        Some(info) => info,
        None => return,
    };

    let worker_id_json =
        serde_json::to_value(&worker_info).expect("WorkerIdInfo serialization should not fail");

    let existing_fields = data
        .disaggregated_params
        .as_mut()
        .and_then(|params| params.as_object_mut());

    match existing_fields {
        // Existing JSON object: add (or overwrite) the key in place.
        Some(fields) => {
            fields.insert("worker_id".to_string(), worker_id_json);
        }
        // Missing, or present but not a JSON object: replace with a fresh object.
        None => {
            data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
        }
    }
}

Yan Ru Pei's avatar
Yan Ru Pei committed
636
// TODO: can this reuse the stream conversion method in Client bindings?
637
impl KvRouter {
Yan Ru Pei's avatar
Yan Ru Pei committed
638
639
640
    /// Helper method to process a request and create a Python async generator
    fn process_request_to_stream<'p>(
        py: Python<'p>,
641
        inner: Arc<RsKvPushRouter>,
Yan Ru Pei's avatar
Yan Ru Pei committed
642
        request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
643
        tracker: Option<Arc<RequestTracker>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
644
645
646
647
    ) -> PyResult<Bound<'p, PyAny>> {
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let single_in = SingleIn::new(request);
            let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
648
            let (tx, rx) = tokio::sync::mpsc::channel::<RsAnnotated<PyObject>>(100);
Yan Ru Pei's avatar
Yan Ru Pei committed
649
650
651

            tokio::spawn(async move {
                let mut stream = stream;
652
                let mut first_item = true;
653
                let mut first_token_gauges_observed = false;
654
655
656
657
658
659
660
661
662

                while let Some(mut response) = stream.next().await {
                    if first_item {
                        first_item = false;
                        if let (Some(tracker), Some(data)) = (&tracker, &mut response.data) {
                            inject_worker_id_from_tracker(data, tracker);
                        }
                    }

663
664
665
666
667
668
669
670
671
672
673
674
675
676
                    if !first_token_gauges_observed {
                        let has_tokens = response
                            .data
                            .as_ref()
                            .map(|d| !d.token_ids.is_empty())
                            .unwrap_or(false);
                        if has_tokens {
                            if let Some(ref tracker) = tracker {
                                tracker.observe_first_token_gauges();
                            }
                            first_token_gauges_observed = true;
                        }
                    }

Yan Ru Pei's avatar
Yan Ru Pei committed
677
678
679
680
681
682
683
684
                    let py_response = Python::with_gil(|py| {
                        pythonize(py, &response.data)
                            .map(|obj| obj.unbind())
                            .map_err(|e| e.to_string())
                    });

                    match py_response {
                        Ok(obj) => {
685
686
                            if tx.send(RsAnnotated::from_data(obj)).await.is_err() {
                                break;
Yan Ru Pei's avatar
Yan Ru Pei committed
687
688
689
690
691
692
693
694
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to pythonize response: {}", e);
                            break;
                        }
                    }
                }
695
696
697
698

                if let Some(ref tracker) = tracker {
                    tracker.observe_finish_gauges();
                }
Yan Ru Pei's avatar
Yan Ru Pei committed
699
700
            });

701
            Ok(crate::AsyncResponseStream::new(rx, false))
Yan Ru Pei's avatar
Yan Ru Pei committed
702
703
        })
    }
704
705
706
}

#[pymethods]
707
708
impl KvRouter {
    /// Create a new KvRouter for KV-aware routing to workers.
709
710
711
712
713
714
715
716
    ///
    /// # Arguments
    /// * `endpoint` - The endpoint to route requests to
    /// * `block_size` - KV cache block size for routing decisions
    /// * `kv_router_config` - Configuration for the KV router
    ///
    /// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
    /// (contains "prefill") or by `router_track_active_blocks` being disabled.
717
    #[new]
718
    #[pyo3(signature = (endpoint, block_size, kv_router_config))]
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
    fn new(
        endpoint: &Endpoint,
        block_size: usize,
        kv_router_config: &super::entrypoint::KvRouterConfig,
    ) -> PyResult<Self> {
        let runtime = pyo3_async_runtimes::tokio::get_runtime();
        runtime.block_on(async move {
            let client = endpoint.inner.client().await.map_err(to_pyerr)?;

            // Create PushRouter with KV router mode
            let push_router = rs::pipeline::PushRouter::<
                llm_rs::protocols::common::preprocessor::PreprocessedRequest,
                rs::protocols::annotated::Annotated<
                    llm_rs::protocols::common::llm_backend::LLMEngineOutput,
                >,
            >::from_client(
                client,
                rs::pipeline::network::egress::push_router::RouterMode::KV,
            )
            .await
            .map_err(to_pyerr)?;

741
742
743
744
745
746
747
            // Create KvRouter using helper function (ensures etcd registration)
            let kv_router = create_kv_router_from_endpoint(
                endpoint,
                block_size,
                Some(kv_router_config.inner()),
            )
            .await?;
748

749
            let kv_push_router = RsKvPushRouter::new(push_router, kv_router);
750
751
752
753
754
755
756
757

            Ok(Self {
                inner: Arc::new(kv_push_router),
            })
        })
    }

    #[allow(clippy::too_many_arguments)]
758
    #[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None, block_mm_infos=None, multi_modal_data=None, mm_routing_info=None))]
759
760
761
762
763
764
765
766
767
    fn generate<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        model: String,
        stop_conditions: Option<PyObject>,
        sampling_options: Option<PyObject>,
        output_options: Option<PyObject>,
        router_config_override: Option<PyObject>,
Yan Ru Pei's avatar
Yan Ru Pei committed
768
769
        worker_id: Option<WorkerId>,
        dp_rank: Option<DpRank>,
770
        extra_args: Option<PyObject>,
771
772
773
        block_mm_infos: Option<PyObject>,
        multi_modal_data: Option<PyObject>,
        mm_routing_info: Option<PyObject>,
774
775
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the options with defaults
776
777
778
779
780
        let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            StopConditions::default()
        };
781

782
783
784
785
786
        let sampling_options: SamplingOptions = if let Some(obj) = sampling_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            SamplingOptions::default()
        };
787

788
789
790
791
792
        let output_options: OutputOptions = if let Some(obj) = output_options {
            depythonize(obj.bind(py)).map_err(to_pyerr)?
        } else {
            OutputOptions::default()
        };
793

794
795
796
797
798
799
        let router_config_override: Option<llm_rs::kv_router::RouterConfigOverride> =
            if let Some(obj) = router_config_override {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };
800

801
802
803
804
805
        let extra_args: Option<serde_json::Value> = if let Some(obj) = extra_args {
            Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
        } else {
            None
        };
806

807
808
809
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829

        let multi_modal_data: Option<llm_rs::protocols::common::preprocessor::MultimodalDataMap> =
            if let Some(obj) = multi_modal_data {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                None
            };

        let mm_routing_info: Option<llm_rs::protocols::common::preprocessor::MmRoutingInfo> =
            if let Some(obj) = mm_routing_info {
                Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
            } else {
                block_mm_infos.map(
                    |infos| llm_rs::protocols::common::preprocessor::MmRoutingInfo {
                        routing_token_ids: token_ids.clone(),
                        block_mm_infos: infos,
                    },
                )
            };

830
831
832
        // Create tracker to capture worker routing info from KvRouter
        let tracker = Arc::new(RequestTracker::new());

833
        // Build the PreprocessedRequest
834
835
836
        let mut request_builder =
            llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder();
        request_builder
837
838
839
840
841
            .model(model)
            .token_ids(token_ids)
            .stop_conditions(stop_conditions)
            .sampling_options(sampling_options)
            .output_options(output_options)
842
            .router_config_override(router_config_override)
843
844
            .multi_modal_data(multi_modal_data)
            .mm_routing_info(mm_routing_info)
845
846
            .extra_args(extra_args)
            .tracker(Some(tracker.clone()));
847

848
849
850
851
852
853
854
855
        // Set routing hints if worker_id or dp_rank is provided
        if worker_id.is_some() || dp_rank.is_some() {
            let routing = llm_rs::protocols::common::preprocessor::RoutingHints {
                backend_instance_id: worker_id,
                dp_rank,
                ..Default::default()
            };
            request_builder.routing(Some(routing));
856
857
858
        }

        let request = request_builder.build().map_err(to_pyerr)?;
859

Yan Ru Pei's avatar
Yan Ru Pei committed
860
        // Use the helper method to process the request
861
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
Yan Ru Pei's avatar
Yan Ru Pei committed
862
    }
863

Yan Ru Pei's avatar
Yan Ru Pei committed
864
865
866
867
868
869
    fn generate_from_request<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
    ) -> PyResult<Bound<'p, PyAny>> {
        // Depythonize the request directly into PreprocessedRequest
870
        let mut request: llm_rs::protocols::common::preprocessor::PreprocessedRequest =
871
            depythonize(request.bind(py)).map_err(to_pyerr)?;
872

873
874
875
876
877
878
879
880
881
882
        // Create tracker if not already set, to capture worker routing info
        let tracker = match request.tracker {
            Some(ref t) => t.clone(),
            None => {
                let t = Arc::new(RequestTracker::new());
                request.tracker = Some(t.clone());
                t
            }
        };

Yan Ru Pei's avatar
Yan Ru Pei committed
883
        // Use the helper method to process the request
884
        Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
885
    }
886

887
    #[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None, lora_name=None))]
Yan Ru Pei's avatar
Yan Ru Pei committed
888
889
890
891
892
893
    fn best_worker<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
        router_config_override: Option<PyObject>,
        request_id: Option<String>,
894
        block_mm_infos: Option<PyObject>,
895
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
896
897
    ) -> PyResult<Bound<'p, PyAny>> {
        let router_config_override = if let Some(obj) = router_config_override {
898
899
900
            let override_config: llm_rs::kv_router::RouterConfigOverride =
                depythonize(obj.bind(py)).map_err(to_pyerr)?;
            Some(override_config)
Yan Ru Pei's avatar
Yan Ru Pei committed
901
902
903
904
        } else {
            None
        };

905
906
907
        let block_mm_infos = block_mm_infos
            .map(|obj| depythonize_block_mm_infos(obj.bind(py)))
            .transpose()?;
908

Yan Ru Pei's avatar
Yan Ru Pei committed
909
910
911
912
913
914
915
916
        let chooser = self.inner.chooser.clone();
        let update_states = request_id.is_some();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let (best_worker, overlap_blocks) = chooser
                .find_best_match(
                    request_id.as_deref(),
                    &token_ids,
917
                    block_mm_infos.as_deref(),
Yan Ru Pei's avatar
Yan Ru Pei committed
918
919
                    router_config_override.as_ref(),
                    update_states,
920
                    lora_name,
921
                    0.0,
922
                    None, // allowed_worker_ids not exposed in Python API yet
Yan Ru Pei's avatar
Yan Ru Pei committed
923
924
925
926
927
928
929
930
                )
                .await
                .map_err(to_pyerr)?;

            Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
        })
    }

931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
    /// Mark prefill as completed for a request
    fn mark_prefill_complete<'p>(
        &self,
        py: Python<'p>,
        request_id: String,
    ) -> PyResult<Bound<'p, PyAny>> {
        let chooser = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            chooser
                .mark_prefill_completed(&request_id)
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

    /// Free a request by its ID, signaling the router to release resources
    fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
        let chooser = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            chooser.free(&request_id).await.map_err(to_pyerr)?;
            Ok(())
        })
    }

958
    #[pyo3(signature = (token_ids, lora_name=None))]
959
960
961
962
    fn get_potential_loads<'p>(
        &self,
        py: Python<'p>,
        token_ids: Vec<u32>,
963
        lora_name: Option<String>,
964
    ) -> PyResult<Bound<'p, PyAny>> {
965
        let chooser = self.inner.chooser.clone();
966
967

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
968
            let loads = chooser
969
                .get_potential_loads(&token_ids, None, lora_name.as_deref())
970
971
972
                .await
                .map_err(to_pyerr)?;

Yan Ru Pei's avatar
Yan Ru Pei committed
973
            // Return loads without aggregation - each (worker_id, dp_rank) pair is a separate entry
974
975
976
977
978
979
980
981
982
            // Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
            Python::with_gil(|py| {
                pythonize(py, &loads)
                    .map(|obj| obj.unbind())
                    .map_err(to_pyerr)
            })
        })
    }

983
984
    /// Dump all events from the KV router's indexer, serialized as a JSON string.
    ///
    /// # Errors
    /// Returns a Python exception if the dump fails or the events cannot be
    /// serialized.
    fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let chooser = self.inner.chooser.clone();

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let dumped = chooser.dump_events().await.map_err(to_pyerr)?;

            // The JSON text is the value the Python awaitable resolves to.
            serde_json::to_string(&dumped).map_err(to_pyerr)
        })
    }
}