// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use async_once_cell::OnceCell as AsyncOnceCell;
use libc::c_char;
use once_cell::sync::OnceCell;
7
use std::borrow::Cow;
8
use std::ffi::CStr;
9
use std::sync::Arc;
10
11
use std::sync::atomic::{AtomicU32, Ordering};

12
13
use dynamo_llm::{
    discovery::{KvWorkerMonitor, ModelWatcher},
14
    kv_router::{protocols::*, publisher::KvEventPublisher},
15
};
16
use dynamo_runtime::discovery::DiscoveryQuery;
17
use dynamo_runtime::{DistributedRuntime, Worker};
18
19
20
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
GuanLuo's avatar
GuanLuo committed
21
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
/// Read a C string argument, substituting `default_val` when it is unusable.
///
/// The default is returned (as `Cow::Borrowed`) when `ptr` is NULL, when the bytes are not
/// valid UTF-8, or when the text is empty after trimming whitespace. Otherwise the trimmed text
/// is returned as an owned copy (`Cow::Owned`), so callers can tell "defaulted" apart from
/// "explicitly provided" by matching on `Cow::Borrowed`.
///
/// # Safety
/// If non-null, `ptr` must point to a NUL-terminated C string that stays valid for the
/// duration of this call.
#[inline]
unsafe fn cstr_or_default<'a>(ptr: *const c_char, default_val: &'a str) -> Cow<'a, str> {
    if !ptr.is_null() {
        // SAFETY: non-null here; the caller guarantees NUL termination and validity.
        let raw = unsafe { CStr::from_ptr(ptr) };
        if let Ok(text) = raw.to_str() {
            let trimmed = text.trim();
            if !trimmed.is_empty() {
                return Cow::Owned(trimmed.to_owned());
            }
        }
    }
    Cow::Borrowed(default_val)
}

42
43
44
45
46
47
48
fn initialize_tracing() {
    // Sets up RUST_LOG environment variable for logging while KV Publishing
    // Example: os.environ["RUST_LOG"] = "debug"
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();

Ryan Olson's avatar
Ryan Olson committed
49
    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
50

51
    tracing::debug!("Tracing initialized");
52
53
54
}

/// Status code returned by this library's `extern "C"` entry points.
///
/// `#[repr(u32)]` fixes the ABI: `OK` is `0` and `ERR` is `1`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DynamoLlmResult {
    /// The call succeeded.
    OK = 0,
    /// The call failed; see the log output for details.
    ERR = 1,
}

60
61
62
63
64
65
66
67
68
69
70
71
/// Default timeout for discovery sync (seconds).
const DEFAULT_DISCOVERY_TIMEOUT_SEC: u64 = 10;

/// Get the discovery sync timeout in seconds.
///
/// Reads the `DYN_DISCOVERY_TIMEOUT_SEC` environment variable. Surrounding whitespace (e.g. a
/// trailing newline from a config file) is ignored. A missing, non-UTF-8, or non-numeric value
/// falls back to [`DEFAULT_DISCOVERY_TIMEOUT_SEC`].
fn get_discovery_timeout_secs() -> u64 {
    std::env::var("DYN_DISCOVERY_TIMEOUT_SEC")
        .ok()
        .and_then(|s| s.trim().parse::<u64>().ok())
        .unwrap_or(DEFAULT_DISCOVERY_TIMEOUT_SEC)
}

72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
/// Wait for the discovery daemon to sync and return at least one instance.
///
/// This ensures later `list()` calls have data available. Discovery is polled every 500 ms.
/// Both empty results and list errors are retried until `timeout_secs` has elapsed. Errors are
/// treated as transient, e.g. while the discovery backend is still starting.
///
/// Returns the number of instances found, or 0 if none appeared before the timeout.
async fn wait_for_discovery_sync(drt: &DistributedRuntime, timeout_secs: u64) -> usize {
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(500);

    tracing::info!("Waiting for discovery to sync...");
    let discovery = drt.discovery();
    let timeout = std::time::Duration::from_secs(timeout_secs);
    let start = std::time::Instant::now();

    loop {
        match discovery.list(DiscoveryQuery::AllModels).await {
            Ok(instances) if !instances.is_empty() => {
                tracing::info!(
                    "Discovery sync complete: found {} instances",
                    instances.len()
                );
                return instances.len();
            }
            Ok(_) => {
                tracing::debug!("No instances yet, waiting...");
            }
            Err(e) => {
                tracing::warn!("Discovery list error: {}, continuing...", e);
            }
        }

        if start.elapsed() > timeout {
            tracing::warn!("Discovery sync timed out waiting for instances");
            return 0;
        }
        tokio::time::sleep(POLL_INTERVAL).await;
    }
}

106
/// # Safety
GuanLuo's avatar
GuanLuo committed
107
/// the namespace_c_str and component_c_str are passed as pointers to C strings
108
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
109
pub unsafe extern "C" fn dynamo_llm_init(
GuanLuo's avatar
GuanLuo committed
110
111
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
112
    kv_block_size: u32,
Neelay Shah's avatar
Neelay Shah committed
113
) -> DynamoLlmResult {
114
115
116
117
    initialize_tracing();
    let wk = match WK.get_or_try_init(Worker::from_settings) {
        Ok(wk) => wk.clone(),
        Err(e) => {
118
            tracing::error!(error = ?e, "Failed to initialize runtime (Worker::from_settings)");
Neelay Shah's avatar
Neelay Shah committed
119
            return DynamoLlmResult::ERR;
120
121
122
123
124
125
126
127
128
129
        }
    };
    let rt = wk.runtime();
    let secondary = rt.secondary().clone();
    let result = secondary.block_on(async {
        // Initialize the distributed runtime
        match DRT
            .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await })
            .await
        {
130
131
132
133
134
            Ok(drt) => {
                // Wait for discovery to sync before returning
                // This is needed because dynamo_create_worker_selection_pipeline() is called
                // immediately after, and it needs discovery.list() to return data
                // the discovery daemon takes time to query K8s and returns async, so we need to wait.
135
136
                let timeout_secs = get_discovery_timeout_secs();
                let instance_count = wait_for_discovery_sync(drt, timeout_secs).await;
137
138
139
140
141
142
143
144
                if instance_count == 0 {
                    tracing::error!(
                        "Discovery sync failed: no worker instances found. Is the backend running?"
                    );
                    return Err(DynamoLlmResult::ERR);
                }
                Ok(())
            }
145
            Err(e) => {
146
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
147
                Err(DynamoLlmResult::ERR)
148
149
150
            }
        }
    });
GuanLuo's avatar
GuanLuo committed
151
    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
152
153
        Ok(s) => s.to_string(),
        Err(e) => {
154
            tracing::error!(error = ?e, "Failed to convert C string to Rust string (namespace)");
Neelay Shah's avatar
Neelay Shah committed
155
            return DynamoLlmResult::ERR;
156
157
158
        }
    };

159
160
161
162
163
    let component_cow = unsafe { cstr_or_default(component_c_str, "backend") };
    if let Cow::Borrowed("backend") = &component_cow {
        tracing::info!("defaulting to \"backend\" for component");
    }
    let component: String = component_cow.into_owned();
164
165

    match result {
166
        Ok(_) => match KV_PUB.get_or_try_init(move || {
Yan Ru Pei's avatar
Yan Ru Pei committed
167
            dynamo_create_kv_publisher(namespace, component, kv_block_size)
168
        }) {
Neelay Shah's avatar
Neelay Shah committed
169
            Ok(_) => DynamoLlmResult::OK,
170
            Err(e) => {
171
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
172
                DynamoLlmResult::ERR
173
174
175
176
177
178
            }
        },
        Err(e) => e,
    }
}

