// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::{
    collections::{HashMap, HashSet},
6
    sync::Arc,
7
8
};

9
use dashmap::{DashMap, mapref::entry::Entry};
10
use parking_lot::{Mutex, RwLock};
11
use tokio::sync::{Notify, oneshot};
12

13
14
use crate::discovery::KvWorkerMonitor;

15
use dynamo_runtime::{
16
    component::{Client, Endpoint, build_transport_type},
17
    discovery::{DiscoveryQuery, DiscoverySpec, watch_and_extract_field},
18
19
20
    prelude::DistributedRuntimeProvider,
    protocols::EndpointId,
};
21
22

use crate::{
23
24
25
26
27
    kv_router::{
        KvRouter, KvRouterConfig, protocols::WorkerId, router_endpoint_id,
        scheduler::DefaultWorkerSelector,
    },
    local_model::runtime_config::{DisaggregatedEndpoint, ModelRuntimeConfig},
28
    model_card::ModelDeploymentCard,
29
    model_type::ModelType,
30
31
32
33
34
35
36
37
    types::{
        generic::tensor::TensorStreamingEngine,
        openai::{
            chat_completions::OpenAIChatCompletionsStreamingEngine,
            completions::OpenAICompletionsStreamingEngine,
            embeddings::OpenAIEmbeddingsStreamingEngine,
        },
    },
38
};
39

40
41
42
43
44
45
46
47
/// State for prefill router activation rendezvous.
///
/// Stored per model name in `ModelManager::prefill_router_activators`. Whichever side
/// arrives first (decode registration or prefill activation) parks its half of a
/// oneshot channel here for the other side to pick up.
enum PrefillActivationState {
    /// Decode model registered, waiting for prefill endpoint.
    /// Holds the sender that `activate_prefill_router` uses to deliver the endpoint.
    DecodeWaiting(oneshot::Sender<Endpoint>),
    /// Prefill endpoint arrived, waiting for decode model to register.
    /// Holds a receiver whose endpoint has already been sent by `activate_prefill_router`.
    PrefillReady(oneshot::Receiver<Endpoint>),
}

48
49
50
51
52
53
54
55
56
/// Errors returned by `ModelManager` engine lookups and registrations.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// No engine is registered under the requested model name (the payload is that name).
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// An engine is already registered under this model name (the payload is that name).
    // NOTE(review): presumably produced by `ModelEngines::add`; its body is not visible here.
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}

57
58
59
60
61
62
/// Central manager for model engines, routing, and configuration.
///
/// Manages model lifecycle including engines, KV routers, prefill coordination,
/// and per-model busy thresholds for load-based request rejection.
///
/// Note: Don't implement Clone for this, put it in an Arc instead.
63
64
65
66
67
pub struct ModelManager {
    // We read a lot and write rarely, so these three are RwLock
    completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
    chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
    embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
68
    tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
69
70
    // Prefill models don't have engines - they're only tracked for discovery/lifecycle
    prefill_engines: RwLock<ModelEngines<()>>,
71

72
    // These are Mutex because we read and write rarely and equally
73
    cards: Mutex<HashMap<String, ModelDeploymentCard>>,
74
    kv_choosers: Mutex<HashMap<EndpointId, Arc<KvRouter>>>,
75
    prefill_router_activators: Mutex<HashMap<String, PrefillActivationState>>,
76
77
78
79
80

    /// Per-model worker monitors for dynamic KV cache load rejection.
    /// Key: model name, Value: cloneable monitor (all fields are Arc).
    /// HTTP endpoint can update thresholds via monitor.set_threshold().
    worker_monitors: RwLock<HashMap<String, KvWorkerMonitor>>,
81
82
83

    /// Runtime configs per endpoint using DashMap for lock-free access.
    /// Outer DashMap: keyed by EndpointId
84
85
86
87
88
89
90
91
    /// Inner RuntimeConfigsWithNotify: shared with KvScheduler
    runtime_configs: DashMap<EndpointId, Arc<RuntimeConfigsWithNotify>>,
}

