"csrc/quantization/marlin/generate_kernels.py" did not exist on "ce96857fdd2bf2390aaa2183561fd1a0f5c464c7"
lib.rs 51.9 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
// SPDX-License-Identifier: Apache-2.0

use async_once_cell::OnceCell as AsyncOnceCell;
use libc::c_char;
use once_cell::sync::OnceCell;
7
use std::borrow::Cow;
8
use std::ffi::CStr;
9
use std::ptr;
10
use std::sync::Arc;
11
use std::sync::atomic::{AtomicU32, Ordering};
12
use std::time::Duration;
13

14
15
16
17
18
use dynamo_kv_router::{
    config::{KvRouterConfig, RouterConfigOverride},
    protocols::*,
};
use dynamo_llm::kv_router::publisher::KvEventPublisher;
19
use dynamo_llm::model_card::ModelDeploymentCard;
20
use dynamo_llm::preprocessor::OpenAIPreprocessor;
21
use dynamo_runtime::discovery::{DiscoveryQuery, hash_pod_name};
22
use dynamo_runtime::{DistributedRuntime, Worker};
23
24
25
26

use dynamo_runtime::Runtime;

use dynamo_llm::discovery::{ModelManager, WORKER_TYPE_DECODE};
27
use dynamo_llm::kv_router::{KvRouter, PrefillRouter};
28
29
use dynamo_runtime::pipeline::RouterMode;

30
31
use std::collections::HashSet;

32
33
34
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
GuanLuo's avatar
GuanLuo committed
35
static KV_PUB: OnceCell<KvEventPublisher> = OnceCell::new();
36

37
38
39
40
41
42
/// Bundle of values `create_routers` obtains from `init_preprocessor` before it
/// builds the routers.
struct DiscoveredModelBootstrap {
    /// Preprocessor that ends up stored in `RouterHandles::preprocessor`.
    preprocessor: Arc<OpenAIPreprocessor>,
    /// Model deployment card; `create_routers` reads `kv_cache_block_size`,
    /// `display_name` and `runtime_config.enable_eagle` from it.
    card: ModelDeploymentCard,
    /// Namespace actually used to build the component/endpoint. It may differ from
    /// the requested namespace (`create_routers` logs this as a rolling-update suffix).
    actual_namespace: String,
}

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
/// Read a C string as trimmed, owned Rust text, or fall back to `default_val`.
///
/// The default is returned (borrowed) when any of the following holds:
/// - `ptr` is NULL,
/// - the bytes are not valid UTF-8,
/// - the text is empty once surrounding whitespace is stripped.
///
/// Otherwise the trimmed text is copied into an owned `Cow`, so the result never
/// borrows from the C buffer.
///
/// # Safety
/// A non-NULL `ptr` must point to a readable, NUL-terminated C string.
#[inline]
unsafe fn cstr_or_default<'a>(ptr: *const c_char, default_val: &'a str) -> Cow<'a, str> {
    if ptr.is_null() {
        return Cow::Borrowed(default_val);
    }

    // SAFETY: `ptr` is non-NULL here and the caller guarantees NUL termination.
    let raw = unsafe { CStr::from_ptr(ptr) };

    let trimmed = match raw.to_str() {
        Ok(text) => text.trim(),
        Err(_) => return Cow::Borrowed(default_val),
    };

    if trimmed.is_empty() {
        Cow::Borrowed(default_val)
    } else {
        Cow::Owned(trimmed.to_owned())
    }
}

62
63
64
65
66
67
68
fn initialize_tracing() {
    // Sets up RUST_LOG environment variable for logging while KV Publishing
    // Example: os.environ["RUST_LOG"] = "debug"
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .finish();

69
70
71
    if tracing::subscriber::set_global_default(subscriber).is_ok() {
        tracing::debug!("Tracing initialized");
    }
72
73
74
}

/// Status code returned by the `dynamo_llm_*` and `dynamo_kv_event_*` C entry points.
///
/// `#[repr(u32)]` fixes the numeric values seen by C callers.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DynamoLlmResult {
    /// The call succeeded.
    OK = 0,
    /// The call failed; details are written to the log.
    ERR = 1,
}

80
81
82
83
84
85
86
87
88
// Wait for the discovery daemon to sync indefinitely and return at least one instance.
// This is because the Model info is registered by workers and it may take up to 30 min for the model weights to load and for the worker to register itself.
// The waiting timeout is implemented in the Kubernetes StartupProbe. The EPP waiting loops runs indefinitely, the Probe is a single source of truth with when to kill the EPP if discovery fails.
// If workers are not found within the probe's failureThreshold × periodSeconds, the pod will be killed and restarted.
// Users can adjust the StartupProbe waiting timed in the DGD for large models.
async fn wait_for_discovery_sync(drt: &DistributedRuntime) -> usize {
    tracing::info!(
        "Waiting for discovery to sync (no timeout - controlled by K8s StartupProbe)..."
    );
89
90
91
92
93
94
95
96
97
98
99
100
    let discovery = drt.discovery();

    loop {
        match discovery.list(DiscoveryQuery::AllModels).await {
            Ok(instances) if !instances.is_empty() => {
                return instances.len();
            }
            Ok(_) => {
                tracing::debug!("No instances yet, waiting...");
                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
            }
            Err(e) => {
101
102
103
                // Log and continue - transient errors shouldn't stop the wait
                tracing::warn!("Discovery list error: {}, retrying...", e);
                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
104
105
106
107
108
            }
        }
    }
}

109
/// # Safety
GuanLuo's avatar
GuanLuo committed
110
/// the namespace_c_str and component_c_str are passed as pointers to C strings
111
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
112
pub unsafe extern "C" fn dynamo_llm_init(
GuanLuo's avatar
GuanLuo committed
113
114
    namespace_c_str: *const c_char,
    component_c_str: *const c_char,
115
    kv_block_size: u32,
Neelay Shah's avatar
Neelay Shah committed
116
) -> DynamoLlmResult {
117
118
119
120
    initialize_tracing();
    let wk = match WK.get_or_try_init(Worker::from_settings) {
        Ok(wk) => wk.clone(),
        Err(e) => {
121
            tracing::error!(error = ?e, "Failed to initialize runtime (Worker::from_settings)");
Neelay Shah's avatar
Neelay Shah committed
122
            return DynamoLlmResult::ERR;
123
124
125
126
127
128
129
130
131
132
        }
    };
    let rt = wk.runtime();
    let secondary = rt.secondary().clone();
    let result = secondary.block_on(async {
        // Initialize the distributed runtime
        match DRT
            .get_or_try_init(async { DistributedRuntime::from_settings(rt.clone()).await })
            .await
        {
133
            Ok(drt) => {
134
                // Wait for discovery to sync before returning.
135
                // This is needed because dynamo_create_worker_selection_pipeline() is called
136
137
138
139
                // immediately after, and it needs discovery.list() to return data.
                // The discovery daemon takes time to query K8s and returns async, so we need to wait.
                // Note: This waits indefinitely - the K8s StartupProbe is the timeout mechanism.
                wait_for_discovery_sync(drt).await;
140
141
                Ok(())
            }
142
            Err(e) => {
143
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
144
                Err(DynamoLlmResult::ERR)
145
146
147
            }
        }
    });
GuanLuo's avatar
GuanLuo committed
148
    let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
149
150
        Ok(s) => s.to_string(),
        Err(e) => {
151
            tracing::error!(error = ?e, "Failed to convert C string to Rust string (namespace)");
Neelay Shah's avatar
Neelay Shah committed
152
            return DynamoLlmResult::ERR;
153
154
155
        }
    };

