model_manager.rs 26.2 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
use std::{
    collections::{HashMap, HashSet},
6
    sync::Arc,
7
8
};

9
use dashmap::{DashMap, mapref::entry::Entry};
10
use parking_lot::{Mutex, RwLock};
11
use tokio::sync::oneshot;
12

13
use crate::discovery::KvWorkerMonitor;
14
use crate::discovery::runtime_configs::RuntimeConfigs;
15

16
use dynamo_runtime::{
17
    component::{Client, Endpoint, build_transport_type},
18
    discovery::DiscoverySpec,
19
20
21
    prelude::DistributedRuntimeProvider,
    protocols::EndpointId,
};
22
23

use crate::{
24
25
26
27
    kv_router::{
        KvRouter, KvRouterConfig, protocols::WorkerId, router_endpoint_id,
        scheduler::DefaultWorkerSelector,
    },
28
    local_model::runtime_config::DisaggregatedEndpoint,
29
    model_card::ModelDeploymentCard,
30
    model_type::ModelType,
31
32
33
34
35
36
37
38
    types::{
        generic::tensor::TensorStreamingEngine,
        openai::{
            chat_completions::OpenAIChatCompletionsStreamingEngine,
            completions::OpenAICompletionsStreamingEngine,
            embeddings::OpenAIEmbeddingsStreamingEngine,
        },
    },
39
};
40

41
42
43
44
45
46
47
48
/// State for prefill router activation rendezvous.
///
/// One value is stored per model name in `ModelManager::prefill_router_activators`.
/// Whichever side arrives first (decode registration or prefill activation) parks its
/// half of a oneshot channel here so the other side can complete the hand-off.
enum PrefillActivationState {
    /// Decode model registered, waiting for prefill endpoint
    DecodeWaiting(oneshot::Sender<Endpoint>),
    /// Prefill endpoint arrived, waiting for decode model to register
    PrefillReady(oneshot::Receiver<Endpoint>),
}

49
50
51
52
53
54
55
56
57
/// Errors produced when adding, removing, or looking up engines by model name.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// No engine is registered under the given model name.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// An engine is already registered under the given model name.
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}

58
59
60
61
62
63
/// Central manager for model engines, routing, and configuration.
///
/// Manages model lifecycle including engines, KV routers, prefill coordination,
/// and per-model busy thresholds for load-based request rejection.
///
/// Note: Don't implement Clone for this, put it in an Arc instead.
64
65
66
67
68
pub struct ModelManager {
    // We read a lot and write rarely, so these three are RwLock
    completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
    chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
    embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
69
    tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
70
71
    // Prefill models don't have engines - they're only tracked for discovery/lifecycle
    prefill_engines: RwLock<ModelEngines<()>>,
72

73
    // These are Mutex because we read and write rarely and equally
74
    cards: Mutex<HashMap<String, ModelDeploymentCard>>,
75
    kv_choosers: Mutex<HashMap<EndpointId, Arc<KvRouter>>>,
76
    prefill_router_activators: Mutex<HashMap<String, PrefillActivationState>>,
77
78
79
80
81

    /// Per-model worker monitors for dynamic KV cache load rejection.
    /// Key: model name, Value: cloneable monitor (all fields are Arc).
    /// HTTP endpoint can update thresholds via monitor.set_threshold().
    worker_monitors: RwLock<HashMap<String, KvWorkerMonitor>>,
82
83
84

    /// Runtime configs per endpoint using DashMap for lock-free access.
    /// Outer DashMap: keyed by EndpointId
85
86
    /// Inner RuntimeConfigs: shared with KvScheduler
    runtime_configs: DashMap<EndpointId, Arc<RuntimeConfigs>>,
87
88
89
90
91
92
93
94
95
96
97
98
99
100
}

