model_manager.rs 30.5 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::{
    collections::{HashMap, HashSet},
6
    sync::Arc,
7
8
};

9
use dashmap::{DashMap, mapref::entry::Entry};
10
use parking_lot::{Mutex, RwLock};
11
use tokio::sync::oneshot;
12

13
14
use crate::discovery::KvWorkerMonitor;

15
use dynamo_runtime::{
16
    component::{Client, Endpoint, build_transport_type},
17
    discovery::{DiscoveryQuery, DiscoverySpec, watch_and_extract_field},
18
19
20
    prelude::DistributedRuntimeProvider,
    protocols::EndpointId,
};
21
22

use crate::{
23
24
25
26
27
    kv_router::{
        KvRouter, KvRouterConfig, protocols::WorkerId, router_endpoint_id,
        scheduler::DefaultWorkerSelector,
    },
    local_model::runtime_config::{DisaggregatedEndpoint, ModelRuntimeConfig},
28
    model_card::ModelDeploymentCard,
29
    model_type::ModelType,
30
31
32
33
34
35
36
37
    types::{
        generic::tensor::TensorStreamingEngine,
        openai::{
            chat_completions::OpenAIChatCompletionsStreamingEngine,
            completions::OpenAICompletionsStreamingEngine,
            embeddings::OpenAIEmbeddingsStreamingEngine,
        },
    },
38
};
39

40
41
42
43
44
45
46
47
/// Rendezvous state used to hand a prefill endpoint to a decode model's prefill router.
///
/// Whichever side arrives first (decode registration or prefill discovery) parks its half of a
/// oneshot channel here; the side that arrives second completes the handoff.
enum PrefillActivationState {
    /// The prefill endpoint arrived first and has already been sent into the channel; this
    /// receiver waits for the decode model to register and claim it.
    PrefillReady(oneshot::Receiver<Endpoint>),
    /// The decode model registered first; this sender waits for the prefill endpoint.
    DecodeWaiting(oneshot::Sender<Endpoint>),
}

48
49
50
51
52
53
54
55
56
/// Errors produced when registering, removing or looking up model engines.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// No engine is registered under the given model name.
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    /// An engine is already registered under the given model name.
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}

57
58
59
60
61
62
/// Central manager for model engines, routing, and configuration.
///
/// Manages model lifecycle including engines, KV routers, prefill coordination,
/// and per-model busy thresholds for load-based request rejection.
///
/// Note: Don't implement Clone for this, put it in an Arc instead.
63
64
65
66
67
pub struct ModelManager {
    // We read a lot and write rarely, so these three are RwLock
    completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
    chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
    embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
68
    tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
69
70
    // Prefill models don't have engines - they're only tracked for discovery/lifecycle
    prefill_engines: RwLock<ModelEngines<()>>,
71

72
    // These are Mutex because we read and write rarely and equally
73
    cards: Mutex<HashMap<String, ModelDeploymentCard>>,
74
    kv_choosers: Mutex<HashMap<EndpointId, Arc<KvRouter>>>,
75
    prefill_router_activators: Mutex<HashMap<String, PrefillActivationState>>,
76
77
78
79
80

    /// Per-model worker monitors for dynamic KV cache load rejection.
    /// Key: model name, Value: cloneable monitor (all fields are Arc).
    /// HTTP endpoint can update thresholds via monitor.set_threshold().
    worker_monitors: RwLock<HashMap<String, KvWorkerMonitor>>,
81
82
83
84
85

    /// Runtime configs per endpoint using DashMap for lock-free access.
    /// Outer DashMap: keyed by EndpointId
    /// Inner Arc<DashMap>: keyed by WorkerId, shared with KvScheduler
    runtime_configs: DashMap<EndpointId, Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>>,
86
87
88
89
90
91
92
93
94
95
96
97
98
99
}

