lib.rs 46.4 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
// SPDX-License-Identifier: Apache-2.0

use async_once_cell::OnceCell as AsyncOnceCell;
use libc::c_char;
use once_cell::sync::OnceCell;
7
use std::borrow::Cow;
8
use std::ffi::CStr;
9
use std::ptr;
10
use std::sync::Arc;
11
use std::sync::atomic::{AtomicU32, Ordering};
12
use std::time::Duration;
13

14
15
use dynamo_llm::kv_router::{protocols::*, publisher::KvEventPublisher};
use dynamo_llm::preprocessor::OpenAIPreprocessor;
16
use dynamo_runtime::discovery::{DiscoveryQuery, hash_pod_name};
17
use dynamo_runtime::{DistributedRuntime, Worker};
18
19
20
21
22
23
24
25
26

use dynamo_runtime::Runtime;

use dynamo_llm::discovery::{ModelManager, WORKER_TYPE_DECODE};
use dynamo_llm::kv_router::KvRouterConfig;
use dynamo_llm::kv_router::protocols::WorkerWithDpRank;
use dynamo_llm::kv_router::{KvRouter, PrefillRouter, RouterConfigOverride};
use dynamo_runtime::pipeline::RouterMode;

27
28
use std::collections::HashSet;

29
30
31
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
GuanLuo's avatar
GuanLuo committed
32
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
/// Read a C string as trimmed text, or fall back to `default_val`.
///
/// The fallback (returned as `Cow::Borrowed(default_val)`) is used when `ptr` is NULL, when the
/// bytes are not valid UTF-8, or when the text is empty after trimming whitespace. Otherwise the
/// trimmed text is returned as an owned string (`Cow::Owned`).
///
/// # Safety
/// A non-NULL `ptr` must point to a valid NUL-terminated C string that stays alive for the
/// duration of this call.
#[inline]
unsafe fn cstr_or_default<'a>(ptr: *const c_char, default_val: &'a str) -> Cow<'a, str> {
    if !ptr.is_null() {
        // SAFETY: `ptr` is non-null and the caller guarantees it is NUL-terminated.
        let raw = unsafe { CStr::from_ptr(ptr) };
        if let Ok(text) = raw.to_str() {
            let trimmed = text.trim();
            if !trimmed.is_empty() {
                return Cow::Owned(trimmed.to_owned());
            }
        }
    }
    Cow::Borrowed(default_val)
}

53
54
55
56
57
58
59
fn initialize_tracing() {
    // Sets up RUST_LOG environment variable for logging while KV Publishing
    // Example: os.environ["RUST_LOG"] = "debug"
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();

Ryan Olson's avatar
Ryan Olson committed
60
    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
61

62
    tracing::debug!("Tracing initialized");
63
64
65
}

/// Status code returned by the `dynamo_llm_*` / `dynamo_kv_event_*` C entry points.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DynamoLlmResult {
    /// The call succeeded.
    OK = 0,
    /// The call failed; details are reported in the logs.
    ERR = 1,
}

71
72
73
74
75
76
77
78
79
// Wait for the discovery daemon to sync indefinitely and return at least one instance.
// This is because the Model info is registered by workers and it may take up to 30 min for the model weights to load and for the worker to register itself.
// The waiting timeout is implemented in the Kubernetes StartupProbe. The EPP waiting loops runs indefinitely, the Probe is a single source of truth with when to kill the EPP if discovery fails.
// If workers are not found within the probe's failureThreshold × periodSeconds, the pod will be killed and restarted.
// Users can adjust the StartupProbe waiting timed in the DGD for large models.
async fn wait_for_discovery_sync(drt: &DistributedRuntime) -> usize {
    tracing::info!(
        "Waiting for discovery to sync (no timeout - controlled by K8s StartupProbe)..."
    );
80
81
82
83
84
85
86
87
88
89
90
91
    let discovery = drt.discovery();

    loop {
        match discovery.list(DiscoveryQuery::AllModels).await {
            Ok(instances) if !instances.is_empty() => {
                return instances.len();
            }
            Ok(_) => {
                tracing::debug!("No instances yet, waiting...");
                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
            }
            Err(e) => {
92
93
94
                // Log and continue - transient errors shouldn't stop the wait
                tracing::warn!("Discovery list error: {}, retrying...", e);
                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
95
96
97
98
99
            }
        }
    }
}

100
/// # Safety
GuanLuo's avatar
GuanLuo committed
101
/// the namespace_c_str and component_c_str are passed as pointers to C strings
102
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
103
pub unsafe extern "C" fn dynamo_llm_init(
GuanLuo's avatar
GuanLuo committed
104
105
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
106
    kv_block_size: u32,
Neelay Shah's avatar
Neelay Shah committed
107
) -> DynamoLlmResult {
108
109
110
111
    initialize_tracing();
    let wk = match WK.get_or_try_init(Worker::from_settings) {
        Ok(wk) => wk.clone(),
        Err(e) => {
112
            tracing::error!(error = ?e, "Failed to initialize runtime (Worker::from_settings)");
Neelay Shah's avatar
Neelay Shah committed
113
            return DynamoLlmResult::ERR;
114
115
116
117
118
119
120
121
122
123
        }
    };
    let rt = wk.runtime();
    let secondary = rt.secondary().clone();
    let result = secondary.block_on(async {
        // Initialize the distributed runtime
        match DRT
            .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await })
            .await
        {
124
            Ok(drt) => {
125
                // Wait for discovery to sync before returning.
126
                // This is needed because dynamo_create_worker_selection_pipeline() is called
127
128
129
130
                // immediately after, and it needs discovery.list() to return data.
                // The discovery daemon takes time to query K8s and returns async, so we need to wait.
                // Note: This waits indefinitely - the K8s StartupProbe is the timeout mechanism.
                wait_for_discovery_sync(drt).await;
131
132
                Ok(())
            }
133
            Err(e) => {
134
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
135
                Err(DynamoLlmResult::ERR)
136
137
138
            }
        }
    });
GuanLuo's avatar
GuanLuo committed
139
    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