156
157
158
159
160
    let component_cow = unsafe { cstr_or_default(component_c_str, "backend") };
    if let Cow::Borrowed("backend") = &component_cow {
        tracing::info!("defaulting to \"backend\" for component");
    }
    let component: String = component_cow.into_owned();
161
162

    match result {
163
        Ok(_) => match KV_PUB.get_or_try_init(move || {
Yan Ru Pei's avatar
Yan Ru Pei committed
164
            dynamo_create_kv_publisher(namespace, component, kv_block_size)
165
        }) {
Neelay Shah's avatar
Neelay Shah committed
166
            Ok(_) => DynamoLlmResult::OK,
167
            Err(e) => {
168
                tracing::error!(error = ?e, "Failed to initialize distributed runtime");
Neelay Shah's avatar
Neelay Shah committed
169
                DynamoLlmResult::ERR
170
171
172
173
174
175
            }
        },
        Err(e) => e,
    }
}

176
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
177
pub extern "C" fn dynamo_llm_shutdown() -> DynamoLlmResult {
178
179
180
    let wk = match WK.get() {
        Some(wk) => wk,
        None => {
181
            tracing::error!("Runtime not initialized");
Neelay Shah's avatar
Neelay Shah committed
182
            return DynamoLlmResult::ERR;
183
184
185
186
187
        }
    };

    wk.runtime().shutdown();

Neelay Shah's avatar
Neelay Shah committed
188
    DynamoLlmResult::OK
189
190
}

191
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
192
193
pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
    DynamoLlmResult::OK
194
195
196
197
}

// instantiate a kv publisher
// this will bring up the task to publish and the channels to await publishing events
Neelay Shah's avatar
Neelay Shah committed
198
199
// the [`dynamo_kv_publish_store_event`] call will use a handle to the publisher to send events
// store and the [`dynamo_kv_event_create_removed`] will create remove events
200
201
202
// these call mus be driving by external c++ threads that are consuming the kv events from the
// c++ executor api

Neelay Shah's avatar
Neelay Shah committed
203
fn dynamo_create_kv_publisher(
GuanLuo's avatar
GuanLuo committed
204
205
    namespace: String,
    component: String,
206
    kv_block_size: u32,
GuanLuo's avatar
GuanLuo committed
207
) -> Result<KvEventPublisher, anyhow::Error> {
208
    tracing::info!("Creating KV Publisher for model: {}", component);
209
210
211
212
213
    match DRT
        .get()
        .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
    {
        Ok(drt) => {
GuanLuo's avatar
GuanLuo committed
214
            let backend = drt.namespace(namespace)?.component(component)?;
Yan Ru Pei's avatar
Yan Ru Pei committed
215
            KvEventPublisher::new(backend, kv_block_size, None)
216
217
218
219
220
221
222
223
224
        }
        Err(e) => Err(e),
    }
}

/// Build the stored-block record for one KV block.
///
/// Hashes the `num_tokens` tokens starting at `token_ids` with
/// `compute_block_hash_for_seq` (passing `lora_name` through `BlockHashOptions`) and
/// uses the first hash it returns as `tokens_hash`. `block_hash` is wrapped unchanged
/// as the external block hash; `mm_extra_info` is always `None`.
///
/// # Panics
/// Panics if `compute_block_hash_for_seq` returns no hashes. The only visible caller
/// passes exactly `kv_block_size` tokens per block.
///
/// `token_ids` must point to at least `num_tokens` readable `u32` values; this is not
/// checked here.
fn kv_event_create_stored_block_from_parts(
    block_hash: u64,
    token_ids: *const u32,
    num_tokens: usize,
    kv_block_size: u32,
    lora_name: Option<&str>,
) -> KvCacheStoredBlockData {
    // SAFETY: the caller guarantees `token_ids` addresses `num_tokens` initialized u32s.
    let block_tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) };

    let hash_options = BlockHashOptions {
        lora_name,
        ..Default::default()
    };
    let tokens_hash = compute_block_hash_for_seq(block_tokens, kv_block_size, hash_options)[0];

    KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block_hash),
        tokens_hash,
        mm_extra_info: None,
    }
}
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);

fn kv_event_create_stored_from_parts(
245
    kv_params: DynamoKvStoredEventParams,
246
    kv_block_size: u32,
247
248
249
250
) -> KvCacheEvent {
    let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();

    let mut token_offset: usize = 0;
251
252
253
254
255
256
257
258
259
    for block_idx in 0..kv_params.num_blocks {
        let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
        let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
        let num_toks = unsafe {
            *kv_params
                .num_block_tokens
                .offset(block_idx.try_into().unwrap())
        };

260
        if num_toks != (kv_block_size as usize) {
Ryan Olson's avatar
Ryan Olson committed
261
262
            if WARN_COUNT
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
263
                    if c < 3 { Some(c + 1) } else { None }
Ryan Olson's avatar
Ryan Olson committed
264
265
266
                })
                .is_ok()
            {
267
                tracing::warn!(
268
269
                    "Block not published. Block size must be {} tokens to be published. Block size is: {}",
                    kv_block_size,
Ryan Olson's avatar
Ryan Olson committed
270
271
                    num_toks
                );
272
273
274
275
276
            }
            break;
        }
        token_offset += num_toks;
        blocks.push(kv_event_create_stored_block_from_parts(
277
278
279
280
            block_hash,
            tokens,
            num_toks,
            kv_block_size,
281
            kv_params.lora_name.as_deref(),
282
283
284
285
286
287
        ));
    }

    KvCacheEvent {
        data: KvCacheEventData::Stored(KvCacheStoreData {
            blocks,
288
            parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
289
            start_position: None,
290
        }),
291
        event_id: kv_params.event_id,
Yan Ru Pei's avatar
Yan Ru Pei committed
292
        dp_rank: 0,
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
    }
}

fn kv_event_create_removed_from_parts(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
) -> KvCacheEvent {
    let block_hashes: Vec<ExternalSequenceBlockHash> =
        unsafe { std::slice::from_raw_parts(block_ids, num_blocks) }
            .to_vec()
            .iter()
            .map(|&v| ExternalSequenceBlockHash(v))
            .collect();
    KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
Yan Ru Pei's avatar
Yan Ru Pei committed
310
        dp_rank: 0,
311
312
313
    }
}

314
315
316
317
318
319
320
/// Raw inputs for one "stored" KV event, as received over the C API.
///
/// The pointer fields are borrowed from the caller and are only read while the event
/// is being built.
pub struct DynamoKvStoredEventParams {
    /// Identifier copied into the resulting event.
    pub event_id: u64,
    /// Token ids of all blocks, concatenated in block order.
    pub token_ids: *const u32,
    /// Per-block token counts (`num_blocks` entries).
    pub num_block_tokens: *const usize,
    /// Per-block external block hashes (`num_blocks` entries).
    pub block_ids: *const u64,
    /// Number of blocks described by the arrays above.
    pub num_blocks: usize,
    /// Hash of the parent block, or `None` when the blocks have no parent.
    pub parent_hash: Option<u64>,
    /// LoRA adapter name fed into the token hash, or `None` for the base model.
    pub lora_name: Option<String>,
}

