// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::borrow::Cow;
use std::collections::HashSet;
use std::ffi::CStr;
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;

use async_once_cell::OnceCell as AsyncOnceCell;
use dynamo_kv_router::{
    config::{KvRouterConfig, RouterConfigOverride},
    protocols::*,
};
use dynamo_llm::discovery::{ModelManager, WORKER_TYPE_DECODE};
use dynamo_llm::kv_router::publisher::KvEventPublisher;
use dynamo_llm::kv_router::{KvRouter, PrefillRouter};
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_runtime::Runtime;
use dynamo_runtime::discovery::{DiscoveryQuery, hash_pod_name};
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::{DistributedRuntime, Worker};
use libc::c_char;
use once_cell::sync::OnceCell;
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
GuanLuo's avatar
GuanLuo committed
34
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
35

36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/// Read a C string into a trimmed Rust string, substituting `default_val` whenever
/// the pointer is NULL, the bytes are not valid UTF-8, or only whitespace remains.
///
/// The default is always returned borrowed (`Cow::Borrowed`), while a caller-supplied
/// value is always returned owned (`Cow::Owned`), so callers can tell which one they
/// got by matching on the variant.
///
/// # Safety
/// `ptr` must be NULL or point to a valid NUL-terminated C string that stays alive
/// for the duration of this call.
#[inline]
unsafe fn cstr_or_default<'a>(ptr: *const c_char, default_val: &'a str) -> Cow<'a, str> {
    if ptr.is_null() {
        return Cow::Borrowed(default_val);
    }
    // SAFETY: `ptr` is non-null and, per this function's contract, NUL-terminated.
    let raw = unsafe { CStr::from_ptr(ptr) };
    let Ok(text) = raw.to_str() else {
        return Cow::Borrowed(default_val);
    };
    let trimmed = text.trim();
    if trimmed.is_empty() {
        Cow::Borrowed(default_val)
    } else {
        Cow::Owned(trimmed.to_owned())
    }
}

55
56
57
58
59
60
61
fn initialize_tracing() {
    // Sets up RUST_LOG environment variable for logging while KV Publishing
    // Example: os.environ["RUST_LOG"] = "debug"
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();

62
63
64
    if tracing::subscriber::set_global_default(subscriber).is_ok() {
        tracing::debug!("Tracing initialized");
    }
65
66
67
}

/// Status code returned by the `dynamo_llm_*` / `dynamo_kv_event_*` C entry points.
#[repr(u32)]
pub enum DynamoLlmResult {
    /// The call succeeded.
    OK = 0,
    /// The call failed; the cause is logged.
    ERR = 1,
}

73
74
75
76
77
78
79
80
81
// Wait for the discovery daemon to sync indefinitely and return at least one instance.
// This is because the Model info is registered by workers and it may take up to 30 min for the model weights to load and for the worker to register itself.
// The waiting timeout is implemented in the Kubernetes StartupProbe. The EPP waiting loops runs indefinitely, the Probe is a single source of truth with when to kill the EPP if discovery fails.
// If workers are not found within the probe's failureThreshold × periodSeconds, the pod will be killed and restarted.
// Users can adjust the StartupProbe waiting timed in the DGD for large models.
async fn wait_for_discovery_sync(drt: &DistributedRuntime) -> usize {
    tracing::info!(
        "Waiting for discovery to sync (no timeout - controlled by K8s StartupProbe)..."
    );
82
83
84
85
86
87
88
89
90
91
92
93
    let discovery = drt.discovery();

    loop {
        match discovery.list(DiscoveryQuery::AllModels).await {
            Ok(instances) if !instances.is_empty() => {
                return instances.len();
            }
            Ok(_) => {
                tracing::debug!("No instances yet, waiting...");
                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
            }
            Err(e) => {
94
95
96
                // Log and continue - transient errors shouldn't stop the wait
                tracing::warn!("Discovery list error: {}, retrying...", e);
                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
97
98
99
100
101
            }
        }
    }
}

102
/// # Safety
GuanLuo's avatar
GuanLuo committed
103
/// the namespace_c_str and component_c_str are passed as pointers to C strings
104
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
105
pub unsafe extern "C" fn dynamo_llm_init(
GuanLuo's avatar
GuanLuo committed
106
107
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
108
    kv_block_size: u32,
Neelay Shah's avatar
Neelay Shah committed
109
) -> DynamoLlmResult {
110
111
112
113
    initialize_tracing();
    let wk = match WK.get_or_try_init(Worker::from_settings) {
        Ok(wk) => wk.clone(),
        Err(e) => {
114
            tracing::error!(error = ?e, "Failed to initialize runtime (Worker::from_settings)");
Neelay Shah's avatar
Neelay Shah committed
115
            return DynamoLlmResult::ERR;
116
117
118
119
120
121
122
123
124
125
        }
    };
    let rt = wk.runtime();
    let secondary = rt.secondary().clone();
    let result = secondary.block_on(async {
        // Initialize the distributed runtime
        match DRT
            .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await })
            .await
        {
126
            Ok(drt) => {
127
                // Wait for discovery to sync before returning.
128
                // This is needed because dynamo_create_worker_selection_pipeline() is called
129
130
131
132
                // immediately after, and it needs discovery.list() to return data.
                // The discovery daemon takes time to query K8s and returns async, so we need to wait.
                // Note: This waits indefinitely - the K8s StartupProbe is the timeout mechanism.
                wait_for_discovery_sync(drt).await;
133
134
                Ok(())
            }
135
            Err(e) => {
136
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
137
                Err(DynamoLlmResult::ERR)
138
139
140
            }
        }
    });
GuanLuo's avatar
GuanLuo committed
141
    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
142
143
        Ok(s) => s.to_string(),
        Err(e) => {
144
            tracing::error!(error = ?e, "Failed to convert C string to Rust string (namespace)");
Neelay Shah's avatar
Neelay Shah committed
145
            return DynamoLlmResult::ERR;
146
147
148
        }
    };

149
150
151
152
153
    let component_cow = unsafe { cstr_or_default(component_c_str, "backend") };
    if let Cow::Borrowed("backend") = &component_cow {
        tracing::info!("defaulting to \"backend\" for component");
    }
    let component: String = component_cow.into_owned();