179
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
180
pub extern "C" fn dynamo_llm_shutdown() -> DynamoLlmResult {
181
182
183
    let wk = match WK.get() {
        Some(wk) => wk,
        None => {
184
            tracing::error!("Runtime not initialized");
Neelay Shah's avatar
Neelay Shah committed
185
            return DynamoLlmResult::ERR;
186
187
188
189
190
        }
    };

    wk.runtime().shutdown();

Neelay Shah's avatar
Neelay Shah committed
191
    DynamoLlmResult::OK
192
193
}

194
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
195
196
pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
    DynamoLlmResult::OK
197
198
199
200
}

// instantiate a kv publisher
// this will bring up the task to publish and the channels to await publishing events
Neelay Shah's avatar
Neelay Shah committed
201
202
// the [`dynamo_kv_publish_store_event`] call will use a handle to the publisher to send events
// store and the [`dynamo_kv_event_create_removed`] will create remove events
203
204
205
// these call mus be driving by external c++ threads that are consuming the kv events from the
// c++ executor api

Neelay Shah's avatar
Neelay Shah committed
206
fn dynamo_create_kv_publisher(
GuanLuo's avatar
GuanLuo committed
207
208
    namespace: String,
    component: String,
209
    kv_block_size: u32,
GuanLuo's avatar
GuanLuo committed
210
) -> Result<KvEventPublisher, anyhow::Error> {
211
    tracing::info!("Creating KV Publisher for model: {}", component);
212
213
214
215
216
    match DRT
        .get()
        .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
    {
        Ok(drt) => {
GuanLuo's avatar
GuanLuo committed
217
            let backend = drt.namespace(namespace)?.component(component)?;
Yan Ru Pei's avatar
Yan Ru Pei committed
218
            KvEventPublisher::new(backend, kv_block_size, None)
219
220
221
222
223
224
225
226
227
        }
        Err(e) => Err(e),
    }
}

/// Build the stored-block record for one KV block.
///
/// `block_hash` becomes the block's external hash. `tokens_hash` is the first hash returned by
/// `compute_block_hash_for_seq` over the block's `num_tokens` token ids. `_lora_id` is currently
/// ignored.
///
/// NOTE(review): this safe fn dereferences `token_ids`; callers must pass a pointer to
/// `num_tokens` readable `u32` values. The `[0]` index presumes at least one full block's worth
/// of tokens.
fn kv_event_create_stored_block_from_parts(
    block_hash: u64,
    token_ids: *const u32,
    num_tokens: usize,
    kv_block_size: u32,
    _lora_id: u64,
) -> KvCacheStoredBlockData {
    // SAFETY: the caller guarantees `token_ids` points to `num_tokens` readable u32s.
    let block_tokens: &[u32] = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) };
    let hashes = compute_block_hash_for_seq(block_tokens, kv_block_size, None);

    KvCacheStoredBlockData {
        tokens_hash: hashes[0],
        block_hash: ExternalSequenceBlockHash(block_hash),
        mm_extra_info: None,
    }
}
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);

fn kv_event_create_stored_from_parts(
245
    kv_params: DynamoKvStoredEventParams,
246
    kv_block_size: u32,
247
248
249
250
) -> KvCacheEvent {
    let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();

    let mut token_offset: usize = 0;
251
252
253
254
255
256
257
258
259
    for block_idx in 0..kv_params.num_blocks {
        let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
        let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
        let num_toks = unsafe {
            *kv_params
                .num_block_tokens
                .offset(block_idx.try_into().unwrap())
        };

260
        if num_toks != (kv_block_size as usize) {
Ryan Olson's avatar
Ryan Olson committed
261
262
            if WARN_COUNT
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
263
                    if c < 3 { Some(c + 1) } else { None }
Ryan Olson's avatar
Ryan Olson committed
264
265
266
                })
                .is_ok()
            {
267
                tracing::warn!(
268
269
                    "Block not published. Block size must be {} tokens to be published. Block size is: {}",
                    kv_block_size,
Ryan Olson's avatar
Ryan Olson committed
270
271
                    num_toks
                );
272
273
274
275
276
            }
            break;
        }
        token_offset += num_toks;
        blocks.push(kv_event_create_stored_block_from_parts(
277
278
279
280
281
            block_hash,
            tokens,
            num_toks,
            kv_block_size,
            kv_params.lora_id,
282
283
284
285
286
287
        ));
    }

    KvCacheEvent {
        data: KvCacheEventData::Stored(KvCacheStoreData {
            blocks,
288
            parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
289
        }),
290
        event_id: kv_params.event_id,
Yan Ru Pei's avatar
Yan Ru Pei committed
291
        dp_rank: 0,
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
    }
}

fn kv_event_create_removed_from_parts(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> KvCacheEvent {
    let block_hashes: Vec<ExternalSequenceBlockHash> =
        unsafe { std::slice::from_raw_parts(block_ids, num_blocks) }
            .to_vec()
            .iter()
            .map(|&v| ExternalSequenceBlockHash(v))
            .collect();
    KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
Yan Ru Pei's avatar
Yan Ru Pei committed
309
        dp_rank: 0,
310
311
312
    }
}

313
314
315
316
317
318
319
320
321
322
/// Raw (non-owning) view of the C arrays describing one `Stored` KV event, as assembled by
/// `dynamo_kv_event_publish_stored`. The pointers must stay valid while the event is built
/// from them.
pub struct DynamoKvStoredEventParams {
    /// Event identifier copied into the published event.
    pub event_id: u64,
    /// Token ids of all blocks, laid out back-to-back (block `i` contributes
    /// `num_block_tokens[i]` tokens).
    pub token_ids: *const u32,
    /// Per-block token counts; `num_blocks` elements.
    pub num_block_tokens: *const usize,
    /// Per-block ids, used as each stored block's `block_hash`; `num_blocks` elements.
    pub block_ids: *const u64,
    /// Number of entries in `block_ids` and `num_block_tokens`.
    pub num_blocks: usize,
    /// Hash of the parent block, or `None` when the blocks have no parent.
    pub parent_hash: Option<u64>,
    /// LoRA id (presumably the adapter id); NOTE(review): currently ignored when building
    /// block data.
    pub lora_id: u64,
}

323
324
325
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
/// has a parent hash or not. nullptr is used to represent no parent hash
326
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
327
pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
328
329
330
331
332
333
334
    event_id: u64,
    token_ids: *const u32,
    num_block_tokens: *const usize,
    block_ids: *const u64,
    num_blocks: usize,
    parent_hash: *const u64,
    lora_id: u64,
Neelay Shah's avatar
Neelay Shah committed
335
) -> DynamoLlmResult {
336
337
338
339
340
341
342
    let parent_hash = {
        if parent_hash.is_null() {
            None
        } else {
            Some(unsafe { *parent_hash })
        }
    };
343
    let kv_params = DynamoKvStoredEventParams {
344
345
346
347
348
349
350
        event_id,
        token_ids,
        num_block_tokens,
        block_ids,
        num_blocks,
        parent_hash,
        lora_id,
351
352
353
    };
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
354
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
355
        Ok(_) => DynamoLlmResult::OK,
356
357
        Err(e) => {
            eprintln!("Error publishing stored kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
358
            DynamoLlmResult::ERR
359
360
361
362
        }
    }
}