/// Runtime configs for an endpoint with a notify for change notifications.
pub struct RuntimeConfigsWithNotify {
    /// Latest runtime config per worker. The runtime config watcher inserts `None` for a
    /// worker whose instance id is known but for which no runtime config was found.
    pub configs: DashMap<WorkerId, Option<ModelRuntimeConfig>>,
    /// Signalled with `notify_waiters()` by the runtime config watcher after each update
    /// of `configs`.
    pub notify: Notify,
}

impl Default for ModelManager {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelManager {
    pub fn new() -> Self {
        Self {
            completion_engines: RwLock::new(ModelEngines::default()),
            chat_completion_engines: RwLock::new(ModelEngines::default()),
            embeddings_engines: RwLock::new(ModelEngines::default()),
106
            tensor_engines: RwLock::new(ModelEngines::default()),
107
            prefill_engines: RwLock::new(ModelEngines::default()),
108
            cards: Mutex::new(HashMap::new()),
109
            kv_choosers: Mutex::new(HashMap::new()),
110
            prefill_router_activators: Mutex::new(HashMap::new()),
111
            worker_monitors: RwLock::new(HashMap::new()),
112
            runtime_configs: DashMap::new(),
113
114
115
        }
    }

116
117
118
119
120
121
122
123
124
125
126
127
128
    /// Compares `candidate_checksum` with the checksum stored for `model_name` in each
    /// engine set selected by the units of `model_type`.
    ///
    /// Returns `None` when none of those engine sets holds a checksum for the model.
    /// Otherwise returns `Some(true)` only if every stored checksum equals the candidate.
    pub fn is_valid_checksum(
        &self,
        model_type: ModelType,
        model_name: &str,
        candidate_checksum: &str,
    ) -> Option<bool> {
        let mut compared_any = false;
        let mut all_match = true;

        for unit in model_type.units() {
            let stored = match unit {
                ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
                ModelType::Completions => self.completion_engines.read().checksum(model_name),
                ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
                ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
                ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
                _ => continue,
            };
            let Some(valid_checksum) = stored else {
                continue;
            };

            tracing::debug!(
                model_name,
                valid_checksum,
                candidate_checksum,
                "is_valid_checksum: check case"
            );
            compared_any = true;
            // No short-circuit: every unit is still inspected (and logged).
            all_match &= valid_checksum == candidate_checksum;
        }

        // The checksum is valid if it is correct for all the ModelType in the bitflag.
        compared_any.then_some(all_match)
    }

154
155
    /// Returns a clone of every saved model deployment card.
    pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
        let cards = self.cards.lock();
        cards.values().cloned().collect()
    }

158
159
    /// Check if a decode model (chat or completions) is registered
    pub fn has_decode_model(&self, model: &str) -> bool {
160
161
        self.chat_completion_engines.read().contains(model)
            || self.completion_engines.read().contains(model)
162
163
164
165
166
167
168
169
170
171
172
    }

    /// Returns `true` when `model` is registered as a prefill model.
    pub fn has_prefill_model(&self, model: &str) -> bool {
        let prefill = self.prefill_engines.read();
        prefill.contains(model)
    }

    /// Check if any model (decode or prefill) is registered.
    /// Note: For registration skip-checks, use has_decode_model() or has_prefill_model() instead.
    pub fn has_model_any(&self, model: &str) -> bool {
        self.has_decode_model(model) || self.has_prefill_model(model)
173
174
    }

175
176
177
178
179
    pub fn model_display_names(&self) -> HashSet<String> {
        self.list_chat_completions_models()
            .into_iter()
            .chain(self.list_completions_models())
            .chain(self.list_embeddings_models())
180
            .chain(self.list_tensor_models())
181
            .chain(self.list_prefill_models())
182
183
184
            .collect()
    }

185
    pub fn list_chat_completions_models(&self) -> Vec<String> {
186
        self.chat_completion_engines.read().list()
187
188
189
    }