154
155

    match result {
156
        Ok(_) => match KV_PUB.get_or_try_init(move || {
Yan Ru Pei's avatar
Yan Ru Pei committed
157
            dynamo_create_kv_publisher(namespace, component, kv_block_size)
158
        }) {
Neelay Shah's avatar
Neelay Shah committed
159
            Ok(_) => DynamoLlmResult::OK,
160
            Err(e) => {
161
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
162
                DynamoLlmResult::ERR
163
164
165
166
167
168
            }
        },
        Err(e) => e,
    }
}

169
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
170
pub extern "C" fn dynamo_llm_shutdown() -> DynamoLlmResult {
171
172
173
    let wk = match WK.get() {
        Some(wk) => wk,
        None => {
174
            tracing::error!("Runtime not initialized");
Neelay Shah's avatar
Neelay Shah committed
175
            return DynamoLlmResult::ERR;
176
177
178
179
180
        }
    };

    wk.runtime().shutdown();

Neelay Shah's avatar
Neelay Shah committed
181
    DynamoLlmResult::OK
182
183
}

184
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
185
186
pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
    DynamoLlmResult::OK
187
188
189
190
}

// instantiate a kv publisher
// this will bring up the task to publish and the channels to await publishing events
Neelay Shah's avatar
Neelay Shah committed
191
192
// the [`dynamo_kv_publish_store_event`] call will use a handle to the publisher to send events
// store and the [`dynamo_kv_event_create_removed`] will create remove events
193
194
195
// these call mus be driving by external c++ threads that are consuming the kv events from the
// c++ executor api

Neelay Shah's avatar
Neelay Shah committed
196
fn dynamo_create_kv_publisher(
GuanLuo's avatar
GuanLuo committed
197
198
    namespace: String,
    component: String,
199
    kv_block_size: u32,
GuanLuo's avatar
GuanLuo committed
200
) -> Result<KvEventPublisher, anyhow::Error> {
201
    tracing::info!("Creating KV Publisher for model: {}", component);
202
203
204
205
206
    match DRT
        .get()
        .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
    {
        Ok(drt) => {
GuanLuo's avatar
GuanLuo committed
207
            let backend = drt.namespace(namespace)?.component(component)?;
Yan Ru Pei's avatar
Yan Ru Pei committed
208
            KvEventPublisher::new(backend, kv_block_size, None)
209
210
211
212
213
214
215
216
217
        }
        Err(e) => Err(e),
    }
}

/// Build the stored-block payload for a single KV block.
///
/// `block_hash` is wrapped as an `ExternalSequenceBlockHash`. `tokens_hash` is computed
/// locally from the block's tokens, `kv_block_size` and the optional LoRA name.
///
/// The caller must ensure:
/// - `token_ids` points to `num_tokens` readable `u32` values;
/// - `num_tokens == kv_block_size as usize`, so that `compute_block_hash_for_seq`
///   yields at least one hash for the `[0]` index below. NOTE(review): this assumes it
///   returns one hash per full block; confirm, because a short block would panic.
fn kv_event_create_stored_block_from_parts(
    block_hash: u64,
    token_ids: *const u32,
    num_tokens: usize,
    kv_block_size: u32,
    lora_name: Option<&str>,
) -> KvCacheStoredBlockData {
    // SAFETY: the caller guarantees `token_ids` points to `num_tokens` valid values.
    let tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) };
    let tokens_hash = compute_block_hash_for_seq(tokens, kv_block_size, None, lora_name)[0];

    KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block_hash),
        tokens_hash,
        mm_extra_info: None,
    }
}
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);

fn kv_event_create_stored_from_parts(
236
    kv_params: DynamoKvStoredEventParams,
237
    kv_block_size: u32,
238
239
240
241
) -> KvCacheEvent {
    let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();

    let mut token_offset: usize = 0;
242
243
244
245
246
247
248
249
250
    for block_idx in 0..kv_params.num_blocks {
        let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
        let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
        let num_toks = unsafe {
            *kv_params
                .num_block_tokens
                .offset(block_idx.try_into().unwrap())
        };

251
        if num_toks != (kv_block_size as usize) {
Ryan Olson's avatar
Ryan Olson committed
252
253
            if WARN_COUNT
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
254
                    if c < 3 { Some(c + 1) } else { None }
Ryan Olson's avatar
Ryan Olson committed
255
256
257
                })
                .is_ok()
            {
258
                tracing::warn!(
259
260
                    "Block not published. Block size must be {} tokens to be published. Block size is: {}",
                    kv_block_size,
Ryan Olson's avatar
Ryan Olson committed
261
262
                    num_toks
                );
263
264
265
266
267
            }
            break;
        }
        token_offset += num_toks;
        blocks.push(kv_event_create_stored_block_from_parts(
268
269
270
271
            block_hash,
            tokens,
            num_toks,
            kv_block_size,
272
            kv_params.lora_name.as_deref(),
273
274
275
276
277
278
        ));
    }

    KvCacheEvent {
        data: KvCacheEventData::Stored(KvCacheStoreData {
            blocks,
279
            parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
280
        }),
281
        event_id: kv_params.event_id,
Yan Ru Pei's avatar
Yan Ru Pei committed
282
        dp_rank: 0,
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
    }
}

fn kv_event_create_removed_from_parts(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> KvCacheEvent {
    let block_hashes: Vec<ExternalSequenceBlockHash> =
        unsafe { std::slice::from_raw_parts(block_ids, num_blocks) }
            .to_vec()
            .iter()
            .map(|&v| ExternalSequenceBlockHash(v))
            .collect();
    KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
Yan Ru Pei's avatar
Yan Ru Pei committed
300
        dp_rank: 0,
301
302
303
    }
}

/// Parameters of a "blocks stored" KV event, as received from the C API.
///
/// The raw pointers are not owned by this struct. In this file it is built and
/// consumed within a single `dynamo_kv_event_publish_stored` call.
pub struct DynamoKvStoredEventParams {
    /// Event id, forwarded as the `KvCacheEvent`'s `event_id`.
    pub event_id: u64,
    /// Concatenated token ids of all blocks, in block order.
    pub token_ids: *const u32,
    /// Per-block token counts (`num_blocks` entries).
    pub num_block_tokens: *const usize,
    /// Per-block external block hashes (`num_blocks` entries).
    pub block_ids: *const u64,
    /// Number of entries in `num_block_tokens` and `block_ids`.
    pub num_blocks: usize,
    /// Parent block hash of the stored blocks, if any.
    pub parent_hash: Option<u64>,
    /// LoRA adapter name; `None` for the base model.
    pub lora_name: Option<String>,
}