363
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
364
pub extern "C" fn dynamo_kv_event_publish_removed(
365
366
367
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
Neelay Shah's avatar
Neelay Shah committed
368
) -> DynamoLlmResult {
369
370
371
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
372
        Ok(_) => DynamoLlmResult::OK,
373
374
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
375
            DynamoLlmResult::ERR
376
377
378
379
        }
    }
}

380
381
382
383
384
// These tests need running etcd and NATS instances, so they are commented out.
// #[cfg(test)]
// mod tests {
//     use super::*;
//     use std::ffi::CString;
//
//     #[test]
//     fn test_dynamo_llm_init() {
//         // Create C-compatible strings
//         let namespace = CString::new("test_namespace").unwrap();
//         let component = CString::new("test_component").unwrap();
//
//         // Call the init function
//         let result = unsafe {
//             dynamo_llm_init(
//                 namespace.as_ptr(),
//                 component.as_ptr(),
//                 32, // kv_block_size
//             )
//         };
//
//         assert_eq!(result as u32, DynamoLlmResult::OK as u32);
//
//         assert!(WK.get().is_some());
//
//         let shutdown_result = dynamo_llm_shutdown();
//         assert_eq!(shutdown_result as u32, DynamoLlmResult::OK as u32);
//     }
// }
410
411
412
/* ------------------------------------------------------------------------
 * Worker selection pipeline
 * ------------------------------------------------------------------------ */
413
use std::pin::Pin;
414
415
416
417

const GENERATE_ENDPOINT: &str = "generate";

use anyhow::Context;
418
use dynamo_runtime::{Runtime, traits::DistributedRuntimeProvider};
419
420
421

use dynamo_llm::discovery::ModelManager;
use dynamo_llm::entrypoint::build_routed_pipeline;
422
use dynamo_llm::http::service::metrics::Metrics;
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
use dynamo_llm::kv_router::KvRouterConfig;
use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_llm::protocols::openai::nvext::NvExt;
use dynamo_llm::types::{
    Annotated,
    openai::chat_completions::{
        NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
    },
};
use dynamo_runtime::{
    engine::AsyncEngineStream,
    pipeline::{ManyOut, RouterMode, ServiceEngine, SingleIn},
};
/// Opaque handle exposed to C — it owns its own Worker/runtime and engine.
pub struct WorkerSelectionPipeline {
    wk: Worker,
    engine: ServiceEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
    >,
443
444
    /// KV router for bookkeeping operations (only present when router_mode is KV)
    kv_router: Option<Arc<dynamo_llm::kv_router::KvRouter>>,
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
}

/// Create a worker-selection pipeline ("generate" endpoint).
///
/// # Safety
/// - `namespace_c_str`, `component_c_str`, and `model_name_c_str` must be **non-null** pointers to
///   **NUL-terminated** C strings that contain **valid UTF-8**. They must remain valid for the
///   duration of this call.
/// - `pipeline_out` must be **non-null** and point to writable memory for a `*mut WorkerSelectionPipeline`.
///   On success this function writes exactly once to `*pipeline_out`. The caller becomes the owner of
///   that pointer and **must** later free it by calling `dynamo_destroy_worker_selection_pipeline`.
/// - Must be called **after** a successful `dynamo_llm_init()`; otherwise behavior is undefined.
/// - This function is not signal-safe and must not be called from a signal handler.
/// - This function may block internally; do not call it from contexts that forbid blocking.
///
/// # Errors
/// Returns `DynamoLlmResult::ERR` on failure and does not write to `pipeline_out`.
462
463
464
465
/// # Safety
/// See detailed safety docs above. Additional parameter:
/// - `enforce_disagg`: If true, requests fail when disaggregated serving is unavailable.
///   If false, falls back to aggregated serving.
466
467
468
469
470
471
472
473
474
475
476
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
    model_name_c_str: *const c_char,
    use_kv_routing: bool,
    busy_threshold: f64,
    overlap_score_weight: f64,
    router_temperature: f64,
    use_kv_events: bool,
    router_replica_sync: bool,
477
    enforce_disagg: bool,
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
    pipeline_out: *mut *mut WorkerSelectionPipeline,
) -> DynamoLlmResult {
    if pipeline_out.is_null() {
        tracing::error!("pipeline_out pointer is null");
        return DynamoLlmResult::ERR;
    }

    let wk = match WK.get() {
        Some(w) => w.clone(),
        None => {
            tracing::error!("Worker not initialized. Call dynamo_llm_init first.");
            return DynamoLlmResult::ERR;
        }
    };

    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(e) => {
            tracing::error!(error = ?e, "bad namespace");
            return DynamoLlmResult::ERR;
        }
    };
500
501
502
503
504
505
506

    let component_cow = unsafe { cstr_or_default(component_c_str, "backend") };
    if let Cow::Borrowed("backend") = &component_cow {
        tracing::info!("defaulting to \"backend\" for component");
    }
    let component: String = component_cow.into_owned();

507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
    let model = match unsafe { CStr::from_ptr(model_name_c_str) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(e) => {
            tracing::error!(error = ?e, "bad model");
            return DynamoLlmResult::ERR;
        }
    };

    let make_engine = || async {
        let router_mode = if use_kv_routing {
            RouterMode::KV
        } else {
            RouterMode::RoundRobin
        };

        let kv_router_config = if use_kv_routing {
            Some(KvRouterConfig::new(
                (overlap_score_weight >= 0.0).then_some(overlap_score_weight),
                (router_temperature >= 0.0).then_some(router_temperature),
                Some(use_kv_events),
                Some(router_replica_sync),
528
                None, // track_active_blocks
529
                None, // track_output_blocks
530
                None, // assume_kv_reuse
531
532
533
534
535
                None, // router_snapshot_threshold
                None, // router_reset_states
                None, // router_ttl_secs
                None, // router_max_tree_size
                None, // router_prune_target_ratio
536
537
538
539
540
541
542
543
544
545
546
547
            ))
        } else {
            None
        };

        create_worker_selection_pipeline_chat(
            &namespace,
            &component,
            &model,
            router_mode,
            (busy_threshold >= 0.0).then_some(busy_threshold),
            kv_router_config,
548
            enforce_disagg,
549
550
551
552
        )
        .await
    };

553
    let (engine, kv_router) = match wk.runtime().secondary().block_on(make_engine()) {
554
555
556
557
558
559
560
        Ok(p) => p,
        Err(e) => {
            tracing::error!(error = ?e, "create_worker_selection_pipeline_chat failed");
            return DynamoLlmResult::ERR;
        }
    };

561
562
563
564
565
    let handle = Box::new(WorkerSelectionPipeline {
        wk,
        engine,
        kv_router,
    });
566
567
568
569
570
571
572
    unsafe {
        *pipeline_out = Box::into_raw(handle);
    }
    DynamoLlmResult::OK
}