140
141
        Ok(s) => s.to_string(),
        Err(e) => {
142
            tracing::error!(error = ?e, "Failed to convert C string to Rust string (namespace)");
Neelay Shah's avatar
Neelay Shah committed
143
            return DynamoLlmResult::ERR;
144
145
146
        }
    };

147
148
149
150
151
    let component_cow = unsafe { cstr_or_default(component_c_str, "backend") };
    if let Cow::Borrowed("backend") = &component_cow {
        tracing::info!("defaulting to \"backend\" for component");
    }
    let component: String = component_cow.into_owned();
152
153

    match result {
154
        Ok(_) => match KV_PUB.get_or_try_init(move || {
Yan Ru Pei's avatar
Yan Ru Pei committed
155
            dynamo_create_kv_publisher(namespace, component, kv_block_size)
156
        }) {
Neelay Shah's avatar
Neelay Shah committed
157
            Ok(_) => DynamoLlmResult::OK,
158
            Err(e) => {
159
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
160
                DynamoLlmResult::ERR
161
162
163
164
165
166
            }
        },
        Err(e) => e,
    }
}

167
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
168
pub extern "C" fn dynamo_llm_shutdown() -> DynamoLlmResult {
169
170
171
    let wk = match WK.get() {
        Some(wk) => wk,
        None => {
172
            tracing::error!("Runtime not initialized");
Neelay Shah's avatar
Neelay Shah committed
173
            return DynamoLlmResult::ERR;
174
175
176
177
178
        }
    };

    wk.runtime().shutdown();

Neelay Shah's avatar
Neelay Shah committed
179
    DynamoLlmResult::OK
180
181
}

182
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
183
184
pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
    DynamoLlmResult::OK
185
186
187
188
}

// instantiate a kv publisher
// this will bring up the task to publish and the channels to await publishing events
Neelay Shah's avatar
Neelay Shah committed
189
190
// the [`dynamo_kv_publish_store_event`] call will use a handle to the publisher to send events
// store and the [`dynamo_kv_event_create_removed`] will create remove events
191
192
193
// these call mus be driving by external c++ threads that are consuming the kv events from the
// c++ executor api

Neelay Shah's avatar
Neelay Shah committed
194
fn dynamo_create_kv_publisher(
GuanLuo's avatar
GuanLuo committed
195
196
    namespace: String,
    component: String,
197
    kv_block_size: u32,
GuanLuo's avatar
GuanLuo committed
198
) -> Result<KvEventPublisher, anyhow::Error> {
199
    tracing::info!("Creating KV Publisher for model: {}", component);
200
201
202
203
204
    match DRT
        .get()
        .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
    {
        Ok(drt) => {
GuanLuo's avatar
GuanLuo committed
205
            let backend = drt.namespace(namespace)?.component(component)?;
Yan Ru Pei's avatar
Yan Ru Pei committed
206
            KvEventPublisher::new(backend, kv_block_size, None)
207
208
209
210
211
212
213
214
215
        }
        Err(e) => Err(e),
    }
}

/// Build the stored-block record for a single KV block.
///
/// `block_hash` is the externally supplied hash for the block. `tokens_hash` is recomputed from
/// the block's `num_tokens` tokens (with `kv_block_size` and `lora_name`) using the first
/// element returned by `compute_block_hash_for_seq`.
///
/// `token_ids` must point to `num_tokens` readable `u32` values. Indexing `[0]` panics if
/// `compute_block_hash_for_seq` returns no hashes, presumably the case for a partial block; the
/// only caller passes exactly-`kv_block_size` blocks.
fn kv_event_create_stored_block_from_parts(
    block_hash: u64,
    token_ids: *const u32,
    num_tokens: usize,
    kv_block_size: u32,
    lora_name: Option<&str>,
) -> KvCacheStoredBlockData {
    // SAFETY: the caller guarantees `token_ids` addresses `num_tokens` readable u32 values.
    let block_tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) };
    let sequence_hashes = compute_block_hash_for_seq(block_tokens, kv_block_size, None, lora_name);

    KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block_hash),
        tokens_hash: sequence_hashes[0],
        mm_extra_info: None,
    }
}
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);

fn kv_event_create_stored_from_parts(
234
    kv_params: DynamoKvStoredEventParams,
235
    kv_block_size: u32,
236
237
238
239
) -> KvCacheEvent {
    let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();

    let mut token_offset: usize = 0;
240
241
242
243
244
245
246
247
248
    for block_idx in 0..kv_params.num_blocks {
        let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
        let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
        let num_toks = unsafe {
            *kv_params
                .num_block_tokens
                .offset(block_idx.try_into().unwrap())
        };

249
        if num_toks != (kv_block_size as usize) {
Ryan Olson's avatar
Ryan Olson committed
250
251
            if WARN_COUNT
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
252
                    if c < 3 { Some(c + 1) } else { None }
Ryan Olson's avatar
Ryan Olson committed
253
254
255
                })
                .is_ok()
            {
256
                tracing::warn!(
257
258
                    "Block not published. Block size must be {} tokens to be published. Block size is: {}",
                    kv_block_size,
Ryan Olson's avatar
Ryan Olson committed
259
260
                    num_toks
                );
261
262
263
264
265
            }
            break;
        }
        token_offset += num_toks;
        blocks.push(kv_event_create_stored_block_from_parts(
266
267
268
269
            block_hash,
            tokens,
            num_toks,
            kv_block_size,
270
            kv_params.lora_name.as_deref(),
271
272
273
274
275
276
        ));
    }

    KvCacheEvent {
        data: KvCacheEventData::Stored(KvCacheStoreData {
            blocks,
277
            parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
278
        }),
279
        event_id: kv_params.event_id,
Yan Ru Pei's avatar
Yan Ru Pei committed
280
        dp_rank: 0,
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
    }
}

fn kv_event_create_removed_from_parts(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> KvCacheEvent {
    let block_hashes: Vec<ExternalSequenceBlockHash> =
        unsafe { std::slice::from_raw_parts(block_ids, num_blocks) }
            .to_vec()
            .iter()
            .map(|&v| ExternalSequenceBlockHash(v))
            .collect();
    KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
Yan Ru Pei's avatar
Yan Ru Pei committed
298
        dp_rank: 0,
299
300
301
    }
}