324
325
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
326
327
/// has a parent hash or not. nullptr is used to represent no parent hash.
/// lora_name is an optional null-terminated C string; pass nullptr for base model.
328
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
329
pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
330
331
332
333
334
335
    event_id: u64,
    token_ids: *const u32,
    num_block_tokens: *const usize,
    block_ids: *const u64,
    num_blocks: usize,
    parent_hash: *const u64,
336
    lora_name: *const c_char,
Neelay Shah's avatar
Neelay Shah committed
337
) -> DynamoLlmResult {
338
339
340
341
342
343
344
    let parent_hash = {
        if parent_hash.is_null() {
            None
        } else {
            Some(unsafe { *parent_hash })
        }
    };
345
346
347
348
349
350
351
352
353
354
355
    let lora_name = if lora_name.is_null() {
        None
    } else {
        match unsafe { CStr::from_ptr(lora_name) }.to_str() {
            Ok(s) => Some(s.to_owned()),
            Err(e) => {
                tracing::error!(error = ?e, "Failed to convert C string to Rust string (lora_name)");
                return DynamoLlmResult::ERR;
            }
        }
    };
356
    let kv_params = DynamoKvStoredEventParams {
357
358
359
360
361
362
        event_id,
        token_ids,
        num_block_tokens,
        block_ids,
        num_blocks,
        parent_hash,
363
        lora_name,
364
365
366
    };
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
367
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
368
        Ok(_) => DynamoLlmResult::OK,
369
370
        Err(e) => {
            eprintln!("Error publishing stored kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
371
            DynamoLlmResult::ERR
372
373
374
375
        }
    }
}

376
#[unsafe(no_mangle)]
Neelay Shah's avatar
Neelay Shah committed
377
pub extern "C" fn dynamo_kv_event_publish_removed(
378
379
380
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
Neelay Shah's avatar
Neelay Shah committed
381
) -> DynamoLlmResult {
382
383
384
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    match publisher.publish(event) {
Neelay Shah's avatar
Neelay Shah committed
385
        Ok(_) => DynamoLlmResult::OK,
386
387
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
Neelay Shah's avatar
Neelay Shah committed
388
            DynamoLlmResult::ERR
389
390
391
392
        }
    }
}

393
/* ------------------------------------------------------------------------
 *  Router Bindings for GAIE EPP
 * ------------------------------------------------------------------------ */

397
398
399
400
401
402
403
404
405
406
407
/// Default timeout for bookkeeping operations, in seconds
/// (`add_request` passes it to `Duration::from_secs`).
const BOOKKEEPING_TIMEOUT_SEC: u64 = 5;
/// Complete routing result for a chat completion request (C-compatible)
#[repr(C)]
pub struct CRoutingResult {
    /// Whether disaggregated mode is active
    pub is_disaggregated: bool,
    /// Prefill worker ID (only valid if is_disaggregated is true)
    pub prefill_worker_id: u64,
    /// Decode worker ID
    pub decode_worker_id: u64,
    /// Data parallel rank selected for the prefill worker
    pub prefill_dp_rank: u32,
    /// Data parallel rank selected for the decode worker
    pub decode_dp_rank: u32,
    /// Token IDs (needed for add_request callback)
    pub token_ids: *mut u32,
    /// Number of tokens in the request
    pub token_count: usize,
}

impl Default for CRoutingResult {
    /// An "empty" result: not disaggregated, all ids and ranks zero, and a NULL
    /// token buffer of length zero.
    fn default() -> Self {
        Self {
            is_disaggregated: false,
            prefill_worker_id: 0,
            decode_worker_id: 0,
            prefill_dp_rank: 0,
            decode_dp_rank: 0,
            token_ids: ptr::null_mut(),
            token_count: 0,
        }
    }
}
431

432
433
434
435
436
437
438
439
440
441
442
443
444
/// Container holding routers and preprocessor for query routing.
///
/// Built by `create_routers`, boxed, and handed to C as a `RouterHandlesPtr`.
pub struct RouterHandles {
    /// Router queried by `query_prefill_worker`.
    prefill_router: Arc<PrefillRouter>,
    /// Router queried by `query_decode_worker` (and cloned by `add_request`).
    decode_router: Arc<KvRouter>,
    // Stored but not read in the code visible here.
    // NOTE(review): presumably kept to hold the manager alive alongside the routers — confirm.
    #[allow(dead_code)]
    model_manager: Arc<ModelManager>,
    // The namespace string passed to `create_routers`; stored but not read in the code visible here.
    #[allow(dead_code)]
    namespace: String,
    /// Cached runtime for executing async operations (avoids creating new runtime per call)
    runtime: Runtime,
    /// Preprocessor for tokenization and template application (fetched via discovery).
    /// `create_routers` always stores `Some(..)` here.
    preprocessor: Option<Arc<OpenAIPreprocessor>>,
}
445

446
447
impl RouterHandles {
    /// Query optimal prefill worker for a request.
448
449
    ///
    /// When `allowed_worker_ids` is Some, only workers in that set are considered.
450
451
452
453
    /// Returns worker_id on success.
    async fn query_prefill_worker(
        &self,
        tokens: &[u32],
454
        block_mm_infos: Option<&[Option<dynamo_kv_router::protocols::BlockExtraInfo>]>,
455
456
457
        update_states: bool,
        lora_name: Option<String>,
        priority_jump: f64,
458
        allowed_worker_ids: Option<HashSet<WorkerId>>,
459
    ) -> Result<(u64, Option<u32>), QueryRouterResult> {
460
461
462
463
        if let Some(ref ids) = allowed_worker_ids {
            self.prefill_router.register_workers(ids);
        }

464
        self.prefill_router
465
466
467
468
469
470
            .query_prefill_worker(
                tokens,
                block_mm_infos,
                update_states,
                lora_name,
                priority_jump,
471
                allowed_worker_ids,
472
            )
473
474
475
476
477
478
            .await
            .map_err(|e| {
                tracing::error!(error = ?e, "Prefill query failed");
                QueryRouterResult::ErrQueryFailed
            })
    }
479

480
481
482
483
    /// Query optimal decode worker for a request.
    /// For disaggregated mode, set `is_disaggregated` to true to use overlap_score_weight=0
    /// (since KV cache is being transferred from prefill, not reused).
    ///
484
485
486
    /// When `allowed_worker_ids` is Some, only workers in that set are considered.
    /// This does NOT overwrite the router's internal worker state — it only filters this decision.
    ///
487
488
489
490
491
492
493
494
    /// Note: The C bindings are query-only and must not mutate router state during worker
    /// selection. State updates require a `context_id` (request id) and are managed via the
    /// explicit bookkeeping APIs (`add_request`, `mark_prefill_complete`, `free_request`).
    /// Returns (worker, overlap_blocks) on success.
    async fn query_decode_worker(
        &self,
        tokens: &[u32],
        is_disaggregated: bool,
495
        allowed_worker_ids: Option<HashSet<WorkerId>>,
496
    ) -> Result<(WorkerWithDpRank, u32), QueryRouterResult> {
497
498
499
500
        if let Some(ref ids) = allowed_worker_ids {
            self.decode_router.register_workers(ids);
        }

501
502
503
504
505
        // For decode phase in disaggregated mode, use overlap_score_weight=0
        // This matches prefill_router.rs
        let config_override = if is_disaggregated {
            Some(RouterConfigOverride {
                overlap_score_weight: Some(0.0),
506
507
                assume_kv_reuse: Some(false),
                track_prefill_tokens: Some(false),
508
                ..Default::default()
509
            })
510
511
512
513
        } else {
            None
        };

514
        self.decode_router
515
516
517
518
519
520
521
522
            .find_best_match(
                None,
                tokens,
                None,
                config_override.as_ref(),
                false,
                None,
                0.0,
523
                None,
524
                allowed_worker_ids,
525
            )
526
527
528
529
530
531
532
            .await
            .map_err(|e| {
                tracing::error!(error = ?e, "Decode query failed");
                QueryRouterResult::ErrQueryFailed
            })
    }
}
533