    pub fn list_completions_models(&self) -> Vec<String> {
190
        self.completion_engines.read().list()
191
192
193
    }

    pub fn list_embeddings_models(&self) -> Vec<String> {
194
        self.embeddings_engines.read().list()
195
196
    }

197
198
199
200
    /// Lists the names of the registered tensor models.
    pub fn list_tensor_models(&self) -> Vec<String> {
        let engines = self.tensor_engines.read();
        engines.list()
    }

201
202
203
204
    /// Lists the names of the registered prefill models.
    pub fn list_prefill_models(&self) -> Vec<String> {
        let engines = self.prefill_engines.read();
        engines.list()
    }

205
206
207
    pub fn add_completions_model(
        &self,
        model: &str,
208
        card_checksum: &str,
209
210
        engine: OpenAICompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
211
        let mut clients = self.completion_engines.write();
212
        clients.add(model, card_checksum, engine)
213
214
215
216
217
    }

    pub fn add_chat_completions_model(
        &self,
        model: &str,
218
        card_checksum: &str,
219
220
        engine: OpenAIChatCompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
221
        let mut clients = self.chat_completion_engines.write();
222
        clients.add(model, card_checksum, engine)
223
224
225
226
227
    }

    pub fn add_embeddings_model(
        &self,
        model: &str,
228
        card_checksum: &str,
229
230
        engine: OpenAIEmbeddingsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
231
        let mut clients = self.embeddings_engines.write();
232
        clients.add(model, card_checksum, engine)
233
234
    }

235
236
237
    pub fn add_tensor_model(
        &self,
        model: &str,
238
        card_checksum: &str,
239
240
241
        engine: TensorStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
242
        clients.add(model, card_checksum, engine)
243
244
    }

245
246
247
248
249
250
    /// Records a prefill model and its card checksum. Prefill models carry no engine,
    /// so the unit value is stored in its place.
    ///
    /// # Errors
    ///
    /// Returns whatever `ModelEngines::add` reports (presumably a duplicate-name error;
    /// its body is not visible here).
    pub fn add_prefill_model(
        &self,
        model: &str,
        card_checksum: &str,
    ) -> Result<(), ModelManagerError> {
        self.prefill_engines.write().add(model, card_checksum, ())
    }

254
    /// Unregisters the completions engine stored for `model`.
    ///
    /// # Errors
    ///
    /// Returns whatever `ModelEngines::remove` reports (its body is not visible here).
    pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.completion_engines.write().remove(model)
    }

    /// Unregisters the chat completions engine stored for `model`.
    ///
    /// # Errors
    ///
    /// Returns whatever `ModelEngines::remove` reports (its body is not visible here).
    pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.chat_completion_engines.write().remove(model)
    }

    /// Unregisters the embeddings engine stored for `model`.
    ///
    /// # Errors
    ///
    /// Returns whatever `ModelEngines::remove` reports (its body is not visible here).
    pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.embeddings_engines.write().remove(model)
    }

269
270
271
272
273
    /// Unregisters the tensor engine stored for `model`.
    ///
    /// # Errors
    ///
    /// Returns whatever `ModelEngines::remove` reports (its body is not visible here).
    pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.tensor_engines.write().remove(model)
    }

274
275
276
277
278
    /// Unregisters the prefill model entry stored for `model`.
    ///
    /// # Errors
    ///
    /// Returns whatever `ModelEngines::remove` reports (its body is not visible here).
    pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.prefill_engines.write().remove(model)
    }