/// Query worker selection on an existing pipeline and return:
573
574
/// - `decode_worker_id_out` (`i64`): The decode worker ID (primary worker)
/// - `prefill_worker_id_out` (`i64`): The prefill worker ID (-1 if not in disaggregated mode)
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
/// - `token_ids_out` (heap-allocated `*mut u32`; caller must free via
///   `dynamo_free_worker_selection_result`)
/// - `token_count_out` (`usize`)
/// - `annotated_request_json_out` (`*mut c_char` to a NUL-terminated C string;
///   caller frees via the same free function)
///
/// # Safety
/// - `pipeline`
///   - Must be a **non-null** pointer previously returned by
///     `dynamo_create_worker_selection_pipeline` and not yet passed to
///     `dynamo_destroy_worker_selection_pipeline`.
///   - Must remain valid for the entire duration of this call.
///   - **Do not** call this function concurrently on the same `pipeline` pointer
///     from multiple threads unless the surrounding code guarantees synchronization.
/// - `request_json_c_str`
///   - Must be a **non-null**, **NUL-terminated** C string containing **valid UTF-8**.
///   - The JSON must represent a valid `NvCreateChatCompletionRequest`; otherwise this
///     function returns `DynamoLlmResult::ERR`.
///   - Must remain valid for the duration of this call.
/// - Output pointers:
595
///   - `decode_worker_id_out`, `prefill_worker_id_out`, `token_ids_out`, `token_count_out`,
596
597
///     and `annotated_request_json_out` must each be **non-null** and point to
///     writable memory for their respective types. On success, this function
598
///     writes to all five outputs exactly once.
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
///   - On **error**, outputs are left unmodified.
/// - Ownership & deallocation:
///   - On success, if there are zero tokens, `*token_ids_out` may be set to `NULL`
///     and `*token_count_out` set to `0`.
///   - If non-null, the buffer written to `*token_ids_out` is allocated with the
///     Rust global allocator and **must** be freed by calling
///     `dynamo_free_worker_selection_result` with the same `token_count_out` value.
///   - The pointer written to `*annotated_request_json_out` is a `CString` allocated
///     by Rust and **must** be freed by calling `dynamo_free_worker_selection_result`.
///   - **Do not** free these with `free(3)` or any other allocator; doing so is
///     undefined behavior.
/// - Blocking & context:
///   - This function may **block** internally while it performs async work; do not
///     call it from contexts that forbid blocking (e.g., signal handlers).
/// - Process/ABI assumptions:
///   - The caller and callee must run in the same process and use the same Rust
///     global allocator for the paired allocation/free described above.
///   - This function is not signal-safe.
///
/// # Errors
/// Returns `DynamoLlmResult::ERR` if any precondition fails (null/invalid pointers,
/// malformed UTF-8/JSON, pipeline errors, allocation failures, etc.). On error, no
/// output pointer is written.
622
623
624
625
626
627
///
/// # Output values
/// - `decode_worker_id_out`: The decode worker ID (primary worker in aggregated mode)
/// - `prefill_worker_id_out`: The prefill worker ID (only set in disaggregated mode, -1 if not present)
/// - `token_ids_out`, `token_count_out`: Token IDs and count
/// - `annotated_request_json_out`: The annotated request JSON
628
629
630
631
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
    pipeline: *mut WorkerSelectionPipeline,
    request_json_c_str: *const c_char,
632
633
    decode_worker_id_out: *mut i64,
    prefill_worker_id_out: *mut i64,
634
635
636
637
638
639
640
641
    token_ids_out: *mut *mut u32,
    token_count_out: *mut usize,
    annotated_request_json_out: *mut *mut c_char,
) -> DynamoLlmResult {
    if pipeline.is_null() {
        tracing::error!("Pipeline pointer is null");
        return DynamoLlmResult::ERR;
    }
642
643
    if decode_worker_id_out.is_null()
        || prefill_worker_id_out.is_null()
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
        || token_ids_out.is_null()
        || token_count_out.is_null()
        || annotated_request_json_out.is_null()
    {
        tracing::error!("One or more output pointers are null");
        return DynamoLlmResult::ERR;
    }

    let req_str = match unsafe { CStr::from_ptr(request_json_c_str) }.to_str() {
        Ok(s) => s,
        Err(e) => {
            tracing::error!(error = ?e, "bad request json");
            return DynamoLlmResult::ERR;
        }
    };
    let request: NvCreateChatCompletionRequest = match serde_json::from_str(req_str) {
        Ok(r) => r,
        Err(e) => {
            tracing::error!(error = ?e, "parse request failed");
            return DynamoLlmResult::ERR;
        }
    };

    let pl = unsafe { &*pipeline };
    let fut = async { query_worker_selection_and_annotate(&pl.engine, request).await };
669
    let (result, annotated_req) = match pl.wk.runtime().secondary().block_on(fut) {
670
671
672
673
674
675
676
        Ok(v) => v,
        Err(e) => {
            tracing::error!(error = ?e, "query_worker_selection_and_annotate failed");
            return DynamoLlmResult::ERR;
        }
    };

677
    let tokens_ptr = if result.tokens.is_empty() {
678
679
        std::ptr::null_mut()
    } else {
680
        let len = result.tokens.len();
681
682
683
684
685
686
687
        let layout = std::alloc::Layout::array::<u32>(len).unwrap();
        let ptr = unsafe { std::alloc::alloc(layout) as *mut u32 };
        if ptr.is_null() {
            tracing::error!("alloc tokens failed");
            return DynamoLlmResult::ERR;
        }
        unsafe {
688
            std::ptr::copy_nonoverlapping(result.tokens.as_ptr(), ptr, len);
689
690
691
692
693
694
695
696
        }
        ptr
    };

    let annotated_json = match serde_json::to_string(&annotated_req) {
        Ok(s) => s,
        Err(e) => {
            if !tokens_ptr.is_null() {
697
698
699
700
                let layout = std::alloc::Layout::array::<u32>(result.tokens.len()).unwrap();
                unsafe {
                    std::alloc::dealloc(tokens_ptr as *mut u8, layout);
                }
701
702
703
704
705
706
707
708
709
710
                tracing::error!(error = ?e, "serialize annotated request failed");
            }
            return DynamoLlmResult::ERR;
        }
    };
    let cjson = match std::ffi::CString::new(annotated_json) {
        Ok(c) => c,
        Err(e) => {
            tracing::error!(error = ?e, "CString::new for annotated JSON failed");
            if !tokens_ptr.is_null() {
711
                let layout = std::alloc::Layout::array::<u32>(result.tokens.len()).unwrap();
712
713
714
715
716
717
718
719
                unsafe {
                    std::alloc::dealloc(tokens_ptr as *mut u8, layout);
                }
            }
            return DynamoLlmResult::ERR;
        }
    };
    unsafe {
720
721
        *decode_worker_id_out = result.decode_worker_id.unwrap_or(0);
        *prefill_worker_id_out = result.prefill_worker_id.unwrap_or(-1);
722
        *token_ids_out = tokens_ptr;
723
        *token_count_out = result.tokens.len();
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
        *annotated_request_json_out = cjson.into_raw();
    }
    DynamoLlmResult::OK
}

/// Destroy a pipeline created by `dynamo_create_worker_selection_pipeline`.
///
/// # Safety
/// - `pipeline`:
///   - Must be a pointer **originally returned by** `dynamo_create_worker_selection_pipeline`
///     (it was produced by `Box::into_raw` on a `WorkerSelectionPipeline`).
///   - Must not have been destroyed already. Destroying it twice is a **double free** and is
///     undefined behavior.
///   - Must not be in use by any other thread during this call, e.g. by an in-flight
///     `dynamo_query_worker_selection_and_annotate`.
///   - After a successful call, the pointer is invalid and must never be dereferenced or used
///     again.
/// - The handle's allocation is reclaimed with Rust's allocator. Caller and callee must
///   therefore live in the same process and share that allocator.
/// - Do not call from contexts that forbid blocking or running destructors (e.g. signal
///   handlers).
///
/// # Errors
/// - Returns `DynamoLlmResult::ERR` if `pipeline` is null.
/// - On `OK`, ownership of `pipeline` has been taken and its resources dropped.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_destroy_worker_selection_pipeline(
    pipeline: *mut WorkerSelectionPipeline,
) -> DynamoLlmResult {
    if !pipeline.is_null() {
        // SAFETY: the caller guarantees `pipeline` came from `Box::into_raw` and is not
        // aliased or already freed.
        let owned: Box<WorkerSelectionPipeline> = unsafe { Box::from_raw(pipeline) };
        drop(owned);
        return DynamoLlmResult::OK;
    }
    tracing::error!("Pipeline pointer is null");
    DynamoLlmResult::ERR
}