302
303
304
305
306
307
308
/// Raw, C-supplied description of a batch of stored KV blocks, bundled for
/// `kv_event_create_stored_from_parts`.
pub struct DynamoKvStoredEventParams {
    /// Identifier of the KV event.
    pub event_id: u64,
    /// Tokens of all blocks, concatenated in block order.
    pub token_ids: *const u32,
    /// Per-block token counts (`num_blocks` entries).
    pub num_block_tokens: *const usize,
    /// Per-block hashes (`num_blocks` entries).
    pub block_ids: *const u64,
    /// Number of blocks described by the arrays above.
    pub num_blocks: usize,
    /// Parent block hash, if the batch has one.
    pub parent_hash: Option<u64>,
    /// Optional LoRA adapter name; `None` for the base model.
    pub lora_name: Option<String>,
}

312
313
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
314
315
/// has a parent hash or not. nullptr is used to represent no parent hash.
/// lora_name is an optional null-terminated C string; pass nullptr for base model.
316
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
317
pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
318
319
320
321
322
323
    event_id: u64,
    token_ids: *const u32,
    num_block_tokens: *const usize,
    block_ids: *const u64,
    num_blocks: usize,
    parent_hash: *const u64,
324
    lora_name: *const c_char,
Neelay Shah's avatar
Neelay Shah committed
325
) -> DynamoLlmResult {
326
327
328
329
330
331
332
    let parent_hash = {
        if parent_hash.is_null() {
            None
        } else {
            Some(unsafe { *parent_hash })
        }
    };
333
334
335
336
337
338
339
340
341
342
343
    let lora_name = if lora_name.is_null() {
        None
    } else {
        match unsafe { CStr::from_ptr(lora_name) }.to_str() {
            Ok(s) => Some(s.to_owned()),
            Err(e) => {
                tracing::error!(error = ?e, "Failed to convert C string to Rust string (lora_name)");
                return DynamoLlmResult::ERR;
            }
        }
    };
344
    let kv_params = DynamoKvStoredEventParams {
345
346
347
348
349
350
        event_id,
        token_ids,
        num_block_tokens,
        block_ids,
        num_blocks,
        parent_hash,
351
        lora_name,
352
353
354
    };
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
355
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
356
        Ok(_) => DynamoLlmResult::OK,
357
358
        Err(e) => {
            eprintln!("Error publishing stored kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
359
            DynamoLlmResult::ERR
360
361
362
363
        }
    }
}

364
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
365
pub extern "C" fn dynamo_kv_event_publish_removed(
366
367
368
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
Neelay Shah's avatar
Neelay Shah committed
369
) -> DynamoLlmResult {
370
371
372
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
373
        Ok(_) => DynamoLlmResult::OK,
374
375
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
376
            DynamoLlmResult::ERR
377
378
379
380
        }
    }
}

381
/* ------------------------------------------------------------------------
382
 *  Router Bindings for GAIE EPP
383
384
 * ------------------------------------------------------------------------ */

385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
/// Timeout, in seconds, applied to router bookkeeping operations such as `add_request`.
const BOOKKEEPING_TIMEOUT_SEC: u64 = 5;

/// Complete routing decision for a chat completion request, laid out for C callers.
#[repr(C)]
pub struct CRoutingResult {
    /// `true` when a separate prefill worker was selected (disaggregated serving).
    pub is_disaggregated: bool,
    /// Selected prefill worker; meaningful only when `is_disaggregated` is `true`.
    pub prefill_worker_id: u64,
    /// Selected decode worker.
    pub decode_worker_id: u64,
    /// Tokenized request, needed later for the `add_request` bookkeeping call.
    pub token_ids: *mut u32,
    /// Number of entries pointed to by `token_ids`.
    pub token_count: usize,
}

impl Default for CRoutingResult {
    /// An aggregated-mode result with zeroed worker ids and no token buffer attached.
    fn default() -> Self {
        CRoutingResult {
            token_ids: ptr::null_mut(),
            token_count: 0,
            decode_worker_id: 0,
            prefill_worker_id: 0,
            is_disaggregated: false,
        }
    }
}
413

414
415
416
417
418
419
420
421
422
423
424
425
426
/// State behind a [`RouterHandlesPtr`]: the prefill/decode routers, the preprocessor, and the
/// runtime used to drive their async APIs from synchronous C calls.
pub struct RouterHandles {
    /// Chooses prefill workers (created disabled in aggregated mode).
    prefill_router: Arc<PrefillRouter>,
    /// KV-aware router that chooses decode workers and holds request bookkeeping.
    decode_router: Arc<KvRouter>,
    // Not read after construction; presumably kept so the manager shared with the routers is
    // owned by the handle for its whole lifetime. TODO confirm.
    #[allow(dead_code)]
    model_manager: Arc<ModelManager>,
    // Not read after construction; retained alongside the routers. TODO confirm intended use.
    #[allow(dead_code)]
    namespace: String,
    /// Runtime created once in `create_routers` and reused by every call on this handle.
    runtime: Runtime,
    /// Tokenization/template preprocessor built from the model card fetched via discovery.
    preprocessor: Option<Arc<OpenAIPreprocessor>>,
}
427

