model_manager.rs 36 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::{collections::HashSet, sync::Arc};
5

6
use dashmap::{DashMap, mapref::entry::Entry};
7
use dynamo_kv_router::{config::KvRouterConfig, protocols::WorkerId};
8
use tokio::sync::oneshot;
9

10
use super::worker_monitor::LoadThresholdConfig;
11
use super::{KvWorkerMonitor, Model, RuntimeConfigWatch, WorkerSet, runtime_config_watch};
12

13
use dynamo_runtime::{
14
    component::{Endpoint, build_transport_type},
15
    discovery::DiscoverySpec,
16
17
18
    prelude::DistributedRuntimeProvider,
    protocols::EndpointId,
};
19
20

use crate::{
21
    kv_router::{KvRouter, router_endpoint_id, scheduler::DefaultWorkerSelector},
22
    local_model::runtime_config::DisaggregatedEndpoint,
23
24
25
26
27
28
    model_card::ModelDeploymentCard,
    types::{
        generic::tensor::TensorStreamingEngine,
        openai::{
            chat_completions::OpenAIChatCompletionsStreamingEngine,
            completions::OpenAICompletionsStreamingEngine,
29
            embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
30
            videos::OpenAIVideosStreamingEngine,
31
32
        },
    },
33
};
34

35
36
37
38
39
40
41
42
/// State for prefill router activation rendezvous
///
/// Whichever side arrives first for a given `"model_name:namespace"` key stores its half of a
/// `oneshot` channel here: a decode WorkerSet registering through
/// `ModelManager::register_prefill_router` stores the sender, a prefill endpoint arriving
/// through `ModelManager::activate_prefill_router` stores an already-loaded receiver. The side
/// that arrives second removes the entry and completes the handoff.
enum PrefillActivationState {
    /// Decode model registered, waiting for prefill endpoint
    /// (the decode side holds the matching receiver).
    DecodeWaiting(oneshot::Sender<Endpoint>),
    /// Prefill endpoint arrived, waiting for decode model to register
    /// (the endpoint has already been sent into this receiver, so it resolves immediately).
    PrefillReady(oneshot::Receiver<Endpoint>),
}

43
44
45
46
47
48
49
50
51
/// Errors returned by [`ModelManager`] lookups and registrations.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// The named model (payload: model name) is not registered, e.g. an engine lookup found no
    /// Model, or a `remove_*_model` call found no matching synthetic WorkerSet.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// The named model (payload: model name) already provides the capability being registered
    /// by an `add_*_model` call.
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}

52
53
/// Central manager for model engines, routing, and configuration.
///
54
55
/// Models are stored hierarchically: ModelManager → Model → WorkerSet.
/// Each WorkerSet owns a complete pipeline built from its specific configuration.
56
57
///
/// Note: Don't implement Clone for this, put it in an Arc instead.
58
pub struct ModelManager {
59
60
    /// Model name → Model (which contains WorkerSets with engines)
    models: DashMap<String, Arc<Model>>,
61

62
    /// Per-instance model cards, keyed by instance path. Used for cleanup on worker removal.
63
    cards: DashMap<String, ModelDeploymentCard>,
64
65

    /// Prefill router activation rendezvous, keyed by "model_name:namespace".
66
    prefill_router_activators: DashMap<String, PrefillActivationState>,
67

68
    /// Per-endpoint runtime config watchers. Keyed by EndpointId (includes namespace).
69
    runtime_configs: DashMap<EndpointId, RuntimeConfigWatch>,
70
71
72
73
74
75
76
77
78
79
80
}

impl Default for ModelManager {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelManager {
    /// Create an empty manager: no models, cards, prefill activators, or runtime-config
    /// watchers.
    pub fn new() -> Self {
        let models = DashMap::new();
        let cards = DashMap::new();
        let prefill_router_activators = DashMap::new();
        let runtime_configs = DashMap::new();
        Self {
            models,
            cards,
            prefill_router_activators,
            runtime_configs,
        }
    }

88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
    // -- Model access --

    /// Return the Model registered under `model_name`, inserting a fresh empty one first if
    /// none exists yet.
    pub fn get_or_create_model(&self, model_name: &str) -> Arc<Model> {
        let slot = self
            .models
            .entry(model_name.to_string())
            .or_insert_with(|| Arc::new(Model::new(model_name.to_string())));
        Arc::clone(slot.value())
    }