279
    pub fn get_embeddings_engine(
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
        &self,
        model: &str,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
        self.embeddings_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Returns a clone of the completions engine registered for `model`.
    ///
    /// # Errors
    ///
    /// Returns [`ModelManagerError::ModelNotFound`] when no completions engine is registered
    /// under that name.
    pub fn get_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
        self.completion_engines
            .read()
            .get(model)
            .cloned()
            // Build the error (and its String) only on the miss path.
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Returns a clone of the chat completions engine registered for `model`.
    ///
    /// # Errors
    ///
    /// Returns [`ModelManagerError::ModelNotFound`] when no chat completions engine is
    /// registered under that name.
    pub fn get_chat_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
        self.chat_completion_engines
            .read()
            .get(model)
            .cloned()
            // Build the error (and its String) only on the miss path.
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

312
313
314
315
316
317
318
319
320
321
322
    /// Returns a clone of the tensor engine registered for `model`.
    ///
    /// # Errors
    ///
    /// Returns [`ModelManagerError::ModelNotFound`] when no tensor engine is registered
    /// under that name.
    pub fn get_tensor_engine(
        &self,
        model: &str,
    ) -> Result<TensorStreamingEngine, ModelManagerError> {
        self.tensor_engines
            .read()
            .get(model)
            .cloned()
            // Build the error (and its String) only on the miss path.
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

323
    /// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
324
325
326
327
    /// deleted.
    pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
        self.cards.lock().insert(key.to_string(), card);
        Ok(())
328
329
    }

330
    /// Remove and return model card for this instance's key. We do this when the instance stops.
331
332
    pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
        self.cards.lock().remove(key)
333
334
335
336
    }

    pub async fn kv_chooser_for(
        &self,
337
        endpoint: &Endpoint,
338
        kv_cache_block_size: u32,
339
        kv_router_config: Option<KvRouterConfig>,
340
    ) -> anyhow::Result<Arc<KvRouter>> {
341
        let endpoint_id = endpoint.id();
342

343
        if let Some(kv_chooser) = self.get_kv_chooser(&endpoint_id) {
344
345
346
            // Check if the existing router has a different block size
            if kv_chooser.block_size() != kv_cache_block_size {
                tracing::warn!(
347
                    endpoint = %endpoint_id,
348
349
                    existing_block_size = %kv_chooser.block_size(),
                    requested_block_size = %kv_cache_block_size,
350
                    "KV Router block size mismatch! Endpoint is requesting a different kv_cache_block_size than the existing router. \
351
352
353
                     This will cause routing to fail silently. Consider using the same block size or restarting the router."
                );
            }
354
355
356
            return Ok(kv_chooser);
        }

357
        let client = endpoint.client().await?;
358
359
360
361
362

        // Register router via discovery mechanism
        let discovery = endpoint.component().drt().discovery();
        let instance_id = discovery.instance_id();

363
        // Build transport for router endpoint based on request plane mode
364
365
        // Use KV_ROUTER_COMPONENT as the component name to distinguish from the generate endpoint's component
        let router_endpoint_id = router_endpoint_id(endpoint.id().namespace);
366
        let transport = build_transport_type(endpoint, &router_endpoint_id, instance_id).await?;
367
368
369
370
371

        let discovery_spec = DiscoverySpec::Endpoint {
            namespace: router_endpoint_id.namespace.clone(),
            component: router_endpoint_id.component.clone(),
            endpoint: router_endpoint_id.name.clone(),
372
            transport,
373
374
375
376
377
378
        };

        discovery.register(discovery_spec).await?;

        // Use instance_id (hex) as the consumer ID for NATS consumer coordination
        let consumer_id = instance_id.to_string();
379

380
381
382
        // Get or create runtime config watcher for this endpoint
        let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;

383
        let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
384
        let chooser = KvRouter::new(
385
386
            endpoint.clone(),
            client,
387
            workers_with_configs,
388
389
            kv_cache_block_size,
            Some(selector),
390
            kv_router_config,
391
            consumer_id,
392
393
        )
        .await?;
394
395
396
        let new_kv_chooser = Arc::new(chooser);
        self.kv_choosers
            .lock()
397
            .insert(endpoint_id, new_kv_chooser.clone());
398
399
        Ok(new_kv_chooser)
    }
400

