lib.rs 37.2 KB
Newer Older
1
2
3
4
5
6
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use async_once_cell::OnceCell as AsyncOnceCell;
use libc::c_char;
use once_cell::sync::OnceCell;
7
use std::borrow::Cow;
8
9
10
use std::ffi::CStr;
use std::sync::atomic::{AtomicU32, Ordering};

Neelay Shah's avatar
Neelay Shah committed
11
use dynamo_llm::kv_router::{
GuanLuo's avatar
GuanLuo committed
12
    indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher,
13
};
14
use dynamo_runtime::{DistributedRuntime, Worker};
15
16
17
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
GuanLuo's avatar
GuanLuo committed
18
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
/// Read a C string as trimmed Rust text, substituting `default_val` when no
/// usable text is available.
///
/// The default is returned (as `Cow::Borrowed`) when:
/// - `ptr` is NULL,
/// - the bytes are not valid UTF-8,
/// - the text is empty or consists only of whitespace.
///
/// Otherwise the trimmed text is copied and returned as `Cow::Owned`. Callers
/// in this file rely on that distinction to detect that the default was used.
///
/// # Safety
/// If `ptr` is non-null it must point to a NUL-terminated C string that stays
/// valid for the duration of the call.
#[inline]
unsafe fn cstr_or_default<'a>(ptr: *const c_char, default_val: &'a str) -> Cow<'a, str> {
    if ptr.is_null() {
        return Cow::Borrowed(default_val);
    }

    // SAFETY: `ptr` is non-null here and the caller guarantees it is a valid,
    // NUL-terminated C string.
    let raw = unsafe { CStr::from_ptr(ptr) };

    let trimmed = match raw.to_str() {
        Ok(text) => text.trim(),
        Err(_) => return Cow::Borrowed(default_val),
    };

    if trimmed.is_empty() {
        Cow::Borrowed(default_val)
    } else {
        Cow::Owned(trimmed.to_owned())
    }
}

39
40
41
42
43
44
45
fn initialize_tracing() {
    // Sets up RUST_LOG environment variable for logging while KV Publishing
    // Example: os.environ["RUST_LOG"] = "debug"
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();

Ryan Olson's avatar
Ryan Olson committed
46
    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
47

48
    tracing::debug!("Tracing initialized");
49
50
51
}

/// Status code returned by the `extern "C"` functions in this file.
///
/// `#[repr(u32)]` pins the representation so the discriminants below are the
/// values seen across the C boundary.
#[repr(u32)]
pub enum DynamoLlmResult {
    /// The call succeeded.
    OK = 0,
    /// The call failed; details are reported through logging.
    ERR = 1,
}

/// # Safety
GuanLuo's avatar
GuanLuo committed
58
/// the namespace_c_str and component_c_str are passed as pointers to C strings
59
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
60
pub unsafe extern "C" fn dynamo_llm_init(
GuanLuo's avatar
GuanLuo committed
61
62
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
63
    kv_block_size: u32,
Neelay Shah's avatar
Neelay Shah committed
64
) -> DynamoLlmResult {
65
66
67
68
    initialize_tracing();
    let wk = match WK.get_or_try_init(Worker::from_settings) {
        Ok(wk) => wk.clone(),
        Err(e) => {
69
            tracing::error!(error = ?e, "Failed to initialize runtime (Worker::from_settings)");
Neelay Shah's avatar
Neelay Shah committed
70
            return DynamoLlmResult::ERR;
71
72
73
74
75
76
77
78
79
80
81
82
        }
    };
    let rt = wk.runtime();
    let secondary = rt.secondary().clone();
    let result = secondary.block_on(async {
        // Initialize the distributed runtime
        match DRT
            .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await })
            .await
        {
            Ok(_) => Ok(()),
            Err(e) => {
83
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
84
                Err(DynamoLlmResult::ERR)
85
86
87
            }
        }
    });
GuanLuo's avatar
GuanLuo committed
88
    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
89
90
        Ok(s) => s.to_string(),
        Err(e) => {
91
            tracing::error!(error = ?e, "Failed to convert C string to Rust string (namespace)");
Neelay Shah's avatar
Neelay Shah committed
92
            return DynamoLlmResult::ERR;
93
94
95
        }
    };

96
97
98
99
100
    let component_cow = unsafe { cstr_or_default(component_c_str, "backend") };
    if let Cow::Borrowed("backend") = &component_cow {
        tracing::info!("defaulting to \"backend\" for component");
    }
    let component: String = component_cow.into_owned();
101
102

    match result {
103
        Ok(_) => match KV_PUB.get_or_try_init(move || {
Yan Ru Pei's avatar
Yan Ru Pei committed
104
            dynamo_create_kv_publisher(namespace, component, kv_block_size)
105
        }) {
Neelay Shah's avatar
Neelay Shah committed
106
            Ok(_) => DynamoLlmResult::OK,
107
            Err(e) => {
108
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
109
                DynamoLlmResult::ERR
110
111
112
113
114
115
            }
        },
        Err(e) => e,
    }
}