428
429
impl RouterHandles {
    /// Query optimal prefill worker for a request.
430
431
    ///
    /// When `allowed_worker_ids` is Some, only workers in that set are considered.
432
433
434
435
    /// Returns worker_id on success.
    async fn query_prefill_worker(
        &self,
        tokens: &[u32],
436
        block_mm_infos: Option<&[Option<dynamo_llm::kv_router::protocols::BlockExtraInfo>]>,
437
438
439
        update_states: bool,
        lora_name: Option<String>,
        priority_jump: f64,
440
        allowed_worker_ids: Option<HashSet<WorkerId>>,
441
442
    ) -> Result<u64, QueryRouterResult> {
        self.prefill_router
443
444
445
446
447
448
            .query_prefill_worker(
                tokens,
                block_mm_infos,
                update_states,
                lora_name,
                priority_jump,
449
                allowed_worker_ids,
450
            )
451
452
453
454
455
456
457
            .await
            .map(|(worker_id, _dp_rank)| worker_id)
            .map_err(|e| {
                tracing::error!(error = ?e, "Prefill query failed");
                QueryRouterResult::ErrQueryFailed
            })
    }
458

459
460
461
462
    /// Query optimal decode worker for a request.
    /// For disaggregated mode, set `is_disaggregated` to true to use overlap_score_weight=0
    /// (since KV cache is being transferred from prefill, not reused).
    ///
463
464
465
    /// When `allowed_worker_ids` is Some, only workers in that set are considered.
    /// This does NOT overwrite the router's internal worker state — it only filters this decision.
    ///
466
467
468
469
470
471
472
473
    /// Note: The C bindings are query-only and must not mutate router state during worker
    /// selection. State updates require a `context_id` (request id) and are managed via the
    /// explicit bookkeeping APIs (`add_request`, `mark_prefill_complete`, `free_request`).
    /// Returns (worker, overlap_blocks) on success.
    async fn query_decode_worker(
        &self,
        tokens: &[u32],
        is_disaggregated: bool,
474
        allowed_worker_ids: Option<HashSet<WorkerId>>,
475
476
477
478
479
480
481
    ) -> Result<(WorkerWithDpRank, u32), QueryRouterResult> {
        // For decode phase in disaggregated mode, use overlap_score_weight=0
        // This matches prefill_router.rs
        let config_override = if is_disaggregated {
            Some(RouterConfigOverride {
                overlap_score_weight: Some(0.0),
                ..Default::default()
482
            })
483
484
485
486
        } else {
            None
        };

487
        self.decode_router
488
489
490
491
492
493
494
495
            .find_best_match(
                None,
                tokens,
                None,
                config_override.as_ref(),
                false,
                None,
                0.0,
496
                None,
497
                allowed_worker_ids,
498
            )
499
500
501
502
503
504
505
            .await
            .map_err(|e| {
                tracing::error!(error = ?e, "Decode query failed");
                QueryRouterResult::ErrQueryFailed
            })
    }
}
506

507
508
/// Opaque pointer to a heap-allocated [`RouterHandles`], handed across the C FFI boundary by
/// `create_routers`.
pub type RouterHandlesPtr = *mut RouterHandles;
509

510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
/// Result codes for the query-router C FFI.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryRouterResult {
    /// Success.
    Ok = 0,
    /// The router handle is invalid.
    ErrInvalidHandle = 1,
    /// A required argument was NULL or not valid UTF-8.
    ErrInvalidParam = 2,
    /// Runtime, discovery, or router construction failed.
    ErrInitFailed = 3,
    /// A routing query failed.
    ErrQueryFailed = 4,
    /// `enforce_disagg` was requested but no prefill workers were found.
    ErrDisaggEnforced = 5,
    /// An operation exceeded its timeout.
    ErrTimeout = 6,
}

/// Build a `KvRouterConfig` from defaults, overridden by optional `DYN_*` environment variables.
fn kv_router_config_from_env() -> KvRouterConfig {
    let mut cfg = KvRouterConfig::default();

    fn env_f64(key: &str) -> Option<f64> {
        std::env::var(key).ok().and_then(|v| v.parse().ok())
528
    }
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
    fn env_bool(key: &str) -> Option<bool> {
        std::env::var(key)
            .ok()
            .and_then(|v| match v.to_lowercase().as_str() {
                "true" | "1" | "yes" | "on" => Some(true),
                "false" | "0" | "no" | "off" => Some(false),
                _ => None,
            })
    }

    if let Some(v) = env_f64("DYN_OVERLAP_SCORE_WEIGHT") {
        cfg.overlap_score_weight = v;
    }
    if let Some(v) = env_f64("DYN_ROUTER_TEMPERATURE") {
        cfg.router_temperature = v;
    }
    if let Some(v) = env_bool("DYN_USE_KV_EVENTS") {
        cfg.use_kv_events = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_REPLICA_SYNC") {
        cfg.router_replica_sync = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_TRACK_ACTIVE_BLOCKS") {
        cfg.router_track_active_blocks = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_TRACK_OUTPUT_BLOCKS") {
        cfg.router_track_output_blocks = v;
    }
557
558
559
    if let Some(v) = env_f64("DYN_ROUTER_QUEUE_THRESHOLD") {
        cfg.router_queue_threshold = Some(v);
    }
560
561
562
563
564
565
566
567

    tracing::info!(
        overlap_score_weight = cfg.overlap_score_weight,
        router_temperature = cfg.router_temperature,
        use_kv_events = cfg.use_kv_events,
        router_replica_sync = cfg.router_replica_sync,
        router_track_active_blocks = cfg.router_track_active_blocks,
        router_track_output_blocks = cfg.router_track_output_blocks,
568
        router_queue_threshold = ?cfg.router_queue_threshold,
569
570
571
572
        "KvRouterConfig initialized (DYN_* env overrides applied)"
    );

    cfg
573
574
}