impl Default for ModelManager {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelManager {
    pub fn new() -> Self {
        Self {
            completion_engines: RwLock::new(ModelEngines::default()),
            chat_completion_engines: RwLock::new(ModelEngines::default()),
            embeddings_engines: RwLock::new(ModelEngines::default()),
100
            tensor_engines: RwLock::new(ModelEngines::default()),
101
            prefill_engines: RwLock::new(ModelEngines::default()),
102
            cards: Mutex::new(HashMap::new()),
103
            kv_choosers: Mutex::new(HashMap::new()),
104
            prefill_router_activators: Mutex::new(HashMap::new()),
105
            worker_monitors: RwLock::new(HashMap::new()),
106
            runtime_configs: DashMap::new(),
107
108
109
        }
    }

110
111
112
113
114
115
116
117
118
119
120
121
122
    pub fn is_valid_checksum(
        &self,
        model_type: ModelType,
        model_name: &str,
        candidate_checksum: &str,
    ) -> Option<bool> {
        let mut results = vec![];
        for unit in model_type.units() {
            let maybe_valid_checksum = match unit {
                ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
                ModelType::Completions => self.completion_engines.read().checksum(model_name),
                ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
                ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
123
                ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
                _ => {
                    continue;
                }
            };
            if let Some(is_valid) = maybe_valid_checksum.map(|valid_checksum| {
                tracing::debug!(
                    model_name,
                    valid_checksum,
                    candidate_checksum,
                    "is_valid_checksum: check case"
                );
                valid_checksum == candidate_checksum
            }) {
                results.push(is_valid)
            }
        }
        if results.is_empty() {
            None
        } else {
            // The checksum is valid if it is correct for all the ModelType in the bitflag.
            Some(results.into_iter().all(|x| x))
        }
    }

148
149
    pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
        self.cards.lock().values().cloned().collect()
150
151
    }

152
153
    /// Check if a decode model (chat or completions) is registered
    pub fn has_decode_model(&self, model: &str) -> bool {
154
155
        self.chat_completion_engines.read().contains(model)
            || self.completion_engines.read().contains(model)
156
157
158
159
160
161
162
163
164
165
166
    }

    /// Check if a prefill model is registered
    pub fn has_prefill_model(&self, model: &str) -> bool {
        self.prefill_engines.read().contains(model)
    }

    /// Check if any model (decode or prefill) is registered.
    /// Note: For registration skip-checks, use has_decode_model() or has_prefill_model() instead.
    pub fn has_model_any(&self, model: &str) -> bool {
        self.has_decode_model(model) || self.has_prefill_model(model)
167
168
    }

169
170
171
172
173
    pub fn model_display_names(&self) -> HashSet<String> {
        self.list_chat_completions_models()
            .into_iter()
            .chain(self.list_completions_models())
            .chain(self.list_embeddings_models())
174
            .chain(self.list_tensor_models())
175
            .chain(self.list_prefill_models())
176
177
178
            .collect()
    }

179
    pub fn list_chat_completions_models(&self) -> Vec<String> {
180
        self.chat_completion_engines.read().list()
181
182
183
    }

    pub fn list_completions_models(&self) -> Vec<String> {
184
        self.completion_engines.read().list()
185
186
187
    }

    pub fn list_embeddings_models(&self) -> Vec<String> {
188
        self.embeddings_engines.read().list()
189
190
    }

191
192
193
194
    pub fn list_tensor_models(&self) -> Vec<String> {
        self.tensor_engines.read().list()
    }

195
196
197
198
    pub fn list_prefill_models(&self) -> Vec<String> {
        self.prefill_engines.read().list()
    }

199
200
201
    pub fn add_completions_model(
        &self,
        model: &str,
202
        card_checksum: &str,
203
204
        engine: OpenAICompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
205
        let mut clients = self.completion_engines.write();
206
        clients.add(model, card_checksum, engine)
207
208
209
210
211
    }

    pub fn add_chat_completions_model(
        &self,
        model: &str,
212
        card_checksum: &str,
213
214
        engine: OpenAIChatCompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
215
        let mut clients = self.chat_completion_engines.write();
216
        clients.add(model, card_checksum, engine)
217
218
219
220
221
    }

    pub fn add_embeddings_model(
        &self,
        model: &str,
222
        card_checksum: &str,
223
224
        engine: OpenAIEmbeddingsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
225
        let mut clients = self.embeddings_engines.write();
226
        clients.add(model, card_checksum, engine)
227
228
    }

229
230
231
    pub fn add_tensor_model(
        &self,
        model: &str,
232
        card_checksum: &str,
233
234
235
        engine: TensorStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
236
        clients.add(model, card_checksum, engine)
237
238
    }

239
240
241
242
243
244
    pub fn add_prefill_model(
        &self,
        model: &str,
        card_checksum: &str,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.prefill_engines.write();
245
        clients.add(model, card_checksum, ())
246
247
    }

248
    pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
249
        let mut clients = self.completion_engines.write();
250
251
252
253
        clients.remove(model)
    }

    pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
254
        let mut clients = self.chat_completion_engines.write();
255
256
257
258
        clients.remove(model)
    }

    pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
