// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use async_once_cell::OnceCell as AsyncOnceCell;
use libc::c_char;
use once_cell::sync::OnceCell;
7
use std::borrow::Cow;
8
use std::ffi::CStr;
9
use std::ptr;
10
use std::sync::Arc;
11
use std::sync::atomic::{AtomicU32, Ordering};
12
use std::time::Duration;
13

14
15
use dynamo_llm::kv_router::{protocols::*, publisher::KvEventPublisher};
use dynamo_llm::preprocessor::OpenAIPreprocessor;
16
use dynamo_runtime::discovery::DiscoveryQuery;
17
use dynamo_runtime::{DistributedRuntime, Worker};
18
19
20
21
22
23
24
25
26

use dynamo_runtime::Runtime;

use dynamo_llm::discovery::{ModelManager, WORKER_TYPE_DECODE};
use dynamo_llm::kv_router::KvRouterConfig;
use dynamo_llm::kv_router::protocols::WorkerWithDpRank;
use dynamo_llm::kv_router::{KvRouter, PrefillRouter, RouterConfigOverride};
use dynamo_runtime::pipeline::RouterMode;

27
28
29
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
GuanLuo's avatar
GuanLuo committed
30
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
/// Read a C string as trimmed UTF-8 text, substituting `default_val` when it is unusable.
///
/// The (borrowed) default is returned when:
/// - `ptr` is NULL,
/// - the bytes are not valid UTF-8,
/// - or the text is empty after trimming whitespace.
///
/// Otherwise an owned copy of the trimmed text is returned.
///
/// # Safety
/// A non-NULL `ptr` must point to a valid NUL-terminated C string.
#[inline]
unsafe fn cstr_or_default<'a>(ptr: *const c_char, default_val: &'a str) -> Cow<'a, str> {
    let fallback = Cow::Borrowed(default_val);
    if ptr.is_null() {
        return fallback;
    }

    // SAFETY: `ptr` is non-NULL and the caller guarantees it is NUL-terminated.
    let raw = unsafe { CStr::from_ptr(ptr) };
    let Ok(text) = raw.to_str() else {
        return fallback;
    };

    let trimmed = text.trim();
    if trimmed.is_empty() {
        fallback
    } else {
        Cow::Owned(trimmed.to_owned())
    }
}

51
52
53
54
55
56
57
fn initialize_tracing() {
    // Sets up RUST_LOG environment variable for logging while KV Publishing
    // Example: os.environ["RUST_LOG"] = "debug"
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();

Ryan Olson's avatar
Ryan Olson committed
58
    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
59

60
    tracing::debug!("Tracing initialized");
61
62
63
}

/// Status code returned by the `dynamo_llm_*` / `dynamo_kv_event_*` C entry points.
///
/// The discriminants are part of the C ABI and must not change.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DynamoLlmResult {
    /// The call succeeded.
    OK = 0,
    /// The call failed; details are written to the log.
    ERR = 1,
}

69
70
71
72
73
74
75
76
77
// Wait for the discovery daemon to sync indefinitely and return at least one instance.
// This is because the Model info is registered by workers and it may take up to 30 min for the model weights to load and for the worker to register itself.
// The waiting timeout is implemented in the Kubernetes StartupProbe. The EPP waiting loops runs indefinitely, the Probe is a single source of truth with when to kill the EPP if discovery fails.
// If workers are not found within the probe's failureThreshold × periodSeconds, the pod will be killed and restarted.
// Users can adjust the StartupProbe waiting timed in the DGD for large models.
async fn wait_for_discovery_sync(drt: &DistributedRuntime) -> usize {
    tracing::info!(
        "Waiting for discovery to sync (no timeout - controlled by K8s StartupProbe)..."
    );
78
79
80
81
82
83
84
85
86
87
88
89
    let discovery = drt.discovery();

    loop {
        match discovery.list(DiscoveryQuery::AllModels).await {
            Ok(instances) if !instances.is_empty() => {
                return instances.len();
            }
            Ok(_) => {
                tracing::debug!("No instances yet, waiting...");
                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
            }
            Err(e) => {
90
91
92
                // Log and continue - transient errors shouldn't stop the wait
                tracing::warn!("Discovery list error: {}, retrying...", e);
                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
93
94
95
96
97
            }
        }
    }
}

98
/// # Safety
GuanLuo's avatar
GuanLuo committed
99
/// the namespace_c_str and component_c_str are passed as pointers to C strings
100
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
101
pub unsafe extern "C" fn dynamo_llm_init(
GuanLuo's avatar
GuanLuo committed
102
103
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
104
    kv_block_size: u32,
Neelay Shah's avatar
Neelay Shah committed
105
) -> DynamoLlmResult {
106
107
108
109
    initialize_tracing();
    let wk = match WK.get_or_try_init(Worker::from_settings) {
        Ok(wk) => wk.clone(),
        Err(e) => {
110
            tracing::error!(error = ?e, "Failed to initialize runtime (Worker::from_settings)");
Neelay Shah's avatar
Neelay Shah committed
111
            return DynamoLlmResult::ERR;
112
113
114
115
116
117
118
119
120
121
        }
    };
    let rt = wk.runtime();
    let secondary = rt.secondary().clone();
    let result = secondary.block_on(async {
        // Initialize the distributed runtime
        match DRT
            .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await })
            .await
        {
122
            Ok(drt) => {
123
                // Wait for discovery to sync before returning.
124
                // This is needed because dynamo_create_worker_selection_pipeline() is called
125
126
127
128
                // immediately after, and it needs discovery.list() to return data.
                // The discovery daemon takes time to query K8s and returns async, so we need to wait.
                // Note: This waits indefinitely - the K8s StartupProbe is the timeout mechanism.
                wait_for_discovery_sync(drt).await;
129
130
                Ok(())
            }
131
            Err(e) => {
132
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
133
                Err(DynamoLlmResult::ERR)
134
135
136
            }
        }
    });