116
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
117
pub extern "C" fn dynamo_llm_shutdown() -> DynamoLlmResult {
118
119
120
    let wk = match WK.get() {
        Some(wk) => wk,
        None => {
121
            tracing::error!("Runtime not initialized");
Neelay Shah's avatar
Neelay Shah committed
122
            return DynamoLlmResult::ERR;
123
124
125
126
127
        }
    };

    wk.runtime().shutdown();

Neelay Shah's avatar
Neelay Shah committed
128
    DynamoLlmResult::OK
129
130
}

131
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
132
133
pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
    DynamoLlmResult::OK
134
135
136
137
}

// instantiate a kv publisher
// this will bring up the task to publish and the channels to await publishing events
Neelay Shah's avatar
Neelay Shah committed
138
139
// the [`dynamo_kv_publish_store_event`] call will use a handle to the publisher to send events
// store and the [`dynamo_kv_event_create_removed`] will create remove events
140
141
142
// these call mus be driving by external c++ threads that are consuming the kv events from the
// c++ executor api

Neelay Shah's avatar
Neelay Shah committed
143
fn dynamo_create_kv_publisher(
GuanLuo's avatar
GuanLuo committed
144
145
    namespace: String,
    component: String,
146
    kv_block_size: u32,
GuanLuo's avatar
GuanLuo committed
147
) -> Result<KvEventPublisher, anyhow::Error> {
148
    tracing::info!("Creating KV Publisher for model: {}", component);
149
150
151
152
153
    match DRT
        .get()
        .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
    {
        Ok(drt) => {
GuanLuo's avatar
GuanLuo committed
154
            let backend = drt.namespace(namespace)?.component(component)?;
Yan Ru Pei's avatar
Yan Ru Pei committed
155
            KvEventPublisher::new(backend, kv_block_size, None)
156
157
158
159
160
161
162
163
164
        }
        Err(e) => Err(e),
    }
}

/// Build the stored-block record for a single block.
///
/// `token_ids`/`num_tokens` describe the block's tokens; they are hashed with
/// `compute_block_hash_for_seq` and the first resulting hash becomes
/// `tokens_hash`. `block_hash` is wrapped as the external block hash.
/// `_lora_id` is accepted but not used.
///
/// The caller must pass a pointer that is valid for reading `num_tokens`
/// `u32` values. Indexing the first hash panics if no hash is produced.
fn kv_event_create_stored_block_from_parts(
    block_hash: u64,
    token_ids: *const u32,
    num_tokens: usize,
    kv_block_size: u32,
    _lora_id: u64,
) -> KvCacheStoredBlockData {
    // SAFETY: the caller guarantees `token_ids` is valid for `num_tokens` reads.
    let block_tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) };
    let sequence_hashes = compute_block_hash_for_seq(block_tokens, kv_block_size);
    let tokens_hash = sequence_hashes[0];

    KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block_hash),
        tokens_hash,
    }
}
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);

fn kv_event_create_stored_from_parts(
180
    kv_params: DynamoKvStoredEventParams,
181
    kv_block_size: u32,
182
183
184
185
) -> KvCacheEvent {
    let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();

    let mut token_offset: usize = 0;
186
187
188
189
190
191
192
193
194
    for block_idx in 0..kv_params.num_blocks {
        let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
        let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
        let num_toks = unsafe {
            *kv_params
                .num_block_tokens
                .offset(block_idx.try_into().unwrap())
        };

195
        if num_toks != (kv_block_size as usize) {
Ryan Olson's avatar
Ryan Olson committed
196
197
            if WARN_COUNT
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
198
                    if c < 3 { Some(c + 1) } else { None }
Ryan Olson's avatar
Ryan Olson committed
199
200
201
                })
                .is_ok()
            {
202
                tracing::warn!(
203
204
                    "Block not published. Block size must be {} tokens to be published. Block size is: {}",
                    kv_block_size,
Ryan Olson's avatar
Ryan Olson committed
205
206
                    num_toks
                );
207
208
209
210
211
            }
            break;
        }
        token_offset += num_toks;
        blocks.push(kv_event_create_stored_block_from_parts(
212
213
214
215
216
            block_hash,
            tokens,
            num_toks,
            kv_block_size,
            kv_params.lora_id,
217
218
219
220
221
222
        ));
    }

    KvCacheEvent {
        data: KvCacheEventData::Stored(KvCacheStoreData {
            blocks,
223
            parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
224
        }),
225
        event_id: kv_params.event_id,
Yan Ru Pei's avatar
Yan Ru Pei committed
226
        dp_rank: 0,
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
    }
}

fn kv_event_create_removed_from_parts(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> KvCacheEvent {
    let block_hashes: Vec<ExternalSequenceBlockHash> =
        unsafe { std::slice::from_raw_parts(block_ids, num_blocks) }
            .to_vec()
            .iter()
            .map(|&v| ExternalSequenceBlockHash(v))
            .collect();
    KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
Yan Ru Pei's avatar
Yan Ru Pei committed
244
        dp_rank: 0,
245
246
247
    }
}