534
535
/// Opaque handle for the router pair.
///
/// `create_routers` produces it with `Box::into_raw`, so C callers must treat it as
/// an opaque pointer and only pass it back to this library.
pub type RouterHandlesPtr = *mut RouterHandles;
536

537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
/// Result codes for query router C FFI.
///
/// `#[repr(u32)]` fixes the numeric values seen by C callers.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryRouterResult {
    /// The call succeeded.
    Ok = 0,
    ErrInvalidHandle = 1,
    /// A required pointer was NULL or a string argument was not valid UTF-8.
    ErrInvalidParam = 2,
    /// Runtime, discovery, or router construction failed.
    ErrInitFailed = 3,
    /// A router query returned an error.
    ErrQueryFailed = 4,
    ErrDisaggEnforced = 5,
    ErrTimeout = 6,
}

/// Build a `KvRouterConfig` from defaults, overridden by optional `DYN_*` environment variables.
fn kv_router_config_from_env() -> KvRouterConfig {
    let mut cfg = KvRouterConfig::default();

    fn env_f64(key: &str) -> Option<f64> {
        std::env::var(key).ok().and_then(|v| v.parse().ok())
555
    }
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
    fn env_bool(key: &str) -> Option<bool> {
        std::env::var(key)
            .ok()
            .and_then(|v| match v.to_lowercase().as_str() {
                "true" | "1" | "yes" | "on" => Some(true),
                "false" | "0" | "no" | "off" => Some(false),
                _ => None,
            })
    }

    if let Some(v) = env_f64("DYN_OVERLAP_SCORE_WEIGHT") {
        cfg.overlap_score_weight = v;
    }
    if let Some(v) = env_f64("DYN_ROUTER_TEMPERATURE") {
        cfg.router_temperature = v;
    }
    if let Some(v) = env_bool("DYN_USE_KV_EVENTS") {
        cfg.use_kv_events = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_REPLICA_SYNC") {
        cfg.router_replica_sync = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_TRACK_ACTIVE_BLOCKS") {
        cfg.router_track_active_blocks = v;
    }
    if let Some(v) = env_bool("DYN_ROUTER_TRACK_OUTPUT_BLOCKS") {
        cfg.router_track_output_blocks = v;
    }
584
585
586
    if let Some(v) = env_bool("DYN_ROUTER_TRACK_PREFILL_TOKENS") {
        cfg.router_track_prefill_tokens = v;
    }
587
588
589
    if let Some(v) = env_f64("DYN_ROUTER_QUEUE_THRESHOLD") {
        cfg.router_queue_threshold = Some(v);
    }
590
591
592
593
594
595
596
597

    tracing::info!(
        overlap_score_weight = cfg.overlap_score_weight,
        router_temperature = cfg.router_temperature,
        use_kv_events = cfg.use_kv_events,
        router_replica_sync = cfg.router_replica_sync,
        router_track_active_blocks = cfg.router_track_active_blocks,
        router_track_output_blocks = cfg.router_track_output_blocks,
598
        router_track_prefill_tokens = cfg.router_track_prefill_tokens,
599
        router_queue_threshold = ?cfg.router_queue_threshold,
600
601
602
603
        "KvRouterConfig initialized (DYN_* env overrides applied)"
    );

    cfg
604
605
}