GuanLuo's avatar
GuanLuo committed
137
    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
138
139
        Ok(s) => s.to_string(),
        Err(e) => {
140
            tracing::error!(error = ?e, "Failed to convert C string to Rust string (namespace)");
Neelay Shah's avatar
Neelay Shah committed
141
            return DynamoLlmResult::ERR;
142
143
144
        }
    };

145
146
147
148
149
    let component_cow = unsafe { cstr_or_default(component_c_str, "backend") };
    if let Cow::Borrowed("backend") = &component_cow {
        tracing::info!("defaulting to \"backend\" for component");
    }
    let component: String = component_cow.into_owned();
150
151

    match result {
152
        Ok(_) => match KV_PUB.get_or_try_init(move || {
Yan Ru Pei's avatar
Yan Ru Pei committed
153
            dynamo_create_kv_publisher(namespace, component, kv_block_size)
154
        }) {
Neelay Shah's avatar
Neelay Shah committed
155
            Ok(_) => DynamoLlmResult::OK,
156
            Err(e) => {
157
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
158
                DynamoLlmResult::ERR
159
160
161
162
163
164
            }
        },
        Err(e) => e,
    }
}

165
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
166
pub extern "C" fn dynamo_llm_shutdown() -> DynamoLlmResult {
167
168
169
    let wk = match WK.get() {
        Some(wk) => wk,
        None => {
170
            tracing::error!("Runtime not initialized");
Neelay Shah's avatar
Neelay Shah committed
171
            return DynamoLlmResult::ERR;
172
173
174
175
176
        }
    };

    wk.runtime().shutdown();

Neelay Shah's avatar
Neelay Shah committed
177
    DynamoLlmResult::OK
178
179
}

180
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
181
182
pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
    DynamoLlmResult::OK
183
184
185
186
}

// instantiate a kv publisher
// this will bring up the task to publish and the channels to await publishing events
Neelay Shah's avatar
Neelay Shah committed
187
188
// the [`dynamo_kv_publish_store_event`] call will use a handle to the publisher to send events
// store and the [`dynamo_kv_event_create_removed`] will create remove events
189
190
191
// these call mus be driving by external c++ threads that are consuming the kv events from the
// c++ executor api

Neelay Shah's avatar
Neelay Shah committed
192
fn dynamo_create_kv_publisher(
GuanLuo's avatar
GuanLuo committed
193
194
    namespace: String,
    component: String,
195
    kv_block_size: u32,
GuanLuo's avatar
GuanLuo committed
196
) -> Result<KvEventPublisher, anyhow::Error> {
197
    tracing::info!("Creating KV Publisher for model: {}", component);
198
199
200
201
202
    match DRT
        .get()
        .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
    {
        Ok(drt) => {
GuanLuo's avatar
GuanLuo committed
203
            let backend = drt.namespace(namespace)?.component(component)?;
Yan Ru Pei's avatar
Yan Ru Pei committed
204
            KvEventPublisher::new(backend, kv_block_size, None)
205
206
207
208
209
210
211
212
213
        }
        Err(e) => Err(e),
    }
}

/// Builds the stored-block record for a single KV cache block.
///
/// `token_ids` must point to `num_tokens` readable `u32` values. The block's tokens hash is
/// the first hash returned by `compute_block_hash_for_seq` for those tokens.
fn kv_event_create_stored_block_from_parts(
    block_hash: u64,
    token_ids: *const u32,
    num_tokens: usize,
    kv_block_size: u32,
    lora_name: Option<&str>,
) -> KvCacheStoredBlockData {
    // SAFETY: the caller guarantees `token_ids` is valid for `num_tokens` reads.
    let tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) };

    // NOTE(review): indexing `[0]` panics if no hash is returned; the visible caller only
    // passes blocks whose token count equals `kv_block_size` — confirm that always yields
    // at least one hash.
    let tokens_hash = compute_block_hash_for_seq(tokens, kv_block_size, None, lora_name)[0];

    KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block_hash),
        tokens_hash,
        mm_extra_info: None,
    }
}
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);

fn kv_event_create_stored_from_parts(
232
    kv_params: DynamoKvStoredEventParams,
233
    kv_block_size: u32,
234
235
236
237
) -> KvCacheEvent {
    let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();

    let mut token_offset: usize = 0;
238
239
240
241
242
243
244
245
246
    for block_idx in 0..kv_params.num_blocks {
        let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
        let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
        let num_toks = unsafe {
            *kv_params
                .num_block_tokens
                .offset(block_idx.try_into().unwrap())
        };

247
        if num_toks != (kv_block_size as usize) {
Ryan Olson's avatar
Ryan Olson committed
248
249
            if WARN_COUNT
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
250
                    if c < 3 { Some(c + 1) } else { None }
Ryan Olson's avatar
Ryan Olson committed
251
252
253
                })
                .is_ok()
            {
254
                tracing::warn!(
255
256
                    "Block not published. Block size must be {} tokens to be published. Block size is: {}",
                    kv_block_size,
Ryan Olson's avatar
Ryan Olson committed
257
258
                    num_toks
                );
259
260
261
262
263
            }
            break;
        }
        token_offset += num_toks;
        blocks.push(kv_event_create_stored_block_from_parts(
264
265
266
267
            block_hash,
            tokens,
            num_toks,
            kv_block_size,
268
            kv_params.lora_name.as_deref(),
269
270
271
272
273
274
        ));
    }

    KvCacheEvent {
        data: KvCacheEventData::Stored(KvCacheStoreData {
            blocks,
275
            parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
276
        }),
277
        event_id: kv_params.event_id,
Yan Ru Pei's avatar
Yan Ru Pei committed
278
        dp_rank: 0,
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
    }
}