259
        let mut clients = self.embeddings_engines.write();
260
261
262
        clients.remove(model)
    }

263
264
265
266
267
    pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
        clients.remove(model)
    }

268
269
270
271
272
    pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let mut clients = self.prefill_engines.write();
        clients.remove(model)
    }

273
    pub fn get_embeddings_engine(
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
        &self,
        model: &str,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
        self.embeddings_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    pub fn get_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
        self.completion_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    pub fn get_chat_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
        self.chat_completion_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

306
307
308
309
310
311
312
313
314
315
316
    pub fn get_tensor_engine(
        &self,
        model: &str,
    ) -> Result<TensorStreamingEngine, ModelManagerError> {
        self.tensor_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

317
318
319
320
321
    /// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
    /// deleted.
    pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
        self.cards.lock().insert(key.to_string(), card);
        Ok(())
322
323
    }

324
325
326
    /// Remove and return model card for this instance's etcd key. We do this when the instance stops.
    pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
        self.cards.lock().remove(key)
327
328
329
330
    }

    pub async fn kv_chooser_for(
        &self,
331
        endpoint: &Endpoint,
332
        kv_cache_block_size: u32,
333
        kv_router_config: Option<KvRouterConfig>,
334
    ) -> anyhow::Result<Arc<KvRouter>> {
335
        let endpoint_id = endpoint.id();
336

337
        if let Some(kv_chooser) = self.get_kv_chooser(&endpoint_id) {
338
339
340
            // Check if the existing router has a different block size
            if kv_chooser.block_size() != kv_cache_block_size {
                tracing::warn!(
341
                    endpoint = %endpoint_id,
342
343
                    existing_block_size = %kv_chooser.block_size(),
                    requested_block_size = %kv_cache_block_size,
344
                    "KV Router block size mismatch! Endpoint is requesting a different kv_cache_block_size than the existing router. \
345
346
347
                     This will cause routing to fail silently. Consider using the same block size or restarting the router."
                );
            }
348
349
350
            return Ok(kv_chooser);
        }

351
        let client = endpoint.client().await?;
352
353
354
355
356

        // Register router via discovery mechanism
        let discovery = endpoint.component().drt().discovery();
        let instance_id = discovery.instance_id();

357
        // Build transport for router endpoint based on request plane mode
358
359
        // Use KV_ROUTER_COMPONENT as the component name to distinguish from the generate endpoint's component
        let router_endpoint_id = router_endpoint_id(endpoint.id().namespace);
360
        let transport = build_transport_type(endpoint, &router_endpoint_id, instance_id).await?;
361
362
363
364
365

        let discovery_spec = DiscoverySpec::Endpoint {
            namespace: router_endpoint_id.namespace.clone(),
            component: router_endpoint_id.component.clone(),
            endpoint: router_endpoint_id.name.clone(),
366
            transport,
367
368
369
370
371
372
        };

        discovery.register(discovery_spec).await?;

        // Use instance_id (hex) as the consumer ID for NATS consumer coordination
        let consumer_id = instance_id.to_string();
373

374
375
376
        // Get or create runtime config watcher for this endpoint
        let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;

377
        let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
378
        let chooser = KvRouter::new(
379
380
            endpoint.clone(),
            client,
381
            workers_with_configs,
382
383
            kv_cache_block_size,
            Some(selector),
384
            kv_router_config,
385
            consumer_id,
386
387
        )
        .await?;
388
389
390
        let new_kv_chooser = Arc::new(chooser);
        self.kv_choosers
            .lock()
391
            .insert(endpoint_id, new_kv_chooser.clone());
392
393
        Ok(new_kv_chooser)
    }
394