248
249
250
251
252
253
254
255
256
257
/// Raw inputs for building a "stored" KV cache event, bundled from the
/// arguments of `dynamo_kv_event_publish_stored`.
///
/// The pointer fields are read by `kv_event_create_stored_from_parts`, which
/// indexes `block_ids` and `num_block_tokens` by block and walks `token_ids`
/// using the per-block token counts.
pub struct DynamoKvStoredEventParams {
    /// Identifier copied into the resulting event.
    pub event_id: u64,
    /// Tokens of all blocks, laid out back to back.
    pub token_ids: *const u32,
    /// Number of tokens in each block (one entry per block).
    pub num_block_tokens: *const usize,
    /// Block hash of each block (one entry per block).
    pub block_ids: *const u64,
    /// Number of blocks described by the arrays above.
    pub num_blocks: usize,
    /// Hash of the parent block, or `None` when there is no parent.
    pub parent_hash: Option<u64>,
    /// LoRA identifier; forwarded to the block builder, which currently ignores it.
    pub lora_id: u64,
}

258
259
260
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
/// has a parent hash or not. nullptr is used to represent no parent hash
261
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
262
pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
263
264
265
266
267
268
269
    event_id: u64,
    token_ids: *const u32,
    num_block_tokens: *const usize,
    block_ids: *const u64,
    num_blocks: usize,
    parent_hash: *const u64,
    lora_id: u64,
Neelay Shah's avatar
Neelay Shah committed
270
) -> DynamoLlmResult {
271
272
273
274
275
276
277
    let parent_hash = {
        if parent_hash.is_null() {
            None
        } else {
            Some(unsafe { *parent_hash })
        }
    };
278
    let kv_params = DynamoKvStoredEventParams {
279
280
281
282
283
284
285
        event_id,
        token_ids,
        num_block_tokens,
        block_ids,
        num_blocks,
        parent_hash,
        lora_id,
286
287
288
    };
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
289
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
290
        Ok(_) => DynamoLlmResult::OK,
291
292
        Err(e) => {
            eprintln!("Error publishing stored kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
293
            DynamoLlmResult::ERR
294
295
296
297
        }
    }
}

298
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
299
pub extern "C" fn dynamo_kv_event_publish_removed(
300
301
302
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
Neelay Shah's avatar
Neelay Shah committed
303
) -> DynamoLlmResult {
304
305
306
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
307
        Ok(_) => DynamoLlmResult::OK,
308
309
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
310
            DynamoLlmResult::ERR
311
312
313
314
        }
    }
}

315
316
317
318
319
// Need to setup etcd and nats to run these tests
// #[cfg(test)]
// mod tests {
//     use super::*;
//     use std::ffi::CString;
320

321
322
323
324
325
//     #[test]
//     fn test_dynamo_llm_init() {
//         // Create C-compatible strings
//         let namespace = CString::new("test_namespace").unwrap();
//         let component = CString::new("test_component").unwrap();
326

327
328
329
330
331
332
333
334
335
//         // Call the init function
//         let result = unsafe {
//             dynamo_llm_init(
//                 namespace.as_ptr(),
//                 component.as_ptr(),
//                 1,  // worker_id
//                 32, // kv_block_size
//             )
//         };
336

337
//         assert_eq!(result as u32, DynamoLlmResult::OK as u32);
338

339
//         assert!(WK.get().is_some());
340

341
342
343
//         let shutdown_result = dynamo_llm_shutdown();
//         assert_eq!(shutdown_result as u32, DynamoLlmResult::OK as u32);
//     }
344
// }
345
346
347
/* ------------------------------------------------------------------------
 * Worker selection pipeline
 * ------------------------------------------------------------------------ */
348
use std::pin::Pin;
349
350
351
352

/// Endpoint name `"generate"`. Presumably the endpoint targeted by the
/// worker-selection pipeline (its docs mention the "generate" endpoint) —
/// the use site is not in this part of the file; confirm there.
const GENERATE_ENDPOINT: &str = "generate";

use anyhow::Context;
353
use dynamo_runtime::{Runtime, distributed::DistributedConfig, traits::DistributedRuntimeProvider};
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426

use dynamo_llm::discovery::ModelManager;
use dynamo_llm::entrypoint::build_routed_pipeline;
use dynamo_llm::kv_router::KvRouterConfig;
use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_llm::protocols::openai::nvext::NvExt;
use dynamo_llm::types::{
    Annotated,
    openai::chat_completions::{
        NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
    },
};
use dynamo_runtime::{
    engine::AsyncEngineStream,
    pipeline::{ManyOut, RouterMode, ServiceEngine, SingleIn},
};
/// Opaque handle exposed to C for worker selection.
///
/// Created by `dynamo_create_worker_selection_pipeline`, which boxes the value
/// and hands the raw pointer to the caller; released by
/// `dynamo_destroy_worker_selection_pipeline`.
pub struct WorkerSelectionPipeline {
    /// Clone of the global worker handle; its runtime is used to block on
    /// queries issued through this pipeline.
    wk: Worker,
    /// Engine that takes a single chat-completion request and yields a stream
    /// of annotated chat-completion responses.
    engine: ServiceEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
    >,
}