401
402
    fn get_kv_chooser(&self, id: &EndpointId) -> Option<Arc<KvRouter>> {
        self.kv_choosers.lock().get(id).cloned()
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
    }

    /// Register a prefill router for a decode model.
    ///
    /// Returns a receiver that yields the prefill endpoint: immediately if the prefill side
    /// already called `activate_prefill_router`, otherwise once it does.
    /// Returns `None` if a decode model was already registered under `model_name`; the
    /// existing registration is left in place.
    pub fn register_prefill_router(
        &self,
        model_name: String,
    ) -> Option<oneshot::Receiver<Endpoint>> {
        let mut pending = self.prefill_router_activators.lock();

        let previous = pending.remove(&model_name);
        match previous {
            None => {
                // First to arrive: keep the sender for the prefill side, hand out the receiver.
                let (sender, receiver) = oneshot::channel();
                pending.insert(
                    model_name.clone(),
                    PrefillActivationState::DecodeWaiting(sender),
                );
                tracing::debug!(
                    model_name = %model_name,
                    "No prefill endpoint available yet, storing sender for future activation"
                );
                Some(receiver)
            }
            Some(PrefillActivationState::PrefillReady(receiver)) => {
                // The endpoint was already sent into this channel by the prefill side.
                tracing::debug!(
                    model_name = %model_name,
                    "Prefill endpoint already available, returning receiver with endpoint"
                );
                Some(receiver)
            }
            Some(PrefillActivationState::DecodeWaiting(sender)) => {
                // Unexpected duplicate decode registration: put the sender back untouched.
                tracing::error!(
                    model_name = %model_name,
                    "Decode model already registered for this prefill router"
                );
                pending.insert(model_name, PrefillActivationState::DecodeWaiting(sender));
                None
            }
        }
    }

    /// Activate a prefill router by sending the endpoint through the oneshot channel.
    /// If no decode model has registered yet, stores the endpoint for future retrieval.
    pub fn activate_prefill_router(
        &self,
        model_name: &str,
        endpoint: Endpoint,
    ) -> anyhow::Result<()> {
        let mut activators = self.prefill_router_activators.lock();

        match activators.remove(model_name) {
            Some(PrefillActivationState::DecodeWaiting(sender)) => {
                // Decode model already registered
                sender.send(endpoint).map_err(|_| {
                    anyhow::anyhow!(
                        "Failed to send endpoint to prefill router activator for model: {}",
                        model_name
                    )
                })?;

                tracing::info!(
                    model_name = %model_name,
                    "Activated prefill router for already-registered decode model"
                );

                Ok(())
            }
            Some(PrefillActivationState::PrefillReady(_)) => {
                // Prefill already activated - this shouldn't happen
                anyhow::bail!("Prefill router for model {} already activated", model_name);
            }
            None => {
                // Decode model not registered yet - create pair and immediately send endpoint
                let (tx, rx) = oneshot::channel();

                tx.send(endpoint).map_err(|_| {
                    anyhow::anyhow!("Failed to send endpoint for prefill model: {}", model_name)
                })?;

                // Store the receiver for when decode model registers
                activators.insert(
                    model_name.to_string(),
                    PrefillActivationState::PrefillReady(rx),
                );

                tracing::info!(
                    model_name = %model_name,
                    "Stored prefill endpoint for future decode model registration"
                );

                Ok(())
            }
        }
500
501
    }

502
    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
503
        self.cards
504
505
            .lock()
            .values()
506
507
            .find(|c| c.display_name == model)
            .and_then(|c| c.runtime_config.tool_call_parser.as_ref())
508
            .map(|parser| parser.to_string())
509
    }
510
511
512
513
514
515
516
517
518

    /// Builds the parsing options for `model` from its tool call parser (if any).
    /// The reasoning parser is not implemented yet, so `None` is always passed for it.
    pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
        crate::protocols::openai::ParsingOptions::new(
            self.get_model_tool_call_parser(model),
            None, // TODO: Implement reasoning parser
        )
    }
519
520
521

    /// Gets or sets the busy threshold for a model via its worker monitor.
    ///
522
523
    /// Get or set the active decode blocks threshold for a model's worker monitor.
    ///
524
525
    /// This is the primary API for HTTP endpoints and external callers.
    /// The threshold (0.0 to 1.0) controls when workers are marked as "busy"