    /// Return the Model registered under `model_name`, or `None` if there is none.
    pub fn get_model(&self, model_name: &str) -> Option<Arc<Model>> {
        let found = self.models.get(model_name)?;
        Some(Arc::clone(found.value()))
    }

    /// Drop the Model registered under `model_name`, but only if it holds no WorkerSets.
    ///
    /// The emptiness check and the removal happen inside a single `remove_if` call, so no other
    /// thread can slip a WorkerSet in between them through the map.
    pub fn remove_model_if_empty(&self, model_name: &str) {
        let was_removed = self
            .models
            .remove_if(model_name, |_, model| model.is_empty())
            .is_some();
        if !was_removed {
            return;
        }
        tracing::info!(model_name, "Removed empty model from manager");
    }

    /// Add a WorkerSet to a Model. Creates the Model if it doesn't exist.
118
    pub fn add_worker_set(&self, model_name: &str, namespace: &str, worker_set: WorkerSet) {
119
        let model = self.get_or_create_model(model_name);
120
        model.add_worker_set(namespace.to_string(), Arc::new(worker_set));
121
122
123
124
125
126
127
128
129
    }

    /// Remove a WorkerSet from a Model. Removes the Model if it becomes empty.
    pub fn remove_worker_set(&self, model_name: &str, namespace: &str) -> Option<Arc<WorkerSet>> {
        let model = self.models.get(model_name)?;
        let removed = model.remove_worker_set(namespace);
        drop(model);
        self.remove_model_if_empty(model_name);
        removed
130
131
    }

132
133
    // -- Model cards --

134
    pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
135
        self.cards.iter().map(|r| r.value().clone()).collect()
136
137
    }

138
139
140
141
142
143
144
145
146
147
148
149
150
151
    /// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
    /// deleted.
    pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
        self.cards.insert(key.to_string(), card);
        Ok(())
    }

    /// Remove and return model card for this instance's key. We do this when the instance stops.
    pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
        self.cards.remove(key).map(|(_, v)| v)
    }

    // -- Engine accessors (delegate through Model → WorkerSet) --

152
153
    /// Return `true` when `model` is registered and has a decode engine (chat or completions).
    pub fn has_decode_model(&self, model: &str) -> bool {
        match self.models.get(model) {
            Some(entry) => entry.has_decode_engine(),
            None => false,
        }
    }

    /// Return `true` when `model` is registered and has a prefill WorkerSet.
    pub fn has_prefill_model(&self, model: &str) -> bool {
        match self.models.get(model) {
            Some(entry) => entry.has_prefill(),
            None => false,
        }
    }

    /// Return `true` when `model` is registered for decode or for prefill.
    /// The prefill check only runs if the decode check fails.
    pub fn has_model_any(&self, model: &str) -> bool {
        if self.has_decode_model(model) {
            return true;
        }
        self.has_prefill_model(model)
    }

169
    pub fn model_display_names(&self) -> HashSet<String> {
170
171
172
173
174
        self.models
            .iter()
            .filter(|entry| entry.value().is_displayable())
            .map(|entry| entry.key().clone())
            .collect()
175
176
    }

177
    pub fn list_chat_completions_models(&self) -> Vec<String> {
178
179
180
181
182
        self.models
            .iter()
            .filter(|entry| entry.value().has_chat_engine())
            .map(|entry| entry.key().clone())
            .collect()
183
184
185
    }

    pub fn list_completions_models(&self) -> Vec<String> {
186
187
188
189
190
        self.models
            .iter()
            .filter(|entry| entry.value().has_completions_engine())
            .map(|entry| entry.key().clone())
            .collect()
191
192
193
    }

    pub fn list_embeddings_models(&self) -> Vec<String> {
194
195
196
197
198
        self.models
            .iter()
            .filter(|entry| entry.value().has_embeddings_engine())
            .map(|entry| entry.key().clone())
            .collect()
199
200
    }

201
    pub fn list_tensor_models(&self) -> Vec<String> {
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
        self.models
            .iter()
            .filter(|entry| entry.value().has_tensor_engine())
            .map(|entry| entry.key().clone())
            .collect()
    }

    pub fn list_images_models(&self) -> Vec<String> {
        self.models
            .iter()
            .filter(|entry| entry.value().has_images_engine())
            .map(|entry| entry.key().clone())
            .collect()
    }

    pub fn list_videos_models(&self) -> Vec<String> {
        self.models
            .iter()
            .filter(|entry| entry.value().has_videos_engine())
            .map(|entry| entry.key().clone())
            .collect()
223
224
    }