395
396
    fn get_kv_chooser(&self, id: &EndpointId) -> Option<Arc<KvRouter>> {
        self.kv_choosers.lock().get(id).cloned()
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
    }

    /// Register a prefill router for a decode model. Returns a receiver that will be
    /// activated when the corresponding prefill model is discovered.
    /// Returns None if the decode model was already registered.
    pub fn register_prefill_router(
        &self,
        model_name: String,
    ) -> Option<oneshot::Receiver<Endpoint>> {
        let mut activators = self.prefill_router_activators.lock();

        match activators.remove(&model_name) {
            Some(PrefillActivationState::PrefillReady(rx)) => {
                // Prefill endpoint already arrived - rx will immediately resolve
                tracing::debug!(
                    model_name = %model_name,
                    "Prefill endpoint already available, returning receiver with endpoint"
                );
                Some(rx)
            }
            Some(PrefillActivationState::DecodeWaiting(tx)) => {
                // Decode already registered - this shouldn't happen, restore state and return None
                tracing::error!(
                    model_name = %model_name,
                    "Decode model already registered for this prefill router"
                );
                activators.insert(model_name, PrefillActivationState::DecodeWaiting(tx));
                None
            }
            None => {
                // New registration: create tx/rx pair, store sender and return receiver
                let (tx, rx) = oneshot::channel();
                activators.insert(
                    model_name.clone(),
                    PrefillActivationState::DecodeWaiting(tx),
                );
                tracing::debug!(
                    model_name = %model_name,
                    "No prefill endpoint available yet, storing sender for future activation"
                );
                Some(rx)
            }
        }
    }

    /// Activate a prefill router by sending the endpoint through the oneshot channel.
    /// If no decode model has registered yet, stores the endpoint for future retrieval.
    pub fn activate_prefill_router(
        &self,
        model_name: &str,
        endpoint: Endpoint,
    ) -> anyhow::Result<()> {
        let mut activators = self.prefill_router_activators.lock();

        match activators.remove(model_name) {
            Some(PrefillActivationState::DecodeWaiting(sender)) => {
                // Decode model already registered
                sender.send(endpoint).map_err(|_| {
                    anyhow::anyhow!(
                        "Failed to send endpoint to prefill router activator for model: {}",
                        model_name
                    )
                })?;

                tracing::info!(
                    model_name = %model_name,
                    "Activated prefill router for already-registered decode model"
                );

                Ok(())
            }
            Some(PrefillActivationState::PrefillReady(_)) => {
                // Prefill already activated - this shouldn't happen
                anyhow::bail!("Prefill router for model {} already activated", model_name);
            }
            None => {
                // Decode model not registered yet - create pair and immediately send endpoint
                let (tx, rx) = oneshot::channel();

                tx.send(endpoint).map_err(|_| {
                    anyhow::anyhow!("Failed to send endpoint for prefill model: {}", model_name)
                })?;

                // Store the receiver for when decode model registers
                activators.insert(
                    model_name.to_string(),
                    PrefillActivationState::PrefillReady(rx),
                );

                tracing::info!(
                    model_name = %model_name,
                    "Stored prefill endpoint for future decode model registration"
                );

                Ok(())
            }
        }
494
495
    }

496
    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
497
        self.cards
498
499
            .lock()
            .values()
500
501
            .find(|c| c.display_name == model)
            .and_then(|c| c.runtime_config.tool_call_parser.as_ref())
502
            .map(|parser| parser.to_string())
503
    }
504
505
506
507
508
509
510
511
512

    /// Creates parsing options with tool call parser and reasoning parser for the specified model.
    /// Currently reasoning parser is not implemented (returns None).
    pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
        let tool_call_parser = self.get_model_tool_call_parser(model);
        let reasoning_parser = None; // TODO: Implement reasoning parser

        crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser)
    }
513
514
515

    /// Gets or sets the busy threshold for a model via its worker monitor.
    ///
516
517
    /// Get or set the active decode blocks threshold for a model's worker monitor.
    ///
518
519
    /// This is the primary API for HTTP endpoints and external callers.
    /// The threshold (0.0 to 1.0) controls when workers are marked as "busy"
520
    /// based on KV cache block utilization.
521
522
523
524
525
526
527
528
529
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `threshold` - `Some(value)` to set, `None` to get existing
    ///
    /// # Returns
    ///
    /// The threshold value as f64, or `None` if no monitor exists for this model.
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
    pub fn active_decode_blocks_threshold(
        &self,
        model: &str,
        threshold: Option<f64>,
    ) -> Option<f64> {
        let monitors = self.worker_monitors.read();
        let monitor = monitors.get(model)?;

        match threshold {
            Some(value) => {
                monitor.set_active_decode_blocks_threshold(value);
                Some(value)
            }
            None => Some(monitor.active_decode_blocks_threshold()),
        }
    }

    /// Get or set the active prefill tokens threshold for a model's worker monitor.
    ///
    /// The threshold is a literal token count (not a percentage).
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `threshold` - `Some(value)` to set, `None` to get existing
    ///
    /// # Returns
    ///
    /// The threshold value as u64, or `None` if no monitor exists for this model.
    pub fn active_prefill_tokens_threshold(
        &self,
        model: &str,
        threshold: Option<u64>,
    ) -> Option<u64> {
564
565
566
567
568
        let monitors = self.worker_monitors.read();
        let monitor = monitors.get(model)?;

        match threshold {
            Some(value) => {
569
                monitor.set_active_prefill_tokens_threshold(value);
570
571
                Some(value)
            }
572
            None => Some(monitor.active_prefill_tokens_threshold()),
573
574
575
576
577
        }
    }

    /// Gets or creates a worker monitor for a model.
    ///