314
315
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
316
317
/// has a parent hash or not. nullptr is used to represent no parent hash.
/// lora_name is an optional null-terminated C string; pass nullptr for base model.
318
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
319
pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
320
321
322
323
324
325
    event_id: u64,
    token_ids: *const u32,
    num_block_tokens: *const usize,
    block_ids: *const u64,
    num_blocks: usize,
    parent_hash: *const u64,
326
    lora_name: *const c_char,
Neelay Shah's avatar
Neelay Shah committed
327
) -> DynamoLlmResult {
328
329
330
331
332
333
334
    let parent_hash = {
        if parent_hash.is_null() {
            None
        } else {
            Some(unsafe { *parent_hash })
        }
    };
335
336
337
338
339
340
341
342
343
344
345
    let lora_name = if lora_name.is_null() {
        None
    } else {
        match unsafe { CStr::from_ptr(lora_name) }.to_str() {
            Ok(s) => Some(s.to_owned()),
            Err(e) => {
                tracing::error!(error = ?e, "Failed to convert C string to Rust string (lora_name)");
                return DynamoLlmResult::ERR;
            }
        }
    };
346
    let kv_params = DynamoKvStoredEventParams {
347
348
349
350
351
352
        event_id,
        token_ids,
        num_block_tokens,
        block_ids,
        num_blocks,
        parent_hash,
353
        lora_name,
354
355
356
    };
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
357
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
358
        Ok(_) => DynamoLlmResult::OK,
359
360
        Err(e) => {
            eprintln!("Error publishing stored kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
361
            DynamoLlmResult::ERR
362
363
364
365
        }
    }
}

366
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
367
pub extern "C" fn dynamo_kv_event_publish_removed(
368
369
370
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
Neelay Shah's avatar
Neelay Shah committed
371
) -> DynamoLlmResult {
372
373
374
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
375
        Ok(_) => DynamoLlmResult::OK,
376
377
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
378
            DynamoLlmResult::ERR
379
380
381
382
        }
    }
}

/* ------------------------------------------------------------------------
 *  Router Bindings for GAIE EPP
 * ------------------------------------------------------------------------ */

/// Timeout, in seconds, applied to router bookkeeping operations (e.g. `add_request`).
const BOOKKEEPING_TIMEOUT_SEC: u64 = 5;
/// Complete routing result for a chat completion request (C-compatible layout).
#[repr(C)]
pub struct CRoutingResult {
    /// Whether disaggregated mode is active.
    pub is_disaggregated: bool,
    /// Prefill worker ID (only valid if `is_disaggregated` is true).
    pub prefill_worker_id: u64,
    /// Decode worker ID.
    pub decode_worker_id: u64,
    /// Token IDs (needed for the `add_request` callback).
    /// NOTE(review): allocation/freeing of this buffer is not shown here; confirm the owner.
    pub token_ids: *mut u32,
    /// Number of tokens pointed to by `token_ids`.
    pub token_count: usize,
}

impl Default for CRoutingResult {
    /// An empty result: aggregated mode, both worker IDs 0, and no tokens
    /// (`token_ids` is NULL).
    fn default() -> Self {
        Self {
            is_disaggregated: false,
            prefill_worker_id: 0,
            decode_worker_id: 0,
            token_ids: ptr::null_mut(),
            token_count: 0,
        }
    }
}
/// Container holding routers and preprocessor for query routing
pub struct RouterHandles {
    prefill_router: Arc<PrefillRouter>,
    decode_router: Arc<KvRouter>,
    #[allow(dead_code)]
    model_manager: Arc<ModelManager>,
    #[allow(dead_code)]
    namespace: String,
    /// Cached runtime for executing async operations (avoids creating new runtime per call)
    runtime: Runtime,
    /// Preprocessor for tokenization and template application (fetched via discovery)
    preprocessor: Option<Arc<OpenAIPreprocessor>>,
}
429

430
431
impl RouterHandles {
    /// Query optimal prefill worker for a request.
432
433
    ///
    /// When `allowed_worker_ids` is Some, only workers in that set are considered.
434
435
436
437
    /// Returns worker_id on success.
    async fn query_prefill_worker(
        &self,
        tokens: &[u32],
438
        block_mm_infos: Option<&[Option<dynamo_kv_router::protocols::BlockExtraInfo>]>,
439
440
441
        update_states: bool,
        lora_name: Option<String>,
        priority_jump: f64,
442
        allowed_worker_ids: Option<HashSet<WorkerId>>,
443
    ) -> Result<u64, QueryRouterResult> {
444
445
446
447
        if let Some(ref ids) = allowed_worker_ids {
            self.prefill_router.register_workers(ids);
        }

448
        self.prefill_router
449
450
451
452
453
454
            .query_prefill_worker(
                tokens,
                block_mm_infos,
                update_states,
                lora_name,
                priority_jump,
455
                allowed_worker_ids,
456
            )
457
458
459
460
461
462
463
            .await
            .map(|(worker_id, _dp_rank)| worker_id)
            .map_err(|e| {
                tracing::error!(error = ?e, "Prefill query failed");
                QueryRouterResult::ErrQueryFailed
            })
    }
464

465
466
467
468
    /// Query optimal decode worker for a request.
    /// For disaggregated mode, set `is_disaggregated` to true to use overlap_score_weight=0
    /// (since KV cache is being transferred from prefill, not reused).
    ///
469
470
471
    /// When `allowed_worker_ids` is Some, only workers in that set are considered.
    /// This does NOT overwrite the router's internal worker state — it only filters this decision.
    ///
472
473
474
475
476
477
478
479
    /// Note: The C bindings are query-only and must not mutate router state during worker
    /// selection. State updates require a `context_id` (request id) and are managed via the
    /// explicit bookkeeping APIs (`add_request`, `mark_prefill_complete`, `free_request`).
    /// Returns (worker, overlap_blocks) on success.
    async fn query_decode_worker(
        &self,
        tokens: &[u32],
        is_disaggregated: bool,
480
        allowed_worker_ids: Option<HashSet<WorkerId>>,
481
    ) -> Result<(WorkerWithDpRank, u32), QueryRouterResult> {
482
483
484
485
        if let Some(ref ids) = allowed_worker_ids {
            self.decode_router.register_workers(ids);
        }

486
487
488
489
490
491
        // For decode phase in disaggregated mode, use overlap_score_weight=0
        // This matches prefill_router.rs
        let config_override = if is_disaggregated {
            Some(RouterConfigOverride {
                overlap_score_weight: Some(0.0),
                ..Default::default()
492
            })
493
494
495
496
        } else {
            None
        };

497
        self.decode_router
498
499
500
501
502
503
504
505
            .find_best_match(
                None,
                tokens,
                None,
                config_override.as_ref(),
                false,
                None,
                0.0,
506
                None,
507
                allowed_worker_ids,
508
            )
509
510
511
512
513
514
515
            .await
            .map_err(|e| {
                tracing::error!(error = ?e, "Decode query failed");
                QueryRouterResult::ErrQueryFailed
            })
    }
}
516