575
/// Create router handles for query-only routing
576
///
577
578
579
/// This function waits for at least one decode worker to be discovered before returning.
/// It auto-detects disaggregated mode by checking if prefill workers are present.
/// The KV cache block size is automatically fetched from the model card via discovery.
580
///
581
582
583
/// # Arguments
/// - `namespace`: Namespace for the model
/// - `component`: Component name (defaults to "backend" if NULL or empty)
584
/// - `enforce_disagg`: If true, requires prefill workers to be present at init time
585
/// - `out_handle`: Output handle
586
///
587
588
589
/// # Safety
/// - All string parameters must be valid null-terminated C strings
/// - The returned handle must be freed with `destroy`
590
#[unsafe(no_mangle)]
591
592
593
pub unsafe extern "C" fn create_routers(
    namespace: *const c_char,
    component: *const c_char,
594
    enforce_disagg: bool,
595
596
597
598
    out_handle: *mut RouterHandlesPtr,
) -> QueryRouterResult {
    if namespace.is_null() || out_handle.is_null() {
        return QueryRouterResult::ErrInvalidParam;
599
600
    }

601
602
603
    let namespace_str = match unsafe { CStr::from_ptr(namespace) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
604
605
    };

606
607
    let component_str = if component.is_null() {
        "backend".to_string()
608
    } else {
609
610
611
        match unsafe { CStr::from_ptr(component) }.to_str() {
            Ok(s) if !s.is_empty() => s.to_owned(),
            _ => "backend".to_string(),
612
613
614
        }
    };

615
616
617
    // Create the runtime once - it will be stored in RouterHandles and reused
    let runtime = match Runtime::from_settings() {
        Ok(rt) => rt,
618
        Err(e) => {
619
620
            tracing::error!(error = ?e, "Failed to create runtime");
            return QueryRouterResult::ErrInitFailed;
621
622
        }
    };
623
624
625
626
627
628
629
630
631
632

    // Clone for use inside the async block (the original will be moved into handles)
    let runtime_for_async = runtime.clone();

    let result = runtime_for_async.secondary().block_on(async {
        let drt = match DistributedRuntime::from_settings(runtime_for_async.clone()).await {
            Ok(drt) => drt,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to create distributed runtime");
                return Err(QueryRouterResult::ErrInitFailed);
633
            }
634
635
636
637
638
639
640
641
642
643
        };

        // Wait for at least one worker to be discovered before proceeding
        // This ensures the decode router can be created successfully
        let instance_count = wait_for_discovery_sync(&drt).await;
        if instance_count == 0 {
            tracing::error!(
                "Discovery sync failed: no worker instances found. Is the backend running?"
            );
            return Err(QueryRouterResult::ErrInitFailed);
644
        }
645
646
647
648
        tracing::info!(
            "Discovery sync complete, {} worker(s) found",
            instance_count
        );
649

650
        let kv_router_config = kv_router_config_from_env();
651

652
653
654
655
656
657
658
659
        // Get component and endpoint
        let component_handle = match drt.namespace(&namespace_str) {
            Ok(ns) => match ns.component(&component_str) {
                Ok(c) => c,
                Err(e) => {
                    tracing::error!(error = ?e, "Failed to get component");
                    return Err(QueryRouterResult::ErrInitFailed);
                }
660
            },
661
662
663
664
665
666
            Err(e) => {
                tracing::error!(error = ?e, "Failed to get namespace");
                return Err(QueryRouterResult::ErrInitFailed);
            }
        };
        let endpoint = component_handle.endpoint("generate");
667

668
        let model_manager = Arc::new(ModelManager::new());
669

670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
        // Fetch model card via discovery and create preprocessor + get block_size
        let (preprocessor, block_size, model_name) =
            match fetch_preprocessor_from_discovery(&drt, &namespace_str).await {
                Ok((prep, bs, name)) => {
                    tracing::info!(
                        kv_cache_block_size = bs,
                        "Preprocessor created from discovery"
                    );
                    (Some(prep), bs, name)
                }
                Err(e) => {
                    tracing::error!(
                        error = %e,
                        "Failed to fetch model card from discovery - cannot determine block_size"
                    );
                    return Err(QueryRouterResult::ErrInitFailed);
                }
            };
688

689
690
691
692
693
        // Create decode router
        let decode_router = match model_manager
            .kv_chooser_for(
                &endpoint,
                block_size,
694
                Some(kv_router_config.clone()),
695
                WORKER_TYPE_DECODE,
696
                Some(model_name.clone()),
697
698
699
700
701
702
703
704
705
            )
            .await
        {
            Ok(r) => r,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to create decode router");
                return Err(QueryRouterResult::ErrInitFailed);
            }
        };
706

707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
        // Create PrefillRouter based on one-time discovery of prefill workers
        // Auto-detects disaggregated mode by checking if prefill workers are present
        // The prefill workers have to be created before the epp is created.
        // Given that we wait first for the decode worker to show up it is reasonable to assume the prefill will be up as well.
        let prefill_router = match find_prefill_endpoint(&drt, &namespace_str).await {
            Some(prefill_endpoint) => {
                tracing::info!("Prefill worker found, running in disaggregated mode");
                let mut prefill_config = kv_router_config;
                prefill_config.router_track_active_blocks = false;

                // Create immediately-resolved channel to activate router
                let (tx, rx) = tokio::sync::oneshot::channel();
                let _ = tx.send(prefill_endpoint);

                PrefillRouter::new(
                    rx,
                    model_manager.clone(),
                    RouterMode::KV,
                    block_size,
                    Some(prefill_config),
727
                    enforce_disagg,
728
                    model_name.clone(),
729
                    namespace_str.clone(),
730
731
                )
            }
732
            None if enforce_disagg => {
733
                tracing::error!(
734
                    "Prefill workers required but none found (enforce_disagg is enabled)"
735
                );
736
737
738
739
                return Err(QueryRouterResult::ErrDisaggEnforced);
            }
            None => {
                tracing::info!("No prefill workers found, running in aggregated mode");
740
                PrefillRouter::disabled(model_manager.clone(), RouterMode::KV, enforce_disagg)
741
742
743
744
745
746
747
748
749
750
751
            }
        };

        Ok((
            prefill_router,
            decode_router,
            model_manager,
            namespace_str,
            preprocessor,
        ))
    });
752
753

    match result {
754
755
756
757
758
759
760
761
762
763
764
        Ok((prefill_router, decode_router, model_manager, namespace_str, preprocessor)) => {
            let handles = RouterHandles {
                prefill_router,
                decode_router,
                model_manager,
                namespace: namespace_str,
                runtime, // Store the runtime for reuse
                preprocessor,
            };
            unsafe { *out_handle = Box::into_raw(Box::new(handles)) };
            QueryRouterResult::Ok
765
        }
766
        Err(code) => code,
767
768
769
770
771
    }
}

/// Add a request to the router's bookkeeping after worker selection.
///
772
773
774
/// Register the request with the KvRouter's scheduler for tracking active blocks
/// and managing prefill/decode lifecycle. Call this after `query_decode` returns
/// worker IDs and before sending the request to the worker.
775
776
///
/// # Safety
777
778
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
779
780
/// - `token_ids` must point to at least `token_count` valid u32 values
#[unsafe(no_mangle)]
781
782
783
pub unsafe extern "C" fn add_request(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
784
785
786
787
    token_ids: *const u32,
    token_count: usize,
    worker_id: u64,
    dp_rank: u32,
788
789
790
791
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
792

793
794
795
796
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
797
798
799
800
801
802
803
804
    };

    let tokens: Vec<u32> = if token_count > 0 && !token_ids.is_null() {
        unsafe { std::slice::from_raw_parts(token_ids, token_count) }.to_vec()
    } else {
        Vec::new()
    };