/// Release the buffers handed out by `dynamo_query_worker_selection_and_annotate`.
///
/// # Safety
/// - Both pointers **must originate from this library**:
///   - `token_ids` must be the exact tokens-buffer pointer returned by
///     `dynamo_query_worker_selection_and_annotate`, allocated by Rust's global
///     allocator in this process.
///   - `annotated_request_json` must be the exact pointer produced by
///     `CString::into_raw` inside `dynamo_query_worker_selection_and_annotate`.
/// - Free each pointer **at most once**; a second call with the same pointer is a
///   double free and is undefined behavior.
/// - Pointer/length pairing:
///   - For a non-null `token_ids`, `token_count` must equal the length that was
///     originally returned; any other value deallocates with the wrong layout.
///   - For a null `token_ids`, `token_count` should be `0`.
///   - A non-null `token_ids` paired with `token_count == 0` is **not** freed by
///     this implementation (deallocation only happens when `token_count > 0`),
///     so it leaks.
/// - Neither pointer may be used after this returns.
/// - Caller and callee must share the same process and allocator/ABI, because
///   deallocation goes through Rust's global allocator.
/// - No other thread may be reading or writing these buffers concurrently.
/// - Do not call from contexts that forbid running destructors (e.g. signal handlers).
///
/// Always returns `DynamoLlmResult::OK`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_free_worker_selection_result(
    token_ids: *mut u32,
    token_count: usize,
    annotated_request_json: *mut c_char,
) -> DynamoLlmResult {
    // Only reclaim the token buffer when it is non-empty, non-null, and its
    // layout for `token_count` u32 values is representable.
    let token_layout = (token_count > 0 && !token_ids.is_null())
        .then(|| std::alloc::Layout::array::<u32>(token_count).ok())
        .flatten();
    if let Some(layout) = token_layout {
        // SAFETY: the caller guarantees `token_ids` was allocated by the global
        // allocator with exactly this layout and has not been freed yet.
        unsafe { std::alloc::dealloc(token_ids.cast::<u8>(), layout) };
    }

    if !annotated_request_json.is_null() {
        // SAFETY: the caller guarantees this came from `CString::into_raw` and is
        // freed only once; rebuilding the `CString` and dropping it frees it.
        drop(unsafe { std::ffi::CString::from_raw(annotated_request_json) });
    }

    DynamoLlmResult::OK
}

815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
/// Default timeout for GAIE bookkeeping operations (30 seconds).
///
/// Upper bound that `run_bookkeeping_with_timeout` waits for a single
/// add/mark-prefill/free bookkeeping future before logging a timeout and giving up.
const GAIE_BOOKKEEPING_TIMEOUT_SECS: u64 = 30;

/// Helper to validate pipeline pointer and extract request_id from C string.
/// Returns `Err(DynamoLlmResult::ERR)` on validation failure, `Ok((pipeline_ref, request_id))` on success.
unsafe fn validate_pipeline_and_request_id(
    pipeline: *mut WorkerSelectionPipeline,
    request_id_c_str: *const c_char,
    operation: &str,
) -> Result<(&'static WorkerSelectionPipeline, String), DynamoLlmResult> {
    if pipeline.is_null() {
        tracing::error!("[GAIE] {} failed: pipeline pointer is null", operation);
        return Err(DynamoLlmResult::ERR);
    }

    let request_id = match unsafe { CStr::from_ptr(request_id_c_str) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(e) => {
            tracing::error!(error = ?e, "[GAIE] {} failed: bad request_id", operation);
            return Err(DynamoLlmResult::ERR);
        }
    };

    // SAFETY: Caller guarantees pipeline is valid for the duration of the call
    let pl: &'static WorkerSelectionPipeline = unsafe { &*pipeline };
    Ok((pl, request_id))
}

/// Run one GAIE bookkeeping future on the worker's secondary runtime, bounded by
/// `GAIE_BOOKKEEPING_TIMEOUT_SECS`.
///
/// `f` is invoked once to build the future, which is then driven with
/// `block_on`, so this call blocks the current thread until the future finishes
/// or the timeout fires. `operation` and `request_id` only label the timeout log.
///
/// # Returns
/// Always `DynamoLlmResult::OK`. A timeout is logged at `warn` level but is not
/// reported as an error, so slow bookkeeping never fails the caller. Validation
/// failures must be reported by the caller before reaching this helper.
///
/// Note: `tokio::time::timeout` drops the inner future when the deadline passes,
/// so a timed-out operation is **cancelled**, not left running; its bookkeeping
/// update may not have been applied.
fn run_bookkeeping_with_timeout<F, Fut>(
    pl: &WorkerSelectionPipeline,
    operation: &'static str,
    request_id: &str,
    f: F,
) -> DynamoLlmResult
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = ()>,
{
    use std::time::Duration;

    let timeout_duration = Duration::from_secs(GAIE_BOOKKEEPING_TIMEOUT_SECS);

    let result = pl
        .wk
        .runtime()
        .secondary()
        .block_on(async move { tokio::time::timeout(timeout_duration, f()).await });

    match result {
        Ok(()) => DynamoLlmResult::OK,
        Err(_elapsed) => {
            tracing::warn!(
                request_id = %request_id,
                timeout_secs = GAIE_BOOKKEEPING_TIMEOUT_SECS,
                "[GAIE] {} timed out",
                operation
            );
            // Deliberately report OK so the caller is never failed by slow
            // bookkeeping; the timed-out future has already been dropped.
            DynamoLlmResult::OK
        }
    }
}

/// Router bookkeeping functions for GAIE integration
/// Add a request to the router's bookkeeping after worker selection.
/// Call this from GAIE Stage 1 after `dynamo_query_worker_selection_and_annotate`.
///
/// This function computes the overlap_blocks internally by querying the indexer,
/// so the caller doesn't need to provide it.
///
/// # Safety
/// - `pipeline` must be a valid, non-null pointer from `dynamo_create_worker_selection_pipeline`
/// - `request_id_c_str` must be a valid NUL-terminated UTF-8 C string
/// - `token_ids` must point to at least `token_count` valid u32 values
/// - Must not be called concurrently on the same pipeline without synchronization
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_router_add_request(
    pipeline: *mut WorkerSelectionPipeline,
    request_id_c_str: *const c_char,
    token_ids: *const u32,
    token_count: usize,
    worker_id: u64,
    dp_rank: u32,
) -> DynamoLlmResult {
    let (pl, request_id) = match unsafe {
        validate_pipeline_and_request_id(pipeline, request_id_c_str, "add_request")
    } {
        Ok(v) => v,
        Err(e) => return e,
    };

    let Some(ref kv_router) = pl.kv_router else {
        tracing::debug!(
            "[GAIE] KV router not available (router_mode is not KV), skipping add_request (no-op)"
        );
        return DynamoLlmResult::OK;
    };

    // Log after kv_router check to reduce noise
    tracing::debug!(
        request_id = %request_id,
        worker_id = worker_id,
        dp_rank = dp_rank,
        token_count = token_count,
        "[GAIE] dynamo_router_add_request processing"
    );

    let tokens: Vec<u32> = if token_count > 0 && !token_ids.is_null() {
        unsafe { std::slice::from_raw_parts(token_ids, token_count) }.to_vec()
    } else {
        Vec::new()
    };

    let kv_router = kv_router.clone();
    let request_id_clone = request_id.clone();

    run_bookkeeping_with_timeout(pl, "add_request", &request_id, || async move {
        let worker = dynamo_llm::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank);

        // Compute overlap_blocks using the public method
        let overlap_blocks = match kv_router.get_overlap_blocks(&tokens, worker).await {
            Ok(overlap) => overlap,
            Err(e) => {
                tracing::warn!(error = ?e, "Failed to compute overlap, using 0");
                0
            }
        };

        kv_router
947
948
949
950
951
952
            .add_request(
                request_id_clone.clone(),
                &tokens,
                overlap_blocks,
                None,
                worker,
953
                None, // lora_name not exposed in C API yet
954
            )
955
956
957
958
959
960
961
962
963
964
965
966
967
968
            .await;

        tracing::debug!(
            request_id = %request_id_clone,
            worker_id = worker_id,
            dp_rank = dp_rank,
            overlap_blocks = overlap_blocks,
            token_count = tokens.len(),
            "[GAIE] dynamo_router_add_request completed - request registered in router bookkeeping"
        );
    })
}