225
    pub fn list_prefill_models(&self) -> Vec<String> {
226
227
228
229
230
        self.models
            .iter()
            .filter(|entry| entry.value().has_prefill())
            .map(|entry| entry.key().clone())
            .collect()
231
232
    }

233
234
235
236
237
238
239
240
    pub fn get_embeddings_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
        self.models
            .get(model)
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
            .get_embeddings_engine()
241
242
    }

243
244
245
246
247
248
249
250
    pub fn get_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
        self.models
            .get(model)
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
            .get_completions_engine()
251
252
    }

253
    pub fn get_chat_completions_engine(
254
255
        &self,
        model: &str,
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
        self.models
            .get(model)
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
            .get_chat_engine()
    }

    pub fn get_tensor_engine(
        &self,
        model: &str,
    ) -> Result<TensorStreamingEngine, ModelManagerError> {
        self.models
            .get(model)
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
            .get_tensor_engine()
    }

    pub fn get_images_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
        self.models
            .get(model)
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
            .get_images_engine()
    }

    pub fn get_videos_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
        self.models
            .get(model)
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
            .get_videos_engine()
    }

    // -- Combined engine + parsing options (atomically from one WorkerSet) --

    pub fn get_chat_completions_engine_with_parsing(
        &self,
        model: &str,
    ) -> Result<
        (
            OpenAIChatCompletionsStreamingEngine,
            crate::protocols::openai::ParsingOptions,
        ),
        ModelManagerError,
    > {
        self.models
            .get(model)
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
            .get_chat_engine_with_parsing()
309
310
    }

311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
    pub fn get_completions_engine_with_parsing(
        &self,
        model: &str,
    ) -> Result<
        (
            OpenAICompletionsStreamingEngine,
            crate::protocols::openai::ParsingOptions,
        ),
        ModelManagerError,
    > {
        self.models
            .get(model)
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
            .get_completions_engine_with_parsing()
    }

    // -- Convenience methods for in-process models (http.rs, grpc.rs) --
    // These create a WorkerSet with a default namespace for local models.
    // TODO: These methods use ModelDeploymentCard::default() for the WorkerSet, which means
    // parsing_options() returns defaults (no tool_call_parser/reasoning_parser). Pass the real
    // MDC from callers so ParsingOptions reflect the model's actual configuration.