805
    let decode_router = handles.decode_router.clone();
806

807
808
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
809

810
811
        tokio::time::timeout(timeout_duration, async {
            let worker = WorkerWithDpRank::new(worker_id, dp_rank);
812

813
            // Compute overlap_blocks using the public method
814
815
816
817
            let overlap_blocks = match decode_router
                .get_overlap_blocks(&tokens, worker, None)
                .await
            {
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
                Ok(overlap) => overlap,
                Err(e) => {
                    tracing::warn!(error = ?e, "Failed to compute overlap, using 0");
                    0
                }
            };

            decode_router
                .add_request(
                    request_id_str.clone(),
                    &tokens,
                    overlap_blocks,
                    None,
                    worker,
                    None, // lora_name
                    None, // router_config_override
                )
                .await;

            tracing::debug!(
                request_id = %request_id_str,
                worker_id = worker_id,
                dp_rank = dp_rank,
                overlap_blocks = overlap_blocks,
                token_count = tokens.len(),
                "add_request completed"
            );
        })
        .await
    });

    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
            tracing::warn!(
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "add_request timed out"
            );
            QueryRouterResult::ErrTimeout
        }
    }
860
861
862
}

/// Mark prefill as completed for a request.
863
864
///
/// Call when the first token is generated to release prefill tokens from decode worker's load
865
866
///
/// # Safety
867
868
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
869
#[unsafe(no_mangle)]
870
871
872
873
874
875
876
pub unsafe extern "C" fn mark_prefill_complete(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
877

878
879
880
881
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
882
883
    };

884
    let decode_router = handles.decode_router.clone();
885

886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);

        tokio::time::timeout(timeout_duration, async {
            if let Err(e) = decode_router.mark_prefill_completed(&request_id_str).await {
                tracing::warn!(
                    request_id = %request_id_str,
                    error = %e,
                    "Failed to mark prefill complete"
                );
            } else {
                tracing::debug!(
                    request_id = %request_id_str,
                    "mark_prefill_complete completed"
                );
            }
        })
        .await
    });
905

906
907
908
    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
909
            tracing::warn!(
910
911
912
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "mark_prefill_complete timed out"
913
            );
914
            QueryRouterResult::ErrTimeout
915
        }
916
    }
917
918
919
}

/// Free a request from the router's bookkeeping.
920
921
///
/// Call this when the stream is closed (completed or cancelled) to release all resources.
922
923
///
/// # Safety
924
925
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
926
#[unsafe(no_mangle)]
927
928
929
930
931
932
933
pub unsafe extern "C" fn free_request(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
934

935
936
937
938
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
939
940
    };

941
    let decode_router = handles.decode_router.clone();
942

943
944
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
945

946
947
948
949
950
951
        tokio::time::timeout(timeout_duration, async {
            if let Err(e) = decode_router.free(&request_id_str).await {
                tracing::warn!(
                    request_id = %request_id_str,
                    error = %e,
                    "Failed to free request"
952
                );
953
            } else {
954
                tracing::debug!(
955
956
                    request_id = %request_id_str,
                    "free_request completed"
957
                );
958
            }
959
960
961
962
963
964
965
966
967
968
969
970
971
        })
        .await
    });

    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
            tracing::warn!(
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "free_request timed out"
            );
            QueryRouterResult::ErrTimeout
972
973
        }
    }
974
}
975

976
977
978
979
980
981
982
983
984
985
/// Destroy router handles
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle or null
/// - After this call, `handle` must not be used
#[unsafe(no_mangle)]
pub unsafe extern "C" fn destroy(handle: RouterHandlesPtr) {
    if !handle.is_null() {
        drop(unsafe { Box::from_raw(handle) });
    }
986
987
}

988
/// Free a routing result.
989
///
990
/// # Safety
991
/// - `result` must be a valid pointer to a CRoutingResult previously returned by route functions
992
#[unsafe(no_mangle)]
993
994
995
pub unsafe extern "C" fn free_routing_result(result: *mut CRoutingResult) {
    if result.is_null() {
        return;
996
    }
997

998
    let res = unsafe { &mut *result };
999

1000
1001
1002
    // Free token IDs
    if !res.token_ids.is_null() && res.token_count > 0 {
        drop(unsafe {
1003
            Box::from_raw(std::ptr::slice_from_raw_parts_mut(
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
                res.token_ids,
                res.token_count,
            ))
        });
        res.token_ids = ptr::null_mut();
        res.token_count = 0;
    }
}

/// Parse a JSON request string, apply the chat template, and tokenize.
/// Returns the token IDs on success, or a `QueryRouterResult` error code.
unsafe fn preprocess_request(
    handles: &RouterHandles,
    request_json: *const c_char,
) -> Result<Vec<u32>, QueryRouterResult> {
1019
1020
1021
1022
    let preprocessor = match &handles.preprocessor {
        Some(p) => p,
        None => {
            tracing::error!("Preprocessor not available");
1023
            return Err(QueryRouterResult::ErrInitFailed);
1024
1025
1026
1027
1028
        }
    };

    let json_str = match unsafe { CStr::from_ptr(request_json) }.to_str() {
        Ok(s) => s,
1029
        Err(_) => return Err(QueryRouterResult::ErrInvalidParam),
1030
1031
1032
1033
1034
1035
1036
    };

    let request: dynamo_llm::types::openai::chat_completions::NvCreateChatCompletionRequest =
        match serde_json::from_str(json_str) {
            Ok(req) => req,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to parse request JSON");
1037
                return Err(QueryRouterResult::ErrInvalidParam);
1038
1039
1040
1041
1042
1043
1044
1045
            }
        };

    let formatted_prompt = match preprocessor.apply_template(&request) {
        Ok(Some(prompt)) => prompt,
        Ok(None) => String::new(),
        Err(e) => {
            tracing::error!(error = ?e, "Failed to apply chat template");
1046
            return Err(QueryRouterResult::ErrQueryFailed);
1047
1048
1049
1050
1051
1052
1053
        }
    };

    let encoding = match preprocessor.tokenize(&formatted_prompt) {
        Ok(enc) => enc,
        Err(e) => {
            tracing::error!(error = ?e, "Failed to tokenize");
1054
            return Err(QueryRouterResult::ErrQueryFailed);
1055
1056
1057
        }
    };

