// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::{
    collections::{HashMap, HashSet},
6
    sync::Arc,
7
8
};

9
use dashmap::{DashMap, mapref::entry::Entry};
10
use parking_lot::{Mutex, RwLock};
11
use tokio::sync::{Notify, oneshot};
12

13
14
use crate::discovery::KvWorkerMonitor;

15
use dynamo_runtime::{
16
    component::{Client, Endpoint, build_transport_type},
17
    discovery::{DiscoveryQuery, DiscoverySpec, watch_and_extract_field},
18
19
20
    prelude::DistributedRuntimeProvider,
    protocols::EndpointId,
};
21
22

use crate::{
23
24
25
26
27
    kv_router::{
        KvRouter, KvRouterConfig, protocols::WorkerId, router_endpoint_id,
        scheduler::DefaultWorkerSelector,
    },
    local_model::runtime_config::{DisaggregatedEndpoint, ModelRuntimeConfig},
28
    model_card::ModelDeploymentCard,
29
    model_type::ModelType,
30
31
32
33
34
35
36
37
    types::{
        generic::tensor::TensorStreamingEngine,
        openai::{
            chat_completions::OpenAIChatCompletionsStreamingEngine,
            completions::OpenAICompletionsStreamingEngine,
            embeddings::OpenAIEmbeddingsStreamingEngine,
        },
    },
38
};
39

40
41
42
43
44
45
46
47
/// Rendezvous state used to hand a prefill [`Endpoint`] to a decode model's prefill router.
///
/// Whichever side arrives first parks its half of a oneshot channel here; the side that
/// arrives second takes the entry out of the map and completes the handoff.
enum PrefillActivationState {
    /// The decode model registered first and holds the matching receiver; this sender
    /// delivers the prefill endpoint once it is discovered.
    DecodeWaiting(oneshot::Sender<Endpoint>),
    /// The prefill endpoint arrived first and has already been sent into the channel;
    /// this receiver is handed to the decode model when it registers.
    PrefillReady(oneshot::Receiver<Endpoint>),
}

48
49
50
51
52
53
54
55
56
/// Errors returned by [`ModelManager`] when registering, removing, or looking up models.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// No engine is registered under the requested model name.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// An engine is already registered under this model name.
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}

57
58
59
60
61
62
/// Central manager for model engines, routing, and configuration.
///
/// Manages model lifecycle including engines, KV routers, prefill coordination,
/// and per-model busy thresholds for load-based request rejection.
///
/// Note: Don't implement Clone for this, put it in an Arc instead.
63
64
65
66
67
pub struct ModelManager {
    // We read a lot and write rarely, so these three are RwLock
    completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
    chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
    embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
68
    tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
69
70
    // Prefill models don't have engines - they're only tracked for discovery/lifecycle
    prefill_engines: RwLock<ModelEngines<()>>,
71

72
    // These are Mutex because we read and write rarely and equally
73
    cards: Mutex<HashMap<String, ModelDeploymentCard>>,
74
    kv_choosers: Mutex<HashMap<EndpointId, Arc<KvRouter>>>,
75
    prefill_router_activators: Mutex<HashMap<String, PrefillActivationState>>,
76
77
78
79
80

    /// Per-model worker monitors for dynamic KV cache load rejection.
    /// Key: model name, Value: cloneable monitor (all fields are Arc).
    /// HTTP endpoint can update thresholds via monitor.set_threshold().
    worker_monitors: RwLock<HashMap<String, KvWorkerMonitor>>,
81
82
83

    /// Runtime configs per endpoint using DashMap for lock-free access.
    /// Outer DashMap: keyed by EndpointId
84
85
86
87
88
89
90
91
    /// Inner RuntimeConfigsWithNotify: shared with KvScheduler
    runtime_configs: DashMap<EndpointId, Arc<RuntimeConfigsWithNotify>>,
}