333
334
335
    pub fn add_chat_completions_model(
        &self,
        model: &str,
336
        card_checksum: &str,
337
338
        engine: OpenAIChatCompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
339
340
341
342
343
344
345
346
347
348
349
        let model_entry = self.get_or_create_model(model);
        if model_entry.has_chat_engine() {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        let namespace = format!("__local_chat_{}", model);
        let mut ws = WorkerSet::new(
            namespace.clone(),
            card_checksum.to_string(),
            ModelDeploymentCard::default(),
        );
        ws.chat_engine = Some(engine);
350
        model_entry.add_worker_set(namespace, Arc::new(ws));
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
        Ok(())
    }

    pub fn add_completions_model(
        &self,
        model: &str,
        card_checksum: &str,
        engine: OpenAICompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let model_entry = self.get_or_create_model(model);
        if model_entry.has_completions_engine() {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        let namespace = format!("__local_completions_{}", model);
        let mut ws = WorkerSet::new(
            namespace.clone(),
            card_checksum.to_string(),
            ModelDeploymentCard::default(),
        );
        ws.completions_engine = Some(engine);
371
        model_entry.add_worker_set(namespace, Arc::new(ws));
372
        Ok(())
373
374
375
376
377
    }

    pub fn add_embeddings_model(
        &self,
        model: &str,
378
        card_checksum: &str,
379
380
        engine: OpenAIEmbeddingsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
381
382
383
384
385
386
387
388
389
390
391
        let model_entry = self.get_or_create_model(model);
        if model_entry.has_embeddings_engine() {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        let namespace = format!("__local_embeddings_{}", model);
        let mut ws = WorkerSet::new(
            namespace.clone(),
            card_checksum.to_string(),
            ModelDeploymentCard::default(),
        );
        ws.embeddings_engine = Some(engine);
392
        model_entry.add_worker_set(namespace, Arc::new(ws));
393
        Ok(())
394
395
    }

396
397
398
    pub fn add_tensor_model(
        &self,
        model: &str,
399
        card_checksum: &str,
400
401
        engine: TensorStreamingEngine,
    ) -> Result<(), ModelManagerError> {
402
403
404
405
406
407
408
409
410
411
412
        let model_entry = self.get_or_create_model(model);
        if model_entry.has_tensor_engine() {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        let namespace = format!("__local_tensor_{}", model);
        let mut ws = WorkerSet::new(
            namespace.clone(),
            card_checksum.to_string(),
            ModelDeploymentCard::default(),
        );
        ws.tensor_engine = Some(engine);
413
        model_entry.add_worker_set(namespace, Arc::new(ws));
414
        Ok(())
415
416
    }

417
418
419
420
421
422
    pub fn add_images_model(
        &self,
        model: &str,
        card_checksum: &str,
        engine: OpenAIImagesStreamingEngine,
    ) -> Result<(), ModelManagerError> {
423
424
425
426
427
428
429
430
431
432
433
        let model_entry = self.get_or_create_model(model);
        if model_entry.has_images_engine() {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        let namespace = format!("__local_images_{}", model);
        let mut ws = WorkerSet::new(
            namespace.clone(),
            card_checksum.to_string(),
            ModelDeploymentCard::default(),
        );
        ws.images_engine = Some(engine);
434
        model_entry.add_worker_set(namespace, Arc::new(ws));
435
        Ok(())
436
437
    }

438
439
440
441
442
443
    pub fn add_videos_model(
        &self,
        model: &str,
        card_checksum: &str,
        engine: OpenAIVideosStreamingEngine,
    ) -> Result<(), ModelManagerError> {
444
445
446
447
448
449
450
451
452
453
454
        let model_entry = self.get_or_create_model(model);
        if model_entry.has_videos_engine() {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        let namespace = format!("__local_videos_{}", model);
        let mut ws = WorkerSet::new(
            namespace.clone(),
            card_checksum.to_string(),
            ModelDeploymentCard::default(),
        );
        ws.videos_engine = Some(engine);
455
        model_entry.add_worker_set(namespace, Arc::new(ws));
456
        Ok(())
457
458
    }

459
460
461
462
463
    pub fn add_prefill_model(
        &self,
        model: &str,
        card_checksum: &str,
    ) -> Result<(), ModelManagerError> {
464
465
466
467
468
469
470
471
472
473
        let model_entry = self.get_or_create_model(model);
        if model_entry.has_prefill() {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        let namespace = format!("__local_prefill_{}", model);
        let ws = WorkerSet::new(
            namespace.clone(),
            card_checksum.to_string(),
            ModelDeploymentCard::default(),
        );
474
        model_entry.add_worker_set(namespace, Arc::new(ws));
475
        Ok(())
476
477
    }

478
    // -- Model removal --
479

480
481
482
483
    /// Remove a model together with every WorkerSet it owns.
    /// Returns the removed Model, or `None` if no model with that name was registered.
    pub fn remove_model(&self, model: &str) -> Option<Arc<Model>> {
        let (_name, removed) = self.models.remove(model)?;
        Some(removed)
    }

486
487
    // Per-type remove methods for in-process models (used by Python bindings).
    // These remove the specific synthetic WorkerSet created by the corresponding add_*_model method.
488

489
490
491
492
493
    pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let namespace = format!("__local_chat_{}", model);
        self.remove_worker_set(model, &namespace)
            .map(|_| ())
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
494
495
    }

496
497
498
499
500
    pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let namespace = format!("__local_completions_{}", model);
        self.remove_worker_set(model, &namespace)
            .map(|_| ())
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
501
502
    }

503
504
505
506
507
    pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let namespace = format!("__local_tensor_{}", model);
        self.remove_worker_set(model, &namespace)
            .map(|_| ())
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
508
509
    }

510
511
512
513
514
    pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let namespace = format!("__local_embeddings_{}", model);
        self.remove_worker_set(model, &namespace)
            .map(|_| ())
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
515
516
    }

517
518
519
520
521
    pub fn remove_images_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let namespace = format!("__local_images_{}", model);
        self.remove_worker_set(model, &namespace)
            .map(|_| ())
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
522
523
    }

524
525
526
527
528
    pub fn remove_videos_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let namespace = format!("__local_videos_{}", model);
        self.remove_worker_set(model, &namespace)
            .map(|_| ())
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
529
530
    }

531
    // -- KV Router creation --
532

533
    #[allow(clippy::too_many_arguments)]
534
535
    pub async fn kv_chooser_for(
        &self,
536
        endpoint: &Endpoint,
537
        kv_cache_block_size: u32,
538
        kv_router_config: Option<KvRouterConfig>,
539
        worker_type: &'static str,
540
        model_name: Option<String>,
541
        is_eagle: bool,
542
    ) -> anyhow::Result<Arc<KvRouter>> {
543
        let client = endpoint.client().await?;
544

545
        // Register router via discovery mechanism.
546
547
548
        let discovery = endpoint.component().drt().discovery();
        let instance_id = discovery.instance_id();

549
        // Build transport for router endpoint based on request plane mode
550
551
552
        // Use the worker's component name so each target pool gets its own router discovery group
        let router_endpoint_id =
            router_endpoint_id(endpoint.id().namespace, endpoint.id().component);
553
        let transport = build_transport_type(endpoint, &router_endpoint_id, instance_id).await?;
554
555
556
557
558

        let discovery_spec = DiscoverySpec::Endpoint {
            namespace: router_endpoint_id.namespace.clone(),
            component: router_endpoint_id.component.clone(),
            endpoint: router_endpoint_id.name.clone(),
559
            transport,
560
561
562
563
        };

        discovery.register(discovery_spec).await?;

564
        // Get of create runtime config watcher for this endpoint
565
566
        let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;

567
        let selector = DefaultWorkerSelector::new(kv_router_config.clone(), worker_type);
568
        let chooser = KvRouter::new(
569
570
            endpoint.clone(),
            client,
571
            workers_with_configs,
572
            kv_cache_block_size,
573
            selector,
574
            kv_router_config,
575
            worker_type,
576
            model_name,
577
            is_eagle,
578
579
        )
        .await?;
580
        Ok(Arc::new(chooser))
581
    }
582

583
584
585
586
587
588
589
590
    // -- Prefill router coordination --
    // Keyed by "model_name:namespace" so each namespace's decode WorkerSet gets its own
    // prefill router activated by same-namespace prefill workers.

    /// Compose the `"{model_name}:{namespace}"` key used for prefill router activators and
    /// registration guards.
    pub(crate) fn model_namespace_key(model_name: &str, namespace: &str) -> String {
        let mut key = String::with_capacity(model_name.len() + 1 + namespace.len());
        key.push_str(model_name);
        key.push(':');
        key.push_str(namespace);
        key
    }

593
594
595
    /// Register a prefill router for a decode WorkerSet. Returns a receiver that will be
    /// activated when the corresponding prefill model in the same namespace is discovered.
    /// Returns None if a decode WorkerSet in this namespace was already registered.
596
597
    pub fn register_prefill_router(
        &self,
598
599
        model_name: &str,
        namespace: &str,
600
    ) -> Option<oneshot::Receiver<Endpoint>> {
601
602
        let key = Self::model_namespace_key(model_name, namespace);
        match self.prefill_router_activators.remove(&key) {
603
            Some((_, PrefillActivationState::PrefillReady(rx))) => {
604
605
606
                // Prefill endpoint already arrived - rx will immediately resolve
                tracing::debug!(
                    model_name = %model_name,
607
608
                    namespace = %namespace,
                    "Prefill endpoint already available for namespace, returning receiver"
609
610
611
                );
                Some(rx)
            }
612
            Some((key, PrefillActivationState::DecodeWaiting(tx))) => {
613
614
615
                // Decode already registered - this shouldn't happen, restore state and return None
                tracing::error!(
                    model_name = %model_name,
616
617
                    namespace = %namespace,
                    "Decode WorkerSet already registered for this prefill router"
618
                );
619
620
                self.prefill_router_activators
                    .insert(key, PrefillActivationState::DecodeWaiting(tx));
621
622
623
624
625
                None
            }
            None => {
                // New registration: create tx/rx pair, store sender and return receiver
                let (tx, rx) = oneshot::channel();
626
627
                self.prefill_router_activators
                    .insert(key, PrefillActivationState::DecodeWaiting(tx));
628
629
                tracing::debug!(
                    model_name = %model_name,
630
631
                    namespace = %namespace,
                    "No prefill endpoint for namespace yet, storing sender for future activation"
632
633
634
635
636
637
638
                );
                Some(rx)
            }
        }
    }

    /// Activate a prefill router by sending the endpoint through the oneshot channel.
639
    /// The namespace must match the decode WorkerSet's namespace.
640
641
642
    pub fn activate_prefill_router(
        &self,
        model_name: &str,
643
        namespace: &str,
644
645
        endpoint: Endpoint,
    ) -> anyhow::Result<()> {
646
647
        let key = Self::model_namespace_key(model_name, namespace);
        match self.prefill_router_activators.remove(&key) {
648
            Some((_, PrefillActivationState::DecodeWaiting(sender))) => {
649
650
                sender.send(endpoint).map_err(|_| {
                    anyhow::anyhow!(
651
652
653
                        "Failed to send endpoint to prefill router activator for {}:{}",
                        model_name,
                        namespace
654
655
656
657
                    )
                })?;
                tracing::info!(
                    model_name = %model_name,
658
659
                    namespace = %namespace,
                    "Activated prefill router for decode WorkerSet"
660
661
662
                );
                Ok(())
            }
663
            Some((_, PrefillActivationState::PrefillReady(_))) => {
664
665
666
667
668
                anyhow::bail!(
                    "Prefill router for {}:{} already activated",
                    model_name,
                    namespace
                );
669
670
671
672
            }
            None => {
                let (tx, rx) = oneshot::channel();
                tx.send(endpoint).map_err(|_| {
673
674
675
676
677
                    anyhow::anyhow!(
                        "Failed to send endpoint for prefill model {}:{}",
                        model_name,
                        namespace
                    )
678
                })?;
679
680
                self.prefill_router_activators
                    .insert(key, PrefillActivationState::PrefillReady(rx));
681
682
                tracing::info!(
                    model_name = %model_name,
683
684
                    namespace = %namespace,
                    "Stored prefill endpoint for future decode WorkerSet registration"
685
686
687
688
                );
                Ok(())
            }
        }
689
690
    }

691
692
693
694
695
696
697
698
699
700
701
    /// Remove the prefill router activator for a (model, namespace) pair.
    /// Called when a WorkerSet is removed to prevent stale activators.
    pub fn remove_prefill_activator(&self, model_name: &str, namespace: &str) {
        let key = Self::model_namespace_key(model_name, namespace);
        if self.prefill_router_activators.remove(&key).is_some() {
            tracing::debug!(
                model_name = %model_name,
                namespace = %namespace,
                "Cleaned up prefill router activator for removed WorkerSet"
            );
        }
702
703
    }

704
    // -- Worker monitoring --
705

706
    /// Gets or sets the load threshold config for a model's worker monitor.
707
    /// Checks across all WorkerSets for the model.
708
    pub fn load_threshold_config(
709
710
        &self,
        model: &str,
711
712
        config: Option<&LoadThresholdConfig>,
    ) -> Option<LoadThresholdConfig> {
713
714
        let model_entry = self.models.get(model)?;
        model_entry.load_threshold_config(config)
715
716
    }

717
718
    /// Gets an existing worker monitor for a specific namespace of a model.
    pub fn get_worker_monitor_for_namespace(
719
720
        &self,
        model: &str,
721
722
723
724
725
726
727
728
729
730
731
732
733
        namespace: &str,
    ) -> Option<KvWorkerMonitor> {
        let model_entry = self.models.get(model)?;
        model_entry.get_worker_monitor_for_namespace(namespace)
    }

    /// Lists all models with worker monitors configured.
    pub fn list_busy_thresholds(&self) -> Vec<(String, LoadThresholdConfig)> {
        let mut result = Vec::new();
        for entry in self.models.iter() {
            if let Some(config) = entry.value().load_threshold_config(None) {
                result.push((entry.key().clone(), config));
            }
734
        }
735
        result
736
737
    }

738
739
    // -- Runtime configs --

740
    /// Get or create a runtime config watcher for an endpoint.
741
742
    /// Spawns a background task that joins instance availability and config discovery.
    /// Returns a `watch::Receiver` with the latest `HashMap<WorkerId, ModelRuntimeConfig>`.
743
744
745
    pub async fn get_or_create_runtime_config_watcher(
        &self,
        endpoint: &Endpoint,
746
    ) -> anyhow::Result<RuntimeConfigWatch> {
747
748
749
750
751
752
        let endpoint_id = endpoint.id();

        if let Some(existing) = self.runtime_configs.get(&endpoint_id) {
            return Ok(existing.clone());
        }

753
754
755
756
757
758
        // Slow path: create the watch (spawns a background task).
        // If another caller raced us, the entry() below picks up the winner;
        // the loser's background task stops once its receivers are dropped.
        let rx = runtime_config_watch(endpoint).await?;
        let result = match self.runtime_configs.entry(endpoint_id) {
            Entry::Occupied(e) => e.get().clone(),
759
            Entry::Vacant(e) => {
760
761
                e.insert(rx.clone());
                rx
762
763
764
            }
        };

765
        Ok(result)
766
767
768
769
770
771
772
773
    }

    /// Look up the disaggregated endpoint advertised by `worker_id` in the latest runtime
    /// configs for `endpoint_id`. Returns `None` in any of these cases:
    /// - no watcher exists for the endpoint;
    /// - the worker is absent from the configs;
    /// - the worker advertises no disaggregated endpoint.
    pub fn get_disaggregated_endpoint(
        &self,
        endpoint_id: &EndpointId,
        worker_id: WorkerId,
    ) -> Option<DisaggregatedEndpoint> {
        let watcher = self.runtime_configs.get(endpoint_id)?;
        let snapshot = watcher.borrow();
        let worker_config = snapshot.get(&worker_id)?;
        worker_config.disaggregated_endpoint.clone()
    }
778
}
779