/// Create a worker-selection pipeline ("generate" endpoint).
///
/// # Safety
/// - `namespace_c_str`, `component_c_str`, and `model_name_c_str` must be **non-null** pointers to
///   **NUL-terminated** C strings that contain **valid UTF-8**. They must remain valid for the
///   duration of this call.
/// - `pipeline_out` must be **non-null** and point to writable memory for a `*mut WorkerSelectionPipeline`.
///   On success this function writes exactly once to `*pipeline_out`. The caller becomes the owner of
///   that pointer and **must** later free it by calling `dynamo_destroy_worker_selection_pipeline`.
/// - Must be called **after** a successful `dynamo_llm_init()`; otherwise behavior is undefined.
/// - This function is not signal-safe and must not be called from a signal handler.
/// - This function may block internally; do not call it from contexts that forbid blocking.
///
/// # Errors
/// Returns `DynamoLlmResult::ERR` on failure and does not write to `pipeline_out`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
    model_name_c_str: *const c_char,
    use_kv_routing: bool,
    busy_threshold: f64,
    overlap_score_weight: f64,
    router_temperature: f64,
    use_kv_events: bool,
    router_replica_sync: bool,
    pipeline_out: *mut *mut WorkerSelectionPipeline,
) -> DynamoLlmResult {
    if pipeline_out.is_null() {
        tracing::error!("pipeline_out pointer is null");
        return DynamoLlmResult::ERR;
    }

    let wk = match WK.get() {
        Some(w) => w.clone(),
        None => {
            tracing::error!("Worker not initialized. Call dynamo_llm_init first.");
            return DynamoLlmResult::ERR;
        }
    };

    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(e) => {
            tracing::error!(error = ?e, "bad namespace");
            return DynamoLlmResult::ERR;
        }
    };
427
428
429
430
431
432
433

    let component_cow = unsafe { cstr_or_default(component_c_str, "backend") };
    if let Cow::Borrowed("backend") = &component_cow {
        tracing::info!("defaulting to \"backend\" for component");
    }
    let component: String = component_cow.into_owned();

434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
    let model = match unsafe { CStr::from_ptr(model_name_c_str) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(e) => {
            tracing::error!(error = ?e, "bad model");
            return DynamoLlmResult::ERR;
        }
    };

    let make_engine = || async {
        let router_mode = if use_kv_routing {
            RouterMode::KV
        } else {
            RouterMode::RoundRobin
        };

        let kv_router_config = if use_kv_routing {
            Some(KvRouterConfig::new(
                (overlap_score_weight >= 0.0).then_some(overlap_score_weight),
                (router_temperature >= 0.0).then_some(router_temperature),
                Some(use_kv_events),
                Some(router_replica_sync),
                None,
                None,
                None,
            ))
        } else {
            None
        };

        create_worker_selection_pipeline_chat(
            &namespace,
            &component,
            &model,
            router_mode,
            (busy_threshold >= 0.0).then_some(busy_threshold),
            kv_router_config,
        )
        .await
    };

    let engine = match wk.runtime().secondary().block_on(make_engine()) {
        Ok(p) => p,
        Err(e) => {
            tracing::error!(error = ?e, "create_worker_selection_pipeline_chat failed");
            return DynamoLlmResult::ERR;
        }
    };

    let handle = Box::new(WorkerSelectionPipeline { wk, engine });
    unsafe {
        *pipeline_out = Box::into_raw(handle);
    }
    DynamoLlmResult::OK
}