/// Per-endpoint worker runtime configs plus a [`Notify`] signalled whenever they change.
///
/// Shared via `Arc` between the `ModelManager` watcher task (the writer) and the KV
/// scheduler.
pub struct RuntimeConfigsWithNotify {
    /// One entry per live worker; `None` while that worker has no runtime config published.
    pub configs: DashMap<WorkerId, Option<ModelRuntimeConfig>>,
    /// Woken with `notify_waiters()` after each watcher update of `configs`.
    pub notify: Notify,
}

impl Default for ModelManager {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelManager {
    pub fn new() -> Self {
        Self {
            completion_engines: RwLock::new(ModelEngines::default()),
            chat_completion_engines: RwLock::new(ModelEngines::default()),
            embeddings_engines: RwLock::new(ModelEngines::default()),
106
            tensor_engines: RwLock::new(ModelEngines::default()),
107
            prefill_engines: RwLock::new(ModelEngines::default()),
108
            cards: Mutex::new(HashMap::new()),
109
            kv_choosers: Mutex::new(HashMap::new()),
110
            prefill_router_activators: Mutex::new(HashMap::new()),
111
            worker_monitors: RwLock::new(HashMap::new()),
112
            runtime_configs: DashMap::new(),
113
114
115
        }
    }

116
117
118
119
120
121
122
123
124
125
126
127
128
    /// Checks `candidate_checksum` against the stored card checksum for every unit of
    /// `model_type` under which `model_name` is registered.
    ///
    /// Returns `None` when none of the relevant engine maps has a checksum for the model;
    /// otherwise `Some(true)` only if every stored checksum equals the candidate.
    pub fn is_valid_checksum(
        &self,
        model_type: ModelType,
        model_name: &str,
        candidate_checksum: &str,
    ) -> Option<bool> {
        let mut seen_any = false;
        let mut all_match = true;

        for unit in model_type.units() {
            let stored = match unit {
                ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
                ModelType::Completions => self.completion_engines.read().checksum(model_name),
                ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
                ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
                ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
                // Units without an engine map carry no checksum to compare.
                _ => continue,
            };

            if let Some(valid_checksum) = stored {
                tracing::debug!(
                    model_name,
                    valid_checksum,
                    candidate_checksum,
                    "is_valid_checksum: check case"
                );
                seen_any = true;
                all_match &= valid_checksum == candidate_checksum;
            }
        }

        seen_any.then_some(all_match)
    }

154
155
    /// Returns clones of every saved model deployment card, in unspecified order.
    pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
        self.cards.lock().values().map(Clone::clone).collect()
    }

158
159
    /// Check if a decode model (chat or completions) is registered
    pub fn has_decode_model(&self, model: &str) -> bool {
160
161
        self.chat_completion_engines.read().contains(model)
            || self.completion_engines.read().contains(model)
162
163
164
165
166
167
168
169
170
171
172
    }

    /// Returns `true` when `model` is tracked as a prefill model.
    pub fn has_prefill_model(&self, model: &str) -> bool {
        let prefill = self.prefill_engines.read();
        prefill.contains(model)
    }

    /// Check if any model (decode or prefill) is registered.
    /// Note: For registration skip-checks, use has_decode_model() or has_prefill_model() instead.
    pub fn has_model_any(&self, model: &str) -> bool {
        self.has_decode_model(model) || self.has_prefill_model(model)
173
174
    }

175
176
177
178
179
    pub fn model_display_names(&self) -> HashSet<String> {
        self.list_chat_completions_models()
            .into_iter()
            .chain(self.list_completions_models())
            .chain(self.list_embeddings_models())
180
            .chain(self.list_tensor_models())
181
            .chain(self.list_prefill_models())
182
183
184
            .collect()
    }

185
    pub fn list_chat_completions_models(&self) -> Vec<String> {
186
        self.chat_completion_engines.read().list()
187
188
189
    }

    pub fn list_completions_models(&self) -> Vec<String> {
190
        self.completion_engines.read().list()
191
192
193
    }