526
    /// based on KV cache block utilization.
527
528
529
530
531
532
533
534
535
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `threshold` - `Some(value)` to set, `None` to get existing
    ///
    /// # Returns
    ///
    /// The threshold value as f64, or `None` if no monitor exists for this model.
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
    pub fn active_decode_blocks_threshold(
        &self,
        model: &str,
        threshold: Option<f64>,
    ) -> Option<f64> {
        let monitors = self.worker_monitors.read();
        let monitor = monitors.get(model)?;

        match threshold {
            Some(value) => {
                monitor.set_active_decode_blocks_threshold(value);
                Some(value)
            }
            None => Some(monitor.active_decode_blocks_threshold()),
        }
    }

    /// Get or set the active prefill tokens threshold for a model's worker monitor.
    ///
    /// The threshold is a literal token count (not a percentage).
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `threshold` - `Some(value)` to set, `None` to get existing
    ///
    /// # Returns
    ///
    /// The threshold value as u64, or `None` if no monitor exists for this model.
    pub fn active_prefill_tokens_threshold(
        &self,
        model: &str,
        threshold: Option<u64>,
    ) -> Option<u64> {
570
571
572
573
574
        let monitors = self.worker_monitors.read();
        let monitor = monitors.get(model)?;

        match threshold {
            Some(value) => {
575
                monitor.set_active_prefill_tokens_threshold(value);
576
577
                Some(value)
            }
578
            None => Some(monitor.active_prefill_tokens_threshold()),
579
580
581
582
583
        }
    }

    /// Gets or creates a worker monitor for a model.
    ///
584
585
    /// If a monitor already exists, updates its thresholds and returns a clone.
    /// If no monitor exists, creates one with the given client and thresholds.
586
587
588
589
590
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `client` - The client for subscribing to KV metrics (only used if creating new)
591
592
    /// * `active_decode_blocks_threshold` - The initial/updated active decode blocks threshold value (0.0-1.0)
    /// * `active_prefill_tokens_threshold` - The initial/updated active prefill tokens threshold value (literal token count)
593
594
595
596
597
598
599
    ///
    /// # Returns
    ///
    /// A cloneable monitor that shares state with the stored instance.
    pub fn get_or_create_worker_monitor(
        &self,
        model: &str,
600
        client: Client,
601
602
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
603
604
605
606
    ) -> KvWorkerMonitor {
        let mut monitors = self.worker_monitors.write();

        if let Some(existing) = monitors.get(model) {
607
608
            existing.set_active_decode_blocks_threshold(active_decode_blocks_threshold);
            existing.set_active_prefill_tokens_threshold(active_prefill_tokens_threshold);
609
610
            existing.clone()
        } else {
611
612
613
614
615
            let monitor = KvWorkerMonitor::new(
                client,
                active_decode_blocks_threshold,
                active_prefill_tokens_threshold,
            );
616
617
618
619
620
621
622
623
624
625
            monitors.insert(model.to_string(), monitor.clone());
            monitor
        }
    }

    /// Returns a clone of the worker monitor stored for `model`, or `None` if there is none.
    pub fn get_worker_monitor(&self, model: &str) -> Option<KvWorkerMonitor> {
        let monitors = self.worker_monitors.read();
        monitors.get(model).cloned()
    }

626
627
    /// Get or create a runtime config watcher for an endpoint.
    /// Spawns a background task to watch DiscoveryQuery::EndpointModels.
628
    /// Returns a shared RuntimeConfigsWithNotify that KvScheduler can use directly.