/// `Default` delegates to [`ModelManager::new`].
impl Default for ModelManager {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelManager {
    pub fn new() -> Self {
        Self {
            completion_engines: RwLock::new(ModelEngines::default()),
            chat_completion_engines: RwLock::new(ModelEngines::default()),
            embeddings_engines: RwLock::new(ModelEngines::default()),
101
            tensor_engines: RwLock::new(ModelEngines::default()),
102
            prefill_engines: RwLock::new(ModelEngines::default()),
103
            cards: Mutex::new(HashMap::new()),
104
            kv_choosers: Mutex::new(HashMap::new()),
105
            prefill_router_activators: Mutex::new(HashMap::new()),
106
            worker_monitors: RwLock::new(HashMap::new()),
107
            runtime_configs: DashMap::new(),
108
109
110
        }
    }

111
112
113
114
115
116
117
118
119
120
121
122
123
    /// Compares `candidate_checksum` with the checksum recorded for `model_name` in every
    /// registry selected by `model_type`.
    ///
    /// # Returns
    ///
    /// * `None` - none of the selected registries has a checksum for `model_name`.
    /// * `Some(true)` - every recorded checksum equals `candidate_checksum`.
    /// * `Some(false)` - at least one recorded checksum differs.
    pub fn is_valid_checksum(
        &self,
        model_type: ModelType,
        model_name: &str,
        candidate_checksum: &str,
    ) -> Option<bool> {
        let mut compared_any = false;
        let mut all_match = true;

        for unit in model_type.units() {
            let recorded = match unit {
                ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
                ModelType::Completions => self.completion_engines.read().checksum(model_name),
                ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
                ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
                ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
                // Units without a registry here contribute nothing to the verdict.
                _ => continue,
            };

            if let Some(valid_checksum) = recorded {
                tracing::debug!(
                    model_name,
                    valid_checksum,
                    candidate_checksum,
                    "is_valid_checksum: check case"
                );
                compared_any = true;
                // Every unit is still visited (and logged) even after a mismatch.
                all_match &= valid_checksum == candidate_checksum;
            }
        }

        // The checksum is valid if it is correct for all the ModelType in the bitflag.
        if compared_any { Some(all_match) } else { None }
    }

149
150
    /// Returns a cloned snapshot of every stored model deployment card.
    pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
        let cards = self.cards.lock();
        cards.values().cloned().collect()
    }

153
154
    /// Check if a decode model (chat or completions) is registered
    pub fn has_decode_model(&self, model: &str) -> bool {
155
156
        self.chat_completion_engines.read().contains(model)
            || self.completion_engines.read().contains(model)
157
158
159
160
161
162
163
164
165
166
167
    }

    /// Returns `true` when `model` is present in the prefill registry.
    pub fn has_prefill_model(&self, model: &str) -> bool {
        let prefill = self.prefill_engines.read();
        prefill.contains(model)
    }

    /// Check if any model (decode or prefill) is registered.
    /// Note: For registration skip-checks, use has_decode_model() or has_prefill_model() instead.
    pub fn has_model_any(&self, model: &str) -> bool {
        self.has_decode_model(model) || self.has_prefill_model(model)
168
169
    }

170
171
172
173
174
    pub fn model_display_names(&self) -> HashSet<String> {
        self.list_chat_completions_models()
            .into_iter()
            .chain(self.list_completions_models())
            .chain(self.list_embeddings_models())
175
            .chain(self.list_tensor_models())
176
            .chain(self.list_prefill_models())
177
178
179
            .collect()
    }

180
    pub fn list_chat_completions_models(&self) -> Vec<String> {
181
        self.chat_completion_engines.read().list()
182
183
184
    }

    pub fn list_completions_models(&self) -> Vec<String> {
185
        self.completion_engines.read().list()
186
187
188
    }

    pub fn list_embeddings_models(&self) -> Vec<String> {
189
        self.embeddings_engines.read().list()
190
191
    }

192
193
194
195
    /// Lists the names of all registered tensor models.
    pub fn list_tensor_models(&self) -> Vec<String> {
        let registry = self.tensor_engines.read();
        registry.list()
    }

196
197
198
199
    /// Lists the names of all registered prefill models.
    pub fn list_prefill_models(&self) -> Vec<String> {
        let registry = self.prefill_engines.read();
        registry.list()
    }

200
201
202
    pub fn add_completions_model(
        &self,
        model: &str,
203
        card_checksum: &str,
204
205
        engine: OpenAICompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
206
        let mut clients = self.completion_engines.write();
207
        clients.add(model, card_checksum, engine)
208
209
210
211
212
    }

    pub fn add_chat_completions_model(
        &self,
        model: &str,
213
        card_checksum: &str,
214
215
        engine: OpenAIChatCompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
216
        let mut clients = self.chat_completion_engines.write();
217
        clients.add(model, card_checksum, engine)
218
219
220
221
222
    }