780
781
782
783
784
785
786
787
788
789
790
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model_card::ModelDeploymentCard;

    /// Build a `WorkerSet` in `namespace` with the given `mdcsum` and a default
    /// `ModelDeploymentCard`.
    fn make_worker_set(namespace: &str, mdcsum: &str) -> WorkerSet {
        WorkerSet::new(
            namespace.to_string(),
            mdcsum.to_string(),
            ModelDeploymentCard::default(),
        )
    }

    // -- CRUD delegation tests --

    #[test]
    fn test_add_and_get_worker_set() {
        let mm = ModelManager::new();
        let ws = make_worker_set("ns1", "abc");
        mm.add_worker_set("llama", "ns1", ws);

        let model = mm
            .get_model("llama")
            .expect("adding a worker set must create the model");
        assert!(model.has_worker_set("ns1"));
        assert_eq!(model.worker_set_count(), 1);
    }

    #[test]
    fn test_add_worker_set_creates_model() {
        let mm = ModelManager::new();
        assert!(mm.get_model("llama").is_none());

        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        assert!(mm.get_model("llama").is_some());
    }

    #[test]
    fn test_remove_worker_set_removes_empty_model() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        assert!(mm.get_model("llama").is_some());

        let removed = mm
            .remove_worker_set("llama", "ns1")
            .expect("ns1 was registered, so removing it must return it");
        assert_eq!(removed.namespace(), "ns1");

        // Model should be auto-removed since it's now empty
        assert!(mm.get_model("llama").is_none());
    }

    #[test]
    fn test_remove_worker_set_keeps_model_with_remaining() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc"));

        mm.remove_worker_set("llama", "ns1");

        // Model should still exist with ns2
        let model = mm
            .get_model("llama")
            .expect("model must survive while ns2 is still registered");
        assert!(!model.has_worker_set("ns1"));
        assert!(model.has_worker_set("ns2"));
        assert_eq!(model.worker_set_count(), 1);
    }

    #[test]
    fn test_remove_worker_set_nonexistent_model() {
        let mm = ModelManager::new();
        assert!(mm.remove_worker_set("llama", "ns1").is_none());
    }

    #[test]
    fn test_remove_worker_set_nonexistent_namespace() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));

        assert!(mm.remove_worker_set("llama", "ns2").is_none());

        // Model should still exist (ns1 still there)
        assert!(mm.get_model("llama").is_some());
    }

    #[test]
    fn test_remove_model_if_empty_noop_when_not_empty() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));

        mm.remove_model_if_empty("llama");
        assert!(mm.get_model("llama").is_some()); // Still has ns1
    }

    #[test]
    fn test_remove_model_if_empty_noop_when_missing() {
        let mm = ModelManager::new();
        mm.remove_model_if_empty("nonexistent"); // Should not panic
    }

    #[test]
    fn test_remove_model() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc"));

        let removed = mm.remove_model("llama");
        assert!(removed.is_some());
        assert!(mm.get_model("llama").is_none());
    }

    #[test]
    fn test_get_or_create_model_idempotent() {
        let mm = ModelManager::new();
        let m1 = mm.get_or_create_model("llama");
        let m2 = mm.get_or_create_model("llama");
        // Both should point to the same Model (same Arc)
        assert!(Arc::ptr_eq(&m1, &m2));
    }

    // -- Model listing and filtering tests --

    #[test]
    fn test_has_decode_model() {
        let mm = ModelManager::new();

        // No model → false
        assert!(!mm.has_decode_model("llama"));

        // Prefill-only set (no engines) → false
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        assert!(!mm.has_decode_model("llama"));
    }

    #[test]
    fn test_has_prefill_model() {
        let mm = ModelManager::new();

        // Prefill set = no engines
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        assert!(mm.has_prefill_model("llama"));
    }

    #[test]
    fn test_has_model_any() {
        let mm = ModelManager::new();
        assert!(!mm.has_model_any("llama"));

        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        assert!(mm.has_model_any("llama")); // has prefill
    }

    #[test]
    fn test_model_display_names_includes_prefill() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));

        let names = mm.model_display_names();
        assert!(names.contains("llama"));
    }

    #[test]
    fn test_model_display_names_empty() {
        let mm = ModelManager::new();
        assert!(mm.model_display_names().is_empty());
    }

    #[test]
    fn test_list_prefill_models() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        mm.add_worker_set("gpt", "ns1", make_worker_set("ns1", "def"));

        let prefill = mm.list_prefill_models();
        assert_eq!(prefill.len(), 2);
        assert!(prefill.contains(&"llama".to_string()));
        assert!(prefill.contains(&"gpt".to_string()));
    }

    // -- Model card tests --

    #[test]
    fn test_save_and_remove_model_card() {
        let mm = ModelManager::new();
        let card = ModelDeploymentCard::default();
        mm.save_model_card("instance/key/1", card)
            .expect("saving a card under a fresh key should succeed");

        let cards = mm.get_model_cards();
        assert_eq!(cards.len(), 1);

        let removed = mm.remove_model_card("instance/key/1");
        assert!(removed.is_some());
        assert!(mm.get_model_cards().is_empty());
    }

    #[test]
    fn test_remove_model_card_nonexistent() {
        let mm = ModelManager::new();
        assert!(mm.remove_model_card("nonexistent").is_none());
    }

    // -- Prefill router rendezvous tests --
    // Note: activate_prefill_router requires an Endpoint (needs DistributedRuntime),
    // so we test the registration state machine and cleanup only.

    #[test]
    fn test_prefill_router_register_new() {
        let mm = ModelManager::new();

        // First registration for a (model, namespace) returns Some(rx)
        let rx = mm.register_prefill_router("llama", "ns1");
        assert!(rx.is_some());
    }

    #[test]
    fn test_prefill_router_double_register_returns_none() {
        let mm = ModelManager::new();

        let rx1 = mm.register_prefill_router("llama", "ns1");
        assert!(rx1.is_some());

        // Second registration for the same (model, namespace) returns None
        let rx2 = mm.register_prefill_router("llama", "ns1");
        assert!(rx2.is_none());
    }

    #[test]
    fn test_prefill_router_different_namespaces_independent() {
        let mm = ModelManager::new();

        // Different namespaces should be independent
        let rx1 = mm.register_prefill_router("llama", "ns1");
        let rx2 = mm.register_prefill_router("llama", "ns2");
        assert!(rx1.is_some());
        assert!(rx2.is_some());
    }

    #[test]
    fn test_prefill_router_different_models_independent() {
        let mm = ModelManager::new();

        // Different models should be independent
        let rx1 = mm.register_prefill_router("llama", "ns1");
        let rx2 = mm.register_prefill_router("gpt", "ns1");
        assert!(rx1.is_some());
        assert!(rx2.is_some());
    }

    #[test]
    fn test_prefill_router_remove_allows_reregister() {
        let mm = ModelManager::new();

        let rx = mm.register_prefill_router("llama", "ns1");
        assert!(rx.is_some());

        // Remove the activator
        mm.remove_prefill_activator("llama", "ns1");

        // Should be able to register again
        let rx2 = mm.register_prefill_router("llama", "ns1");
        assert!(rx2.is_some());
    }

    #[test]
    fn test_prefill_router_remove_nonexistent_noop() {
        let mm = ModelManager::new();
        // Should not panic
        mm.remove_prefill_activator("llama", "ns1");
    }

    #[test]
    fn test_model_namespace_key_format() {
        assert_eq!(
            ModelManager::model_namespace_key("llama", "ns1"),
            "llama:ns1"
        );
        assert_eq!(
            ModelManager::model_namespace_key("gpt-4", "default-abc"),
            "gpt-4:default-abc"
        );
    }
}