517
518
/// Opaque handle for the router pair
pub type RouterHandlesPtr = *mut RouterHandles;
/// Result codes for the query-router C FFI.
#[repr(u32)]
pub enum QueryRouterResult {
    /// Success.
    Ok = 0,
    /// Invalid router handle.
    ErrInvalidHandle = 1,
    /// An argument was NULL, not valid UTF-8, or otherwise invalid.
    ErrInvalidParam = 2,
    /// Runtime, discovery, preprocessor or router initialization failed.
    ErrInitFailed = 3,
    /// A routing query failed.
    ErrQueryFailed = 4,
    /// `enforce_disagg` was requested but no prefill workers were found.
    ErrDisaggEnforced = 5,
    /// An operation timed out (presumably the bookkeeping timeout).
    ErrTimeout = 6,
}

/// Build a `KvRouterConfig` from defaults, overridden by optional `DYN_*` environment variables.
fn kv_router_config_from_env() -> KvRouterConfig {
    let mut cfg = KvRouterConfig::default();

    fn env_f64(key: &str) -> Option<f64> {
        std::env::var(key).ok().and_then(|v| v.parse().ok())
538
    }
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
    fn env_bool(key: &str) -> Option<bool> {
        std::env::var(key)
            .ok()
            .and_then(|v| match v.to_lowercase().as_str() {
                "true" | "1" | "yes" | "on" => Some(true),
                "false" | "0" | "no" | "off" => Some(false),
                _ => None,
            })
    }

    if let Some(v) = env_f64("DYN_OVERLAP_SCORE_WEIGHT") {
        cfg.overlap_score_weight = v;
    }
    if let Some(v) = env_f64("DYN_ROUTER_TEMPERATURE") {
        cfg.router_temperature = v;
    }
    if let Some(v) = env_bool("DYN_USE_KV_EVENTS") {
        cfg.use_kv_events = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_REPLICA_SYNC") {
        cfg.router_replica_sync = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_TRACK_ACTIVE_BLOCKS") {
        cfg.router_track_active_blocks = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_TRACK_OUTPUT_BLOCKS") {
        cfg.router_track_output_blocks = v;
    }
567
568
569
    if let Some(v) = env_f64("DYN_ROUTER_QUEUE_THRESHOLD") {
        cfg.router_queue_threshold = Some(v);
    }
570
571
572
573
574
575
576
577

    tracing::info!(
        overlap_score_weight = cfg.overlap_score_weight,
        router_temperature = cfg.router_temperature,
        use_kv_events = cfg.use_kv_events,
        router_replica_sync = cfg.router_replica_sync,
        router_track_active_blocks = cfg.router_track_active_blocks,
        router_track_output_blocks = cfg.router_track_output_blocks,
578
        router_queue_threshold = ?cfg.router_queue_threshold,
579
580
581
582
        "KvRouterConfig initialized (DYN_* env overrides applied)"
    );

    cfg
583
584
}