fn kv_event_create_removed_from_parts(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> KvCacheEvent {
    let block_hashes: Vec<ExternalSequenceBlockHash> =
        unsafe { std::slice::from_raw_parts(block_ids, num_blocks) }
            .to_vec()
            .iter()
            .map(|&v| ExternalSequenceBlockHash(v))
            .collect();
    KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
Yan Ru Pei's avatar
Yan Ru Pei committed
296
        dp_rank: 0,
297
298
299
    }
}

300
301
302
303
304
305
306
/// Raw inputs describing one "stored" KV event, as received through the C API.
///
/// The pointer fields are borrowed from the caller and are only read while the event is
/// being built.
pub struct DynamoKvStoredEventParams {
    /// Identifier copied into the published event.
    pub event_id: u64,
    /// Token ids of all blocks, laid out back to back in block order.
    pub token_ids: *const u32,
    /// Token count of each block (`num_blocks` entries).
    pub num_block_tokens: *const usize,
    /// External hash of each block (`num_blocks` entries).
    pub block_ids: *const u64,
    /// Number of blocks described by the arrays above.
    pub num_blocks: usize,
    /// Hash of the parent block, if any.
    pub parent_hash: Option<u64>,
    /// LoRA adapter name, or `None` for the base model.
    pub lora_name: Option<String>,
}

310
311
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
312
313
/// has a parent hash or not. nullptr is used to represent no parent hash.
/// lora_name is an optional null-terminated C string; pass nullptr for base model.
314
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
315
pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
316
317
318
319
320
321
    event_id: u64,
    token_ids: *const u32,
    num_block_tokens: *const usize,
    block_ids: *const u64,
    num_blocks: usize,
    parent_hash: *const u64,
322
    lora_name: *const c_char,
Neelay Shah's avatar
Neelay Shah committed
323
) -> DynamoLlmResult {
324
325
326
327
328
329
330
    let parent_hash = {
        if parent_hash.is_null() {
            None
        } else {
            Some(unsafe { *parent_hash })
        }
    };
331
332
333
334
335
336
337
338
339
340
341
    let lora_name = if lora_name.is_null() {
        None
    } else {
        match unsafe { CStr::from_ptr(lora_name) }.to_str() {
            Ok(s) => Some(s.to_owned()),
            Err(e) => {
                tracing::error!(error = ?e, "Failed to convert C string to Rust string (lora_name)");
                return DynamoLlmResult::ERR;
            }
        }
    };
342
    let kv_params = DynamoKvStoredEventParams {
343
344
345
346
347
348
        event_id,
        token_ids,
        num_block_tokens,
        block_ids,
        num_blocks,
        parent_hash,
349
        lora_name,
350
351
352
    };
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
353
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
354
        Ok(_) => DynamoLlmResult::OK,
355
356
        Err(e) => {
            eprintln!("Error publishing stored kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
357
            DynamoLlmResult::ERR
358
359
360
361
        }
    }
}

362
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
363
pub extern "C" fn dynamo_kv_event_publish_removed(
364
365
366
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
Neelay Shah's avatar
Neelay Shah committed
367
) -> DynamoLlmResult {
368
369
370
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
371
        Ok(_) => DynamoLlmResult::OK,
372
373
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
374
            DynamoLlmResult::ERR
375
376
377
378
        }
    }
}

379
/* ------------------------------------------------------------------------
 *  Router Bindings for GAIE EPP
 * ------------------------------------------------------------------------ */

383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
/// Default timeout, in seconds, for bookkeeping operations.
const BOOKKEEPING_TIMEOUT_SEC: u64 = 5;
/// Complete routing result for a chat completion request (C-compatible)
#[repr(C)]
#[derive(Debug)]
pub struct CRoutingResult {
    /// Whether disaggregated mode is active
    pub is_disaggregated: bool,
    /// Prefill worker ID (only valid if is_disaggregated is true)
    pub prefill_worker_id: u64,
    /// Decode worker ID
    pub decode_worker_id: u64,
    /// Token IDs (needed for add_request callback)
    pub token_ids: *mut u32,
    /// Number of tokens in the request
    pub token_count: usize,
}

/// The default result is "empty": aggregated mode, worker ids 0, and no token buffer
/// (`token_ids` is NULL with `token_count` 0).
impl Default for CRoutingResult {
    fn default() -> Self {
        Self {
            is_disaggregated: false,
            prefill_worker_id: 0,
            decode_worker_id: 0,
            token_ids: ptr::null_mut(),
            token_count: 0,
        }
    }
}
411