606
/// Create router handles for query-only routing
607
///
608
609
610
/// This function waits for at least one decode worker to be discovered before returning.
/// It auto-detects disaggregated mode by checking if prefill workers are present.
/// The KV cache block size is automatically fetched from the model card via discovery.
611
///
612
613
614
/// # Arguments
/// - `namespace`: Namespace for the model
/// - `component`: Component name (defaults to "backend" if NULL or empty)
615
/// - `enforce_disagg`: If true, requires prefill workers to be present at init time
616
/// - `out_handle`: Output handle
617
///
618
619
620
/// # Safety
/// - All string parameters must be valid null-terminated C strings
/// - The returned handle must be freed with `destroy`
621
#[unsafe(no_mangle)]
622
623
624
pub unsafe extern "C" fn create_routers(
    namespace: *const c_char,
    component: *const c_char,
625
    enforce_disagg: bool,
626
627
    out_handle: *mut RouterHandlesPtr,
) -> QueryRouterResult {
628
629
    initialize_tracing();

630
631
    if namespace.is_null() || out_handle.is_null() {
        return QueryRouterResult::ErrInvalidParam;
632
633
    }

634
635
636
    let namespace_str = match unsafe { CStr::from_ptr(namespace) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
637
638
    };

639
640
    let component_str = if component.is_null() {
        "backend".to_string()
641
    } else {
642
643
644
        match unsafe { CStr::from_ptr(component) }.to_str() {
            Ok(s) if !s.is_empty() => s.to_owned(),
            _ => "backend".to_string(),
645
646
647
        }
    };

648
649
650
    // Create the runtime once - it will be stored in RouterHandles and reused
    let runtime = match Runtime::from_settings() {
        Ok(rt) => rt,
651
        Err(e) => {
652
653
            tracing::error!(error = ?e, "Failed to create runtime");
            return QueryRouterResult::ErrInitFailed;
654
655
        }
    };
656
657
658
659
660
661
662
663
664
665

    // Clone for use inside the async block (the original will be moved into handles)
    let runtime_for_async = runtime.clone();

    let result = runtime_for_async.secondary().block_on(async {
        let drt = match DistributedRuntime::from_settings(runtime_for_async.clone()).await {
            Ok(drt) => drt,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to create distributed runtime");
                return Err(QueryRouterResult::ErrInitFailed);
666
            }
667
668
        };

669
670
671
672
673
674
675
676
677
678
679
680
681
682
        let DiscoveredModelBootstrap {
            preprocessor,
            card,
            actual_namespace,
        } = match init_preprocessor(&drt, &namespace_str).await {
            Ok(result) => result,
            Err(e) => {
                tracing::error!(error = %e, "Failed to initialize preprocessor");
                return Err(QueryRouterResult::ErrInitFailed);
            }
        };
        let block_size = card.kv_cache_block_size;
        let model_name = card.display_name.clone();
        let enable_eagle = card.runtime_config.enable_eagle;
683
684
685

        if actual_namespace != namespace_str {
            tracing::info!(
686
687
                base_namespace = %namespace_str,
                actual_namespace = %actual_namespace,
688
                "Worker namespace has rolling-update suffix"
689
            );
690
691
        }

692
693
        let mut kv_router_config = kv_router_config_from_env();
        kv_router_config.skip_initial_worker_wait = true;
694

695
696
697
        // Build endpoint using the actual namespace discovered from workers,
        // which may include a rolling-update hash suffix.
        let component_handle = match drt.namespace(&actual_namespace) {
698
699
700
701
702
703
            Ok(ns) => match ns.component(&component_str) {
                Ok(c) => c,
                Err(e) => {
                    tracing::error!(error = ?e, "Failed to get component");
                    return Err(QueryRouterResult::ErrInitFailed);
                }
704
            },
705
706
707
708
709
710
            Err(e) => {
                tracing::error!(error = ?e, "Failed to get namespace");
                return Err(QueryRouterResult::ErrInitFailed);
            }
        };
        let endpoint = component_handle.endpoint("generate");
711

712
        let model_manager = Arc::new(ModelManager::new());
713

714
715
716
717
718
        // Create decode router
        let decode_router = match model_manager
            .kv_chooser_for(
                &endpoint,
                block_size,
719
                Some(kv_router_config.clone()),
720
                None,
721
                WORKER_TYPE_DECODE,
722
                Some(model_name.clone()),
723
                enable_eagle,
724
725
726
727
728
729
730
731
732
            )
            .await
        {
            Ok(r) => r,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to create decode router");
                return Err(QueryRouterResult::ErrInitFailed);
            }
        };
733

734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
        // Wait for the runtime config watch to be populated with at least one
        // decode worker's ModelRuntimeConfig. skip_initial_worker_wait=true
        // skips this inside KvRouter::new, but the selector needs workers in
        // workers_with_configs to avoid NoEndpoints on the first request.
        // discovery sync already confirmed workers exist; this just waits for
        // the async join of instance IDs + configs to complete in the watch.
        {
            let mut config_watch = model_manager
                .get_or_create_runtime_config_watcher(&endpoint)
                .await
                .map_err(|e| {
                    tracing::error!(error = ?e, "Failed to get runtime config watcher");
                    QueryRouterResult::ErrInitFailed
                })?;
            tracing::info!(
                "Waiting for decode workers to register ModelRuntimeConfig \
                 (no timeout - controlled by K8s StartupProbe)..."
            );
            let wait_result = config_watch.wait_for(|m| !m.is_empty()).await.map(|_| ());
            match wait_result {
                Ok(()) => {
                    let count = config_watch.borrow().len();
                    tracing::info!(
                        worker_count = count,
                        "Runtime config watch populated with decode workers"
                    );
                }
                Err(_) => {
                    tracing::error!(
                        "Runtime config watch closed before any workers appeared. \
                         Decode routing will fail. \
                         Verify workers are running and publishing to discovery."
                    );
                    return Err(QueryRouterResult::ErrInitFailed);
                }
            }
        }

772
773
774
775
776
777
778
779
780
781
782
783
784
785
        // Create PrefillRouter with a pending activation channel.
        // A background task watches discovery for prefill workers and activates
        // the router when one appears. Before activation, requests gracefully
        // fallback to decode-only routing.
        let mut prefill_config = kv_router_config;
        prefill_config.router_track_active_blocks = false;

        let (prefill_tx, prefill_rx) = tokio::sync::oneshot::channel();
        let prefill_router = PrefillRouter::new(
            prefill_rx,
            model_manager.clone(),
            RouterMode::KV,
            block_size,
            Some(prefill_config),
786
            None,
787
788
789
790
791
792
793
794
795
796
797
            enforce_disagg,
            model_name.clone(),
            actual_namespace.clone(),
            enable_eagle,
        );

        // Spawn background discovery watcher for prefill workers.
        // Polls discovery until a prefill-only worker appears in the same
        // rolling-update namespace, then sends its endpoint through the channel
        // to activate the PrefillRouter.
        spawn_prefill_discovery_watcher(drt.clone(), actual_namespace.clone(), prefill_tx);
798
799
800
801
802
803

        Ok((
            prefill_router,
            decode_router,
            model_manager,
            namespace_str,
804
            Some(preprocessor),
805
806
        ))
    });
807
808

    match result {
809
810
811
812
813
814
815
816
817
818
819
        Ok((prefill_router, decode_router, model_manager, namespace_str, preprocessor)) => {
            let handles = RouterHandles {
                prefill_router,
                decode_router,
                model_manager,
                namespace: namespace_str,
                runtime, // Store the runtime for reuse
                preprocessor,
            };
            unsafe { *out_handle = Box::into_raw(Box::new(handles)) };
            QueryRouterResult::Ok
820
        }
821
        Err(code) => code,
822
823
824
825
826
    }
}

/// Add a request to the router's bookkeeping after worker selection.
///
827
828
829
/// Register the request with the KvRouter's scheduler for tracking active blocks
/// and managing prefill/decode lifecycle. Call this after `query_decode` returns
/// worker IDs and before sending the request to the worker.
830
831
///
/// # Safety
832
833
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
834
835
/// - `token_ids` must point to at least `token_count` valid u32 values
#[unsafe(no_mangle)]
836
837
838
pub unsafe extern "C" fn add_request(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
839
840
841
842
    token_ids: *const u32,
    token_count: usize,
    worker_id: u64,
    dp_rank: u32,
843
844
845
846
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
847

848
849
850
851
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
852
853
854
855
856
857
858
859
    };

    let tokens: Vec<u32> = if token_count > 0 && !token_ids.is_null() {
        unsafe { std::slice::from_raw_parts(token_ids, token_count) }.to_vec()
    } else {
        Vec::new()
    };

860
    let decode_router = handles.decode_router.clone();
861

862
863
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
864

865
866
        tokio::time::timeout(timeout_duration, async {
            let worker = WorkerWithDpRank::new(worker_id, dp_rank);
867
868
869
870
871
872
            let router_config_override = RouterConfigOverride {
                overlap_score_weight: Some(0.0),
                assume_kv_reuse: Some(false),
                track_prefill_tokens: Some(false),
                ..Default::default()
            };
873

874
            // Compute overlap_blocks using the public method
875
            let overlap_blocks = match decode_router
876
                .get_overlap_blocks(&tokens, None, worker, None)
877
878
                .await
            {
879
880
881
882
883
884
885
                Ok(overlap) => overlap,
                Err(e) => {
                    tracing::warn!(error = ?e, "Failed to compute overlap, using 0");
                    0
                }
            };

886
            let cached_tokens = overlap_blocks as usize * decode_router.block_size() as usize;
887
888
889
890
            decode_router
                .add_request(
                    request_id_str.clone(),
                    &tokens,
891
                    None,
892
                    cached_tokens,
893
894
895
                    None,
                    worker,
                    None, // lora_name
896
                    Some(&router_config_override),
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
                )
                .await;

            tracing::debug!(
                request_id = %request_id_str,
                worker_id = worker_id,
                dp_rank = dp_rank,
                overlap_blocks = overlap_blocks,
                token_count = tokens.len(),
                "add_request completed"
            );
        })
        .await
    });

    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
            tracing::warn!(
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "add_request timed out"
            );
            QueryRouterResult::ErrTimeout
        }
    }