585
/// Create router handles for query-only routing
586
///
587
588
589
/// This function waits for at least one decode worker to be discovered before returning.
/// It auto-detects disaggregated mode by checking if prefill workers are present.
/// The KV cache block size is automatically fetched from the model card via discovery.
590
///
591
592
593
/// # Arguments
/// - `namespace`: Namespace for the model
/// - `component`: Component name (defaults to "backend" if NULL or empty)
594
/// - `enforce_disagg`: If true, requires prefill workers to be present at init time
595
/// - `out_handle`: Output handle
596
///
597
598
599
/// # Safety
/// - All string parameters must be valid null-terminated C strings
/// - The returned handle must be freed with `destroy`
600
#[unsafe(no_mangle)]
601
602
603
pub unsafe extern "C" fn create_routers(
    namespace: *const c_char,
    component: *const c_char,
604
    enforce_disagg: bool,
605
606
    out_handle: *mut RouterHandlesPtr,
) -> QueryRouterResult {
607
608
    initialize_tracing();

609
610
    if namespace.is_null() || out_handle.is_null() {
        return QueryRouterResult::ErrInvalidParam;
611
612
    }

613
614
615
    let namespace_str = match unsafe { CStr::from_ptr(namespace) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
616
617
    };

618
619
    let component_str = if component.is_null() {
        "backend".to_string()
620
    } else {
621
622
623
        match unsafe { CStr::from_ptr(component) }.to_str() {
            Ok(s) if !s.is_empty() => s.to_owned(),
            _ => "backend".to_string(),
624
625
626
        }
    };

627
628
629
    // Create the runtime once - it will be stored in RouterHandles and reused
    let runtime = match Runtime::from_settings() {
        Ok(rt) => rt,
630
        Err(e) => {
631
632
            tracing::error!(error = ?e, "Failed to create runtime");
            return QueryRouterResult::ErrInitFailed;
633
634
        }
    };
635
636
637
638
639
640
641
642
643
644

    // Clone for use inside the async block (the original will be moved into handles)
    let runtime_for_async = runtime.clone();

    let result = runtime_for_async.secondary().block_on(async {
        let drt = match DistributedRuntime::from_settings(runtime_for_async.clone()).await {
            Ok(drt) => drt,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to create distributed runtime");
                return Err(QueryRouterResult::ErrInitFailed);
645
            }
646
647
        };

648
649
650
651
652
653
654
655
656
657
658
659
660
661
        let (preprocessor, block_size, model_name, actual_namespace) =
            match init_preprocessor(&drt, &namespace_str).await {
                Ok(result) => result,
                Err(e) => {
                    tracing::error!(error = %e, "Failed to initialize preprocessor");
                    return Err(QueryRouterResult::ErrInitFailed);
                }
            };

        if actual_namespace != namespace_str {
            tracing::info!(
                base_namespace = namespace_str,
                actual_namespace = actual_namespace,
                "Worker namespace has rolling-update suffix"
662
            );
663
664
        }

665
666
        let mut kv_router_config = kv_router_config_from_env();
        kv_router_config.skip_initial_worker_wait = true;
667

668
669
670
        // Build endpoint using the actual namespace discovered from workers,
        // which may include a rolling-update hash suffix.
        let component_handle = match drt.namespace(&actual_namespace) {
671
672
673
674
675
676
            Ok(ns) => match ns.component(&component_str) {
                Ok(c) => c,
                Err(e) => {
                    tracing::error!(error = ?e, "Failed to get component");
                    return Err(QueryRouterResult::ErrInitFailed);
                }
677
            },
678
679
680
681
682
683
            Err(e) => {
                tracing::error!(error = ?e, "Failed to get namespace");
                return Err(QueryRouterResult::ErrInitFailed);
            }
        };
        let endpoint = component_handle.endpoint("generate");
684

685
        let model_manager = Arc::new(ModelManager::new());
686

687
688
689
690
691
        // Create decode router
        let decode_router = match model_manager
            .kv_chooser_for(
                &endpoint,
                block_size,
692
                Some(kv_router_config.clone()),
693
                WORKER_TYPE_DECODE,
694
                Some(model_name.clone()),
695
696
697
698
699
700
701
702
703
            )
            .await
        {
            Ok(r) => r,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to create decode router");
                return Err(QueryRouterResult::ErrInitFailed);
            }
        };
704

705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
        // Wait for the runtime config watch to be populated with at least one
        // decode worker's ModelRuntimeConfig. skip_initial_worker_wait=true
        // skips this inside KvRouter::new, but the selector needs workers in
        // workers_with_configs to avoid NoEndpoints on the first request.
        // discovery sync already confirmed workers exist; this just waits for
        // the async join of instance IDs + configs to complete in the watch.
        {
            let mut config_watch = model_manager
                .get_or_create_runtime_config_watcher(&endpoint)
                .await
                .map_err(|e| {
                    tracing::error!(error = ?e, "Failed to get runtime config watcher");
                    QueryRouterResult::ErrInitFailed
                })?;
            tracing::info!(
                "Waiting for decode workers to register ModelRuntimeConfig \
                 (no timeout - controlled by K8s StartupProbe)..."
            );
            let wait_result = config_watch.wait_for(|m| !m.is_empty()).await.map(|_| ());
            match wait_result {
                Ok(()) => {
                    let count = config_watch.borrow().len();
                    tracing::info!(
                        worker_count = count,
                        "Runtime config watch populated with decode workers"
                    );
                }
                Err(_) => {
                    tracing::error!(
                        "Runtime config watch closed before any workers appeared. \
                         Decode routing will fail. \
                         Verify workers are running and publishing to discovery."
                    );
                    return Err(QueryRouterResult::ErrInitFailed);
                }
            }
        }

743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
        // Create PrefillRouter based on one-time discovery of prefill workers
        // Auto-detects disaggregated mode by checking if prefill workers are present
        // The prefill workers have to be created before the epp is created.
        // Given that we wait first for the decode worker to show up it is reasonable to assume the prefill will be up as well.
        let prefill_router = match find_prefill_endpoint(&drt, &namespace_str).await {
            Some(prefill_endpoint) => {
                tracing::info!("Prefill worker found, running in disaggregated mode");
                let mut prefill_config = kv_router_config;
                prefill_config.router_track_active_blocks = false;

                // Create immediately-resolved channel to activate router
                let (tx, rx) = tokio::sync::oneshot::channel();
                let _ = tx.send(prefill_endpoint);

                PrefillRouter::new(
                    rx,
                    model_manager.clone(),
                    RouterMode::KV,
                    block_size,
                    Some(prefill_config),
763
                    enforce_disagg,
764
                    model_name.clone(),
765
                    namespace_str.clone(),
766
767
                )
            }
768
            None if enforce_disagg => {
769
                tracing::error!(
770
                    "Prefill workers required but none found (enforce_disagg is enabled)"
771
                );
772
773
774
775
                return Err(QueryRouterResult::ErrDisaggEnforced);
            }
            None => {
                tracing::info!("No prefill workers found, running in aggregated mode");
776
                PrefillRouter::disabled(model_manager.clone(), RouterMode::KV, enforce_disagg)
777
778
779
780
781
782
783
784
785
786
787
            }
        };

        Ok((
            prefill_router,
            decode_router,
            model_manager,
            namespace_str,
            preprocessor,
        ))
    });
788
789

    match result {
790
791
792
793
794
795
796
797
798
799
800
        Ok((prefill_router, decode_router, model_manager, namespace_str, preprocessor)) => {
            let handles = RouterHandles {
                prefill_router,
                decode_router,
                model_manager,
                namespace: namespace_str,
                runtime, // Store the runtime for reuse
                preprocessor,
            };
            unsafe { *out_handle = Box::into_raw(Box::new(handles)) };
            QueryRouterResult::Ok
801
        }
802
        Err(code) => code,
803
804
805
806
807
    }
}

/// Add a request to the router's bookkeeping after worker selection.
///
808
809
810
/// Register the request with the KvRouter's scheduler for tracking active blocks
/// and managing prefill/decode lifecycle. Call this after `query_decode` returns
/// worker IDs and before sending the request to the worker.
811
812
///
/// # Safety
813
814
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
815
816
/// - `token_ids` must point to at least `token_count` valid u32 values
#[unsafe(no_mangle)]
817
818
819
pub unsafe extern "C" fn add_request(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
820
821
822
823
    token_ids: *const u32,
    token_count: usize,
    worker_id: u64,
    dp_rank: u32,
824
825
826
827
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
828

829
830
831
832
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
833
834
835
836
837
838
839
840
    };

    let tokens: Vec<u32> = if token_count > 0 && !token_ids.is_null() {
        unsafe { std::slice::from_raw_parts(token_ids, token_count) }.to_vec()
    } else {
        Vec::new()
    };

841
    let decode_router = handles.decode_router.clone();
842

843
844
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
845