412
413
414
415
416
417
418
419
420
421
422
423
424
/// Container holding routers and preprocessor for query routing
pub struct RouterHandles {
    /// Router queried by `query_prefill_worker` to pick a prefill worker.
    prefill_router: Arc<PrefillRouter>,
    /// Router queried by `query_decode_worker` to pick a decode worker.
    decode_router: Arc<KvRouter>,
    // NOTE(review): never read in the visible code; presumably held to keep the manager
    // alive alongside the routers created from it — confirm.
    #[allow(dead_code)]
    model_manager: Arc<ModelManager>,
    // NOTE(review): never read in the visible code; stored from the `create_routers` argument.
    #[allow(dead_code)]
    namespace: String,
    /// Cached runtime for executing async operations (avoids creating new runtime per call)
    runtime: Runtime,
    /// Preprocessor for tokenization and template application (fetched via discovery)
    preprocessor: Option<Arc<OpenAIPreprocessor>>,
}
425

426
427
428
429
430
431
impl RouterHandles {
    /// Query optimal prefill worker for a request.
    /// Returns worker_id on success.
    async fn query_prefill_worker(
        &self,
        tokens: &[u32],
432
        block_mm_infos: Option<&[Option<dynamo_llm::kv_router::protocols::BlockExtraInfo>]>,
433
434
435
436
437
        update_states: bool,
        lora_name: Option<String>,
        priority_jump: f64,
    ) -> Result<u64, QueryRouterResult> {
        self.prefill_router
438
439
440
441
442
443
444
            .query_prefill_worker(
                tokens,
                block_mm_infos,
                update_states,
                lora_name,
                priority_jump,
            )
445
446
447
448
449
450
451
            .await
            .map(|(worker_id, _dp_rank)| worker_id)
            .map_err(|e| {
                tracing::error!(error = ?e, "Prefill query failed");
                QueryRouterResult::ErrQueryFailed
            })
    }
452

453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
    /// Query optimal decode worker for a request.
    /// For disaggregated mode, set `is_disaggregated` to true to use overlap_score_weight=0
    /// (since KV cache is being transferred from prefill, not reused).
    ///
    /// Note: The C bindings are query-only and must not mutate router state during worker
    /// selection. State updates require a `context_id` (request id) and are managed via the
    /// explicit bookkeeping APIs (`add_request`, `mark_prefill_complete`, `free_request`).
    /// Returns (worker, overlap_blocks) on success.
    async fn query_decode_worker(
        &self,
        tokens: &[u32],
        is_disaggregated: bool,
    ) -> Result<(WorkerWithDpRank, u32), QueryRouterResult> {
        // For decode phase in disaggregated mode, use overlap_score_weight=0
        // This matches prefill_router.rs
        let config_override = if is_disaggregated {
            Some(RouterConfigOverride {
                overlap_score_weight: Some(0.0),
                ..Default::default()
472
            })
473
474
475
476
        } else {
            None
        };

477
        self.decode_router
478
479
480
481
482
483
484
485
486
            .find_best_match(
                None,
                tokens,
                None,
                config_override.as_ref(),
                false,
                None,
                0.0,
            )
487
488
489
490
491
492
493
            .await
            .map_err(|e| {
                tracing::error!(error = ?e, "Decode query failed");
                QueryRouterResult::ErrQueryFailed
            })
    }
}
494

495
496
/// Opaque handle for the router pair: a raw pointer to a heap-allocated [`RouterHandles`],
/// as written to `out_handle` by `create_routers`.
pub type RouterHandlesPtr = *mut RouterHandles;
497

498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
/// Result codes for query router C FFI
///
/// The discriminants are part of the C ABI and must not change.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryRouterResult {
    /// The call succeeded.
    Ok = 0,
    // NOTE(review): not returned by any code visible here; presumably for a bad router handle.
    ErrInvalidHandle = 1,
    /// A required pointer was NULL or a string argument was not valid UTF-8.
    ErrInvalidParam = 2,
    /// Runtime, discovery or router initialization failed.
    ErrInitFailed = 3,
    /// A prefill or decode router query failed.
    ErrQueryFailed = 4,
    /// Prefill workers are required (decode fallback disabled) but none were found.
    ErrDisaggEnforced = 5,
    // NOTE(review): presumably returned when a bookkeeping operation times out — confirm.
    ErrTimeout = 6,
}