    pub fn add_embeddings_model(
        &self,
        model: &str,
223
        card_checksum: &str,
224
225
        engine: OpenAIEmbeddingsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
226
        let mut clients = self.embeddings_engines.write();
227
        clients.add(model, card_checksum, engine)
228
229
    }

230
231
232
    pub fn add_tensor_model(
        &self,
        model: &str,
233
        card_checksum: &str,
234
235
236
        engine: TensorStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
237
        clients.add(model, card_checksum, engine)
238
239
    }

240
241
242
243
244
245
    /// Records a prefill model and its card checksum. Prefill models carry no engine, so
    /// the unit value is stored in the engine slot.
    ///
    /// # Errors
    ///
    /// Returns `ModelManagerError::ModelAlreadyExists` if `model` is already registered.
    pub fn add_prefill_model(
        &self,
        model: &str,
        card_checksum: &str,
    ) -> Result<(), ModelManagerError> {
        self.prefill_engines.write().add(model, card_checksum, ())
    }

249
    /// Unregisters the completions engine for `model`.
    ///
    /// # Errors
    ///
    /// Returns `ModelManagerError::ModelNotFound` if `model` is not registered.
    pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.completion_engines.write().remove(model)
    }

    /// Unregisters the chat completions engine for `model`.
    ///
    /// # Errors
    ///
    /// Returns `ModelManagerError::ModelNotFound` if `model` is not registered.
    pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.chat_completion_engines.write().remove(model)
    }

    /// Unregisters the embeddings engine for `model`.
    ///
    /// # Errors
    ///
    /// Returns `ModelManagerError::ModelNotFound` if `model` is not registered.
    pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.embeddings_engines.write().remove(model)
    }

264
265
266
267
268
    /// Unregisters the tensor engine for `model`.
    ///
    /// # Errors
    ///
    /// Returns `ModelManagerError::ModelNotFound` if `model` is not registered.
    pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.tensor_engines.write().remove(model)
    }

269
270
271
272
273
    /// Unregisters the prefill model `model`.
    ///
    /// # Errors
    ///
    /// Returns `ModelManagerError::ModelNotFound` if `model` is not registered.
    pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.prefill_engines.write().remove(model)
    }

274
    pub fn get_embeddings_engine(
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
        &self,
        model: &str,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
        self.embeddings_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Returns a clone of the completions engine registered for `model`.
    ///
    /// # Errors
    ///
    /// Returns `ModelManagerError::ModelNotFound` if `model` is not registered.
    pub fn get_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
        self.completion_engines
            .read()
            .get(model)
            .cloned()
            // Build the error lazily so a successful lookup does not allocate the name.
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Returns a clone of the chat completions engine registered for `model`.
    ///
    /// # Errors
    ///
    /// Returns `ModelManagerError::ModelNotFound` if `model` is not registered.
    pub fn get_chat_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
        self.chat_completion_engines
            .read()
            .get(model)
            .cloned()
            // Build the error lazily so a successful lookup does not allocate the name.
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

307
308
309
310
311
312
313
314
315
316
317
    /// Returns a clone of the tensor engine registered for `model`.
    ///
    /// # Errors
    ///
    /// Returns `ModelManagerError::ModelNotFound` if `model` is not registered.
    pub fn get_tensor_engine(
        &self,
        model: &str,
    ) -> Result<TensorStreamingEngine, ModelManagerError> {
        self.tensor_engines
            .read()
            .get(model)
            .cloned()
            // Build the error lazily so a successful lookup does not allocate the name.
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

318
    /// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
319
320
321
322
    /// deleted.
    pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
        self.cards.lock().insert(key.to_string(), card);
        Ok(())
323
324
    }

325
    /// Remove and return model card for this instance's key. We do this when the instance stops.
326
327
    pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
        self.cards.lock().remove(key)
328
329
330
331
    }

    pub async fn kv_chooser_for(
        &self,
332
        endpoint: &Endpoint,
333
        kv_cache_block_size: u32,
334
        kv_router_config: Option<KvRouterConfig>,
335
    ) -> anyhow::Result<Arc<KvRouter>> {
336
        let endpoint_id = endpoint.id();
337

338
        if let Some(kv_chooser) = self.get_kv_chooser(&endpoint_id) {
339
340
341
            // Check if the existing router has a different block size
            if kv_chooser.block_size() != kv_cache_block_size {
                tracing::warn!(
342
                    endpoint = %endpoint_id,
343
344
                    existing_block_size = %kv_chooser.block_size(),
                    requested_block_size = %kv_cache_block_size,
345
                    "KV Router block size mismatch! Endpoint is requesting a different kv_cache_block_size than the existing router. \
346
347
348
                     This will cause routing to fail silently. Consider using the same block size or restarting the router."
                );
            }
349
350
351
            return Ok(kv_chooser);
        }

352
        let client = endpoint.client().await?;
353
354
355
356
357

        // Register router via discovery mechanism
        let discovery = endpoint.component().drt().discovery();
        let instance_id = discovery.instance_id();

358
        // Build transport for router endpoint based on request plane mode
359
360
        // Use KV_ROUTER_COMPONENT as the component name to distinguish from the generate endpoint's component
        let router_endpoint_id = router_endpoint_id(endpoint.id().namespace);
361
        let transport = build_transport_type(endpoint, &router_endpoint_id, instance_id).await?;
362
363
364
365
366

        let discovery_spec = DiscoverySpec::Endpoint {
            namespace: router_endpoint_id.namespace.clone(),
            component: router_endpoint_id.component.clone(),
            endpoint: router_endpoint_id.name.clone(),
367
            transport,
368
369
370
371
        };

        discovery.register(discovery_spec).await?;

372
373
374
        // Get or create runtime config watcher for this endpoint
        let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;

375
        let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
376
        let chooser = KvRouter::new(
377
378
            endpoint.clone(),
            client,
379
            workers_with_configs,
380
381
            kv_cache_block_size,
            Some(selector),
382
            kv_router_config,
383
            instance_id,
384
385
        )
        .await?;
386
387
388
        let new_kv_chooser = Arc::new(chooser);
        self.kv_choosers
            .lock()
389
            .insert(endpoint_id, new_kv_chooser.clone());
390
391
        Ok(new_kv_chooser)
    }