923
924
925
}

/// Mark prefill as completed for a request.
926
927
///
/// Call when the first token is generated to release prefill tokens from decode worker's load
928
929
///
/// # Safety
930
931
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
932
#[unsafe(no_mangle)]
933
934
935
936
937
938
939
pub unsafe extern "C" fn mark_prefill_complete(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
940

941
942
943
944
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
945
946
    };

947
    let decode_router = handles.decode_router.clone();
948

949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);

        tokio::time::timeout(timeout_duration, async {
            if let Err(e) = decode_router.mark_prefill_completed(&request_id_str).await {
                tracing::warn!(
                    request_id = %request_id_str,
                    error = %e,
                    "Failed to mark prefill complete"
                );
            } else {
                tracing::debug!(
                    request_id = %request_id_str,
                    "mark_prefill_complete completed"
                );
            }
        })
        .await
    });
968

969
970
971
    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
972
            tracing::warn!(
973
974
975
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "mark_prefill_complete timed out"
976
            );
977
            QueryRouterResult::ErrTimeout
978
        }
979
    }
980
981
982
}

/// Free a request from the router's bookkeeping.
983
984
///
/// Call this when the stream is closed (completed or cancelled) to release all resources.
985
986
///
/// # Safety
987
988
/// - `handle` must be a valid RouterHandles handle
/// - `request_id` must be a valid null-terminated C string
989
#[unsafe(no_mangle)]
990
991
992
993
994
995
996
pub unsafe extern "C" fn free_request(
    handle: RouterHandlesPtr,
    request_id: *const c_char,
) -> QueryRouterResult {
    if handle.is_null() || request_id.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }
997

998
999
1000
1001
    let handles = unsafe { &*handle };
    let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
        Ok(s) => s.to_owned(),
        Err(_) => return QueryRouterResult::ErrInvalidParam,
1002
1003
    };

1004
    let decode_router = handles.decode_router.clone();
1005

1006
1007
    let result = handles.runtime.secondary().block_on(async {
        let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
1008

1009
1010
1011
1012
1013
1014
        tokio::time::timeout(timeout_duration, async {
            if let Err(e) = decode_router.free(&request_id_str).await {
                tracing::warn!(
                    request_id = %request_id_str,
                    error = %e,
                    "Failed to free request"
1015
                );
1016
            } else {
1017
                tracing::debug!(
1018
1019
                    request_id = %request_id_str,
                    "free_request completed"
1020
                );
1021
            }
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
        })
        .await
    });

    match result {
        Ok(()) => QueryRouterResult::Ok,
        Err(_elapsed) => {
            tracing::warn!(
                request_id = %request_id_str,
                timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
                "free_request timed out"
            );
            QueryRouterResult::ErrTimeout
1035
1036
        }
    }
1037
}
/// Destroy router handles
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle or null
/// - After this call, `handle` must not be used
#[unsafe(no_mangle)]
pub unsafe extern "C" fn destroy(handle: RouterHandlesPtr) {
    if !handle.is_null() {
        drop(unsafe { Box::from_raw(handle) });
    }
1049
1050
}

/// Free a routing result.
1052
///
1053
/// # Safety
1054
/// - `result` must be a valid pointer to a CRoutingResult previously returned by route functions
1055
#[unsafe(no_mangle)]
1056
1057
1058
pub unsafe extern "C" fn free_routing_result(result: *mut CRoutingResult) {
    if result.is_null() {
        return;
1059
    }
1060

1061
    let res = unsafe { &mut *result };
1062

1063
1064
1065
    // Free token IDs
    if !res.token_ids.is_null() && res.token_count > 0 {
        drop(unsafe {
1066
            Box::from_raw(std::ptr::slice_from_raw_parts_mut(
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
                res.token_ids,
                res.token_count,
            ))
        });
        res.token_ids = ptr::null_mut();
        res.token_count = 0;
    }
}

/// Parse a JSON request string, apply the chat template, and tokenize.
/// Returns the token IDs on success, or a `QueryRouterResult` error code.
unsafe fn preprocess_request(
    handles: &RouterHandles,
    request_json: *const c_char,
) -> Result<Vec<u32>, QueryRouterResult> {
1082
1083
1084
1085
    let preprocessor = match &handles.preprocessor {
        Some(p) => p,
        None => {
            tracing::error!("Preprocessor not available");
1086
            return Err(QueryRouterResult::ErrInitFailed);
1087
1088
1089
1090
1091
        }
    };

    let json_str = match unsafe { CStr::from_ptr(request_json) }.to_str() {
        Ok(s) => s,
1092
        Err(_) => return Err(QueryRouterResult::ErrInvalidParam),
1093
1094
1095
1096
1097
1098
1099
    };

    let request: dynamo_llm::types::openai::chat_completions::NvCreateChatCompletionRequest =
        match serde_json::from_str(json_str) {
            Ok(req) => req,
            Err(e) => {
                tracing::error!(error = ?e, "Failed to parse request JSON");
1100
                return Err(QueryRouterResult::ErrInvalidParam);
1101
1102
1103
1104
1105
1106
1107
1108
            }
        };

    let formatted_prompt = match preprocessor.apply_template(&request) {
        Ok(Some(prompt)) => prompt,
        Ok(None) => String::new(),
        Err(e) => {
            tracing::error!(error = ?e, "Failed to apply chat template");
1109
            return Err(QueryRouterResult::ErrQueryFailed);
1110
1111
1112
1113
1114
1115
1116
        }
    };

    let encoding = match preprocessor.tokenize(&formatted_prompt) {
        Ok(enc) => enc,
        Err(e) => {
            tracing::error!(error = ?e, "Failed to tokenize");
1117
            return Err(QueryRouterResult::ErrQueryFailed);
1118
1119
1120
        }
    };

1121
1122
1123
1124
1125
1126
1127
1128
    let token_ids = encoding.token_ids().to_vec();
    tracing::info!(
        token_count = token_ids.len(),
        first_tokens = ?&token_ids[..std::cmp::min(5, token_ids.len())],
        "[EPP-TOKENIZE] Tokenized prompt in C bindings (this is the ONLY tokenization)"
    );

    Ok(token_ids)
1129
}
/// Build the optional allow-list of worker IDs from the EPP pods JSON.
///
/// The input is expected to be a JSON array. For each element the pod name is
/// read from `pod.podName`, falling back to a top-level `podName`; elements
/// without a string pod name are skipped. Each name is mapped to a worker ID
/// with `hash_pod_name`.
///
/// Returns `None` (no filtering) when the pointer is null, the string is not
/// valid UTF-8 or is empty, the JSON fails to parse, or no pod name was found.
///
/// # Safety
/// - `pods_json` must be null or a valid null-terminated C string
unsafe fn parse_pods_filter(pods_json: *const c_char) -> Option<HashSet<WorkerId>> {
    if pods_json.is_null() {
        return None;
    }

    let raw = unsafe { CStr::from_ptr(pods_json) }.to_str().ok()?;
    if raw.is_empty() {
        return None;
    }

    let entries = match serde_json::from_str::<Vec<serde_json::Value>>(raw) {
        Ok(entries) => entries,
        Err(e) => {
            tracing::error!(error = ?e, "Failed to parse pods JSON");
            return None;
        }
    };

    let allowed: HashSet<WorkerId> = entries
        .iter()
        .filter_map(|entry| {
            entry
                .get("pod")
                .and_then(|inner| inner.get("podName"))
                .or_else(|| entry.get("podName"))
                .and_then(|value| value.as_str())
        })
        .map(|name| {
            let worker_id = hash_pod_name(name);
            tracing::debug!(
                pod_name = name,
                worker_id = format!("{:x}", worker_id),
                "Mapped EPP pod to worker_id"
            );
            worker_id
        })
        .collect();

    tracing::info!(
        pod_count = entries.len(),
        unique_worker_ids = allowed.len(),
        "Parsed EPP pods into allowed_worker_ids filter"
    );

    // An empty set means "no usable pods", which is reported as "no filter".
    (!allowed.is_empty()).then_some(allowed)
}
/// Write token IDs into a `CRoutingResult`, transferring ownership to the caller.
///
/// The tokens are copied into a freshly allocated boxed slice whose raw pointer
/// and length are stored in `out.token_ids` / `out.token_count`. The buffer is
/// intentionally leaked here; it is reclaimed by `free_routing_result`, which
/// rebuilds the boxed slice from the same pointer and length.
fn write_tokens_to_result(tokens: &[u32], out: &mut CRoutingResult) {
    let boxed: Box<[u32]> = Box::from(tokens);
    out.token_count = boxed.len();
    // `Box::into_raw` hands over ownership without taking a raw pointer from a
    // box that is subsequently moved (the error-prone `as_mut_ptr` + `forget`
    // pattern).
    out.token_ids = Box::into_raw(boxed).cast::<u32>();
}