/// Build a `KvRouterConfig` from defaults, overridden by optional `DYN_*` environment variables.
///
/// Supported env vars (all optional — unset or empty values are ignored):
/// - `DYN_OVERLAP_SCORE_WEIGHT` — Weight for overlap score in worker selection (default: 1.0)
/// - `DYN_ROUTER_TEMPERATURE` — Temperature for worker sampling via softmax (default: 0.0)
/// - `DYN_USE_KV_EVENTS` — Use KV events for cache tracking (default: true)
/// - `DYN_ROUTER_REPLICA_SYNC` — Enable replica synchronization (default: false)
/// - `DYN_ROUTER_TRACK_ACTIVE_BLOCKS` — Track active blocks (default: true)
/// - `DYN_ROUTER_TRACK_OUTPUT_BLOCKS` — Track output blocks during generation (default: false)
fn kv_router_config_from_env() -> KvRouterConfig {
    let mut cfg = KvRouterConfig::default();

    fn env_f64(key: &str) -> Option<f64> {
        std::env::var(key).ok().and_then(|v| v.parse().ok())
524
    }
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
    fn env_bool(key: &str) -> Option<bool> {
        std::env::var(key)
            .ok()
            .and_then(|v| match v.to_lowercase().as_str() {
                "true" | "1" | "yes" | "on" => Some(true),
                "false" | "0" | "no" | "off" => Some(false),
                _ => None,
            })
    }

    if let Some(v) = env_f64("DYN_OVERLAP_SCORE_WEIGHT") {
        cfg.overlap_score_weight = v;
    }
    if let Some(v) = env_f64("DYN_ROUTER_TEMPERATURE") {
        cfg.router_temperature = v;
    }
    if let Some(v) = env_bool("DYN_USE_KV_EVENTS") {
        cfg.use_kv_events = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_REPLICA_SYNC") {
        cfg.router_replica_sync = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_TRACK_ACTIVE_BLOCKS") {
        cfg.router_track_active_blocks = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_TRACK_OUTPUT_BLOCKS") {
        cfg.router_track_output_blocks = v;
    }

    tracing::info!(
        overlap_score_weight = cfg.overlap_score_weight,
        router_temperature = cfg.router_temperature,
        use_kv_events = cfg.use_kv_events,
        router_replica_sync = cfg.router_replica_sync,
        router_track_active_blocks = cfg.router_track_active_blocks,
        router_track_output_blocks = cfg.router_track_output_blocks,
        "KvRouterConfig initialized (DYN_* env overrides applied)"
    );

    cfg
565
566
}

567
/// Create router handles for query-only routing
568
///
569
570
571
/// This function waits for at least one decode worker to be discovered before returning.
/// It auto-detects disaggregated mode by checking if prefill workers are present.
/// The KV cache block size is automatically fetched from the model card via discovery.
572
///
573
574
575
/// # Arguments
/// - `namespace`: Namespace for the model
/// - `component`: Component name (defaults to "backend" if NULL or empty)
576
/// - `decode_fallback`: If true, allows falling back to decode-only mode when no prefill workers are found
577
/// - `out_handle`: Output handle
578
///
579
580
581
/// # Safety
/// - All string parameters must be valid null-terminated C strings
/// - The returned handle must be freed with `destroy`
582
#[unsafe(no_mangle)]
583
584
585
pub unsafe extern "C" fn create_routers(
    namespace: *const c_char,
    component: *const c_char,
586
    decode_fallback: bool,
587
588
589
590
    out_handle: *mut RouterHandlesPtr,
) -> QueryRouterResult {
    if namespace.is_null() || out_handle.is_null() {
        return QueryRouterResult::ErrInvalidParam;
591
592
    }

593
594
595
    let namespace_str = match unsafe { CStr::from_ptr(namespace) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
596
597
    };

598
599
    let component_str = if component.is_null() {
        "backend".to_string()
600
    } else {
601
602
603
        match unsafe { CStr::from_ptr(component) }.to_str() {
            Ok(s) if !s.is_empty() => s.to_owned(),
            _ => "backend".to_string(),
604
605
606
        }
    };

607
608
609
    // Create the runtime once - it will be stored in RouterHandles and reused
    let runtime = match Runtime::from_settings() {
        Ok(rt) => rt,
610
        Err(e) => {
611
612
            tracing::error!(error = ?e, "Failed to create runtime");
            return QueryRouterResult::ErrInitFailed;
613
614
        }
    };
615
616
617
618
619
620
621
622
623
624

    // Clone for use inside the async block (the original will be moved into handles)
    let runtime_for_async = runtime.clone();

    let result = runtime_for_async.secondary().block_on(async {
        let drt = match DistributedRuntime::from_settings(runtime_for_async.clone()).await {
            Ok(drt) => drt,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to create distributed runtime");
                return Err(QueryRouterResult::ErrInitFailed);
625
            }
626
627
628
629
630
631
632
633
634
635
        };

        // Wait for at least one worker to be discovered before proceeding
        // This ensures the decode router can be created successfully
        let instance_count = wait_for_discovery_sync(&drt).await;
        if instance_count == 0 {
            tracing::error!(
                "Discovery sync failed: no worker instances found. Is the backend running?"
            );
            return Err(QueryRouterResult::ErrInitFailed);
636
        }
637
638
639
640
        tracing::info!(
            "Discovery sync complete, {} worker(s) found",
            instance_count
        );
641

642
        let kv_router_config = kv_router_config_from_env();
643

644
645
646
647
648
649
650
651
        // Get component and endpoint
        let component_handle = match drt.namespace(&namespace_str) {
            Ok(ns) => match ns.component(&component_str) {
                Ok(c) => c,
                Err(e) => {
                    tracing::error!(error = ?e, "Failed to get component");
                    return Err(QueryRouterResult::ErrInitFailed);
                }
652
            },
653
654
655
656
657
658
            Err(e) => {
                tracing::error!(error = ?e, "Failed to get namespace");
                return Err(QueryRouterResult::ErrInitFailed);
            }
        };
        let endpoint = component_handle.endpoint("generate");
659

660
        let model_manager = Arc::new(ModelManager::new());
661

662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
        // Fetch model card via discovery and create preprocessor + get block_size
        let (preprocessor, block_size, model_name) =
            match fetch_preprocessor_from_discovery(&drt, &namespace_str).await {
                Ok((prep, bs, name)) => {
                    tracing::info!(
                        kv_cache_block_size = bs,
                        "Preprocessor created from discovery"
                    );
                    (Some(prep), bs, name)
                }
                Err(e) => {
                    tracing::error!(
                        error = %e,
                        "Failed to fetch model card from discovery - cannot determine block_size"
                    );
                    return Err(QueryRouterResult::ErrInitFailed);
                }
            };
680

681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
        // Create decode router
        let decode_router = match model_manager
            .kv_chooser_for(
                &endpoint,
                block_size,
                Some(kv_router_config),
                WORKER_TYPE_DECODE,
            )
            .await
        {
            Ok(r) => r,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to create decode router");
                return Err(QueryRouterResult::ErrInitFailed);
            }
        };
697

698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
        // Create PrefillRouter based on one-time discovery of prefill workers
        // Auto-detects disaggregated mode by checking if prefill workers are present
        // The prefill workers have to be created before the epp is created.
        // Given that we wait first for the decode worker to show up it is reasonable to assume the prefill will be up as well.
        let prefill_router = match find_prefill_endpoint(&drt, &namespace_str).await {
            Some(prefill_endpoint) => {
                tracing::info!("Prefill worker found, running in disaggregated mode");
                let mut prefill_config = kv_router_config;
                prefill_config.router_track_active_blocks = false;

                // Create immediately-resolved channel to activate router
                let (tx, rx) = tokio::sync::oneshot::channel();
                let _ = tx.send(prefill_endpoint);

                PrefillRouter::new(
                    rx,
                    model_manager.clone(),
                    RouterMode::KV,
                    block_size,
                    Some(prefill_config),
718
                    decode_fallback,
719
                    model_name.clone(),
720
                    namespace_str.clone(),
721
722
                )
            }
723
724
725
726
            None if !decode_fallback => {
                tracing::error!(
                    "Prefill workers required but none found and decode fallback is disabled"
                );
727
728
729
730
                return Err(QueryRouterResult::ErrDisaggEnforced);
            }
            None => {
                tracing::info!("No prefill workers found, running in aggregated mode");
731
                PrefillRouter::disabled(model_manager.clone(), RouterMode::KV, decode_fallback)
732
733
734
735
736
737
738
739
740
741
742
            }
        };

        Ok((
            prefill_router,
            decode_router,
            model_manager,
            namespace_str,
            preprocessor,
        ))
    });