/// Mark prefill as completed for a request.
969
/// Call this from the EPP extension point when the first token is generated.
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
///
/// # Safety
/// - `pipeline` must be a valid, non-null pointer from `dynamo_create_worker_selection_pipeline`
/// - `request_id_c_str` must be a valid NUL-terminated UTF-8 C string
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_router_mark_prefill_complete(
    pipeline: *mut WorkerSelectionPipeline,
    request_id_c_str: *const c_char,
) -> DynamoLlmResult {
    let (pl, request_id) = match unsafe {
        validate_pipeline_and_request_id(pipeline, request_id_c_str, "mark_prefill_complete")
    } {
        Ok(v) => v,
        Err(e) => return e,
    };

    let Some(ref kv_router) = pl.kv_router else {
        tracing::debug!(
            "[GAIE] KV router not available (router_mode is not KV), skipping mark_prefill_complete (no-op)"
        );
        return DynamoLlmResult::OK;
    };

    // Log after kv_router check to reduce noise
    tracing::debug!(
        request_id = %request_id,
        "[GAIE] dynamo_router_mark_prefill_complete processing"
    );

    let kv_router = kv_router.clone();
    let request_id_clone = request_id.clone();

    run_bookkeeping_with_timeout(pl, "mark_prefill_complete", &request_id, || async move {
        if let Err(e) = kv_router.mark_prefill_completed(&request_id_clone).await {
            tracing::warn!(
                "Failed to mark prefill completed for {}: {}",
                request_id_clone,
                e
            );
        } else {
            tracing::debug!(
                request_id = %request_id_clone,
                "[GAIE] dynamo_router_mark_prefill_complete completed - prefill tokens released"
            );
        }
    })
}

/// Free a request from the router's bookkeeping.
/// Call this from GAIE hook when the stream is closed (completed or cancelled).
///
/// # Safety
/// - `pipeline` must be a valid, non-null pointer from `dynamo_create_worker_selection_pipeline`
/// - `request_id_c_str` must be a valid NUL-terminated UTF-8 C string
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_router_free_request(
    pipeline: *mut WorkerSelectionPipeline,
    request_id_c_str: *const c_char,
) -> DynamoLlmResult {
    let (pl, request_id) = match unsafe {
        validate_pipeline_and_request_id(pipeline, request_id_c_str, "free_request")
    } {
        Ok(v) => v,
        Err(e) => return e,
    };

    let Some(ref kv_router) = pl.kv_router else {
        tracing::debug!(
            "[GAIE] KV router not available (router_mode is not KV), skipping free_request (no-op)"
        );
        return DynamoLlmResult::OK;
    };

    // Log after kv_router check to reduce noise
    tracing::debug!(
        request_id = %request_id,
        "[GAIE] dynamo_router_free_request processing"
    );

    let kv_router = kv_router.clone();
    let request_id_clone = request_id.clone();

    run_bookkeeping_with_timeout(pl, "free_request", &request_id, || async move {
        if let Err(e) = kv_router.free(&request_id_clone).await {
            tracing::warn!("Failed to free request {}: {}", request_id_clone, e);
        } else {
            tracing::debug!(
                request_id = %request_id_clone,
                "[GAIE] dynamo_router_free_request completed - request removed from bookkeeping"
            );
        }
    })
}

1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
/// Outcome of a GAIE Stage 1 worker-selection query.
///
/// Every field starts out empty (`None` / empty `Vec`) and is filled in only
/// when the corresponding value is found in the router's response.
#[derive(Clone, Debug, Default)]
pub struct WorkerSelectionResult {
    /// ID of the decode worker (the sole worker in aggregated mode, the
    /// decode-only worker in disaggregated mode).
    pub decode_worker_id: Option<i64>,
    /// ID of the prefill worker; only set in disaggregated mode.
    pub prefill_worker_id: Option<i64>,
    /// Token IDs produced by tokenizing the request.
    pub tokens: Vec<u32>,
}

1075
/// Helper function to extract worker selection information from the annotation stream
1076
1077
1078
1079
///
/// The response format (from disaggregated_params in nvext):
/// - worker_id: {"prefill_worker_id": 123, "decode_worker_id": 456}
/// - token_ids: [1, 2, 3, ...]
1080
1081
pub async fn extract_worker_selection_from_stream(
    mut stream: Pin<Box<dyn AsyncEngineStream<Annotated<NvCreateChatCompletionStreamResponse>>>>,
1082
1083
) -> anyhow::Result<WorkerSelectionResult> {
    use dynamo_llm::protocols::openai::nvext::WorkerIdInfo;
1084
1085
    use futures::StreamExt;

1086
    let mut result = WorkerSelectionResult::default();
1087
1088

    while let Some(response) = stream.next().await {
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
        // Check for data in nvext (worker_id and token_ids are direct fields)
        // nvext is a serde_json::Value, so we access it as a JSON object
        if let Some(data) = &response.data
            && let Some(nvext) = &data.nvext
        {
            // Extract worker_id
            if let Some(worker_id_value) = nvext.get("worker_id")
                && let Ok(worker_info) =
                    serde_json::from_value::<WorkerIdInfo>(worker_id_value.clone())
            {
                result.decode_worker_id = worker_info.decode_worker_id.map(|id| id as i64);
                result.prefill_worker_id = worker_info.prefill_worker_id.map(|id| id as i64);
                tracing::debug!(
                    decode_worker_id = ?result.decode_worker_id,
                    prefill_worker_id = ?result.prefill_worker_id,
                    "Parsed worker_id from nvext"
                );
1106
1107
            }

1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
            // Extract token_ids
            if let Some(token_ids_value) = nvext.get("token_ids")
                && let Ok(parsed_tokens) =
                    serde_json::from_value::<Vec<u32>>(token_ids_value.clone())
            {
                result.tokens = parsed_tokens;
                tracing::debug!(
                    "Successfully parsed {} tokens from nvext",
                    result.tokens.len()
                );
1118
1119
1120
1121
1122
            }
        }
    }

    tracing::info!(
1123
1124
1125
1126
        decode_worker_id = ?result.decode_worker_id,
        prefill_worker_id = ?result.prefill_worker_id,
        token_count = result.tokens.len(),
        "Worker selection extraction complete"
1127
    );
1128
    Ok(result)
1129
1130
1131
1132
1133
}