/// Query worker selection on an existing pipeline and return:
/// - `worker_instance_id_out` (`i64`)
/// - `token_ids_out` (heap-allocated `*mut u32`; caller must free via
///   `dynamo_free_worker_selection_result`)
/// - `token_count_out` (`usize`)
/// - `annotated_request_json_out` (`*mut c_char` to a NUL-terminated C string;
///   caller frees via the same free function)
///
/// # Safety
/// - `pipeline`
///   - Must be a **non-null** pointer previously returned by
///     `dynamo_create_worker_selection_pipeline` and not yet passed to
///     `dynamo_destroy_worker_selection_pipeline`.
///   - Must remain valid for the entire duration of this call.
///   - **Do not** call this function concurrently on the same `pipeline` pointer
///     from multiple threads unless the surrounding code guarantees synchronization.
/// - `request_json_c_str`
///   - Must be a **non-null**, **NUL-terminated** C string containing **valid UTF-8**.
///   - The JSON must represent a valid `NvCreateChatCompletionRequest`; otherwise this
///     function returns `DynamoLlmResult::ERR`.
///   - Must remain valid for the duration of this call.
/// - Output pointers:
///   - `worker_instance_id_out`, `token_ids_out`, `token_count_out`,
///     and `annotated_request_json_out` must each be **non-null** and point to
///     writable memory for their respective types. On success, this function
///     writes to all four outputs exactly once.
///   - On **error**, outputs are left unmodified.
/// - Ownership & deallocation:
///   - On success, if there are zero tokens, `*token_ids_out` may be set to `NULL`
///     and `*token_count_out` set to `0`.
///   - If non-null, the buffer written to `*token_ids_out` is allocated with the
///     Rust global allocator and **must** be freed by calling
///     `dynamo_free_worker_selection_result` with the same `token_count_out` value.
///   - The pointer written to `*annotated_request_json_out` is a `CString` allocated
///     by Rust and **must** be freed by calling `dynamo_free_worker_selection_result`.
///   - **Do not** free these with `free(3)` or any other allocator; doing so is
///     undefined behavior.
/// - Blocking & context:
///   - This function may **block** internally while it performs async work; do not
///     call it from contexts that forbid blocking (e.g., signal handlers).
/// - Process/ABI assumptions:
///   - The caller and callee must run in the same process and use the same Rust
///     global allocator for the paired allocation/free described above.
///   - This function is not signal-safe.
///
/// # Errors
/// Returns `DynamoLlmResult::ERR` if any precondition fails (null/invalid pointers,
/// malformed UTF-8/JSON, pipeline errors, allocation failures, etc.). On error, no
/// output pointer is written.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
    pipeline: *mut WorkerSelectionPipeline,
    request_json_c_str: *const c_char,
    worker_instance_id_out: *mut i64,
    token_ids_out: *mut *mut u32,
    token_count_out: *mut usize,
    annotated_request_json_out: *mut *mut c_char,
) -> DynamoLlmResult {
    if pipeline.is_null() {
        tracing::error!("Pipeline pointer is null");
        return DynamoLlmResult::ERR;
    }
    if worker_instance_id_out.is_null()
        || token_ids_out.is_null()
        || token_count_out.is_null()
        || annotated_request_json_out.is_null()
    {
        tracing::error!("One or more output pointers are null");
        return DynamoLlmResult::ERR;
    }

    let req_str = match unsafe { CStr::from_ptr(request_json_c_str) }.to_str() {
        Ok(s) => s,
        Err(e) => {
            tracing::error!(error = ?e, "bad request json");
            return DynamoLlmResult::ERR;
        }
    };
    let request: NvCreateChatCompletionRequest = match serde_json::from_str(req_str) {
        Ok(r) => r,
        Err(e) => {
            tracing::error!(error = ?e, "parse request failed");
            return DynamoLlmResult::ERR;
        }
    };

    let pl = unsafe { &*pipeline };
    let fut = async { query_worker_selection_and_annotate(&pl.engine, request).await };
    let (worker_id, tokens, annotated_req) = match pl.wk.runtime().secondary().block_on(fut) {
        Ok(v) => v,
        Err(e) => {
            tracing::error!(error = ?e, "query_worker_selection_and_annotate failed");
            return DynamoLlmResult::ERR;
        }
    };

    let tokens_ptr = if tokens.is_empty() {
        std::ptr::null_mut()
    } else {
        let len = tokens.len();
        let layout = std::alloc::Layout::array::<u32>(len).unwrap();
        let ptr = unsafe { std::alloc::alloc(layout) as *mut u32 };
        if ptr.is_null() {
            tracing::error!("alloc tokens failed");
            return DynamoLlmResult::ERR;
        }
        unsafe {
            std::ptr::copy_nonoverlapping(tokens.as_ptr(), ptr, len);
        }
        ptr
    };

    let annotated_json = match serde_json::to_string(&annotated_req) {
        Ok(s) => s,
        Err(e) => {
            let layout = std::alloc::Layout::array::<u32>(tokens.len()).unwrap();
            unsafe {
                std::alloc::dealloc(tokens_ptr as *mut u8, layout);
            }
            if !tokens_ptr.is_null() {
                tracing::error!(error = ?e, "serialize annotated request failed");
            }
            return DynamoLlmResult::ERR;
        }
    };
    let cjson = match std::ffi::CString::new(annotated_json) {
        Ok(c) => c,
        Err(e) => {
            tracing::error!(error = ?e, "CString::new for annotated JSON failed");
            if !tokens_ptr.is_null() {
                let layout = std::alloc::Layout::array::<u32>(tokens.len()).unwrap();
                unsafe {
                    std::alloc::dealloc(tokens_ptr as *mut u8, layout);
                }
            }
            return DynamoLlmResult::ERR;
        }
    };
    unsafe {
        *worker_instance_id_out = worker_id;
        *token_ids_out = tokens_ptr;
        *token_count_out = tokens.len();
        *annotated_request_json_out = cjson.into_raw();
    }
    DynamoLlmResult::OK
}

/// Release a pipeline handle and everything it owns.
///
/// # Safety
/// - `pipeline` must be a pointer obtained from
///   `dynamo_create_worker_selection_pipeline` (which produces it with
///   `Box::into_raw`). NULL is rejected with `ERR`.
/// - The pointer must not have been destroyed or otherwise freed before;
///   destroying it twice is a **double free** and undefined behavior.
/// - No other thread may be using the handle while this function runs (for
///   example an in-flight `dynamo_query_worker_selection_and_annotate` call).
/// - Once this function returns `OK` the pointer is dangling and must never be
///   used again.
/// - Caller and callee must live in the same process and share the same
///   allocator, because the allocation made by Rust for the handle is reclaimed
///   here.
/// - Destructors run inside this call; do not invoke it from contexts that
///   forbid that or forbid blocking (e.g., signal handlers).
///
/// # Errors
/// - Returns `DynamoLlmResult::ERR` if `pipeline` is null.
/// - Returns `DynamoLlmResult::OK` after taking ownership of `pipeline` and
///   dropping the underlying resources.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_destroy_worker_selection_pipeline(
    pipeline: *mut WorkerSelectionPipeline,
) -> DynamoLlmResult {
    if pipeline.is_null() {
        tracing::error!("Pipeline pointer is null");
        return DynamoLlmResult::ERR;
    }

    // SAFETY: non-null was checked; the caller guarantees the pointer came from
    // `Box::into_raw` in the create function and has not been freed yet.
    let owned: Box<WorkerSelectionPipeline> = unsafe { Box::from_raw(pipeline) };
    drop(owned);

    DynamoLlmResult::OK
}