743
744

    match result {
745
746
747
748
749
750
751
752
753
754
755
        Ok((prefill_router, decode_router, model_manager, namespace_str, preprocessor)) => {
            let handles = RouterHandles {
                prefill_router,
                decode_router,
                model_manager,
                namespace: namespace_str,
                runtime, // Store the runtime for reuse
                preprocessor,
            };
            unsafe { *out_handle = Box::into_raw(Box::new(handles)) };
            QueryRouterResult::Ok
756
        }
757
        Err(code) => code,
758
759
760
761
762
    }
}

/// Add a request to the router's bookkeeping after worker selection.
///
763
764
765
/// Register the request with the KvRouter's scheduler for tracking active blocks
/// and managing prefill/decode lifecycle. Call this after `query_decode` returns
/// worker IDs and before sending the request to the worker.
766
767
///
/// # Safety
768
769
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
770
771
/// - `token_ids` must point to at least `token_count` valid u32 values
#[unsafe(no_mangle)]
772
773
774
pub unsafe extern "C" fn add_request(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
775
776
777
778
    token_ids: *const u32,
    token_count: usize,
    worker_id: u64,
    dp_rank: u32,
779
780
781
782
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
783

784
785
786
787
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
788
789
790
791
792
793
794
795
    };

    let tokens: Vec<u32> = if token_count > 0 && !token_ids.is_null() {
        unsafe { std::slice::from_raw_parts(token_ids, token_count) }.to_vec()
    } else {
        Vec::new()
    };

796
    let decode_router = handles.decode_router.clone();
797

798
799
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
800

801
802
        tokio::time::timeout(timeout_duration, async {
            let worker = WorkerWithDpRank::new(worker_id, dp_rank);
803

804
            // Compute overlap_blocks using the public method
805
806
807
808
            let overlap_blocks = match decode_router
                .get_overlap_blocks(&tokens, worker, None)
                .await
            {
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
                Ok(overlap) => overlap,
                Err(e) => {
                    tracing::warn!(error = ?e, "Failed to compute overlap, using 0");
                    0
                }
            };

            decode_router
                .add_request(
                    request_id_str.clone(),
                    &tokens,
                    overlap_blocks,
                    None,
                    worker,
                    None, // lora_name
                    None, // router_config_override
                )
                .await;

            tracing::debug!(
                request_id = %request_id_str,
                worker_id = worker_id,
                dp_rank = dp_rank,
                overlap_blocks = overlap_blocks,
                token_count = tokens.len(),
                "add_request completed"
            );
        })
        .await
    });

    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
            tracing::warn!(
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "add_request timed out"
            );
            QueryRouterResult::ErrTimeout
        }
    }
851
852
853
}

/// Mark prefill as completed for a request.
854
855
///
/// Call when the first token is generated to release prefill tokens from decode worker's load
856
857
///
/// # Safety
858
859
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
860
#[unsafe(no_mangle)]
861
862
863
864
865
866
867
pub unsafe extern "C" fn mark_prefill_complete(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
868

869
870
871
872
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
873
874
    };

875
    let decode_router = handles.decode_router.clone();
876

877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);

        tokio::time::timeout(timeout_duration, async {
            if let Err(e) = decode_router.mark_prefill_completed(&request_id_str).await {
                tracing::warn!(
                    request_id = %request_id_str,
                    error = %e,
                    "Failed to mark prefill complete"
                );
            } else {
                tracing::debug!(
                    request_id = %request_id_str,
                    "mark_prefill_complete completed"
                );
            }
        })
        .await
    });