    pub fn list_embeddings_models(&self) -> Vec<String> {
194
        self.embeddings_engines.read().list()
195
196
    }

197
198
199
200
    /// Names of all models with a tensor engine.
    pub fn list_tensor_models(&self) -> Vec<String> {
        let engines = self.tensor_engines.read();
        engines.list()
    }

201
202
203
204
    /// Names of all models tracked as prefill models.
    pub fn list_prefill_models(&self) -> Vec<String> {
        let engines = self.prefill_engines.read();
        engines.list()
    }

205
206
207
    pub fn add_completions_model(
        &self,
        model: &str,
208
        card_checksum: &str,
209
210
        engine: OpenAICompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
211
        let mut clients = self.completion_engines.write();
212
        clients.add(model, card_checksum, engine)
213
214
215
216
217
    }

    pub fn add_chat_completions_model(
        &self,
        model: &str,
218
        card_checksum: &str,
219
220
        engine: OpenAIChatCompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
221
        let mut clients = self.chat_completion_engines.write();
222
        clients.add(model, card_checksum, engine)
223
224
225
226
227
    }

    pub fn add_embeddings_model(
        &self,
        model: &str,
228
        card_checksum: &str,
229
230
        engine: OpenAIEmbeddingsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
231
        let mut clients = self.embeddings_engines.write();
232
        clients.add(model, card_checksum, engine)
233
234
    }

235
236
237
    pub fn add_tensor_model(
        &self,
        model: &str,
238
        card_checksum: &str,
239
240
241
        engine: TensorStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
242
        clients.add(model, card_checksum, engine)
243
244
    }

245
246
247
248
249
250
    /// Tracks `model` as a prefill model (no engine, only the card checksum).
    ///
    /// # Errors
    /// [`ModelManagerError::ModelAlreadyExists`] if `model` is already tracked as prefill.
    pub fn add_prefill_model(
        &self,
        model: &str,
        card_checksum: &str,
    ) -> Result<(), ModelManagerError> {
        self.prefill_engines
            .write()
            .add(model, card_checksum, ())
    }

254
    /// Unregisters the completions engine for `model`.
    ///
    /// # Errors
    /// [`ModelManagerError::ModelNotFound`] if no completions engine is registered for `model`.
    pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.completion_engines.write().remove(model)
    }

    /// Unregisters the chat-completions engine for `model`.
    ///
    /// # Errors
    /// [`ModelManagerError::ModelNotFound`] if no chat engine is registered for `model`.
    pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.chat_completion_engines.write().remove(model)
    }

    /// Unregisters the embeddings engine for `model`.
    ///
    /// # Errors
    /// [`ModelManagerError::ModelNotFound`] if no embeddings engine is registered for `model`.
    pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.embeddings_engines.write().remove(model)
    }

269
270
271
272
273
    /// Unregisters the tensor engine for `model`.
    ///
    /// # Errors
    /// [`ModelManagerError::ModelNotFound`] if no tensor engine is registered for `model`.
    pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.tensor_engines.write().remove(model)
    }

274
275
276
277
278
    /// Stops tracking `model` as a prefill model.
    ///
    /// # Errors
    /// [`ModelManagerError::ModelNotFound`] if `model` is not tracked as prefill.
    pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.prefill_engines.write().remove(model)
    }