1058
1059
    Ok(encoding.token_ids().to_vec())
}
1060

1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
/// Parse a pods JSON array into an optional set of allowed worker IDs.
///
/// Each array entry's pod name is read from `pod.podName`, falling back to a
/// top-level `podName`. Every name is mapped with `hash_pod_name`.
///
/// Returns `None` (meaning "no filter") when the pointer is null, the string is
/// empty or not UTF-8, the JSON does not parse (logged at `error`), or no pod
/// names were found.
///
/// # Safety
/// `pods_json` must be null or a valid null-terminated C string.
unsafe fn parse_pods_filter(pods_json: *const c_char) -> Option<HashSet<WorkerId>> {
    if pods_json.is_null() {
        return None;
    }

    // SAFETY: non-null, and the caller guarantees NUL termination.
    let raw = match unsafe { CStr::from_ptr(pods_json) }.to_str() {
        Ok(s) if !s.is_empty() => s,
        _ => return None,
    };

    let pods: Vec<serde_json::Value> = match serde_json::from_str(raw) {
        Ok(list) => list,
        Err(e) => {
            tracing::error!(error = ?e, "Failed to parse pods JSON");
            return None;
        }
    };

    let worker_ids: HashSet<WorkerId> = pods
        .iter()
        .filter_map(|entry| {
            entry
                .get("pod")
                .and_then(|p| p.get("podName"))
                .or_else(|| entry.get("podName"))
                .and_then(|v| v.as_str())
        })
        .map(|name| {
            let worker_id = hash_pod_name(name);
            tracing::debug!(
                pod_name = name,
                worker_id = format!("{:x}", worker_id),
                "Mapped EPP pod to worker_id"
            );
            worker_id
        })
        .collect();

    tracing::info!(
        pod_count = pods.len(),
        unique_worker_ids = worker_ids.len(),
        "Parsed EPP pods into allowed_worker_ids filter"
    );

    (!worker_ids.is_empty()).then_some(worker_ids)
}
1105

1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
/// Write token IDs into a `CRoutingResult`, transferring ownership to the caller.
///
/// The tokens are copied into a heap-allocated `Box<[u32]>` of exactly
/// `tokens.len()` elements. Ownership passes to `out.token_ids`, and the buffer
/// must be released with `free_routing_result`, which rebuilds the same boxed
/// slice. Any buffer already stored in `out` is overwritten, not freed.
fn write_tokens_to_result(tokens: &[u32], out: &mut CRoutingResult) {
    let boxed: Box<[u32]> = Box::from(tokens);
    out.token_count = boxed.len();
    // `Box::into_raw` hands off ownership without moving the Box after a pointer
    // has been derived from it. The previous `as_mut_ptr` + `mem::forget` pattern
    // moved the Box, which invalidates the derived pointer under Rust's aliasing rules.
    out.token_ids = Box::into_raw(boxed).cast::<u32>();
}

/// Route a request to select the best **prefill** worker only.
///
/// This is used in disaggregated mode where the EPP runs separate prefill and decode
/// scoring profiles.  It tokenizes the request and queries only the prefill router.
///
/// The returned `CRoutingResult` contains:
/// - `prefill_worker_id`: the selected prefill worker
/// - `decode_worker_id`: 0 (unused — decode is handled by `route_decode_request`)
/// - `is_disaggregated`: always true (this function is only called in disagg mode)
/// - `token_ids` / `token_count`: the tokenized request (caller must free via `free_routing_result`)
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `pods_json` must be a valid null-terminated C string containing JSON, or null
/// - `out_result` must be a valid pointer
#[unsafe(no_mangle)]
pub unsafe extern "C" fn route_prefill_request(
    handle: RouterHandlesPtr,
    request_json: *const c_char,
    pods_json: *const c_char,
    out_result: *mut CRoutingResult,
) -> QueryRouterResult {
    if handle.is_null() || request_json.is_null() || out_result.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }

    let handles = unsafe { &*handle };

    let tokens = match unsafe { preprocess_request(handles, request_json) } {
        Ok(t) => t,
        Err(code) => return code,
    };

    let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };

    let result = handles.runtime.secondary().block_on(async {
        let prefill_worker_id = handles
            .query_prefill_worker(&tokens, None, false, None, 0.0, allowed_worker_ids)
1154
1155
1156
1157
            .await?;

        tracing::info!(
            prefill_worker_id = prefill_worker_id,
1158
1159
            token_count = tokens.len(),
            "Routed prefill request"
1160
1161
        );

1162
        Ok(prefill_worker_id)
1163
1164
    });

1165
    match result {
1166
1167
1168
1169
1170
1171
        Ok(prefill_worker_id) => {
            let out = unsafe { &mut *out_result };
            *out = CRoutingResult::default();
            out.is_disaggregated = true;
            out.prefill_worker_id = prefill_worker_id;
            write_tokens_to_result(&tokens, out);
1172
1173
1174
1175
            QueryRouterResult::Ok
        }
        Err(code) => code,
    }
1176
1177
}