846
847
        tokio::time::timeout(timeout_duration, async {
            let worker = WorkerWithDpRank::new(worker_id, dp_rank);
848

849
            // Compute overlap_blocks using the public method
850
851
852
853
            let overlap_blocks = match decode_router
                .get_overlap_blocks(&tokens, worker, None)
                .await
            {
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
                Ok(overlap) => overlap,
                Err(e) => {
                    tracing::warn!(error = ?e, "Failed to compute overlap, using 0");
                    0
                }
            };

            decode_router
                .add_request(
                    request_id_str.clone(),
                    &tokens,
                    overlap_blocks,
                    None,
                    worker,
                    None, // lora_name
                    None, // router_config_override
                )
                .await;

            tracing::debug!(
                request_id = %request_id_str,
                worker_id = worker_id,
                dp_rank = dp_rank,
                overlap_blocks = overlap_blocks,
                token_count = tokens.len(),
                "add_request completed"
            );
        })
        .await
    });

    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
            tracing::warn!(
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "add_request timed out"
            );
            QueryRouterResult::ErrTimeout
        }
    }
896
897
898
}

/// Mark prefill as completed for a request.
899
900
///
/// Call when the first token is generated to release prefill tokens from decode worker's load
901
902
///
/// # Safety
903
904
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
905
#[unsafe(no_mangle)]
906
907
908
909
910
911
912
pub unsafe extern "C" fn mark_prefill_complete(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
913

914
915
916
917
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
918
919
    };

920
    let decode_router = handles.decode_router.clone();
921

922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);

        tokio::time::timeout(timeout_duration, async {
            if let Err(e) = decode_router.mark_prefill_completed(&request_id_str).await {
                tracing::warn!(
                    request_id = %request_id_str,
                    error = %e,
                    "Failed to mark prefill complete"
                );
            } else {
                tracing::debug!(
                    request_id = %request_id_str,
                    "mark_prefill_complete completed"
                );
            }
        })
        .await
    });
941

942
943
944
    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
945
            tracing::warn!(
946
947
948
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "mark_prefill_complete timed out"
949
            );
950
            QueryRouterResult::ErrTimeout
951
        }
952
    }
953
954
955
}

/// Free a request from the router's bookkeeping.
956
957
///
/// Call this when the stream is closed (completed or cancelled) to release all resources.
958
959
///
/// # Safety
960
961
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
962
#[unsafe(no_mangle)]
963
964
965
966
967
968
969
pub unsafe extern "C" fn free_request(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
970

971
972
973
974
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
975
976
    };

977
    let decode_router = handles.decode_router.clone();
978

979
980
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
981

982
983
984
985
986
987
        tokio::time::timeout(timeout_duration, async {
            if let Err(e) = decode_router.free(&request_id_str).await {
                tracing::warn!(
                    request_id = %request_id_str,
                    error = %e,
                    "Failed to free request"
988
                );
989
            } else {
990
                tracing::debug!(
991
992
                    request_id = %request_id_str,
                    "free_request completed"
993
                );
994
            }
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
        })
        .await
    });

    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
            tracing::warn!(
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "free_request timed out"
            );
            QueryRouterResult::ErrTimeout
1008
1009
        }
    }
1010
}
1011

1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
/// Destroy router handles
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle or null
/// - After this call, `handle` must not be used
#[unsafe(no_mangle)]
pub unsafe extern "C" fn destroy(handle: RouterHandlesPtr) {
    if !handle.is_null() {
        drop(unsafe { Box::from_raw(handle) });
    }
1022
1023
}

1024
/// Free a routing result.
1025
///
1026
/// # Safety
1027
/// - `result` must be a valid pointer to a CRoutingResult previously returned by route functions
1028
#[unsafe(no_mangle)]
1029
1030
1031
pub unsafe extern "C" fn free_routing_result(result: *mut CRoutingResult) {
    if result.is_null() {
        return;
1032
    }
1033

1034
    let res = unsafe { &mut *result };
1035

1036
1037
1038
    // Free token IDs
    if !res.token_ids.is_null() && res.token_count > 0 {
        drop(unsafe {
1039
            Box::from_raw(std::ptr::slice_from_raw_parts_mut(
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
                res.token_ids,
                res.token_count,
            ))
        });
        res.token_ids = ptr::null_mut();
        res.token_count = 0;
    }
}

/// Parse a JSON request string, apply the chat template, and tokenize.
/// Returns the token IDs on success, or a `QueryRouterResult` error code.
unsafe fn preprocess_request(
    handles: &RouterHandles,
    request_json: *const c_char,
) -> Result<Vec<u32>, QueryRouterResult> {
1055
1056
1057
1058
    let preprocessor = match &handles.preprocessor {
        Some(p) => p,
        None => {
            tracing::error!("Preprocessor not available");
1059
            return Err(QueryRouterResult::ErrInitFailed);
1060
1061
1062
1063
1064
        }
    };

    let json_str = match unsafe { CStr::from_ptr(request_json) }.to_str() {
        Ok(s) => s,
1065
        Err(_) => return Err(QueryRouterResult::ErrInvalidParam),
1066
1067
1068
1069
1070
1071
1072
    };

    let request: dynamo_llm::types::openai::chat_completions::NvCreateChatCompletionRequest =
        match serde_json::from_str(json_str) {
            Ok(req) => req,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to parse request JSON");
1073
                return Err(QueryRouterResult::ErrInvalidParam);
1074
1075
1076
1077
1078
1079
1080
1081
            }
        };

    let formatted_prompt = match preprocessor.apply_template(&request) {
        Ok(Some(prompt)) => prompt,
        Ok(None) => String::new(),
        Err(e) => {
            tracing::error!(error = ?e, "Failed to apply chat template");
1082
            return Err(QueryRouterResult::ErrQueryFailed);
1083
1084
1085
1086
1087
1088
1089
        }
    };

    let encoding = match preprocessor.tokenize(&formatted_prompt) {
        Ok(enc) => enc,
        Err(e) => {
            tracing::error!(error = ?e, "Failed to tokenize");
1090
            return Err(QueryRouterResult::ErrQueryFailed);
1091
1092
1093
        }
    };

1094
1095
    Ok(encoding.token_ids().to_vec())
}
1096