392

393
394
    fn get_kv_chooser(&self, id: &EndpointId) -> Option<Arc<KvRouter>> {
        self.kv_choosers.lock().get(id).cloned()
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
    }

    /// Register a prefill router for a decode model. Returns a receiver that will be
    /// activated when the corresponding prefill model is discovered.
    /// Returns None if the decode model was already registered.
    pub fn register_prefill_router(
        &self,
        model_name: String,
    ) -> Option<oneshot::Receiver<Endpoint>> {
        let mut pending = self.prefill_router_activators.lock();

        let receiver = match pending.remove(&model_name) {
            None => {
                // First side to arrive: park the sender and hand the receiver to the caller.
                let (sender, receiver) = oneshot::channel();
                pending.insert(
                    model_name.clone(),
                    PrefillActivationState::DecodeWaiting(sender),
                );
                tracing::debug!(
                    model_name = %model_name,
                    "No prefill endpoint available yet, storing sender for future activation"
                );
                receiver
            }
            Some(PrefillActivationState::PrefillReady(receiver)) => {
                // The endpoint was already sent into this channel, so the receiver resolves at once.
                tracing::debug!(
                    model_name = %model_name,
                    "Prefill endpoint already available, returning receiver with endpoint"
                );
                receiver
            }
            Some(PrefillActivationState::DecodeWaiting(sender)) => {
                // Duplicate decode registration: put the parked sender back untouched.
                tracing::error!(
                    model_name = %model_name,
                    "Decode model already registered for this prefill router"
                );
                pending.insert(model_name, PrefillActivationState::DecodeWaiting(sender));
                return None;
            }
        };

        Some(receiver)
    }

    /// Activate a prefill router by sending the endpoint through the oneshot channel.
    /// If no decode model has registered yet, stores the endpoint for future retrieval.
    pub fn activate_prefill_router(
        &self,
        model_name: &str,
        endpoint: Endpoint,
    ) -> anyhow::Result<()> {
        let mut activators = self.prefill_router_activators.lock();

        match activators.remove(model_name) {
            Some(PrefillActivationState::DecodeWaiting(sender)) => {
                // Decode model already registered
                sender.send(endpoint).map_err(|_| {
                    anyhow::anyhow!(
                        "Failed to send endpoint to prefill router activator for model: {}",
                        model_name
                    )
                })?;

                tracing::info!(
                    model_name = %model_name,
                    "Activated prefill router for already-registered decode model"
                );

                Ok(())
            }
            Some(PrefillActivationState::PrefillReady(_)) => {
                // Prefill already activated - this shouldn't happen
                anyhow::bail!("Prefill router for model {} already activated", model_name);
            }
            None => {
                // Decode model not registered yet - create pair and immediately send endpoint
                let (tx, rx) = oneshot::channel();

                tx.send(endpoint).map_err(|_| {
                    anyhow::anyhow!("Failed to send endpoint for prefill model: {}", model_name)
                })?;

                // Store the receiver for when decode model registers
                activators.insert(
                    model_name.to_string(),
                    PrefillActivationState::PrefillReady(rx),
                );

                tracing::info!(
                    model_name = %model_name,
                    "Stored prefill endpoint for future decode model registration"
                );

                Ok(())
            }
        }
492
493
    }