/// Route a request to select the best **prefill** worker only.
///
/// This is used in disaggregated mode where the EPP runs separate prefill and decode
/// scoring profiles.  It tokenizes the request and queries only the prefill router.
///
/// The returned `CRoutingResult` contains:
/// - `prefill_worker_id`: the selected prefill worker
/// - `decode_worker_id`: 0 (unused — decode is handled by `route_decode_request`)
/// - `is_disaggregated`: always true (this function is only called in disagg mode)
/// - `token_ids` / `token_count`: the tokenized request (caller must free via `free_routing_result`)
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `pods_json` must be a valid null-terminated C string containing JSON, or null
/// - `out_result` must be a valid pointer
#[unsafe(no_mangle)]
pub unsafe extern "C" fn route_prefill_request(
    handle: RouterHandlesPtr,
    request_json: *const c_char,
    pods_json: *const c_char,
    out_result: *mut CRoutingResult,
) -> QueryRouterResult {
    if handle.is_null() || request_json.is_null() || out_result.is_null() {
        return QueryRouterResult::ErrInvalidParam;
    }

    let handles = unsafe { &*handle };

    let tokens = match unsafe { preprocess_request(handles, request_json) } {
        Ok(t) => t,
        Err(code) => return code,
    };

    let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };

    let result = handles.runtime.secondary().block_on(async {
atchernych's avatar
atchernych committed
1222
        let (prefill_worker_id, prefill_dp_rank) = handles
1223
            .query_prefill_worker(&tokens, None, false, None, 0.0, allowed_worker_ids)
1224
1225
            .await?;

1226
1227
        let prefill_dp_rank = prefill_dp_rank.unwrap_or(u32::MAX);

1228
1229
        tracing::info!(
            prefill_worker_id = prefill_worker_id,
atchernych's avatar
atchernych committed
1230
            prefill_dp_rank = prefill_dp_rank,
1231
1232
            token_count = tokens.len(),
            "Routed prefill request"
1233
1234
        );

atchernych's avatar
atchernych committed
1235
        Ok((prefill_worker_id, prefill_dp_rank))
1236
1237
    });

1238
    match result {
atchernych's avatar
atchernych committed
1239
        Ok((prefill_worker_id, prefill_dp_rank)) => {
1240
1241
1242
1243
            let out = unsafe { &mut *out_result };
            *out = CRoutingResult::default();
            out.is_disaggregated = true;
            out.prefill_worker_id = prefill_worker_id;
atchernych's avatar
atchernych committed
1244
            out.prefill_dp_rank = prefill_dp_rank;
1245
            write_tokens_to_result(&tokens, out);
1246
1247
1248
1249
            QueryRouterResult::Ok
        }
        Err(code) => code,
    }
1250
1251
}

/// Route a request to select the best **decode** worker only.
///
/// This is used in both aggregated and disaggregated modes.
/// - When `is_disaggregated` is true, the decode router uses `overlap_score_weight=0`
///   (KV cache is being transferred from prefill, not reused locally).
/// - When `is_disaggregated` is false, normal KV-aware scoring is used.
///
/// The returned `CRoutingResult` contains:
/// - `decode_worker_id`: the selected decode worker
/// - `prefill_worker_id`: 0 (unused — prefill is handled by `route_prefill_request`)
/// - `is_disaggregated`: mirrors the input parameter
/// - `token_ids` / `token_count`: the tokenized request (caller must free via `free_routing_result`)
1264
///
1265
/// # Safety
1266
1267
1268
1269
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `pods_json` must be a valid null-terminated C string containing JSON, or null
/// - `out_result` must be a valid pointer
1270
#[unsafe(no_mangle)]
1271
1272
1273
1274
1275
1276
1277
1278
1279
pub unsafe extern "C" fn route_decode_request(
    handle: RouterHandlesPtr,
    request_json: *const c_char,
    pods_json: *const c_char,
    is_disaggregated: bool,
    out_result: *mut CRoutingResult,
) -> QueryRouterResult {
    if handle.is_null() || request_json.is_null() || out_result.is_null() {
        return QueryRouterResult::ErrInvalidParam;
1280
1281
    }

1282
    let handles = unsafe { &*handle };
1283

1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
    let tokens = match unsafe { preprocess_request(handles, request_json) } {
        Ok(t) => t,
        Err(code) => return code,
    };

    let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };

    let result = handles.runtime.secondary().block_on(async {
        let (decode_worker, _overlap_blocks) = handles
            .query_decode_worker(&tokens, is_disaggregated, allowed_worker_ids)
            .await?;

        tracing::info!(
            is_disaggregated = is_disaggregated,
            decode_worker_id = decode_worker.worker_id,
            decode_dp_rank = decode_worker.dp_rank,
            token_count = tokens.len(),
            "Routed decode request"
        );

        Ok(decode_worker)
    });

    match result {
        Ok(decode_worker) => {
            let out = unsafe { &mut *out_result };
            *out = CRoutingResult::default();
            out.is_disaggregated = is_disaggregated;
            out.decode_worker_id = decode_worker.worker_id;
atchernych's avatar
atchernych committed
1313
            out.decode_dp_rank = decode_worker.dp_rank;
1314
1315
1316
1317
            write_tokens_to_result(&tokens, out);
            QueryRouterResult::Ok
        }
        Err(code) => code,
1318
    }
1319
1320
}