1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
/// Build an optional allow-list of worker IDs from the EPP `pods_json` payload.
///
/// The payload is parsed as a JSON array. For each entry, the pod name is read
/// from `pod.podName`, falling back to a top-level `podName`, and hashed with
/// `hash_pod_name` into a worker ID.
///
/// Returns `None`, meaning no filtering, when:
/// - the pointer is null;
/// - the string is empty or not UTF-8;
/// - the JSON fails to parse (logged at `error` level);
/// - no entry yields a pod name.
///
/// # Safety
/// `pods_json` must be null or a valid NUL-terminated C string.
unsafe fn parse_pods_filter(pods_json: *const c_char) -> Option<HashSet<WorkerId>> {
    if pods_json.is_null() {
        return None;
    }

    // SAFETY: `pods_json` is non-null and, per the contract, NUL-terminated.
    let raw = unsafe { CStr::from_ptr(pods_json) }.to_str().ok()?;
    if raw.is_empty() {
        return None;
    }

    let pods: Vec<serde_json::Value> = match serde_json::from_str(raw) {
        Ok(pods) => pods,
        Err(e) => {
            tracing::error!(error = ?e, "Failed to parse pods JSON");
            return None;
        }
    };

    let worker_ids: HashSet<WorkerId> = pods
        .iter()
        .filter_map(|pod| {
            pod.get("pod")
                .and_then(|p| p.get("podName"))
                .or_else(|| pod.get("podName"))
                .and_then(|v| v.as_str())
        })
        .map(|name| {
            let worker_id = hash_pod_name(name);
            tracing::debug!(
                pod_name = name,
                worker_id = format!("{:x}", worker_id),
                "Mapped EPP pod to worker_id"
            );
            worker_id
        })
        .collect();

    tracing::info!(
        pod_count = pods.len(),
        unique_worker_ids = worker_ids.len(),
        "Parsed EPP pods into allowed_worker_ids filter"
    );

    (!worker_ids.is_empty()).then_some(worker_ids)
}
1141

1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
/// Write token IDs into a `CRoutingResult`, transferring ownership to the caller.
///
/// The tokens are copied into a newly allocated boxed slice. That slice is
/// exposed through `out.token_ids` / `out.token_count`, and the caller must
/// release it with `free_routing_result`, which rebuilds the same `Box<[u32]>`.
fn write_tokens_to_result(tokens: &[u32], out: &mut CRoutingResult) {
    let boxed: Box<[u32]> = Box::from(tokens);
    out.token_count = boxed.len();
    // `Box::into_raw` gives up ownership without moving the Box after a raw
    // pointer was derived from it. The former `as_mut_ptr` + `mem::forget`
    // sequence moved the Box, which retags it and can invalidate the pointer
    // under Rust's aliasing model.
    out.token_ids = Box::into_raw(boxed).cast::<u32>();
}

/// Route a request to select the best **prefill** worker only.
///
/// This is used in disaggregated mode where the EPP runs separate prefill and decode
/// scoring profiles.  It tokenizes the request and queries only the prefill router.
///
/// The returned `CRoutingResult` contains:
/// - `prefill_worker_id`: the selected prefill worker
/// - `decode_worker_id`: 0 (unused — decode is handled by `route_decode_request`)
/// - `is_disaggregated`: always true (this function is only called in disagg mode)
/// - `token_ids` / `token_count`: the tokenized request (caller must free via `free_routing_result`)
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `pods_json` must be a valid null-terminated C string containing JSON, or null
/// - `out_result` must be a valid pointer
#[unsafe(no_mangle)]
pub unsafe extern "C" fn route_prefill_request(
    handle: RouterHandlesPtr,
    request_json: *const c_char,
    pods_json: *const c_char,
    out_result: *mut CRoutingResult,
) -> QueryRouterResult {
    if handle.is_null() || request_json.is_null() || out_result.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }

    let handles = unsafe { &*handle };

    let tokens = match unsafe { preprocess_request(handles, request_json) } {
        Ok(t) => t,
        Err(code) => return code,
    };

    let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };

    let result = handles.runtime.secondary().block_on(async {
        let prefill_worker_id = handles
            .query_prefill_worker(&tokens, None, false, None, 0.0, allowed_worker_ids)
1190
1191
1192
1193
            .await?;

        tracing::info!(
            prefill_worker_id = prefill_worker_id,
1194
1195
            token_count = tokens.len(),
            "Routed prefill request"
1196
1197
        );

1198
        Ok(prefill_worker_id)
1199
1200
    });

1201
    match result {
1202
1203
1204
1205
1206
1207
        Ok(prefill_worker_id) => {
            let out = unsafe { &mut *out_result };
            *out = CRoutingResult::default();
            out.is_disaggregated = true;
            out.prefill_worker_id = prefill_worker_id;
            write_tokens_to_result(&tokens, out);
1208
1209
1210
1211
            QueryRouterResult::Ok
        }
        Err(code) => code,
    }
1212
1213
}

1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
/// Route a request to select the best **decode** worker only.
///
/// This is used in both aggregated and disaggregated modes.
/// - When `is_disaggregated` is true, the decode router uses `overlap_score_weight=0`
///   (KV cache is being transferred from prefill, not reused locally).
/// - When `is_disaggregated` is false, normal KV-aware scoring is used.
///
/// The returned `CRoutingResult` contains:
/// - `decode_worker_id`: the selected decode worker
/// - `prefill_worker_id`: 0 (unused — prefill is handled by `route_prefill_request`)
/// - `is_disaggregated`: mirrors the input parameter
/// - `token_ids` / `token_count`: the tokenized request (caller must free via `free_routing_result`)
1226
///
1227
/// # Safety
1228
1229
1230
1231
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `pods_json` must be a valid null-terminated C string containing JSON, or null
/// - `out_result` must be a valid pointer
1232
#[unsafe(no_mangle)]
1233
1234
1235
1236
1237
1238
1239
1240
1241
pub unsafe extern "C" fn route_decode_request(
    handle: RouterHandlesPtr,
    request_json: *const c_char,
    pods_json: *const c_char,
    is_disaggregated: bool,
    out_result: *mut CRoutingResult,
) -> QueryRouterResult {
    if handle.is_null() || request_json.is_null() || out_result.is_null() {
        return QueryRouterResult::ErrInvalidParam;
1242
1243
    }

1244
    let handles = unsafe { &*handle };
1245

1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
    let tokens = match unsafe { preprocess_request(handles, request_json) } {
        Ok(t) => t,
        Err(code) => return code,
    };

    let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };

    let result = handles.runtime.secondary().block_on(async {
        let (decode_worker, _overlap_blocks) = handles
            .query_decode_worker(&tokens, is_disaggregated, allowed_worker_ids)
            .await?;

        tracing::info!(
            is_disaggregated = is_disaggregated,
            decode_worker_id = decode_worker.worker_id,
            decode_dp_rank = decode_worker.dp_rank,
            token_count = tokens.len(),
            "Routed decode request"
        );

        Ok(decode_worker)
    });

    match result {
        Ok(decode_worker) => {
            let out = unsafe { &mut *out_result };
            *out = CRoutingResult::default();
            out.is_disaggregated = is_disaggregated;
            out.decode_worker_id = decode_worker.worker_id;
            write_tokens_to_result(&tokens, out);
            QueryRouterResult::Ok
        }
        Err(code) => code,
1279
    }