279
    pub fn get_embeddings_engine(
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
        &self,
        model: &str,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
        self.embeddings_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Returns a clone of the completions engine registered for `model`.
    ///
    /// # Errors
    /// [`ModelManagerError::ModelNotFound`] if no completions engine is registered for `model`.
    pub fn get_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
        // `ok_or_else` so the error String is only allocated when the lookup misses.
        self.completion_engines
            .read()
            .get(model)
            .cloned()
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Returns a clone of the chat-completions engine registered for `model`.
    ///
    /// # Errors
    /// [`ModelManagerError::ModelNotFound`] if no chat engine is registered for `model`.
    pub fn get_chat_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
        // `ok_or_else` so the error String is only allocated when the lookup misses.
        self.chat_completion_engines
            .read()
            .get(model)
            .cloned()
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

312
313
314
315
316
317
318
319
320
321
322
    /// Returns a clone of the tensor engine registered for `model`.
    ///
    /// # Errors
    /// [`ModelManagerError::ModelNotFound`] if no tensor engine is registered for `model`.
    pub fn get_tensor_engine(
        &self,
        model: &str,
    ) -> Result<TensorStreamingEngine, ModelManagerError> {
        // `ok_or_else` so the error String is only allocated when the lookup misses.
        self.tensor_engines
            .read()
            .get(model)
            .cloned()
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

323
    /// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
324
325
326
327
    /// deleted.
    pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
        self.cards.lock().insert(key.to_string(), card);
        Ok(())
328
329
    }

330
    /// Remove and return model card for this instance's key. We do this when the instance stops.
331
332
    pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
        self.cards.lock().remove(key)
333
334
335
336
    }

    pub async fn kv_chooser_for(
        &self,
337
        endpoint: &Endpoint,
338
        kv_cache_block_size: u32,
339
        kv_router_config: Option<KvRouterConfig>,
340
    ) -> anyhow::Result<Arc<KvRouter>> {
341
        let endpoint_id = endpoint.id();
342

343
        if let Some(kv_chooser) = self.get_kv_chooser(&endpoint_id) {
344
345
346
            // Check if the existing router has a different block size
            if kv_chooser.block_size() != kv_cache_block_size {
                tracing::warn!(
347
                    endpoint = %endpoint_id,
348
349
                    existing_block_size = %kv_chooser.block_size(),
                    requested_block_size = %kv_cache_block_size,
350
                    "KV Router block size mismatch! Endpoint is requesting a different kv_cache_block_size than the existing router. \
351
352
353
                     This will cause routing to fail silently. Consider using the same block size or restarting the router."
                );
            }
354
355
356
            return Ok(kv_chooser);
        }

357
        let client = endpoint.client().await?;
358
359
360
361
362

        // Register router via discovery mechanism
        let discovery = endpoint.component().drt().discovery();
        let instance_id = discovery.instance_id();

363
        // Build transport for router endpoint based on request plane mode
364
365
        // Use KV_ROUTER_COMPONENT as the component name to distinguish from the generate endpoint's component
        let router_endpoint_id = router_endpoint_id(endpoint.id().namespace);
366
        let transport = build_transport_type(endpoint, &router_endpoint_id, instance_id).await?;
367
368
369
370
371

        let discovery_spec = DiscoverySpec::Endpoint {
            namespace: router_endpoint_id.namespace.clone(),
            component: router_endpoint_id.component.clone(),
            endpoint: router_endpoint_id.name.clone(),
372
            transport,
373
374
375
376
        };

        discovery.register(discovery_spec).await?;

377
378
379
        // Get or create runtime config watcher for this endpoint
        let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;

380
        let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
381
        let chooser = KvRouter::new(
382
383
            endpoint.clone(),
            client,
384
            workers_with_configs,
385
386
            kv_cache_block_size,
            Some(selector),
387
            kv_router_config,
388
            instance_id,
389
390
        )
        .await?;
391
392
393
        let new_kv_chooser = Arc::new(chooser);
        self.kv_choosers
            .lock()
394
            .insert(endpoint_id, new_kv_chooser.clone());
395
396
        Ok(new_kv_chooser)
    }
397