629
630
631
    pub async fn get_or_create_runtime_config_watcher(
        &self,
        endpoint: &Endpoint,
632
    ) -> anyhow::Result<Arc<RuntimeConfigsWithNotify>> {
633
634
635
636
637
638
639
640
        let endpoint_id = endpoint.id();

        // Fast path: return existing if present
        if let Some(existing) = self.runtime_configs.get(&endpoint_id) {
            return Ok(existing.clone());
        }

        // Atomic get-or-insert to avoid TOCTOU race
641
642
643
644
645
        let inner = Arc::new(RuntimeConfigsWithNotify {
            configs: DashMap::new(),
            notify: Notify::new(),
        });
        let (result, is_new) = match self.runtime_configs.entry(endpoint_id) {
646
647
            Entry::Occupied(e) => (e.get().clone(), false),
            Entry::Vacant(e) => {
648
649
                e.insert(inner.clone());
                (inner, true)
650
651
652
653
654
            }
        };

        // Only spawn watcher if we were the one who inserted
        if is_new {
655
            self.spawn_runtime_config_watcher(endpoint, result.clone())
656
657
658
                .await?;
        }

659
        Ok(result)
660
661
662
663
664
665
666
667
668
    }

    /// Returns the disaggregated endpoint recorded in the runtime config of `worker_id`
    /// under `endpoint_id`.
    ///
    /// Yields `None` when the endpoint or the worker is unknown, when the worker has no
    /// runtime config, or when that config carries no disaggregated endpoint.
    /// Used by PrefillRouter for bootstrap info - works for ANY routing mode.
    pub fn get_disaggregated_endpoint(
        &self,
        endpoint_id: &EndpointId,
        worker_id: WorkerId,
    ) -> Option<DisaggregatedEndpoint> {
        let per_endpoint = self.runtime_configs.get(endpoint_id)?;
        let entry = per_endpoint.configs.get(&worker_id)?;
        entry
            .value()
            .as_ref()
            .and_then(|config| config.disaggregated_endpoint.clone())
    }

    /// Spawn background task to watch runtime configs via discovery.
675
    /// Blocks until at least one worker with a runtime config is available.
676
677
678
    async fn spawn_runtime_config_watcher(
        &self,
        endpoint: &Endpoint,
679
        inner: Arc<RuntimeConfigsWithNotify>,
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
    ) -> anyhow::Result<()> {
        let component = endpoint.component();
        let cancellation_token = component.drt().primary_token();

        // Set up discovery watch for EndpointModels
        let discovery = component.drt().discovery();
        let endpoint_id = endpoint.id();
        let discovery_key = DiscoveryQuery::EndpointModels {
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
        };
        let discovery_stream = discovery
            .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
            .await?;

        // Extract runtime_config from ModelDeploymentCard
        let mut runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });

        // Also watch instance IDs
        let client = endpoint.client().await?;
        let mut instance_ids_rx = client.instance_avail_watcher();

706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
        // Wait for at least one worker with runtime config before proceeding.
        // This ensures the DashMap is populated before KvScheduler starts.
        tracing::info!("ModelManager: Waiting for at least one worker with runtime config...");
        runtime_configs_rx
            .changed()
            .await
            .map_err(|_| anyhow::anyhow!("runtime configs watch sender shutdown while waiting"))?;

        // Populate initial state
        {
            let instance_ids = instance_ids_rx.borrow();
            let configs = runtime_configs_rx.borrow();
            for worker_id in instance_ids.iter() {
                let config = configs.get(worker_id).cloned();
                inner.configs.insert(*worker_id, config);
            }
            tracing::info!(
                "ModelManager: Found {} workers, proceeding",
                inner.configs.len()
            );
        }

        // Spawn background task to update configs for future changes
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
        let cancel_token = cancellation_token.clone();
        tokio::spawn(async move {
            tracing::trace!("ModelManager runtime config watcher started");
            loop {
                // Wait for either instances or configs to change
                tokio::select! {
                    _ = cancel_token.cancelled() => {
                        tracing::trace!("ModelManager runtime config watcher shutting down");
                        break;
                    }
                    result = instance_ids_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("instance IDs watch sender shutdown in ModelManager");
                            break;
                        }
                    }
                    result = runtime_configs_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("runtime configs watch sender shutdown in ModelManager");
                            break;
                        }
                    }
                }

                // Get the latest values from both channels
                let new_instance_ids = instance_ids_rx.borrow_and_update().clone();
                let new_configs = runtime_configs_rx.borrow_and_update().clone();

                // Update the DashMap
                // First, remove workers that no longer exist
                let current_workers: HashSet<WorkerId> =