1280
1281
}

1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
/// Initialize the preprocessor, block size, and model name.
///
/// Waits for discovery to sync (model card must be available for tokenization),
/// then creates the preprocessor from the model card. The `kv_cache_block_size`
/// and `model_name` are taken from the model card to ensure consistency with
/// the worker configuration.
async fn init_preprocessor(
    drt: &DistributedRuntime,
    target_namespace: &str,
) -> anyhow::Result<(Option<Arc<OpenAIPreprocessor>>, u32, String, String)> {
    let instance_count = wait_for_discovery_sync(drt).await;
    if instance_count == 0 {
        anyhow::bail!("Discovery sync failed: no worker instances found. Is the backend running?");
    }
    tracing::info!(
        "Discovery sync complete, {} worker(s) found",
        instance_count
    );

    // Retry fetching the preprocessor: model card metadata may arrive after
    // worker endpoints are registered.
    let (prep, block_size, model_name, actual_namespace) = loop {
        match fetch_preprocessor_from_discovery(drt, target_namespace).await {
            Ok(result) => break result,
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    target_namespace,
                    "Model card not available yet, retrying in 5s..."
                );
                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            }
        }
    };

    tracing::info!(
        kv_cache_block_size = block_size,
        model_name = model_name,
        actual_namespace = actual_namespace,
        "Preprocessor initialized from model card"
    );

    Ok((Some(prep), block_size, model_name, actual_namespace))
}

1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
/// Fetch model card via discovery and create preprocessor.
///
/// This function:
/// 1. Lists all models via discovery
/// 2. Finds the first model in the target namespace (decode workers only)
/// 3. Downloads the model config (tokenizer files) if needed
/// 4. Creates an OpenAIPreprocessor from the model card
/// 5. Returns the preprocessor, the kv_cache_block_size, and model_name from the model card
async fn fetch_preprocessor_from_discovery(
    drt: &DistributedRuntime,
    target_namespace: &str,
1338
) -> anyhow::Result<(Arc<OpenAIPreprocessor>, u32, String, String)> {
1339
    use dynamo_llm::model_card::ModelDeploymentCard;
1340
    use dynamo_runtime::discovery::DiscoveryInstance;
1341

1342
    let discovery = drt.discovery();
1343

1344
1345
    // List all models
    let instances = discovery.list(DiscoveryQuery::AllModels).await?;
1346

1347
1348
1349
1350
    // Find first model card in the target namespace (decode workers only).
    // Use prefix matching because workers may append a rolling-update hash
    // suffix to the base namespace (e.g. "ns-dgd-58908edc" vs "ns-dgd").
    let mut model_card: Option<(ModelDeploymentCard, String)> = None;
1351

1352
1353
    for instance in instances {
        if let DiscoveryInstance::Model { namespace, .. } = &instance {
1354
            if !namespace.starts_with(target_namespace) {
1355
1356
1357
                continue;
            }

1358
            let actual_namespace = namespace.clone();
1359
1360
1361
1362
1363
1364
            match instance.deserialize_model::<ModelDeploymentCard>() {
                Ok(card) => {
                    // Skip prefill-only workers, we want decode workers for routing
                    if card.model_type.supports_prefill()
                        && !card.model_type.supports_chat()
                        && !card.model_type.supports_completions()
1365
                    {
1366
                        continue;
1367
                    }
1368
                    model_card = Some((card, actual_namespace));
1369
                    break;
1370
                }
1371
1372
1373
                Err(e) => {
                    tracing::debug!(error = %e, "Failed to deserialize model card, skipping");
                    continue;
1374
1375
1376
                }
            }
        }
1377
    }
1378

1379
    let (mut card, actual_namespace) = model_card.ok_or_else(|| {
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
        anyhow::anyhow!(
            "No model found in namespace '{}' via discovery",
            target_namespace
        )
    })?;

    let kv_cache_block_size = card.kv_cache_block_size;
    let model_name = card.name().to_string();
    tracing::info!(
        model_name = model_name,
        kv_cache_block_size = kv_cache_block_size,
1391
        actual_namespace = actual_namespace,
1392
        "Found model card via discovery"
1393
1394
    );

1395
1396
    // Download config (tokenizer files) if not local
    card.download_config().await?;
1397

1398
1399
    // Create preprocessor
    let preprocessor = OpenAIPreprocessor::new(card)?;
1400
1401
1402
1403
1404
1405
    Ok((
        preprocessor,
        kv_cache_block_size,
        model_name,
        actual_namespace,
    ))
1406
}
1407

1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
/// Find a prefill endpoint from already-discovered instances (one-time filter).
/// Returns the endpoint if a prefill worker is found in the target namespace.
async fn find_prefill_endpoint(
    drt: &DistributedRuntime,
    target_namespace: &str,
) -> Option<dynamo_runtime::component::Endpoint> {
    use dynamo_llm::model_card::ModelDeploymentCard;
    use dynamo_runtime::discovery::DiscoveryInstance;

    let discovery = drt.discovery();
    let instances = match discovery.list(DiscoveryQuery::AllModels).await {
        Ok(instances) => instances,
        Err(e) => {
            tracing::warn!(error = %e, "Failed to list instances for prefill discovery");
            return None;
        }
1424
1425
    };

1426
1427
1428
1429
1430
1431
1432
1433
    for instance in instances {
        if let DiscoveryInstance::Model {
            namespace,
            component,
            endpoint,
            ..
        } = &instance
        {
1434
            if !namespace.starts_with(target_namespace) {
1435
1436
                continue;
            }
1437

1438
1439
1440
1441
            let card = match instance.deserialize_model::<ModelDeploymentCard>() {
                Ok(card) => card,
                Err(_) => continue,
            };
1442

1443
1444
1445
1446
            // Only handle prefill models
            if !card.model_type.supports_prefill() {
                continue;
            }
1447

1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
            tracing::info!(
                model_name = card.name(),
                "Prefill worker found in discovered instances"
            );

            // Build and return the endpoint
            if let Ok(ns) = drt.namespace(namespace)
                && let Ok(comp) = ns.component(component)
            {
                return Some(comp.endpoint(endpoint));
            }
        }
    }
1461

1462
    None
1463
}