578
579
    /// If a monitor already exists, updates its thresholds and returns a clone.
    /// If no monitor exists, creates one with the given client and thresholds.
580
581
582
583
584
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `client` - The client for subscribing to KV metrics (only used if creating new)
585
586
    /// * `active_decode_blocks_threshold` - The initial/updated active decode blocks threshold value (0.0-1.0)
    /// * `active_prefill_tokens_threshold` - The initial/updated active prefill tokens threshold value (literal token count)
587
588
589
590
591
592
593
    ///
    /// # Returns
    ///
    /// A cloneable monitor that shares state with the stored instance.
    pub fn get_or_create_worker_monitor(
        &self,
        model: &str,
594
        client: Client,
595
596
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
597
598
599
600
    ) -> KvWorkerMonitor {
        let mut monitors = self.worker_monitors.write();

        if let Some(existing) = monitors.get(model) {
601
602
            existing.set_active_decode_blocks_threshold(active_decode_blocks_threshold);
            existing.set_active_prefill_tokens_threshold(active_prefill_tokens_threshold);
603
604
            existing.clone()
        } else {
605
606
607
608
609
            let monitor = KvWorkerMonitor::new(
                client,
                active_decode_blocks_threshold,
                active_prefill_tokens_threshold,
            );
610
611
612
613
614
615
616
617
618
619
            monitors.insert(model.to_string(), monitor.clone());
            monitor
        }
    }

    /// Gets an existing worker monitor for a model, if one exists.
    pub fn get_worker_monitor(&self, model: &str) -> Option<KvWorkerMonitor> {
        self.worker_monitors.read().get(model).cloned()
    }

620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
    /// Get or create a runtime config watcher for an endpoint.
    /// Spawns a background task to watch DiscoveryQuery::EndpointModels.
    /// Returns a shared Arc<DashMap> that KvScheduler can use directly.
    pub async fn get_or_create_runtime_config_watcher(
        &self,
        endpoint: &Endpoint,
    ) -> anyhow::Result<Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>> {
        let endpoint_id = endpoint.id();

        // Fast path: return existing if present
        if let Some(existing) = self.runtime_configs.get(&endpoint_id) {
            return Ok(existing.clone());
        }

        // Atomic get-or-insert to avoid TOCTOU race
        let inner_map = Arc::new(DashMap::new());
        let (map, is_new) = match self.runtime_configs.entry(endpoint_id) {
            Entry::Occupied(e) => (e.get().clone(), false),
            Entry::Vacant(e) => {
                e.insert(inner_map.clone());
                (inner_map, true)
            }
        };

        // Only spawn watcher if we were the one who inserted
        if is_new {
            self.spawn_runtime_config_watcher(endpoint, map.clone())
                .await?;
        }

        Ok(map)
    }

    /// Get disaggregated endpoint for a specific worker.
    /// Used by PrefillRouter for bootstrap info - works for ANY routing mode.
    pub fn get_disaggregated_endpoint(
        &self,
        endpoint_id: &EndpointId,
        worker_id: WorkerId,
    ) -> Option<DisaggregatedEndpoint> {
        let inner_map = self.runtime_configs.get(endpoint_id)?;
        let config_ref = inner_map.get(&worker_id)?;
        config_ref.as_ref()?.disaggregated_endpoint.clone()
    }

    /// Spawn background task to watch runtime configs via discovery.
    async fn spawn_runtime_config_watcher(
        &self,
        endpoint: &Endpoint,
        inner_map: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
    ) -> anyhow::Result<()> {
        let component = endpoint.component();
        let cancellation_token = component.drt().primary_token();

        // Set up discovery watch for EndpointModels
        let discovery = component.drt().discovery();
        let endpoint_id = endpoint.id();
        let discovery_key = DiscoveryQuery::EndpointModels {
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
        };
        let discovery_stream = discovery
            .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
            .await?;

        // Extract runtime_config from ModelDeploymentCard
        let mut runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });

        // Also watch instance IDs
        let client = endpoint.client().await?;
        let mut instance_ids_rx = client.instance_avail_watcher();

        // Spawn background task to update inner_map
        let cancel_token = cancellation_token.clone();
        tokio::spawn(async move {
            tracing::trace!("ModelManager runtime config watcher started");
            loop {
                // Wait for either instances or configs to change
                tokio::select! {
                    _ = cancel_token.cancelled() => {
                        tracing::trace!("ModelManager runtime config watcher shutting down");
                        break;
                    }
                    result = instance_ids_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("instance IDs watch sender shutdown in ModelManager");
                            break;
                        }
                    }
                    result = runtime_configs_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("runtime configs watch sender shutdown in ModelManager");
                            break;
                        }
                    }
                }

                // Get the latest values from both channels
                let new_instance_ids = instance_ids_rx.borrow_and_update().clone();
                let new_configs = runtime_configs_rx.borrow_and_update().clone();

                // Update the DashMap
                // First, remove workers that no longer exist
                let current_workers: HashSet<WorkerId> =
                    inner_map.iter().map(|r| *r.key()).collect();
                let new_workers: HashSet<WorkerId> = new_instance_ids.iter().copied().collect();
                for removed_worker in current_workers.difference(&new_workers) {
                    inner_map.remove(removed_worker);
                }

                // Then, add/update workers
                for worker_id in &new_instance_ids {
                    let config = new_configs.get(worker_id).cloned();
                    if config.is_some() {
                        let prev_config = inner_map.get(worker_id);
                        if prev_config.as_ref().map(|r| r.value()) != Some(&config) {
                            tracing::info!(
                                "ModelManager: Runtime config found for worker_id: {}",
                                worker_id
                            );
                        }
                    }
                    inner_map.insert(*worker_id, config);
                }

                tracing::trace!(
                    "ModelManager: Updated runtime_configs with {} workers",
                    inner_map.len()
                );
            }
            tracing::trace!("ModelManager runtime config watcher shutting down");
        });

        Ok(())
    }