398
399
    fn get_kv_chooser(&self, id: &EndpointId) -> Option<Arc<KvRouter>> {
        self.kv_choosers.lock().get(id).cloned()
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
    }

    /// Registers a decode model's interest in the prefill endpoint for `model_name`.
    ///
    /// Returns a receiver that yields the prefill [`Endpoint`]: immediately if the prefill
    /// side already activated, otherwise once [`Self::activate_prefill_router`] is called.
    /// Returns `None`, leaving the existing waiter in place, if a decode model has already
    /// registered for `model_name`.
    pub fn register_prefill_router(
        &self,
        model_name: String,
    ) -> Option<oneshot::Receiver<Endpoint>> {
        let mut activators = self.prefill_router_activators.lock();

        let previous = activators.remove(&model_name);
        match previous {
            None => {
                // Decode arrived first: park the sender until the prefill endpoint shows up.
                let (tx, rx) = oneshot::channel();
                tracing::debug!(
                    model_name = %model_name,
                    "No prefill endpoint available yet, storing sender for future activation"
                );
                activators.insert(model_name, PrefillActivationState::DecodeWaiting(tx));
                Some(rx)
            }
            Some(PrefillActivationState::PrefillReady(rx)) => {
                // Prefill arrived first: the endpoint is already in the channel.
                tracing::debug!(
                    model_name = %model_name,
                    "Prefill endpoint already available, returning receiver with endpoint"
                );
                Some(rx)
            }
            Some(waiting @ PrefillActivationState::DecodeWaiting(_)) => {
                // Unexpected duplicate decode registration: keep the original waiter.
                tracing::error!(
                    model_name = %model_name,
                    "Decode model already registered for this prefill router"
                );
                activators.insert(model_name, waiting);
                None
            }
        }
    }

    /// Activate a prefill router by sending the endpoint through the oneshot channel.
    /// If no decode model has registered yet, stores the endpoint for future retrieval.
    pub fn activate_prefill_router(
        &self,
        model_name: &str,
        endpoint: Endpoint,
    ) -> anyhow::Result<()> {
        let mut activators = self.prefill_router_activators.lock();

        match activators.remove(model_name) {
            Some(PrefillActivationState::DecodeWaiting(sender)) => {
                // Decode model already registered
                sender.send(endpoint).map_err(|_| {
                    anyhow::anyhow!(
                        "Failed to send endpoint to prefill router activator for model: {}",
                        model_name
                    )
                })?;

                tracing::info!(
                    model_name = %model_name,
                    "Activated prefill router for already-registered decode model"
                );

                Ok(())
            }
            Some(PrefillActivationState::PrefillReady(_)) => {
                // Prefill already activated - this shouldn't happen
                anyhow::bail!("Prefill router for model {} already activated", model_name);
            }
            None => {
                // Decode model not registered yet - create pair and immediately send endpoint
                let (tx, rx) = oneshot::channel();

                tx.send(endpoint).map_err(|_| {
                    anyhow::anyhow!("Failed to send endpoint for prefill model: {}", model_name)
                })?;

                // Store the receiver for when decode model registers
                activators.insert(
                    model_name.to_string(),
                    PrefillActivationState::PrefillReady(rx),
                );

                tracing::info!(
                    model_name = %model_name,
                    "Stored prefill endpoint for future decode model registration"
                );

                Ok(())
            }
        }
497
498
    }

499
    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
500
        self.cards
501
502
            .lock()
            .values()
503
504
            .find(|c| c.display_name == model)
            .and_then(|c| c.runtime_config.tool_call_parser.as_ref())
505
            .map(|parser| parser.to_string())
506
    }
507
508
509
510
511
512
513
514
515

    /// Builds the OpenAI parsing options for `model`.
    ///
    /// Only the tool-call parser is resolved (from saved model cards); the reasoning parser
    /// is always `None` until it is implemented.
    pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
        // TODO: Implement reasoning parser
        crate::protocols::openai::ParsingOptions::new(self.get_model_tool_call_parser(model), None)
    }
516
517
518

    /// Gets or sets the busy threshold for a model via its worker monitor.
    ///
519
520
    /// Get or set the active decode blocks threshold for a model's worker monitor.
    ///
521
522
    /// This is the primary API for HTTP endpoints and external callers.
    /// The threshold (0.0 to 1.0) controls when workers are marked as "busy"
523
    /// based on KV cache block utilization.