896

897
898
899
    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
900
            tracing::warn!(
901
902
903
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "mark_prefill_complete timed out"
904
            );
905
            QueryRouterResult::ErrTimeout
906
        }
907
    }
908
909
910
}

/// Free a request from the router's bookkeeping.
911
912
///
/// Call this when the stream is closed (completed or cancelled) to release all resources.
913
914
///
/// # Safety
915
916
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
917
#[unsafe(no_mangle)]
918
919
920
921
922
923
924
pub unsafe extern "C" fn free_request(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
925

926
927
928
929
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
930
931
    };

932
    let decode_router = handles.decode_router.clone();
933

934
935
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
936

937
938
939
940
941
942
        tokio::time::timeout(timeout_duration, async {
            if let Err(e) = decode_router.free(&request_id_str).await {
                tracing::warn!(
                    request_id = %request_id_str,
                    error = %e,
                    "Failed to free request"
943
                );
944
            } else {
945
                tracing::debug!(
946
947
                    request_id = %request_id_str,
                    "free_request completed"
948
                );
949
            }
950
951
952
953
954
955
956
957
958
959
960
961
962
        })
        .await
    });

    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
            tracing::warn!(
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "free_request timed out"
            );
            QueryRouterResult::ErrTimeout
963
964
        }
    }
965
}
966

967
968
969
970
971
972
973
974
975
976
/// Destroy router handles
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle or null
/// - After this call, `handle` must not be used
#[unsafe(no_mangle)]
pub unsafe extern "C" fn destroy(handle: RouterHandlesPtr) {
    if !handle.is_null() {
        drop(unsafe { Box::from_raw(handle) });
    }
977
978
}

979
/// Route a chat completion request in a single call.
980
///
981
982
983
984
985
986
987
/// This is the main function for EPP to route a `/v1/chat/completions` request.
/// It combines tokenization and worker selection in one call:
/// 1. Applies the chat template to the request JSON
/// 2. Tokenizes the formatted prompt
/// 3. Queries the prefill router (if disaggregated mode)
/// 4. Queries the decode router
/// 5. Returns worker IDs and token_ids
988
///
989
990
991
992
993
994
/// After this call, EPP should:
/// - Call `add_request()` to register the request for bookkeeping
/// - Set worker ID headers and forward to backend
/// - Call `mark_prefill_complete()` on first token
/// - Call `free_request()` when the stream ends
/// - Call `free_routing_result()` to free the result
995
///
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
/// # Safety
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `out_result` must be a valid pointer
#[unsafe(no_mangle)]
pub unsafe extern "C" fn route_request(
    handle: RouterHandlesPtr,
    request_json: *const c_char,
    out_result: *mut CRoutingResult,
) -> QueryRouterResult {
    if handle.is_null() || request_json.is_null() || out_result.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
1009

1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
    let handles = unsafe { &*handle };

    // Get preprocessor
    let preprocessor = match &handles.preprocessor {
        Some(p) => p,
        None => {
            tracing::error!("Preprocessor not available");
            return QueryRouterResult::ErrInitFailed;
        }
    };

    let json_str = match unsafe { CStr::from_ptr(request_json) }.to_str() {
        Ok(s) => s,
        Err(_) => return QueryRouterResult::ErrInvalidParam,
    };

    // Parse JSON
    let request: dynamo_llm::types::openai::chat_completions::NvCreateChatCompletionRequest =
        match serde_json::from_str(json_str) {
            Ok(req) => req,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to parse request JSON");
                return QueryRouterResult::ErrInvalidParam;
            }
        };

    // Apply chat template
    let formatted_prompt = match preprocessor.apply_template(&request) {
        Ok(Some(prompt)) => prompt,
        Ok(None) => String::new(),
        Err(e) => {
            tracing::error!(error = ?e, "Failed to apply chat template");
            return QueryRouterResult::ErrQueryFailed;
        }
    };

    // Tokenize
    let encoding = match preprocessor.tokenize(&formatted_prompt) {
        Ok(enc) => enc,
        Err(e) => {
            tracing::error!(error = ?e, "Failed to tokenize");
            return QueryRouterResult::ErrQueryFailed;
        }
    };

    let tokens = encoding.token_ids();
    let token_count = tokens.len();
    let is_disaggregated = handles.prefill_router.is_activated();

    // Query workers
    let result = handles.runtime.secondary().block_on(async {
        let prefill_worker_id = if is_disaggregated {
            handles
1063
                .query_prefill_worker(tokens, None, false, None, 0.0)
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
                .await?
        } else {
            0
        };

        let (decode_worker, _overlap_blocks) = handles
            .query_decode_worker(tokens, is_disaggregated)
            .await?;

        tracing::info!(
            is_disaggregated = is_disaggregated,
            prefill_worker_id = prefill_worker_id,
            decode_worker_id = decode_worker.worker_id,
            decode_dp_rank = decode_worker.dp_rank,
            token_count = token_count,
            "Routed chat request"
        );

        Ok((prefill_worker_id, decode_worker))
1083
1084
    });