760
761
    /// Lists all models that have worker monitors (and thus busy thresholds) configured.
    ///
762
763
    /// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
    pub fn list_busy_thresholds(&self) -> Vec<(String, f64, u64)> {
764
765
766
        self.worker_monitors
            .read()
            .iter()
767
768
769
770
771
772
773
            .map(|(k, monitor)| {
                (
                    k.clone(),
                    monitor.active_decode_blocks_threshold(),
                    monitor.active_prefill_tokens_threshold(),
                )
            })
774
775
            .collect()
    }
776
777
778
779
780
781
}

/// Registry of one kind of engine, keyed by model name. It also stores the
/// deployment-card checksum recorded for each model.
pub struct ModelEngines<E> {
    /// Optional default model name
    default: Option<String>,
    /// Key: model name, value: the engine registered for it.
    engines: HashMap<String, E>,
    /// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
    /// same card.
    checksums: HashMap<String, String>,
}

impl<E> Default for ModelEngines<E> {
    fn default() -> Self {
        Self {
            default: None,
            engines: HashMap::new(),
792
            checksums: HashMap::new(),
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
        }
    }
}

impl<E> ModelEngines<E> {
    #[allow(dead_code)]
    fn set_default(&mut self, model: &str) {
        self.default = Some(model.to_string());
    }

    #[allow(dead_code)]
    fn clear_default(&mut self) {
        self.default = None;
    }

808
    fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
809
810
811
812
        if self.engines.contains_key(model) {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        self.engines.insert(model.to_string(), engine);
813
814
        self.checksums
            .insert(model.to_string(), checksum.to_string());
815
816
817
818
819
820
821
        Ok(())
    }

    fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
        if self.engines.remove(model).is_none() {
            return Err(ModelManagerError::ModelNotFound(model.to_string()));
        }
822
        let _ = self.checksums.remove(model);
823
824
825
826
827
828
829
830
831
832
833
834
835
836
        Ok(())
    }

    fn get(&self, model: &str) -> Option<&E> {
        self.engines.get(model)
    }

    fn contains(&self, model: &str) -> bool {
        self.engines.contains_key(model)
    }

    pub fn list(&self) -> Vec<String> {
        self.engines.keys().map(|k| k.to_owned()).collect()
    }
837
838
839
840
841
842

    /// Returns a newly allocated String for called convenience. All the places I use
    /// this I need a String.
    pub fn checksum(&self, model: &str) -> Option<String> {
        self.checksums.get(model).map(|s| s.to_string())
    }
843
}