494
    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
495
        self.cards
496
497
            .lock()
            .values()
498
499
            .find(|c| c.display_name == model)
            .and_then(|c| c.runtime_config.tool_call_parser.as_ref())
500
            .map(|parser| parser.to_string())
501
    }
502
503
504
505
506
507
508
509
510

    /// Creates parsing options with tool call parser and reasoning parser for the specified model.
    /// Currently reasoning parser is not implemented (returns None).
    pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
        use crate::protocols::openai::ParsingOptions;

        // TODO: Implement reasoning parser
        ParsingOptions::new(self.get_model_tool_call_parser(model), None)
    }
511
512
513

    /// Gets or sets the busy threshold for a model via its worker monitor.
    ///
514
515
    /// Get or set the active decode blocks threshold for a model's worker monitor.
    ///
516
517
    /// This is the primary API for HTTP endpoints and external callers.
    /// The threshold (0.0 to 1.0) controls when workers are marked as "busy"
518
    /// based on KV cache block utilization.
519
520
521
522
523
524
525
526
527
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `threshold` - `Some(value)` to set, `None` to get existing
    ///
    /// # Returns
    ///
    /// The threshold value as f64, or `None` if no monitor exists for this model.
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
    pub fn active_decode_blocks_threshold(
        &self,
        model: &str,
        threshold: Option<f64>,
    ) -> Option<f64> {
        let monitors = self.worker_monitors.read();
        let monitor = monitors.get(model)?;

        match threshold {
            Some(value) => {
                monitor.set_active_decode_blocks_threshold(value);
                Some(value)
            }
            None => Some(monitor.active_decode_blocks_threshold()),
        }
    }

    /// Get or set the active prefill tokens threshold for a model's worker monitor.
    ///
    /// The threshold is a literal token count (not a percentage).
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `threshold` - `Some(value)` to set, `None` to get existing
    ///
    /// # Returns
    ///
    /// The threshold value as u64, or `None` if no monitor exists for this model.
    pub fn active_prefill_tokens_threshold(
        &self,
        model: &str,
        threshold: Option<u64>,
    ) -> Option<u64> {
562
563
564
565
566
        let monitors = self.worker_monitors.read();
        let monitor = monitors.get(model)?;

        match threshold {
            Some(value) => {
567
                monitor.set_active_prefill_tokens_threshold(value);
568
569
                Some(value)
            }
570
            None => Some(monitor.active_prefill_tokens_threshold()),
571
572
573
574
575
        }
    }

    /// Gets or creates a worker monitor for a model.
    ///
576
577
    /// If a monitor already exists, updates its thresholds and returns a clone.
    /// If no monitor exists, creates one with the given client and thresholds.
578
579
580
581
582
    ///
    /// # Arguments
    ///
    /// * `model` - The model name
    /// * `client` - The client for subscribing to KV metrics (only used if creating new)
583
584
    /// * `active_decode_blocks_threshold` - The initial/updated active decode blocks threshold value (0.0-1.0)
    /// * `active_prefill_tokens_threshold` - The initial/updated active prefill tokens threshold value (literal token count)
585
586
587
588
589
590
591
    ///
    /// # Returns
    ///
    /// A cloneable monitor that shares state with the stored instance.
    pub fn get_or_create_worker_monitor(
        &self,
        model: &str,
592
        client: Client,
593
594
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
595
596
597
598
    ) -> KvWorkerMonitor {
        let mut monitors = self.worker_monitors.write();

        if let Some(existing) = monitors.get(model) {
599
600
            existing.set_active_decode_blocks_threshold(active_decode_blocks_threshold);
            existing.set_active_prefill_tokens_threshold(active_prefill_tokens_threshold);
601
602
            existing.clone()
        } else {
603
604
605
606
607
            let monitor = KvWorkerMonitor::new(
                client,
                active_decode_blocks_threshold,
                active_prefill_tokens_threshold,
            );
608
609
610
611
612
613
614
615
616
617
            monitors.insert(model.to_string(), monitor.clone());
            monitor
        }
    }

    /// Gets an existing worker monitor for a model, if one exists.
    pub fn get_worker_monitor(&self, model: &str) -> Option<KvWorkerMonitor> {
        let monitors = self.worker_monitors.read();
        monitors.get(model).cloned()
    }