1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
    match result {
        Ok((prefill_worker_id, decode_worker)) => {
            // Allocate and copy token IDs for caller (needed for add_request bookkeeping)
            let token_vec: Vec<u32> = tokens.to_vec();
            let mut tokens_boxed = token_vec.into_boxed_slice();
            let token_ptr = tokens_boxed.as_mut_ptr();
            std::mem::forget(tokens_boxed);

            unsafe {
                *out_result = CRoutingResult {
                    is_disaggregated,
                    prefill_worker_id,
                    decode_worker_id: decode_worker.worker_id,
                    token_ids: token_ptr,
                    token_count,
                };
            }
            QueryRouterResult::Ok
        }
        Err(code) => code,
    }
1106
1107
}

1108
/// Free a routing result.
1109
///
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
/// # Safety
/// - `result` must be a valid pointer to a CRoutingResult previously returned by route functions
#[unsafe(no_mangle)]
pub unsafe extern "C" fn free_routing_result(result: *mut CRoutingResult) {
    if result.is_null() {
        return;
    }

    let res = unsafe { &mut *result };

    // Free token IDs
    if !res.token_ids.is_null() && res.token_count > 0 {
        drop(unsafe {
            Box::from_raw(std::slice::from_raw_parts_mut(
                res.token_ids,
                res.token_count,
            ))
        });
        res.token_ids = ptr::null_mut();
        res.token_count = 0;
    }
1131
1132
}

1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
/// Fetch model card via discovery and create preprocessor.
///
/// This function:
/// 1. Lists all models via discovery
/// 2. Finds the first model in the target namespace (decode workers only)
/// 3. Downloads the model config (tokenizer files) if needed
/// 4. Creates an OpenAIPreprocessor from the model card
/// 5. Returns the preprocessor, the kv_cache_block_size, and model_name from the model card
async fn fetch_preprocessor_from_discovery(
    drt: &DistributedRuntime,
    target_namespace: &str,
) -> anyhow::Result<(Arc<OpenAIPreprocessor>, u32, String)> {
1145
    use dynamo_llm::model_card::ModelDeploymentCard;
1146
    use dynamo_runtime::discovery::DiscoveryInstance;
1147

1148
    let discovery = drt.discovery();
1149

1150
1151
    // List all models
    let instances = discovery.list(DiscoveryQuery::AllModels).await?;
1152

1153
1154
    // Find first model card in the target namespace (decode workers only)
    let mut model_card: Option<ModelDeploymentCard> = None;
1155

1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
    for instance in instances {
        if let DiscoveryInstance::Model { namespace, .. } = &instance {
            // Filter by namespace
            if namespace != target_namespace {
                continue;
            }

            match instance.deserialize_model::<ModelDeploymentCard>() {
                Ok(card) => {
                    // Skip prefill-only workers, we want decode workers for routing
                    if card.model_type.supports_prefill()
                        && !card.model_type.supports_chat()
                        && !card.model_type.supports_completions()
1169
                    {
1170
                        continue;
1171
                    }
1172
1173
                    model_card = Some(card);
                    break;
1174
                }
1175
1176
1177
                Err(e) => {
                    tracing::debug!(error = %e, "Failed to deserialize model card, skipping");
                    continue;
1178
1179
1180
                }
            }
        }
1181
    }
1182

1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
    let mut card = model_card.ok_or_else(|| {
        anyhow::anyhow!(
            "No model found in namespace '{}' via discovery",
            target_namespace
        )
    })?;

    let kv_cache_block_size = card.kv_cache_block_size;
    let model_name = card.name().to_string();
    tracing::info!(
        model_name = model_name,
        kv_cache_block_size = kv_cache_block_size,
        "Found model card via discovery"
1196
1197
    );

1198
1199
    // Download config (tokenizer files) if not local
    card.download_config().await?;
1200

1201
1202
1203
1204
    // Create preprocessor
    let preprocessor = OpenAIPreprocessor::new(card)?;
    Ok((preprocessor, kv_cache_block_size, model_name))
}
1205

1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
/// Find a prefill endpoint from already-discovered instances (one-time filter).
/// Returns the endpoint if a prefill worker is found in the target namespace.
async fn find_prefill_endpoint(
    drt: &DistributedRuntime,
    target_namespace: &str,
) -> Option<dynamo_runtime::component::Endpoint> {
    use dynamo_llm::model_card::ModelDeploymentCard;
    use dynamo_runtime::discovery::DiscoveryInstance;

    let discovery = drt.discovery();
    let instances = match discovery.list(DiscoveryQuery::AllModels).await {
        Ok(instances) => instances,
        Err(e) => {
            tracing::warn!(error = %e, "Failed to list instances for prefill discovery");
            return None;
        }
1222
1223
    };

1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
    for instance in instances {
        if let DiscoveryInstance::Model {
            namespace,
            component,
            endpoint,
            ..
        } = &instance
        {
            // Filter by namespace
            if namespace != target_namespace {
                continue;
            }
1236

1237
1238
1239
1240
            let card = match instance.deserialize_model::<ModelDeploymentCard>() {
                Ok(card) => card,
                Err(_) => continue,
            };
1241

1242
1243
1244
1245
            // Only handle prefill models
            if !card.model_type.supports_prefill() {
                continue;
            }
1246

1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
            tracing::info!(
                model_name = card.name(),
                "Prefill worker found in discovered instances"
            );

            // Build and return the endpoint
            if let Ok(ns) = drt.namespace(namespace)
                && let Ok(comp) = ns.component(component)
            {
                return Some(comp.endpoint(endpoint));
            }
        }
    }
1260

1261
    None
1262
}