/// Release the buffers handed out by `dynamo_query_worker_selection_and_annotate`.
///
/// # Safety
/// - Both pointers **must originate from this library**:
///   - `token_ids` must be exactly the token buffer pointer returned by
///     `dynamo_query_worker_selection_and_annotate`, allocated with Rust's
///     global allocator in this process.
///   - `annotated_request_json` must be exactly the pointer produced by
///     `CString::into_raw` inside `dynamo_query_worker_selection_and_annotate`.
/// - Free each pointer **at most once**; a second call with the same pointer is
///   a double free and undefined behavior.
/// - Pointer/length rules:
///   - A non-null `token_ids` requires `token_count` to equal the length that
///     was originally returned; a different length deallocates with the wrong
///     layout.
///   - A null `token_ids` should come with `token_count == 0`.
///   - A non-null `token_ids` with `token_count == 0` is leaked, because the
///     buffer is only deallocated when `token_count > 0`.
/// - The pointers are **invalid** after this call and must not be used again.
/// - Caller and callee must live in the same process and share the same
///   allocator/ABI (deallocation goes through Rust's global allocator).
/// - No other thread may be reading or writing the buffers while they are freed.
/// - Destructors run inside this call; do not invoke it from contexts that
///   forbid that (e.g., signal handlers).
///
/// Always returns `DynamoLlmResult::OK`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_free_worker_selection_result(
    token_ids: *mut u32,
    token_count: usize,
    annotated_request_json: *mut c_char,
) -> DynamoLlmResult {
    let has_token_buffer = token_count > 0 && !token_ids.is_null();
    if has_token_buffer {
        // A count whose layout cannot be computed is silently ignored.
        if let Ok(layout) = std::alloc::Layout::array::<u32>(token_count) {
            // SAFETY: the caller guarantees the buffer was allocated by this
            // library with exactly this element count.
            unsafe { std::alloc::dealloc(token_ids.cast::<u8>(), layout) };
        }
    }

    if !annotated_request_json.is_null() {
        // SAFETY: the caller guarantees the pointer came from
        // `CString::into_raw` and has not been freed yet.
        let owned_json = unsafe { std::ffi::CString::from_raw(annotated_request_json) };
        drop(owned_json);
    }

    DynamoLlmResult::OK
}

/// Helper function to extract worker selection information from the annotation stream
pub async fn extract_worker_selection_from_stream(
    mut stream: Pin<Box<dyn AsyncEngineStream<Annotated<NvCreateChatCompletionStreamResponse>>>>,
) -> anyhow::Result<(i64, Vec<u32>)> {
    use futures::StreamExt;

    let mut worker_id: i64 = 0;
    let mut tokens: Vec<u32> = Vec::new();

    while let Some(response) = stream.next().await {
        let Some(event) = &response.event else {
            tracing::error!("Response has no event field");
            continue;
        };

        match event.as_str() {
            "worker_instance_id" => {
                tracing::debug!("Found worker_instance_id event");

                let Some(first_comment) = response.comment.as_ref().and_then(|v| v.first()) else {
                    tracing::debug!("worker_instance_id event without comments");
                    continue;
                };

                // Try JSON string first (e.g. `"1732646935200805498"`), then plain integer.
                if let Ok(id_string) = serde_json::from_str::<String>(first_comment) {
                    match id_string.parse::<i64>() {
                        Ok(parsed_id) => {
                            worker_id = parsed_id;
                            tracing::debug!("parsed worker_id from JSON string: {}", worker_id);
                        }
                        Err(_) => {
                            tracing::error!(
                                "failed to parse number from JSON string: '{}'",
                                id_string
                            );
                        }
                    }
                    continue;
                }

                match first_comment.parse::<i64>() {
                    Ok(parsed_id) => {
                        worker_id = parsed_id;
                        tracing::debug!("parsed worker_id directly: {}", worker_id);
                    }
                    Err(_) => {
                        tracing::error!("failed to parse worker_id from: '{}'", first_comment);
                    }
                }
            }

            "token_data" => {
                tracing::debug!("Found token_data event");

                let Some(first_comment) = response.comment.as_ref().and_then(|v| v.first()) else {
                    tracing::debug!("token_data event without comments");
                    continue;
                };

                tracing::debug!("Token comment: '{}'", first_comment);
                match serde_json::from_str::<Vec<u32>>(first_comment) {
                    Ok(parsed_tokens) => {
                        tokens = parsed_tokens;
                        tracing::debug!("Successfully parsed {} tokens", tokens.len());
                    }
                    Err(e) => {
                        tracing::error!("Failed to parse tokens from '{}': {}", first_comment, e);
                    }
                }
            }

            other => {
                tracing::debug!("Unknown event type: '{}'", other);
            }
        }
    }

    tracing::info!(
        "Final worker_id={}, tokens.len()={}",
        worker_id,
        tokens.len()
    );
    Ok((worker_id, tokens))
}