524
525
526
527
528
529
530
531
532
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `threshold` - `Some(value)` to set, `None` to get existing
    ///
    /// # Returns
    ///
    /// The threshold value as f64, or `None` if no monitor exists for this model.
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
    pub fn active_decode_blocks_threshold(
        &self,
        model: &str,
        threshold: Option<f64>,
    ) -> Option<f64> {
        let monitors = self.worker_monitors.read();
        let monitor = monitors.get(model)?;

        match threshold {
            Some(value) => {
                monitor.set_active_decode_blocks_threshold(value);
                Some(value)
            }
            None => Some(monitor.active_decode_blocks_threshold()),
        }
    }

    /// Get or set the active prefill tokens threshold for a model's worker monitor.
    ///
    /// The threshold is a literal token count (not a percentage).
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `threshold` - `Some(value)` to set, `None` to get existing
    ///
    /// # Returns
    ///
    /// The threshold value as u64, or `None` if no monitor exists for this model.
    pub fn active_prefill_tokens_threshold(
        &self,
        model: &str,
        threshold: Option<u64>,
    ) -> Option<u64> {
567
568
569
570
571
        let monitors = self.worker_monitors.read();
        let monitor = monitors.get(model)?;

        match threshold {
            Some(value) => {
572
                monitor.set_active_prefill_tokens_threshold(value);
573
574
                Some(value)
            }
575
            None => Some(monitor.active_prefill_tokens_threshold()),
576
577
578
579
580
        }
    }

    /// Gets or creates a worker monitor for a model.
    ///
581
582
    /// If a monitor already exists, updates its thresholds and returns a clone.
    /// If no monitor exists, creates one with the given client and thresholds.
583
584
585
586
587
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `client` - The client for subscribing to KV metrics (only used if creating new)
588
589
    /// * `active_decode_blocks_threshold` - The initial/updated active decode blocks threshold value (0.0-1.0)
    /// * `active_prefill_tokens_threshold` - The initial/updated active prefill tokens threshold value (literal token count)
590
591
592
593
594
595
596
    ///
    /// # Returns
    ///
    /// A cloneable monitor that shares state with the stored instance.
    pub fn get_or_create_worker_monitor(
        &self,
        model: &str,
597
        client: Client,
598
599
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
600
601
602
603
    ) -> KvWorkerMonitor {
        let mut monitors = self.worker_monitors.write();

        if let Some(existing) = monitors.get(model) {
604
605
            existing.set_active_decode_blocks_threshold(active_decode_blocks_threshold);
            existing.set_active_prefill_tokens_threshold(active_prefill_tokens_threshold);
606
607
            existing.clone()
        } else {
608
609
610
611
612
            let monitor = KvWorkerMonitor::new(
                client,
                active_decode_blocks_threshold,
                active_prefill_tokens_threshold,
            );
613
614
615
616
617
618
619
620
621
622
            monitors.insert(model.to_string(), monitor.clone());
            monitor
        }
    }

    /// Returns a clone of the worker monitor registered for `model`, if there is one.
    pub fn get_worker_monitor(&self, model: &str) -> Option<KvWorkerMonitor> {
        let monitors = self.worker_monitors.read();
        monitors.get(model).cloned()
    }

623
624
    /// Get or create a runtime config watcher for an endpoint.
    /// Spawns a background task to watch DiscoveryQuery::EndpointModels.
625
    /// Returns a shared RuntimeConfigsWithNotify that KvScheduler can use directly.