/// Utility function to add the "query_instance_id" annotation to an OpenAI request
///
/// This function modifies the request to include the annotation that signals the KV router
1134
/// to return worker selection information (worker_fid and token_data) instead of
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
/// performing actual inference.
///
/// # Parameters
/// - `request`: Mutable reference to the OpenAI chat completion request
///
/// # Returns
/// The same request with the "query_instance_id" annotation added
pub fn add_query_instance_id(
    request: &mut NvCreateChatCompletionRequest,
) -> &mut NvCreateChatCompletionRequest {
1145
1146
    // Send empty value - router treats empty as aggregated / aggregated worker selection
    set_kv_annotation(request, "query_instance_id".to_string(), "")
1147
1148
}

1149
1150
1151
1152
1153
1154
1155
1156
// Note: set_worker_ids_for_stage2 and set_token_data_for_stage2 have been removed.
// The EPP now handles routing configuration via HTTP headers:
// - `x-worker-instance-id`: decode worker ID
// - `x-prefill-instance-id`: prefill worker ID (disaggregated mode only)
// - `x-enable-local-updates`: set to "false" to disable router bookkeeping
//
// Body modifications are NOT sent to the inference engine (only headers are forwarded),
// so these functions were ineffective.
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181

/// Return the request's annotation list for in-place editing.
///
/// Missing pieces are created on demand: a default `NvExt` if `nvext` is
/// `None`, then an empty `Vec` if its `annotations` is `None`.
fn ensure_annotations(request: &mut NvCreateChatCompletionRequest) -> &mut Vec<String> {
    let ext = request
        .nvext
        .get_or_insert_with(|| NvExt::builder().build().expect("NvExt builder should not fail"));
    ext.annotations.get_or_insert_with(Vec::default)
}

/// Set a `key:value` annotation, replacing any existing annotation for `key`.
///
/// Every annotation that starts with `key:` is removed before the new
/// `key:value` entry is appended, so at most one entry per key remains.
/// Returns the same request for chaining.
fn set_kv_annotation(
    request: &mut NvCreateChatCompletionRequest,
    key: String,
    value: impl Into<String>,
) -> &mut NvCreateChatCompletionRequest {
    let prefix = format!("{key}:");
    let value: String = value.into();
    let annotation = format!("{prefix}{value}");

    let list = ensure_annotations(request);
    list.retain(|existing| !existing.starts_with(prefix.as_str()));
    list.push(annotation);
    request
}

1182
/// Wrapper function that queries worker selection for GAIE Stage 1
1183
///
1184
1185
/// This function performs the complete GAIE Stage 1 flow:
/// 1. Clones the original request and adds "query_instance_id:" (empty) annotation
1186
/// 2. Calls engine.generate() with the modified request
1187
/// 3. Extracts worker_id info and tokens from the response stream
1188
1189
1190
1191
1192
1193
1194
1195
1196
/// 4. Returns WorkerSelectionResult and the original request
///
/// Note: The EPP (caller) is responsible for setting HTTP headers for Stage 2:
/// - `x-worker-instance-id`: decode worker ID
/// - `x-prefill-instance-id`: prefill worker ID (disaggregated mode only)
/// - `x-enable-local-updates`: "false" to disable router bookkeeping
///
/// Body modifications are NOT forwarded to the inference engine, so this function
/// does not modify the request body.
1197
1198
1199
1200
1201
1202
///
/// # Parameters
/// - `engine`: The worker selection pipeline engine
/// - `original_request`: The original OpenAI request to process
///
/// # Returns
1203
/// A tuple containing (WorkerSelectionResult, original_request)
1204
1205
1206
1207
1208
pub async fn query_worker_selection_and_annotate(
    engine: &ServiceEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
    >,
1209
    original_request: NvCreateChatCompletionRequest,
1210
1211
) -> anyhow::Result<(WorkerSelectionResult, NvCreateChatCompletionRequest)> {
    // GAIE Stage 1: Query for worker selection
1212
1213
1214
1215
    let mut query_request = original_request.clone();
    add_query_instance_id(&mut query_request);
    let single_in = SingleIn::new(query_request);
    let response_stream = engine.generate(single_in).await?;
1216
    let result = extract_worker_selection_from_stream(response_stream).await?;
1217

1218
1219
1220
    // Return the original request unchanged.
    // The EPP sets routing headers (worker IDs, enable_local_updates) which the
    // Dynamo frontend reads via apply_header_routing_overrides().
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
    Ok((result, original_request))
}

/// Spawn a background task to watch for prefill models and activate prefill routers.
/// This is a lightweight watcher that only handles prefill model discovery.
fn spawn_prefill_watcher(
    drt: DistributedRuntime,
    model_manager: Arc<ModelManager>,
    target_namespace: String,
) {
    use dynamo_llm::model_card::ModelDeploymentCard;
    use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery};
    use dynamo_runtime::protocols::EndpointId;
    use futures::StreamExt;

    tokio::spawn(async move {
        let discovery = drt.discovery();
        let mut stream = match discovery
            .list_and_watch(DiscoveryQuery::AllModels, None)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::error!(error = %e, "Failed to start prefill discovery stream");
                return;
            }
        };

        while let Some(result) = stream.next().await {
            let event = match result {
                Ok(e) => e,
                Err(e) => {
                    tracing::error!(error = %e, "Error in prefill discovery stream");
                    continue;
                }
            };

            match event {
                DiscoveryEvent::Added(instance) => {
                    let (endpoint_id, card) = match &instance {
                        DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            ..
                        } => {
                            // Filter by namespace
                            if namespace != &target_namespace {
                                continue;
                            }

                            let eid = EndpointId {
                                namespace: namespace.clone(),
                                component: component.clone(),
                                name: endpoint.clone(),
                            };

                            match instance.deserialize_model::<ModelDeploymentCard>() {
                                Ok(card) => (eid, card),
                                Err(_) => continue,
                            }
                        }
                        _ => continue,
                    };

                    // Only handle prefill models
                    if !card.model_type.supports_prefill() {
                        continue;
                    }

                    tracing::info!(
                        model_name = card.name(),
                        "Prefill model discovered, activating prefill router"
                    );

                    // Get the endpoint and activate the prefill router
                    if let Ok(ns) = drt.namespace(&endpoint_id.namespace)
                        && let Ok(comp) = ns.component(&endpoint_id.component)
                    {
                        let endpoint = comp.endpoint(&endpoint_id.name);
                        if let Err(e) = model_manager.activate_prefill_router(card.name(), endpoint)
                        {
                            tracing::warn!(
                                model_name = card.name(),
                                error = %e,
                                "Failed to activate prefill router"
                            );
                        } else {
                            tracing::info!(
                                model_name = card.name(),
                                "Prefill router activated successfully"
                            );
                        }
                    }
                }
1316
                DiscoveryEvent::Removed(id) => {
1317
1318
1319
1320
                    // Log removal for observability
                    // Note: The PrefillRouter remains active - worker availability
                    // is handled dynamically by the underlying Client's instance tracking
                    tracing::debug!(
1321
                        instance_id = id.instance_id(),
1322
1323
1324
1325
1326
1327
                        "Prefill worker instance removed from discovery"
                    );
                }
            }
        }
    });
1328
1329
}