/// Tag an OpenAI chat completion request with the `query_instance_id` annotation.
///
/// The annotation signals the KV router to return worker selection information
/// (`worker_instance_id` and `token_data`) instead of performing actual
/// inference. The annotation is added at most once; a request that already
/// carries it is left unchanged.
///
/// # Parameters
/// - `request`: the chat completion request to modify in place
///
/// # Returns
/// The same mutable reference that was passed in, to allow chaining.
pub fn add_query_instance_id(
    request: &mut NvCreateChatCompletionRequest,
) -> &mut NvCreateChatCompletionRequest {
    let marker = "query_instance_id";
    add_annotation_unique(request, marker)
}

/// Record the selected worker on a request as a `worker_instance_id:<id>` annotation.
///
/// Any existing annotation starting with `worker_instance_id:` is replaced.
///
/// # Parameters
/// - `request`: the chat completion request to modify in place
/// - `worker_id`: the worker instance id, written in decimal
///
/// # Returns
/// The same mutable reference that was passed in, to allow chaining.
pub fn add_worker_instance_id_annotation(
    request: &mut NvCreateChatCompletionRequest,
    worker_id: i64,
) -> &mut NvCreateChatCompletionRequest {
    let key = String::from("worker_instance_id");
    let value = worker_id.to_string();
    set_kv_annotation(request, key, value)
}

/// Record token ids on a request as a `token_data:<json array>` annotation.
///
/// The tokens are serialized with `serde_json`; should serialization fail, an
/// empty string is used as the value. Any existing annotation starting with
/// `token_data:` is replaced.
///
/// # Parameters
/// - `request`: the chat completion request to modify in place
/// - `tokens`: the token ids to embed
///
/// # Returns
/// The same mutable reference that was passed in, to allow chaining.
pub fn add_token_data_annotation<'a>(
    request: &'a mut NvCreateChatCompletionRequest,
    tokens: &[u32],
) -> &'a mut NvCreateChatCompletionRequest {
    set_kv_annotation(
        request,
        String::from("token_data"),
        serde_json::to_string(tokens).unwrap_or_default(),
    )
}

/// Return the request's annotation list, creating it on demand.
///
/// If `request.nvext` is `None` it is populated with a default-built `NvExt`;
/// if the extension's `annotations` is `None` it becomes an empty `Vec`.
///
/// # Panics
/// Panics if `NvExt::builder().build()` returns an error.
fn ensure_annotations(request: &mut NvCreateChatCompletionRequest) -> &mut Vec<String> {
    request
        .nvext
        .get_or_insert_with(|| {
            let built = NvExt::builder().build();
            built.expect("NvExt builder should not fail")
        })
        .annotations
        .get_or_insert_with(Vec::new)
}

/// Append `annotation` to the request unless an identical entry already exists.
///
/// Returns the same mutable reference that was passed in, to allow chaining.
fn add_annotation_unique(
    request: &mut NvCreateChatCompletionRequest,
    annotation: impl Into<String>,
) -> &mut NvCreateChatCompletionRequest {
    let wanted: String = annotation.into();
    let existing = ensure_annotations(request);
    let already_present = existing.contains(&wanted);
    if !already_present {
        existing.push(wanted);
    }
    request
}

/// Store a `key:value` annotation, replacing any previous value for `key`.
///
/// Every existing annotation that starts with `key:` is removed before the new
/// `key:value` entry is appended at the end of the list.
///
/// Returns the same mutable reference that was passed in, to allow chaining.
fn set_kv_annotation(
    request: &mut NvCreateChatCompletionRequest,
    key: String, // owned on purpose: `request` stays the only borrowed parameter
    value: impl Into<String>,
) -> &mut NvCreateChatCompletionRequest {
    let prefix = format!("{key}:");
    let value: String = value.into();

    let mut entry = String::with_capacity(prefix.len() + value.len());
    entry.push_str(&prefix);
    entry.push_str(&value);

    let annotations = ensure_annotations(request);
    annotations.retain(|existing| !existing.starts_with(prefix.as_str()));
    annotations.push(entry);
    request
}

/// Ask the engine which worker it would select, then annotate the original request.
///
/// Steps:
/// 1. Clone the original request and add the "query_instance_id" annotation to the clone
/// 2. Call `engine.generate()` with the clone
/// 3. Extract `worker_instance_id` and the tokens from the response stream
/// 4. Add `worker_instance_id` and `token_data` annotations to the original request
///
/// # Parameters
/// - `engine`: The worker selection pipeline engine
/// - `original_request`: The original OpenAI request to process
///
/// # Returns
/// A tuple `(worker_instance_id, tokens, annotated_original_request)`.
///
/// # Errors
/// Propagates any error from `engine.generate()` or from reading the
/// response stream.
pub async fn query_worker_selection_and_annotate(
    engine: &ServiceEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
    >,
    mut original_request: NvCreateChatCompletionRequest,
) -> anyhow::Result<(i64, Vec<u32>, NvCreateChatCompletionRequest)> {
    // The probe is a copy: the query annotation never reaches the original request.
    let probe = {
        let mut cloned = original_request.clone();
        add_query_instance_id(&mut cloned);
        cloned
    };

    let selection_stream = engine.generate(SingleIn::new(probe)).await?;
    let (worker_id, tokens) = extract_worker_selection_from_stream(selection_stream).await?;

    add_worker_instance_id_annotation(&mut original_request, worker_id);
    add_token_data_annotation(&mut original_request, &tokens);

    Ok((worker_id, tokens, original_request))
}