626
627
628
    pub async fn get_or_create_runtime_config_watcher(
        &self,
        endpoint: &Endpoint,
629
    ) -> anyhow::Result<Arc<RuntimeConfigsWithNotify>> {
630
631
632
633
634
635
636
637
        let endpoint_id = endpoint.id();

        // Fast path: return existing if present
        if let Some(existing) = self.runtime_configs.get(&endpoint_id) {
            return Ok(existing.clone());
        }

        // Atomic get-or-insert to avoid TOCTOU race
638
639
640
641
642
        let inner = Arc::new(RuntimeConfigsWithNotify {
            configs: DashMap::new(),
            notify: Notify::new(),
        });
        let (result, is_new) = match self.runtime_configs.entry(endpoint_id) {
643
644
            Entry::Occupied(e) => (e.get().clone(), false),
            Entry::Vacant(e) => {
645
646
                e.insert(inner.clone());
                (inner, true)
647
648
649
650
651
            }
        };

        // Only spawn watcher if we were the one who inserted
        if is_new {
652
            self.spawn_runtime_config_watcher(endpoint, result.clone())
653
654
655
                .await?;
        }

656
        Ok(result)
657
658
659
660
661
662
663
664
665
    }

    /// Looks up the disaggregated endpoint that `worker_id` advertises in its runtime
    /// config for `endpoint_id`.
    ///
    /// Used by PrefillRouter for bootstrap info - works for ANY routing mode. Returns `None`
    /// if the endpoint has no watcher, the worker is unknown, the worker has no runtime
    /// config yet, or the config carries no disaggregated endpoint.
    pub fn get_disaggregated_endpoint(
        &self,
        endpoint_id: &EndpointId,
        worker_id: WorkerId,
    ) -> Option<DisaggregatedEndpoint> {
        let endpoint_configs = self.runtime_configs.get(endpoint_id)?;
        let worker_config = endpoint_configs.configs.get(&worker_id)?;
        worker_config
            .as_ref()
            .and_then(|config| config.disaggregated_endpoint.clone())
    }

    /// Spawn background task to watch runtime configs via discovery.
672
    /// Blocks until at least one worker with a runtime config is available.
673
674
675
    async fn spawn_runtime_config_watcher(
        &self,
        endpoint: &Endpoint,
676
        inner: Arc<RuntimeConfigsWithNotify>,
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
    ) -> anyhow::Result<()> {
        let component = endpoint.component();
        let cancellation_token = component.drt().primary_token();

        // Set up discovery watch for EndpointModels
        let discovery = component.drt().discovery();
        let endpoint_id = endpoint.id();
        let discovery_key = DiscoveryQuery::EndpointModels {
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
        };
        let discovery_stream = discovery
            .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
            .await?;

        // Extract runtime_config from ModelDeploymentCard
        let mut runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });

        // Also watch instance IDs
        let client = endpoint.client().await?;
        let mut instance_ids_rx = client.instance_avail_watcher();

703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
        // Wait for at least one worker with runtime config before proceeding.
        // This ensures the DashMap is populated before KvScheduler starts.
        tracing::info!("ModelManager: Waiting for at least one worker with runtime config...");
        runtime_configs_rx
            .changed()
            .await
            .map_err(|_| anyhow::anyhow!("runtime configs watch sender shutdown while waiting"))?;

        // Populate initial state
        {
            let instance_ids = instance_ids_rx.borrow();
            let configs = runtime_configs_rx.borrow();
            for worker_id in instance_ids.iter() {
                let config = configs.get(worker_id).cloned();
                inner.configs.insert(*worker_id, config);
            }
            tracing::info!(
                "ModelManager: Found {} workers, proceeding",
                inner.configs.len()
            );
        }

        // Spawn background task to update configs for future changes
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
        let cancel_token = cancellation_token.clone();
        tokio::spawn(async move {
            tracing::trace!("ModelManager runtime config watcher started");
            loop {
                // Wait for either instances or configs to change
                tokio::select! {
                    _ = cancel_token.cancelled() => {
                        tracing::trace!("ModelManager runtime config watcher shutting down");
                        break;
                    }
                    result = instance_ids_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("instance IDs watch sender shutdown in ModelManager");
                            break;
                        }
                    }
                    result = runtime_configs_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("runtime configs watch sender shutdown in ModelManager");
                            break;
                        }
                    }
                }

                // Get the latest values from both channels
                let new_instance_ids = instance_ids_rx.borrow_and_update().clone();
                let new_configs = runtime_configs_rx.borrow_and_update().clone();

                // Update the DashMap
                // First, remove workers that no longer exist
                let current_workers: HashSet<WorkerId> =