618
    /// Get or create a runtime config watcher for an endpoint.
619
620
    /// Spawns a background task to watch for worker config changes.
    /// Returns a shared RuntimeConfigs that KvScheduler can use directly.
621
622
623
    pub async fn get_or_create_runtime_config_watcher(
        &self,
        endpoint: &Endpoint,
624
    ) -> anyhow::Result<Arc<RuntimeConfigs>> {
625
626
627
628
629
630
631
632
        let endpoint_id = endpoint.id();

        // Fast path: return existing if present
        if let Some(existing) = self.runtime_configs.get(&endpoint_id) {
            return Ok(existing.clone());
        }

        // Atomic get-or-insert to avoid TOCTOU race
633
        let inner = Arc::new(RuntimeConfigs::new());
634
        let (result, is_new) = match self.runtime_configs.entry(endpoint_id) {
635
636
            Entry::Occupied(e) => (e.get().clone(), false),
            Entry::Vacant(e) => {
637
638
                e.insert(inner.clone());
                (inner, true)
639
640
641
642
643
            }
        };

        // Only spawn watcher if we were the one who inserted
        if is_new {
644
            result.start_watcher(endpoint).await?;
645
646
        }

647
        Ok(result)
648
649
650
651
652
653
654
655
656
    }

    /// Get disaggregated endpoint for a specific worker.
    /// Used by PrefillRouter for bootstrap info - works for ANY routing mode.
    ///
    /// Returns `None` when the endpoint has no runtime configs entry, the worker has no
    /// entry, the worker's config is unset, or the config carries no disaggregated endpoint.
    pub fn get_disaggregated_endpoint(
        &self,
        endpoint_id: &EndpointId,
        worker_id: WorkerId,
    ) -> Option<DisaggregatedEndpoint> {
        let endpoint_configs = self.runtime_configs.get(endpoint_id)?;
        let worker_slot = endpoint_configs.configs.get(&worker_id)?;
        let worker_config = worker_slot.as_ref()?;
        worker_config.disaggregated_endpoint.clone()
    }

662
663
    /// Lists all models that have worker monitors (and thus busy thresholds) configured.
    ///
664
665
    /// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
    pub fn list_busy_thresholds(&self) -> Vec<(String, f64, u64)> {
666
667
668
        self.worker_monitors
            .read()
            .iter()
669
670
671
672
673
674
675
            .map(|(k, monitor)| {
                (
                    k.clone(),
                    monitor.active_decode_blocks_threshold(),
                    monitor.active_prefill_tokens_threshold(),
                )
            })
676
677
            .collect()
    }
678
679
680
681
682
683
}

pub struct ModelEngines<E> {
    /// Optional default model name
    default: Option<String>,
    engines: HashMap<String, E>,
684
685
686
    /// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
    /// same card.
    checksums: HashMap<String, String>,
687
688
689
690
691
692
693
}

impl<E> Default for ModelEngines<E> {
    fn default() -> Self {
        Self {
            default: None,
            engines: HashMap::new(),
694
            checksums: HashMap::new(),
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
        }
    }
}

impl<E> ModelEngines<E> {
    #[allow(dead_code)]
    fn set_default(&mut self, model: &str) {
        self.default = Some(model.to_string());
    }

    #[allow(dead_code)]
    fn clear_default(&mut self) {
        self.default = None;
    }

710
    fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
711
712
713
714
        if self.engines.contains_key(model) {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        self.engines.insert(model.to_string(), engine);
715
716
        self.checksums
            .insert(model.to_string(), checksum.to_string());
717
718
719
720
721
722
723
        Ok(())
    }

    fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
        if self.engines.remove(model).is_none() {
            return Err(ModelManagerError::ModelNotFound(model.to_string()));
        }
724
        let _ = self.checksums.remove(model);
725
726
727
728
729
730
731
732
733
734
735
736
737
738
        Ok(())
    }

    fn get(&self, model: &str) -> Option<&E> {
        self.engines.get(model)
    }

    fn contains(&self, model: &str) -> bool {
        self.engines.contains_key(model)
    }

    pub fn list(&self) -> Vec<String> {
        self.engines.keys().map(|k| k.to_owned()).collect()
    }
739
740
741
742
743
744

    /// Returns a newly allocated String for called convenience. All the places I use
    /// this I need a String.
    pub fn checksum(&self, model: &str) -> Option<String> {
        self.checksums.get(model).map(|s| s.to_string())
    }
745
}