917
/// Create a worker selection pipeline for OpenAI Chat Completion requests
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
///
/// This is a concrete implementation that works specifically with NvCreateChatCompletionRequest
/// and is designed for use with C bindings. Uses the "generate" endpoint by default.
///
/// # Parameters
/// - `namespace`: namespace name
/// - `component_name`: component name
/// - `model_name`: Name/slug of the model to load
/// - `router_mode`: How to route requests (KV, RoundRobin, etc.)
/// - `busy_threshold`: Optional threshold for busy worker detection
/// - `kv_router_config`: Optional KV router configuration (only used when router_mode is KV)
///
/// # Returns
/// A configured worker selection pipeline ready to use
pub async fn create_worker_selection_pipeline_chat(
    namespace: &str,
    component_name: &str,
    model_name: &str,
    router_mode: RouterMode,
    busy_threshold: Option<f64>,
    kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<
    ServiceEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
    >,
> {
    let runtime = Runtime::from_settings()?;
946
    let dst_config = DistributedConfig::from_settings();
947
948
949
950
951
952
    let drt_owned = DistributedRuntime::new(runtime, dst_config).await?;
    let distributed_runtime: &'static DistributedRuntime = Box::leak(Box::new(drt_owned));

    let component = distributed_runtime
        .namespace(namespace)?
        .component(component_name)?;
953
954
    let endpoint = component.endpoint(GENERATE_ENDPOINT);
    let client = endpoint.client().await?;
955

956
957
958
959
960
961
    // Discover the model card by searching all instances with this model name
    tracing::debug!("Looking for model: {}", model_name);
    tracing::debug!("Namespace: {}", namespace);

    use dynamo_llm::discovery::ModelWatcher;
    let model_manager = std::sync::Arc::new(ModelManager::new());
962
963
964
965
966
967
    let router_config = dynamo_llm::entrypoint::RouterConfig {
        router_mode,
        kv_router_config: kv_router_config.unwrap_or_default(),
        busy_threshold,
        enforce_disagg: false,
    };
968
969
970
    let watcher = ModelWatcher::new(
        component.drt().clone(),
        model_manager.clone(),
971
        router_config,
972
973
974
975
976
977
978
979
980
981
982
983
    );
    let cards = watcher
        .cards_for_model(model_name, Some(namespace), false)
        .await
        .with_context(|| format!("Failed to discover model: {}", model_name))?;

    tracing::debug!("Found {} cards for model {}", cards.len(), model_name);

    let card = cards.into_iter().next().ok_or_else(|| {
        tracing::error!("No ModelDeploymentCard found for model: {}", model_name);
        anyhow::anyhow!("ModelDeploymentCard not found for model: {}", model_name)
    })?;
984
985
986
987

    let chooser = if router_mode == RouterMode::KV {
        Some(
            model_manager
988
                .kv_chooser_for(&endpoint, card.kv_cache_block_size, kv_router_config)
989
990
991
992
993
994
                .await?,
        )
    } else {
        None
    };

995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
    // Download model config files from HuggingFace for EPP
    // The backend's card has NATS URLs which aren't accessible from EPP
    tracing::debug!(
        "Downloading model config files for EPP: {}",
        card.display_name
    );

    let local_path = dynamo_llm::hub::from_hf(&card.display_name, true)
        .await
        .with_context(|| {
            format!(
                "Failed to download model config files for: {}",
                card.display_name
            )
        })?;

    // Load a fresh card from local files, then copy runtime config from original card
    tracing::debug!("Loading ModelDeploymentCard from local path...");
    let mut card_with_local_files = ModelDeploymentCard::load_from_disk(&local_path, None)
        .with_context(|| format!("Failed to load card from disk: {:?}", local_path))?;

    // Copy runtime settings from the backend's card
    tracing::debug!("Copying runtime config from backend card...");
    card_with_local_files.runtime_config = card.runtime_config.clone();
    card_with_local_files.kv_cache_block_size = card.kv_cache_block_size;
    card_with_local_files.context_length = card.context_length;

    // Load the tokenizer from the downloaded files
    tracing::debug!("Loading tokenizer from local files...");
    let hf_tokenizer = card_with_local_files
        .tokenizer_hf()
        .with_context(|| format!("Failed to load tokenizer for: {}", card.display_name))?;

    let engine = build_routed_pipeline::<
        NvCreateChatCompletionRequest,
        NvCreateChatCompletionStreamResponse,
    >(
        &card_with_local_files,
        &client,
        router_mode,
        busy_threshold,
        chooser,
        hf_tokenizer,
1038
1039
        None,  // prefill_chooser
        false, // enforce_disagg
1040
1041
1042
1043
    )
    .await?;

    Ok(engine)
1044
}