/// Initialize the preprocessor and fetch the model card used for routing.
1322
1323
///
/// Waits for discovery to sync (model card must be available for tokenization),
1324
1325
/// then creates the preprocessor from the model card. Router settings are
/// derived directly from the returned card by the caller.
1326
1327
1328
async fn init_preprocessor(
    drt: &DistributedRuntime,
    target_namespace: &str,
1329
) -> anyhow::Result<DiscoveredModelBootstrap> {
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
    let instance_count = wait_for_discovery_sync(drt).await;
    if instance_count == 0 {
        anyhow::bail!("Discovery sync failed: no worker instances found. Is the backend running?");
    }
    tracing::info!(
        "Discovery sync complete, {} worker(s) found",
        instance_count
    );

    // Retry fetching the preprocessor: model card metadata may arrive after
    // worker endpoints are registered.
1341
    let bootstrap = loop {
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
        match fetch_preprocessor_from_discovery(drt, target_namespace).await {
            Ok(result) => break result,
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    target_namespace,
                    "Model card not available yet, retrying in 5s..."
                );
                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            }
        }
    };

    tracing::info!(
1356
1357
1358
1359
        kv_cache_block_size = bootstrap.card.kv_cache_block_size,
        model_name = %bootstrap.card.display_name,
        actual_namespace = %bootstrap.actual_namespace,
        enable_eagle = bootstrap.card.runtime_config.enable_eagle,
1360
1361
1362
        "Preprocessor initialized from model card"
    );

1363
    Ok(bootstrap)
1364
1365
}

/// Spawn a background task that watches discovery for a prefill-only worker
/// in the given namespace. When found, sends its endpoint through `tx` to
/// activate the PrefillRouter. Polls every 1 second until a match is found.
///
/// A worker counts as prefill-only when its model card supports prefill but
/// neither chat nor completions. The namespace must match `target_namespace`
/// exactly.
///
/// The task also stops once the receiving half of `tx` has been dropped, so it
/// does not keep polling discovery after nobody can be activated any more.
/// Must be called from within a Tokio runtime context (uses `tokio::spawn`).
fn spawn_prefill_discovery_watcher(
    drt: DistributedRuntime,
    target_namespace: String,
    tx: tokio::sync::oneshot::Sender<dynamo_runtime::component::Endpoint>,
) {
    use dynamo_llm::model_card::ModelDeploymentCard;
    use dynamo_runtime::discovery::DiscoveryInstance;

    tokio::spawn(async move {
        let discovery = drt.discovery();
        tracing::info!(
            namespace = target_namespace,
            "Background task: watching for prefill workers to register..."
        );

        loop {
            // Nothing is listening any more: a later `send` could only fail,
            // so stop polling instead of looping forever.
            if tx.is_closed() {
                tracing::debug!("PrefillRouter activation channel already closed");
                return;
            }

            match discovery.list(DiscoveryQuery::AllModels).await {
                Ok(instances) => {
                    for instance in instances {
                        let DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            ..
                        } = &instance
                        else {
                            continue;
                        };

                        if namespace != &target_namespace {
                            continue;
                        }

                        // Cards that fail to deserialize are skipped; they are
                        // looked at again on the next poll.
                        let card = match instance.deserialize_model::<ModelDeploymentCard>() {
                            Ok(card) => card,
                            Err(_) => continue,
                        };

                        // Only prefill-only workers activate the PrefillRouter.
                        if !card.model_type.supports_prefill()
                            || card.model_type.supports_chat()
                            || card.model_type.supports_completions()
                        {
                            continue;
                        }

                        tracing::info!(
                            model_name = card.name(),
                            namespace = namespace.as_str(),
                            "Prefill worker discovered, activating PrefillRouter"
                        );

                        // If the endpoint cannot be resolved, keep scanning and
                        // retry on a later poll.
                        if let Ok(ns) = drt.namespace(namespace)
                            && let Ok(comp) = ns.component(component)
                        {
                            let ep = comp.endpoint(endpoint);
                            if tx.send(ep).is_err() {
                                tracing::debug!("PrefillRouter activation channel already closed");
                            }
                            return;
                        }
                    }
                }
                Err(e) => {
                    tracing::debug!(
                        error = %e,
                        "Failed to list models while watching for prefill workers, retrying"
                    );
                }
            }

            tokio::time::sleep(Duration::from_secs(1)).await;
        }
    });
}

/// Fetch model card via discovery and create preprocessor.
///
/// This function:
/// 1. Lists all models via discovery
/// 2. Finds the first model in the target namespace (decode workers only)
/// 3. Downloads the model config (tokenizer files) if needed
/// 4. Creates an OpenAIPreprocessor from the model card
1441
/// 5. Returns the preprocessor, the model card, and the resolved worker namespace
1442
1443
1444
async fn fetch_preprocessor_from_discovery(
    drt: &DistributedRuntime,
    target_namespace: &str,
1445
) -> anyhow::Result<DiscoveredModelBootstrap> {
1446
    use dynamo_runtime::discovery::DiscoveryInstance;
1447

1448
    let discovery = drt.discovery();
1449

1450
1451
    // List all models
    let instances = discovery.list(DiscoveryQuery::AllModels).await?;
1452

1453
1454
1455
1456
    // Find first model card in the target namespace (decode workers only).
    // Use prefix matching because workers may append a rolling-update hash
    // suffix to the base namespace (e.g. "ns-dgd-58908edc" vs "ns-dgd").
    let mut model_card: Option<(ModelDeploymentCard, String)> = None;
1457

1458
1459
    for instance in instances {
        if let DiscoveryInstance::Model { namespace, .. } = &instance {
1460
            if !namespace.starts_with(target_namespace) {
1461
1462
1463
                continue;
            }

1464
            let actual_namespace = namespace.clone();
1465
1466
1467
1468
1469
1470
            match instance.deserialize_model::<ModelDeploymentCard>() {
                Ok(card) => {
                    // Skip prefill-only workers, we want decode workers for routing
                    if card.model_type.supports_prefill()
                        && !card.model_type.supports_chat()
                        && !card.model_type.supports_completions()
1471
                    {
1472
                        continue;
1473
                    }
1474
                    model_card = Some((card, actual_namespace));
1475
                    break;
1476
                }
1477
1478
1479
                Err(e) => {
                    tracing::debug!(error = %e, "Failed to deserialize model card, skipping");
                    continue;
1480
1481
1482
                }
            }
        }
1483
    }
1484

1485
    let (mut card, actual_namespace) = model_card.ok_or_else(|| {
1486
1487
1488
1489
1490
1491
1492
        anyhow::anyhow!(
            "No model found in namespace '{}' via discovery",
            target_namespace
        )
    })?;

    tracing::info!(
1493
1494
1495
1496
        model_name = %card.display_name,
        kv_cache_block_size = card.kv_cache_block_size,
        actual_namespace = %actual_namespace,
        enable_eagle = card.runtime_config.enable_eagle,
1497
        "Found model card via discovery"
1498
1499
    );

1500
1501
    // Download config (tokenizer files) if not local
    card.download_config().await?;
1502

1503
    // Create preprocessor
1504
1505
    let preprocessor = OpenAIPreprocessor::new(card.clone())?;
    Ok(DiscoveredModelBootstrap {
1506
        preprocessor,
1507
        card,
1508
        actual_namespace,
1509
    })
1510
}