757
                    inner.configs.iter().map(|r| *r.key()).collect();
758
759
                let new_workers: HashSet<WorkerId> = new_instance_ids.iter().copied().collect();
                for removed_worker in current_workers.difference(&new_workers) {
760
                    inner.configs.remove(removed_worker);
761
762
763
764
765
766
                }

                // Then, add/update workers
                for worker_id in &new_instance_ids {
                    let config = new_configs.get(worker_id).cloned();
                    if config.is_some() {
767
                        let prev_config = inner.configs.get(worker_id);
768
769
                        if prev_config.as_ref().map(|r| r.value()) != Some(&config) {
                            tracing::info!(
770
                                "ModelManager: Runtime config found for worker_id: {worker_id}"
771
772
773
                            );
                        }
                    }
774
                    inner.configs.insert(*worker_id, config);
775
776
                }

777
778
779
                // Notify waiters that configs have changed
                inner.notify.notify_waiters();

780
781
                tracing::trace!(
                    "ModelManager: Updated runtime_configs with {} workers",
782
                    inner.configs.len()
783
784
785
786
787
788
789
790
                );
            }
            tracing::trace!("ModelManager runtime config watcher shutting down");
        });

        Ok(())
    }

791
792
    /// Lists all models that have worker monitors (and thus busy thresholds) configured.
    ///
793
794
    /// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
    pub fn list_busy_thresholds(&self) -> Vec<(String, f64, u64)> {
795
796
797
        self.worker_monitors
            .read()
            .iter()
798
799
800
801
802
803
804
            .map(|(k, monitor)| {
                (
                    k.clone(),
                    monitor.active_decode_blocks_threshold(),
                    monitor.active_prefill_tokens_threshold(),
                )
            })
805
806
            .collect()
    }
807
808
809
810
811
812
}

/// Registry of engines of type `E`, keyed by model name, along with the deployment-card
/// checksum each model was registered with.
pub struct ModelEngines<E> {
    /// Optional default model name (see `set_default` / `clear_default`).
    default: Option<String>,
    /// Engine for each registered model name.
    engines: HashMap<String, E>,
    /// Key: model name, value: checksum of the ModelDeploymentCard. New instances must
    /// have the same card.
    checksums: HashMap<String, String>,
}

impl<E> Default for ModelEngines<E> {
    fn default() -> Self {
        Self {
            default: None,
            engines: HashMap::new(),
823
            checksums: HashMap::new(),
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
        }
    }
}

impl<E> ModelEngines<E> {
    #[allow(dead_code)]
    fn set_default(&mut self, model: &str) {
        self.default = Some(model.to_string());
    }

    #[allow(dead_code)]
    fn clear_default(&mut self) {
        self.default = None;
    }

839
    fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
840
841
842
843
        if self.engines.contains_key(model) {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        self.engines.insert(model.to_string(), engine);
844
845
        self.checksums
            .insert(model.to_string(), checksum.to_string());
846
847
848
849
850
851
852
        Ok(())
    }

    fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
        if self.engines.remove(model).is_none() {
            return Err(ModelManagerError::ModelNotFound(model.to_string()));
        }
853
        let _ = self.checksums.remove(model);
854
855
856
857
858
859
860
861
862
863
864
865
866
867
        Ok(())
    }

    fn get(&self, model: &str) -> Option<&E> {
        self.engines.get(model)
    }

    fn contains(&self, model: &str) -> bool {
        self.engines.contains_key(model)
    }

    pub fn list(&self) -> Vec<String> {
        self.engines.keys().map(|k| k.to_owned()).collect()
    }
868
869
870
871
872
873

    /// Returns a newly allocated String for called convenience. All the places I use
    /// this I need a String.
    pub fn checksum(&self, model: &str) -> Option<String> {
        self.checksums.get(model).map(|s| s.to_string())
    }
874
}