1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
/// Route a request to select the best **decode** worker only.
///
/// This is used in both aggregated and disaggregated modes.
/// - When `is_disaggregated` is true, the decode router uses `overlap_score_weight=0`
///   (KV cache is being transferred from prefill, not reused locally).
/// - When `is_disaggregated` is false, normal KV-aware scoring is used.
///
/// The returned `CRoutingResult` contains:
/// - `decode_worker_id`: the selected decode worker
/// - `prefill_worker_id`: 0 (unused — prefill is handled by `route_prefill_request`)
/// - `is_disaggregated`: mirrors the input parameter
/// - `token_ids` / `token_count`: the tokenized request (caller must free via `free_routing_result`)
1190
///
1191
/// # Safety
1192
1193
1194
1195
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `pods_json` must be a valid null-terminated C string containing JSON, or null
/// - `out_result` must be a valid pointer
1196
#[unsafe(no_mangle)]
1197
1198
1199
1200
1201
1202
1203
1204
1205
pub unsafe extern "C" fn route_decode_request(
    handle: RouterHandlesPtr,
    request_json: *const c_char,
    pods_json: *const c_char,
    is_disaggregated: bool,
    out_result: *mut CRoutingResult,
) -> QueryRouterResult {
    if handle.is_null() || request_json.is_null() || out_result.is_null() {
        return QueryRouterResult::ErrInvalidParam;
1206
1207
    }

1208
    let handles = unsafe { &*handle };
1209

1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
    let tokens = match unsafe { preprocess_request(handles, request_json) } {
        Ok(t) => t,
        Err(code) => return code,
    };

    let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };

    let result = handles.runtime.secondary().block_on(async {
        let (decode_worker, _overlap_blocks) = handles
            .query_decode_worker(&tokens, is_disaggregated, allowed_worker_ids)
            .await?;

        tracing::info!(
            is_disaggregated = is_disaggregated,
            decode_worker_id = decode_worker.worker_id,
            decode_dp_rank = decode_worker.dp_rank,
            token_count = tokens.len(),
            "Routed decode request"
        );

        Ok(decode_worker)
    });

    match result {
        Ok(decode_worker) => {
            let out = unsafe { &mut *out_result };
            *out = CRoutingResult::default();
            out.is_disaggregated = is_disaggregated;
            out.decode_worker_id = decode_worker.worker_id;
            write_tokens_to_result(&tokens, out);
            QueryRouterResult::Ok
        }
        Err(code) => code,
1243
    }
1244
1245
}

1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
/// Fetch model card via discovery and create preprocessor.
///
/// This function:
/// 1. Lists all models via discovery
/// 2. Finds the first model in the target namespace (decode workers only)
/// 3. Downloads the model config (tokenizer files) if needed
/// 4. Creates an OpenAIPreprocessor from the model card
/// 5. Returns the preprocessor, the kv_cache_block_size, and model_name from the model card
async fn fetch_preprocessor_from_discovery(
    drt: &DistributedRuntime,
    target_namespace: &str,
) -> anyhow::Result<(Arc<OpenAIPreprocessor>, u32, String)> {
1258
    use dynamo_llm::model_card::ModelDeploymentCard;
1259
    use dynamo_runtime::discovery::DiscoveryInstance;
1260

1261
    let discovery = drt.discovery();
1262

1263
1264
    // List all models
    let instances = discovery.list(DiscoveryQuery::AllModels).await?;
1265

1266
1267
    // Find first model card in the target namespace (decode workers only)
    let mut model_card: Option<ModelDeploymentCard> = None;
1268

1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
    for instance in instances {
        if let DiscoveryInstance::Model { namespace, .. } = &instance {
            // Filter by namespace
            if namespace != target_namespace {
                continue;
            }

            match instance.deserialize_model::<ModelDeploymentCard>() {
                Ok(card) => {
                    // Skip prefill-only workers, we want decode workers for routing
                    if card.model_type.supports_prefill()
                        && !card.model_type.supports_chat()
                        && !card.model_type.supports_completions()
1282
                    {
1283
                        continue;
1284
                    }
1285
1286
                    model_card = Some(card);
                    break;
1287
                }
1288
1289
1290
                Err(e) => {
                    tracing::debug!(error = %e, "Failed to deserialize model card, skipping");
                    continue;
1291
1292
1293
                }
            }
        }
1294
    }
1295

1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
    let mut card = model_card.ok_or_else(|| {
        anyhow::anyhow!(
            "No model found in namespace '{}' via discovery",
            target_namespace
        )
    })?;

    let kv_cache_block_size = card.kv_cache_block_size;
    let model_name = card.name().to_string();
    tracing::info!(
        model_name = model_name,
        kv_cache_block_size = kv_cache_block_size,
        "Found model card via discovery"
1309
1310
    );

1311
1312
    // Download config (tokenizer files) if not local
    card.download_config().await?;
1313

1314
1315
1316
1317
    // Create preprocessor
    let preprocessor = OpenAIPreprocessor::new(card)?;
    Ok((preprocessor, kv_cache_block_size, model_name))
}
1318

1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
/// Find a prefill endpoint from already-discovered instances (one-time filter).
/// Returns the endpoint if a prefill worker is found in the target namespace.
async fn find_prefill_endpoint(
    drt: &DistributedRuntime,
    target_namespace: &str,
) -> Option<dynamo_runtime::component::Endpoint> {
    use dynamo_llm::model_card::ModelDeploymentCard;
    use dynamo_runtime::discovery::DiscoveryInstance;

    let discovery = drt.discovery();
    let instances = match discovery.list(DiscoveryQuery::AllModels).await {
        Ok(instances) => instances,
        Err(e) => {
            tracing::warn!(error = %e, "Failed to list instances for prefill discovery");
            return None;
        }
1335
1336
    };

1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
    for instance in instances {
        if let DiscoveryInstance::Model {
            namespace,
            component,
            endpoint,
            ..
        } = &instance
        {
            // Filter by namespace
            if namespace != target_namespace {
                continue;
            }
1349

1350
1351
1352
1353
            let card = match instance.deserialize_model::<ModelDeploymentCard>() {
                Ok(card) => card,
                Err(_) => continue,
            };
1354

1355
1356
1357
1358
            // Only handle prefill models
            if !card.model_type.supports_prefill() {
                continue;
            }
1359

1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
            tracing::info!(
                model_name = card.name(),
                "Prefill worker found in discovered instances"
            );

            // Build and return the endpoint
            if let Ok(ns) = drt.namespace(namespace)
                && let Ok(comp) = ns.component(component)
            {
                return Some(comp.endpoint(endpoint));
            }
        }
    }
1373

1374
    None
1375
}