1330
/// Create a worker selection pipeline for OpenAI Chat Completion requests
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
///
/// This is a concrete implementation that works specifically with NvCreateChatCompletionRequest
/// and is designed for use with C bindings. Uses the "generate" endpoint by default.
///
/// # Parameters
/// - `namespace`: namespace name
/// - `component_name`: component name
/// - `model_name`: Name/slug of the model to load
/// - `router_mode`: How to route requests (KV, RoundRobin, etc.)
/// - `busy_threshold`: Optional threshold for busy worker detection
/// - `kv_router_config`: Optional KV router configuration (only used when router_mode is KV)
1342
/// - `enforce_disagg`: If true, fail requests when disaggregated serving is unavailable
1343
1344
///
/// # Returns
1345
/// A tuple of (engine, kv_router) where kv_router is Some when router_mode is KV
1346
1347
1348
1349
1350
1351
1352
pub async fn create_worker_selection_pipeline_chat(
    namespace: &str,
    component_name: &str,
    model_name: &str,
    router_mode: RouterMode,
    busy_threshold: Option<f64>,
    kv_router_config: Option<KvRouterConfig>,
1353
    enforce_disagg: bool,
1354
) -> anyhow::Result<(
1355
1356
1357
1358
    ServiceEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
    >,
1359
1360
    Option<Arc<dynamo_llm::kv_router::KvRouter>>,
)> {
1361
1362
    use dynamo_llm::kv_router::PrefillRouter;

1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
    // Use the global DRT singleton - initialize if not already done
    // Check if already initialized (by dynamo_llm_init) to avoid redundant sync wait
    let needs_sync = DRT.get().is_none();

    let distributed_runtime = DRT
        .get_or_try_init(async {
            tracing::debug!("Initializing DistributedRuntime singleton (standalone mode)");
            DistributedRuntime::from_settings(Runtime::from_settings()?).await
        })
        .await
        .map_err(|e| anyhow::anyhow!("Failed to initialize DistributedRuntime: {}", e))?;

    // Only wait for discovery sync if we just initialized the DRT
    // (dynamo_llm_init already does this when it initializes)
    if needs_sync {
1378
1379
        let timeout_secs = get_discovery_timeout_secs();
        let instance_count = wait_for_discovery_sync(distributed_runtime, timeout_secs).await;
1380
1381
1382
1383
1384
1385
        if instance_count == 0 {
            return Err(anyhow::anyhow!(
                "Discovery sync failed: no worker instances found. Is the backend running?"
            ));
        }
    }
1386
1387
1388
1389

    let component = distributed_runtime
        .namespace(namespace)?
        .component(component_name)?;
1390
1391
    let endpoint = component.endpoint(GENERATE_ENDPOINT);
    let client = endpoint.client().await?;
1392

1393
1394
1395
1396
    // Discover the model card by searching all instances with this model name
    tracing::debug!("Looking for model: {}", model_name);
    tracing::debug!("Namespace: {}", namespace);

1397
    let model_manager = Arc::new(ModelManager::new());
1398
1399
1400
    let router_config = dynamo_llm::entrypoint::RouterConfig {
        router_mode,
        kv_router_config: kv_router_config.unwrap_or_default(),
1401
1402
1403
1404
1405
        load_threshold_config: dynamo_llm::discovery::LoadThresholdConfig {
            active_decode_blocks_threshold: busy_threshold,
            active_prefill_tokens_threshold: None,
            active_prefill_tokens_threshold_frac: None,
        },
1406
        enforce_disagg,
1407
    };
1408
1409
    // Create metrics for migration tracking (not exposed via /metrics in C bindings)
    let metrics = Arc::new(Metrics::new());
1410
1411
1412
    let watcher = ModelWatcher::new(
        component.drt().clone(),
        model_manager.clone(),
1413
        router_config,
1414
        None,
1415
        metrics.clone(),
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
    );
    let cards = watcher
        .cards_for_model(model_name, Some(namespace), false)
        .await
        .with_context(|| format!("Failed to discover model: {}", model_name))?;

    tracing::debug!("Found {} cards for model {}", cards.len(), model_name);

    let card = cards.into_iter().next().ok_or_else(|| {
        tracing::error!("No ModelDeploymentCard found for model: {}", model_name);
        anyhow::anyhow!("ModelDeploymentCard not found for model: {}", model_name)
    })?;
1428
1429
1430
1431

    let chooser = if router_mode == RouterMode::KV {
        Some(
            model_manager
1432
                .kv_chooser_for(&endpoint, card.kv_cache_block_size, kv_router_config)
1433
1434
1435
1436
1437
1438
                .await?,
        )
    } else {
        None
    };

1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
    // Create prefill chooser for dynamic disaggregation support
    // This registers the model and returns a receiver that will be activated
    // when a prefill worker is discovered
    let prefill_chooser = model_manager
        .register_prefill_router(model_name.to_string())
        .map(|rx| {
            // Create prefill-specific config with track_active_blocks disabled
            let mut prefill_config = kv_router_config.unwrap_or_default();
            prefill_config.router_track_active_blocks = false;

            PrefillRouter::new(
                rx,
                model_manager.clone(),
                router_mode,
                card.kv_cache_block_size,
                Some(prefill_config),
                enforce_disagg,
            )
        });

    // Start background watcher for prefill model discovery
    // This will activate the prefill router when prefill workers join
    spawn_prefill_watcher(
        component.drt().clone(),
        model_manager.clone(),
        namespace.to_string(),
    );

1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
    // Download model config files from HuggingFace for EPP
    // The backend's card has NATS URLs which aren't accessible from EPP
    tracing::debug!(
        "Downloading model config files for EPP: {}",
        card.display_name
    );

    let local_path = dynamo_llm::hub::from_hf(&card.display_name, true)
        .await
        .with_context(|| {
            format!(
                "Failed to download model config files for: {}",
                card.display_name
            )
        })?;

    // Load a fresh card from local files, then copy runtime config from original card
    tracing::debug!("Loading ModelDeploymentCard from local path...");
    let mut card_with_local_files = ModelDeploymentCard::load_from_disk(&local_path, None)
        .with_context(|| format!("Failed to load card from disk: {:?}", local_path))?;

    // Copy runtime settings from the backend's card
    tracing::debug!("Copying runtime config from backend card...");
    card_with_local_files.runtime_config = card.runtime_config.clone();
    card_with_local_files.kv_cache_block_size = card.kv_cache_block_size;
    card_with_local_files.context_length = card.context_length;

    // Load the tokenizer from the downloaded files
    tracing::debug!("Loading tokenizer from local files...");
    let hf_tokenizer = card_with_local_files
        .tokenizer_hf()
        .with_context(|| format!("Failed to load tokenizer for: {}", card.display_name))?;

1500
1501
    // Create worker monitor if busy_threshold is set
    // Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
    let worker_monitor = busy_threshold.map(|t| {
        KvWorkerMonitor::new(
            client.clone(),
            dynamo_llm::discovery::LoadThresholdConfig {
                active_decode_blocks_threshold: Some(t),
                active_prefill_tokens_threshold: None,
                active_prefill_tokens_threshold_frac: None,
            },
        )
    });
1512

1513
1514
1515
    // Clone chooser before passing to build_routed_pipeline (which takes ownership)
    let kv_router = chooser.clone();

1516
1517
1518
1519
1520
1521
    let engine = build_routed_pipeline::<
        NvCreateChatCompletionRequest,
        NvCreateChatCompletionStreamResponse,
    >(
        &card_with_local_files,
        &client,
1522
        model_manager.clone(),
1523
        router_mode,
1524
        worker_monitor,
1525
1526
        chooser,
        hf_tokenizer,
1527
1528
        prefill_chooser,
        enforce_disagg,
1529
        metrics,
1530
1531
1532
    )
    .await?;

1533
    Ok((engine, kv_router))
1534
}