760
                    inner.configs.iter().map(|r| *r.key()).collect();
761
762
                let new_workers: HashSet<WorkerId> = new_instance_ids.iter().copied().collect();
                for removed_worker in current_workers.difference(&new_workers) {
763
                    inner.configs.remove(removed_worker);
764
765
766
767
768
769
                }

                // Then, add/update workers
                for worker_id in &new_instance_ids {
                    let config = new_configs.get(worker_id).cloned();
                    if config.is_some() {
770
                        let prev_config = inner.configs.get(worker_id);
771
772
                        if prev_config.as_ref().map(|r| r.value()) != Some(&config) {
                            tracing::info!(
773
                                "ModelManager: Runtime config found for worker_id: {worker_id}"
774
775
776
                            );
                        }
                    }
777
                    inner.configs.insert(*worker_id, config);
778
779
                }

780
781
782
                // Notify waiters that configs have changed
                inner.notify.notify_waiters();

783
784
                tracing::trace!(
                    "ModelManager: Updated runtime_configs with {} workers",
785
                    inner.configs.len()
786
787
788
789
790
791
792
793
                );
            }
            tracing::trace!("ModelManager runtime config watcher shutting down");
        });

        Ok(())
    }

794
795
    /// Lists all models that have worker monitors (and thus busy thresholds) configured.
    ///
796
797
    /// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
    pub fn list_busy_thresholds(&self) -> Vec<(String, f64, u64)> {
798
799
800
        self.worker_monitors
            .read()
            .iter()
801
802
803
804
805
806
807
            .map(|(k, monitor)| {
                (
                    k.clone(),
                    monitor.active_decode_blocks_threshold(),
                    monitor.active_prefill_tokens_threshold(),
                )
            })
808
809
            .collect()
    }
810
811
812
813
814
815
}

/// Registry of engines of type `E`, keyed by model name.
///
/// Alongside each engine the registry keeps the checksum of the model's
/// `ModelDeploymentCard` under the same key, plus an optional default model name.
pub struct ModelEngines<E> {
    /// Optional default model name
    default: Option<String>,
    /// Key: Model name, value: the engine registered for that model.
    engines: HashMap<String, E>,
    /// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
    /// same card.
    checksums: HashMap<String, String>,
}

impl<E> Default for ModelEngines<E> {
    fn default() -> Self {
        Self {
            default: None,
            engines: HashMap::new(),
826
            checksums: HashMap::new(),
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
        }
    }
}

impl<E> ModelEngines<E> {
    #[allow(dead_code)]
    fn set_default(&mut self, model: &str) {
        self.default = Some(model.to_string());
    }

    #[allow(dead_code)]
    fn clear_default(&mut self) {
        self.default = None;
    }

842
    fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
843
844
845
846
        if self.engines.contains_key(model) {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        self.engines.insert(model.to_string(), engine);
847
848
        self.checksums
            .insert(model.to_string(), checksum.to_string());
849
850
851
852
853
854
855
        Ok(())
    }

    fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
        if self.engines.remove(model).is_none() {
            return Err(ModelManagerError::ModelNotFound(model.to_string()));
        }
856
        let _ = self.checksums.remove(model);
857
858
859
860
861
862
863
864
865
866
867
868
869
870
        Ok(())
    }

    fn get(&self, model: &str) -> Option<&E> {
        self.engines.get(model)
    }

    fn contains(&self, model: &str) -> bool {
        self.engines.contains_key(model)
    }

    pub fn list(&self) -> Vec<String> {
        self.engines.keys().map(|k| k.to_owned()).collect()
    }
871
872
873
874
875
876

    /// Returns a newly allocated String for called convenience. All the places I use
    /// this I need a String.
    pub fn checksum(&self, model: &str) -> Option<String> {
        self.checksums.get(model).map(|s| s.to_string())
    }
877
}