watcher.rs 51.9 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Ryan Olson's avatar
Ryan Olson committed
2
3
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5
use std::time::Duration;
6
use tokio::sync::Notify;
7
use tokio::sync::mpsc::Sender;
8
use tokio::task::JoinHandle;
9

10
use anyhow::Context as _;
11
use dashmap::{DashMap, DashSet};
12
use dynamo_kv_router::PrefillLoadEstimator;
13
use futures::StreamExt;
Ryan Olson's avatar
Ryan Olson committed
14

Neelay Shah's avatar
Neelay Shah committed
15
use dynamo_runtime::{
16
    DistributedRuntime,
17
18
19
20
    discovery::{
        DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery, DiscoveryStream,
        ModelCardInstanceId,
    },
21
    pipeline::{
22
23
        ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
        network::egress::push_router::PushRouter,
24
    },
25
    protocols::{EndpointId, annotated::Annotated},
26
};
Ryan Olson's avatar
Ryan Olson committed
27

28
29
use crate::{
    backend::Backend,
30
    discovery::{KvWorkerMonitor, WORKER_TYPE_DECODE, WorkerSet},
31
    entrypoint::{self, ChatEngineFactoryCallback, RouterConfig},
32
    http::service::metrics::Metrics,
33
    kv_router::PrefillRouter,
34
    model_card::ModelDeploymentCard,
35
36
    model_type::{ModelInput, ModelType},
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
37
38
39
    protocols::{
        common::llm_backend::EmbeddingsEngineOutput,
        openai::{
40
            audios::{NvAudioSpeechResponse, NvCreateAudioSpeechRequest},
41
42
43
44
45
            chat_completions::{
                NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            },
            completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
            embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
46
            images::{NvCreateImageRequest, NvImagesResponse},
47
            videos::{NvCreateVideoRequest, NvVideosResponse},
48
        },
49
        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
50
51
    },
};
52

53
use super::ModelManager;
54
55
56
57
58
59
60
61
62
63
64
use crate::namespace::NamespaceFilter;

/// Builds the storage key under which a WorkerSet is kept.
///
/// The key is the namespace itself, with a `:prefill` suffix appended when
/// `model_type` supports prefill. Prefill and decode workers that share a
/// namespace therefore get distinct keys and do not block each other's
/// registration.
fn worker_set_key(namespace: &str, model_type: ModelType) -> String {
    let mut key = String::from(namespace);
    if model_type.supports_prefill() {
        key.push_str(":prefill");
    }
    key
}
65

66
#[derive(Debug, Clone)]
67
pub enum ModelUpdate {
68
69
    Added(ModelDeploymentCard),
    Removed(ModelDeploymentCard),
70
71
}

72
pub struct ModelWatcher {
73
    manager: Arc<ModelManager>,
74
    drt: DistributedRuntime,
75
    router_config: RouterConfig,
76
    migration_limit: u32,
77
    migration_max_seq_len: Option<u32>,
78
    notify_on_model: Notify,
79
    model_update_tx: Option<Sender<ModelUpdate>>,
80
    chat_engine_factory: Option<ChatEngineFactoryCallback>,
81
    prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
82
    metrics: Arc<Metrics>,
83
84
    /// Guards against concurrent pipeline construction for the same (model, namespace).
    registering_worker_sets: DashSet<String>,
85
86
87
    /// Wakes tasks blocked in `recover_concurrent_registration` when a
    /// `RegistrationGuard` drops (i.e. a registration completes or panics).
    registration_notify: Notify,
88
89
90
    /// Tracks in-flight `handle_put` tasks by instance path so that `handle_delete`
    /// can await a racing put before proceeding with cleanup.
    pending_puts: DashMap<String, JoinHandle<()>>,
Ryan Olson's avatar
Ryan Olson committed
91
92
}

93
94
95
96
/// Every individual model type, in the order they are checked when deciding
/// which `ModelUpdate::Removed` notifications to send after a model is removed.
const ALL_MODEL_TYPES: &[ModelType] = &[
    ModelType::Chat,
    ModelType::Completions,
    ModelType::Embedding,
    ModelType::Images,
    ModelType::Audios,
    ModelType::Videos,
    ModelType::TensorBased,
    ModelType::Prefill,
];
103

104
105
106
107
108
109
110
111
112
113
/// Returns true if no models in the manager support the given model type.
fn is_model_type_list_empty(manager: &ModelManager, model_type: ModelType) -> bool {
    if model_type == ModelType::Chat {
        manager.list_chat_completions_models().is_empty()
    } else if model_type == ModelType::Completions {
        manager.list_completions_models().is_empty()
    } else if model_type == ModelType::Embedding {
        manager.list_embeddings_models().is_empty()
    } else if model_type == ModelType::Images {
        manager.list_images_models().is_empty()
114
115
    } else if model_type == ModelType::Audios {
        manager.list_audios_models().is_empty()
116
117
118
119
120
121
122
123
124
125
126
    } else if model_type == ModelType::Videos {
        manager.list_videos_models().is_empty()
    } else if model_type == ModelType::TensorBased {
        manager.list_tensor_models().is_empty()
    } else if model_type == ModelType::Prefill {
        manager.list_prefill_models().is_empty()
    } else {
        true
    }
}

127
128
/// RAII guard that removes a key from a `DashSet` on drop and wakes any tasks
/// waiting for the registration to finish via the shared [`Notify`].
129
130
131
132
133
/// Ensures `registering_worker_sets` is cleaned up even if the registration
/// task panics, preventing permanent poisoning of the registration key.
struct RegistrationGuard<'a> {
    set: &'a DashSet<String>,
    key: String,
134
    notify: &'a Notify,
135
136
137
138
139
}

impl Drop for RegistrationGuard<'_> {
    fn drop(&mut self) {
        self.set.remove(&self.key);
140
        self.notify.notify_waiters();
141
142
143
    }
}

144
impl ModelWatcher {
145
    #[allow(clippy::too_many_arguments)]
146
    pub fn new(
147
        runtime: DistributedRuntime,
148
        model_manager: Arc<ModelManager>,
149
        router_config: RouterConfig,
150
        migration_limit: u32,
151
        migration_max_seq_len: Option<u32>,
152
        chat_engine_factory: Option<ChatEngineFactoryCallback>,
153
        prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
154
        metrics: Arc<Metrics>,
155
156
    ) -> ModelWatcher {
        Self {
157
            manager: model_manager,
158
            drt: runtime,
159
            router_config,
160
            migration_limit,
161
            migration_max_seq_len,
162
            notify_on_model: Notify::new(),
163
            model_update_tx: None,
164
            chat_engine_factory,
165
            prefill_load_estimator,
166
            metrics,
167
            registering_worker_sets: DashSet::new(),
168
            registration_notify: Notify::new(),
169
            pending_puts: DashMap::new(),
170
        }
171
    }
Ryan Olson's avatar
Ryan Olson committed
172

173
174
175
176
    /// Installs the channel on which [`ModelUpdate`] notifications are sent.
    ///
    /// A sender installed by an earlier call is dropped and replaced by `tx`.
    pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
        drop(self.model_update_tx.replace(tx));
    }

177
178
179
180
181
182
183
184
185
186
187
    /// Wait until we have at least one chat completions model and return its name.
    ///
    /// Returns the first name reported by the manager's chat completions list.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            // Create the `Notified` future BEFORE inspecting the list. Producers
            // signal with `notify_waiters()`, which stores no permit: it only
            // reaches `Notified` futures that already exist. Creating the future
            // after the check would lose a notification that lands between the
            // check and the await, leaving this task asleep although a model is
            // already registered.
            let notified = self.notify_on_model.notified();
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            notified.await;
        }
    }

188
189
190
191
192
    /// Common watch logic with optional namespace filtering.
    ///
    /// Takes `Arc<Self>` so that each `handle_put` call can be spawned into its own
    /// tokio task, preventing a slow HuggingFace config download for one model from
    /// blocking discovery events for all subsequent models.
193
    pub async fn watch(
194
        self: Arc<Self>,
195
        mut discovery_stream: DiscoveryStream,
196
        namespace_filter: NamespaceFilter,
197
198
199
200
201
202
203
204
205
206
    ) {
        while let Some(result) = discovery_stream.next().await {
            let event = match result {
                Ok(event) => event,
                Err(err) => {
                    tracing::error!(%err, "Error in discovery stream");
                    continue;
                }
            };

207
            match event {
208
                DiscoveryEvent::Added(instance) => {
209
210
                    // Extract ModelCardInstanceId and card from the discovery instance
                    let (mcid, mut card) = match &instance {
211
212
213
214
215
                        DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            instance_id,
216
                            model_suffix,
217
218
                            ..
                        } => {
219
                            let mcid = ModelCardInstanceId {
220
221
                                namespace: namespace.clone(),
                                component: component.clone(),
222
223
224
                                endpoint: endpoint.clone(),
                                instance_id: *instance_id,
                                model_suffix: model_suffix.clone(),
225
226
227
                            };

                            match instance.deserialize_model::<ModelDeploymentCard>() {
228
                                Ok(card) => (mcid, card),
229
230
231
232
233
234
235
236
237
238
                                Err(err) => {
                                    tracing::error!(%err, instance_id, "Failed to deserialize model card");
                                    continue;
                                }
                            }
                        }
                        _ => {
                            tracing::error!(
                                "Unexpected discovery instance type (expected ModelCard)"
                            );
239
240
241
                            continue;
                        }
                    };
242

243
244
                    // Filter by namespace using the configured filter
                    if !namespace_filter.matches(&mcid.namespace) {
245
                        tracing::debug!(
246
                            model_namespace = mcid.namespace,
247
248
                            namespace_filter = ?namespace_filter,
                            "Skipping model due to namespace filter"
249
250
251
252
                        );
                        continue;
                    }

253
254
255
256
257
258
259
                    // If a WorkerSet already exists for this (model, namespace, type),
                    // validate that the new worker's checksum matches. Different
                    // WorkerSets (different namespaces) are allowed to have different checksums to support rolling updates.
                    let ws_key = worker_set_key(&mcid.namespace, card.model_type);
                    if let Some(model) = self.manager.get_model(card.name())
                        && !model.is_checksum_compatible(&ws_key, card.mdcsum())
                    {
260
261
                        tracing::error!(
                            model_name = card.name(),
262
                            namespace = mcid.namespace,
263
264
265
                            new_checksum = card.mdcsum(),
                            "Checksum for new worker does not match existing WorkerSet's checksum. \
                             Drain all old workers in this namespace before deploying a new version."
266
267
268
269
270
271
272
273
274
275
                        );
                        // TODO: mark that instance down in clients
                        // Not obvious how to do that given the current design
                        // Instances come from an `InstanceSource` in a `Client` in a `PushRouter`.
                        // Calling `report_instance_down` on the Client should do it (although
                        // needs more testing).
                        // The `PushRouter` is in `ModelMananger` (`self.manager` here), but inside
                        // interface `AsyncEngine` which only has a `generate` method.
                        continue;
                    }
276

277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
                    // Spawn each handle_put into its own task so that a slow
                    // HuggingFace config download for one model cannot block
                    // discovery events for all subsequent models.
                    //
                    // The JoinHandle is stored in `pending_puts` so that a
                    // subsequent `handle_delete` for the same instance can
                    // await the in-flight put before attempting cleanup.
                    let instance_key = mcid.to_path();
                    let watcher = Arc::clone(&self);
                    let handle = tokio::spawn(async move {
                        match watcher.handle_put(&mcid, &mut card).await {
                            Ok(()) => {
                                tracing::info!(
                                    model_name = card.name(),
                                    namespace = mcid.namespace,
                                    "added model"
                                );
                                watcher.notify_on_model.notify_waiters();
                            }
                            Err(err) => {
                                tracing::error!(
                                    model_name = card.name(),
                                    namespace = mcid.namespace,
                                    error = format!("{err:#}"),
                                    "Error adding model from discovery",
                                );
                            }
304
                        }
305
306
307
308
                        // Note: we intentionally do NOT remove from pending_puts here.
                        // Only the watch loop (on duplicate events) and handle_delete
                        // manage pending_puts, avoiding a race where a completed task's
                        // cleanup could remove a newer task's entry.
309
310
                    });
                    // If a duplicate Added event arrives while the first task is still
311
312
313
314
315
316
317
318
                    // in-flight, abort the old task to cancel redundant work.
                    //
                    // `instance_key` is `mcid.to_path()` = "{ns}/{component}/{endpoint}/{instance_id:x}",
                    // so this is keyed per-worker-instance, NOT per-model. Two different workers
                    // registering the same model produce two different keys and run independently.
                    // The only case that hits this branch is the etcd watch replaying the same
                    // worker's Added event (reconnect or re-sync) — where cancelling the earlier
                    // redundant task is exactly what we want.
319
320
                    if let Some((_, old_handle)) = self.pending_puts.remove(&instance_key) {
                        old_handle.abort();
321
                    }
322
                    self.pending_puts.insert(instance_key, handle);
323
                }
324
325
326
327
                DiscoveryEvent::Removed(id) => {
                    // Extract ModelCardInstanceId from the removal event
                    let model_card_instance_id = match &id {
                        DiscoveryInstanceId::Model(mcid) => mcid,
328
                        DiscoveryInstanceId::Endpoint(_) | DiscoveryInstanceId::EventChannel(_) => {
329
330
331
332
333
334
                            tracing::error!(
                                "Unexpected discovery instance type in removal (expected Model)"
                            );
                            continue;
                        }
                    };
335

336
                    match self
337
                        .handle_delete(model_card_instance_id, &namespace_filter)
338
339
340
341
342
343
344
345
346
347
348
                        .await
                    {
                        Ok(Some(model_name)) => {
                            tracing::info!(model_name, "removed model");
                        }
                        Ok(None) => {
                            // There are other instances running this model, nothing to do
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "error removing model");
                        }
349
                    }
350
                }
351
            }
Ryan Olson's avatar
Ryan Olson committed
352
353
354
        }
    }

355
356
    /// Handle a worker removal. Cleans up per-namespace WorkerSets and the Model itself
    /// when no instances remain. Returns the model name if the entire Model was removed.
357
358
    async fn handle_delete(
        &self,
359
        mcid: &ModelCardInstanceId,
360
        namespace_filter: &NamespaceFilter,
361
    ) -> anyhow::Result<Option<String>> {
362
        let key = mcid.to_path();
363
364
365
366
367

        // If there is an in-flight handle_put for this instance, wait for it
        // to complete before we attempt cleanup. Without this, a Removed event
        // arriving while handle_put is still downloading HF config would fail
        // to find the model card, leaving a stale registration.
368
        if let Some((_, mut handle)) = self.pending_puts.remove(&key) {
369
370
371
            tracing::debug!(key = %key, "awaiting in-flight handle_put before delete");
            // Ignore join errors (panic in the spawned task) — we still proceed
            // with cleanup since the put may have partially registered the model.
372
373
374
375
376
377
378
379
380
381
382
383
384
            match tokio::time::timeout(Duration::from_secs(60), &mut handle).await {
                Ok(_) => {}
                Err(_) => {
                    // Abort the timed-out task so it cannot register the model
                    // after we proceed with deletion.
                    handle.abort();
                    let _ = handle.await;
                    tracing::warn!(
                        key = %key,
                        "Timed out waiting for in-flight handle_put, aborted and proceeding with delete"
                    );
                }
            }
385
        }
386
        let card = match self.manager.remove_model_card(&key) {
387
            Some(card) => card,
388
            None => {
389
                anyhow::bail!("Missing ModelDeploymentCard for {}", key);
390
391
            }
        };
392
        let model_name = card.name().to_string();
393
394
395
396
397
        let worker_namespace = &mcid.namespace;
        let worker_component = &mcid.component;
        let ws_key = worker_set_key(&mcid.namespace, card.model_type);

        // Query discovery for all remaining instances of this model
398
        let active_instances = self
399
            .cards_for_model_with_endpoints(&model_name, namespace_filter)
400
401
            .await
            .with_context(|| model_name.clone())?;
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422

        // Check if instances of the SAME component remain in this namespace.
        // In disaggregated deployments, prefill and decode are different components
        // in the same namespace, so we must check at the component level to avoid
        // removing one type's WorkerSet while the other still has workers.
        let component_has_instances = active_instances.iter().any(|(eid, _)| {
            eid.namespace == *worker_namespace && eid.component == *worker_component
        });

        if !component_has_instances {
            // No more workers of this component in this namespace — remove its WorkerSet
            if let Some(_removed_ws) = self.manager.remove_worker_set(&model_name, &ws_key) {
                // remove_prefill_activator uses deployment namespace (not ws_key)
                self.manager
                    .remove_prefill_activator(&model_name, worker_namespace);
                tracing::info!(
                    model_name,
                    namespace = %worker_namespace,
                    "Removed WorkerSet (no remaining instances in namespace)"
                );
            }
423
424
425
426
427
428
429
430
431

            // If the removed component was a prefill worker, deactivate the decode-side
            // prefill router so requests fall back to aggregated mode (or fail cleanly
            // with enforce_disagg). The decode WorkerSet's namespace matches the
            // deployment namespace, not the ws_key.
            if card.model_type.supports_prefill() {
                self.manager
                    .deactivate_prefill_router_for_decode(&model_name, worker_namespace);
            }
432
433
434
        }

        // Check if the Model still has instances in any namespace
435
        if !active_instances.is_empty() {
436
437
438
            tracing::debug!(
                model_name,
                active_instance_count = active_instances.len(),
439
                "Model has other active instances in other namespaces"
440
            );
441
442
            return Ok(None);
        }
443

444
445
        // No instances remain anywhere — remove the entire Model
        let _ = self.manager.remove_model(&model_name);
446

447
        if let Some(tx) = &self.model_update_tx {
448
            for model_type in ALL_MODEL_TYPES {
449
450
                if card.model_type.intersects(*model_type)
                    && is_model_type_list_empty(&self.manager, *model_type)
451
                {
452
                    tx.send(ModelUpdate::Removed(card.clone())).await.ok();
453
454
455
                }
            }
        }
456

457
        Ok(Some(model_name))
458
    }
Ryan Olson's avatar
Ryan Olson committed
459

460
    // Handles a PUT event from store, this usually means adding a new model to the list of served
461
    // models.
462
463
    async fn handle_put(
        &self,
464
        mcid: &ModelCardInstanceId,
465
466
        card: &mut ModelDeploymentCard,
    ) -> anyhow::Result<()> {
467
468
469
470
471
        // Check if this specific (model, namespace, type) WorkerSet already exists.
        // If so, this is just another worker joining an existing set — no pipeline build needed.
        let model_name = card.name().to_string();
        let namespace = mcid.namespace.clone();
        let ws_key = worker_set_key(&namespace, card.model_type);
472

473
474
475
        if let Some(model) = self.manager.get_model(&model_name)
            && model.has_worker_set(&ws_key)
        {
476
477
478
479
480
481
482
483
484
485
486
487
            if !model.is_checksum_compatible(&ws_key, card.mdcsum()) {
                tracing::error!(
                    model_name = card.name(),
                    namespace = namespace,
                    new_checksum = card.mdcsum(),
                    "Checksum for new worker does not match existing WorkerSet's checksum. \
                     Drain all old workers in this namespace before deploying a new version."
                );
                return Err(anyhow::anyhow!(
                    "Checksum mismatch for worker in namespace {namespace}"
                ));
            }
488
489
            self.manager
                .save_model_card(&mcid.to_path(), card.clone())?;
490
491
            tracing::debug!(
                model_name = card.name(),
492
493
                namespace = namespace,
                "Worker joined existing WorkerSet, skipping pipeline build"
494
495
496
497
            );
            return Ok(());
        }

498
499
500
501
502
        // Guard against concurrent pipeline construction for the same (model, namespace, type)
        let registration_key = ModelManager::model_namespace_key(&model_name, &ws_key);
        if !self
            .registering_worker_sets
            .insert(registration_key.clone())
503
504
505
506
507
508
509
510
511
512
            && !self
                .recover_concurrent_registration(
                    mcid,
                    card,
                    &model_name,
                    &namespace,
                    &ws_key,
                    &registration_key,
                )
                .await?
513
        {
514
515
516
517
518
            return Ok(());
        }

        // RAII guard ensures the registration key is removed even if
        // do_worker_set_registration panics, preventing permanent poisoning.
519
        // It also wakes any waiters in recover_concurrent_registration.
520
521
522
        let _guard = RegistrationGuard {
            set: &self.registering_worker_sets,
            key: registration_key,
523
            notify: &self.registration_notify,
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
        };

        self.do_worker_set_registration(mcid, card).await
    }

    /// Handle the case where another task is already building the pipeline for this
    /// (model, namespace, type). This is a recovery path — it waits for the in-flight
    /// registration to finish, then either joins the resulting WorkerSet or retries.
    ///
    /// Returns `true` if the caller should proceed with its own registration
    /// (i.e. the other task failed), `false` if the worker was handled (joined or rejected).
    async fn recover_concurrent_registration(
        &self,
        mcid: &ModelCardInstanceId,
        card: &mut ModelDeploymentCard,
        model_name: &str,
        namespace: &str,
        ws_key: &str,
        registration_key: &str,
    ) -> anyhow::Result<bool> {
        // Wait for the in-flight registration to complete so we can validate
        // the new worker's checksum. Without this, a concurrent worker with a
        // mismatched checksum could sneak past the early check in `watch`.
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
        //
        // Uses a Notify + enable() loop instead of polling to wake up
        // immediately when the RegistrationGuard drops, avoiding up to 100ms
        // of unnecessary latency and wasted CPU cycles.
        // An absolute deadline ensures spurious wakeups (from unrelated
        // registrations sharing the same Notify) cannot extend the wait
        // beyond 30 seconds.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
        loop {
            let notified = self.registration_notify.notified();
            tokio::pin!(notified);
            // Register interest in the notification BEFORE checking the
            // condition to avoid a race where the guard drops between
            // our check and the .await.
            notified.as_mut().enable();
            if !self.registering_worker_sets.contains(registration_key) {
                break;
            }
            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
            if remaining.is_zero() {
                break;
            }
            if tokio::time::timeout(remaining, notified).await.is_err() {
                break;
            }
572
573
        }

574
575
576
577
578
579
580
581
582
583
584
585
586
587
        // If we timed out and the other task is still running, bail out rather
        // than proceeding with concurrent pipeline construction.
        if self.registering_worker_sets.contains(registration_key) {
            // Save the model card so handle_delete can find it for cleanup.
            self.manager
                .save_model_card(&mcid.to_path(), card.clone())?;
            tracing::warn!(
                model_name = card.name(),
                namespace = namespace,
                "Timed out waiting for concurrent registration to complete, skipping"
            );
            return Ok(false);
        }

588
589
590
591
592
        // Validate checksum against the registered model
        if let Some(model) = self.manager.get_model(model_name)
            && !model.is_checksum_compatible(ws_key, card.mdcsum())
        {
            tracing::error!(
593
                model_name = card.name(),
594
                namespace = namespace,
595
596
597
                new_checksum = card.mdcsum(),
                "Checksum for new worker does not match existing WorkerSet's checksum. \
                 Drain all old workers in this namespace before deploying a new version."
598
            );
599
            return Ok(false);
600
601
        }

602
603
604
605
606
607
608
609
        // If the first registration failed or timed out, no WorkerSet exists.
        // Fall through to do_worker_set_registration instead of becoming a ghost
        // worker (registered in cards but with no serving pipeline).
        if self
            .manager
            .get_model(model_name)
            .is_none_or(|m| !m.has_worker_set(ws_key))
        {
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
            // Only the first waiter to re-insert the key should proceed with
            // registration. Other waiters return false to avoid concurrent builds.
            if !self
                .registering_worker_sets
                .insert(registration_key.to_string())
            {
                // Save the model card so handle_delete can find it for cleanup.
                self.manager
                    .save_model_card(&mcid.to_path(), card.clone())?;
                tracing::debug!(
                    model_name = card.name(),
                    namespace = namespace,
                    "Another waiter won the re-registration race, skipping"
                );
                return Ok(false);
            }
626
627
628
629
630
631
632
            tracing::warn!(
                model_name = card.name(),
                namespace = namespace,
                "Concurrent registration produced no WorkerSet, retrying"
            );
            return Ok(true);
        }
633

634
635
636
637
638
639
640
641
        self.manager
            .save_model_card(&mcid.to_path(), card.clone())?;
        tracing::debug!(
            model_name = card.name(),
            namespace = namespace,
            "Worker joined existing WorkerSet, skipping pipeline build"
        );
        Ok(false)
642
643
    }

644
645
646
    /// Build a complete WorkerSet with all engines for this (model, namespace)
    /// and add it to the Model.
    async fn do_worker_set_registration(
647
648
649
650
651
652
653
654
655
656
657
658
        &self,
        mcid: &ModelCardInstanceId,
        card: &mut ModelDeploymentCard,
    ) -> anyhow::Result<()> {
        card.download_config().await?;

        let component = self
            .drt
            .namespace(&mcid.namespace)?
            .component(&mcid.component)?;
        let endpoint = component.endpoint(&mcid.endpoint);
        let client = endpoint.client().await?;
659
660
661
662
663
664
        let instance_watcher = client.instance_avail_watcher();
        tracing::debug!(
            model_name = card.name(),
            namespace = mcid.namespace,
            "building worker set pipeline"
        );
665
666
667
        self.manager
            .save_model_card(&mcid.to_path(), card.clone())?;

668
        let checksum = card.mdcsum();
669
670
671
672
673
674
        let namespace = mcid.namespace.clone();
        let ws_key = worker_set_key(&namespace, card.model_type);

        // Build the WorkerSet with all applicable engines
        let mut worker_set = WorkerSet::new(namespace.clone(), checksum.to_string(), card.clone());
        worker_set.set_instance_watcher(instance_watcher);
675
676
677

        if card.model_input == ModelInput::Tokens
            && (card.model_type.supports_chat() || card.model_type.supports_completions())
678
679
680
681
682
        {
            // Case 1: Tokens + (Chat OR Completions OR Both)
            // A model that expects pre-processed requests meaning it's up to us whether we
            // handle Chat or Completions requests, so handle whatever the model supports.

683
            let endpoint = component.endpoint(&mcid.endpoint);
684
685
686
687
688
689
690
691
692
            // Create the KV router whenever any local routed pipeline will be built.
            // The chat factory builds its own router, but completions currently always
            // uses the local routed pipeline and therefore still needs a chooser.
            let needs_local_chat_pipeline =
                card.model_type.supports_chat() && self.chat_engine_factory.is_none();
            let needs_local_completions_pipeline = card.model_type.supports_completions();
            let kv_chooser = if self.router_config.router_mode == RouterMode::KV
                && (needs_local_chat_pipeline || needs_local_completions_pipeline)
            {
693
694
                Some(
                    self.manager
695
696
697
                        .kv_chooser_for(
                            &endpoint,
                            card.kv_cache_block_size,
698
                            Some(self.router_config.kv_router_config.clone()),
699
                            self.prefill_load_estimator.clone(),
700
                            WORKER_TYPE_DECODE, // This is the decode router
701
                            Some(card.display_name.clone()),
702
                            card.runtime_config.enable_eagle,
703
                        )
704
705
706
707
708
                        .await?,
                )
            } else {
                None
            };
709

710
711
712
713
714
715
716
717
718
719
720
721
722
            // Loading the tokenizer is expensive (~10 MiB JSON), so only do it
            // once and only when a local pipeline actually needs it.  Models
            // without tokenizer.json (e.g. Qwen3-Omni) set tokenizer = None;
            // they rely on a Python chat_engine_factory for tokenization.
            // When a chat_engine_factory handles chat and no completions are
            // needed, skip tokenizer loading entirely — even if the file exists.
            let needs_rust_tokenizer =
                needs_local_chat_pipeline || needs_local_completions_pipeline;
            let tokenizer = if needs_rust_tokenizer && card.has_tokenizer() {
                Some(card.tokenizer().context("tokenizer")?)
            } else {
                None
            };
723

724
725
            // Create prefill chooser once if we're building pipelines
            // Both chat and completions will share the same prefill chooser instance
726
            let model_name = card.name().to_string();
727
728
            let prefill_chooser = self
                .manager
729
                .register_prefill_router(&model_name, &namespace)
730
731
                .map(|rx| {
                    // Create prefill-specific config with track_active_blocks disabled
732
                    let mut prefill_config = self.router_config.kv_router_config.clone();
733
734
735
736
737
                    prefill_config.router_track_active_blocks = false;

                    PrefillRouter::new(
                        rx,
                        self.manager.clone(),
738
                        self.router_config.router_mode,
739
740
                        card.kv_cache_block_size,
                        Some(prefill_config),
741
                        self.prefill_load_estimator.clone(),
742
                        self.router_config.enforce_disagg,
743
744
                        model_name.clone(),
                        namespace.clone(),
745
                        card.runtime_config.enable_eagle,
746
747
748
                    )
                });

749
750
751
            // Create a new worker monitor for this WorkerSet. Each WorkerSet gets its own
            // monitor (1-to-1) since each monitor is scoped to this WorkerSet's Client/namespace.
            // The monitor tracks Prometheus metrics (active_decode_blocks, active_prefill_tokens,
752
            // worker TTFT/ITL cleanup). The thresholds control busy detection behavior only.
753
754
755
756
757
758
759
760
761
762
            //
            // IMPORTANT: When KV routing is active, the monitor must use the KvRouter's Client
            // so that busy-state updates (via update_free_instances) are visible to the
            // PushRouter, which also uses the KvRouter's Client (see common.rs:258-263).
            // Using a different Client instance would cause the PushRouter to never see
            // busy workers, since each Client::new() creates independent ArcSwap state.
            let monitor_client = kv_chooser
                .as_ref()
                .map(|chooser| chooser.client().clone())
                .unwrap_or_else(|| client.clone());
763
            let worker_monitor = Some(KvWorkerMonitor::new(
764
                monitor_client,
765
766
                self.router_config.load_threshold_config.clone(),
            ));
767

768
769
770
            // Store KV router, worker monitor, and prefill router on the WorkerSet.
            // The prefill router is stored so the watcher can deactivate/reactivate it
            // when prefill workers die or rejoin.
771
772
            worker_set.kv_router = kv_chooser.clone();
            worker_set.worker_monitor = worker_monitor.clone();
773
            worker_set.prefill_router = prefill_chooser.clone();
774

775
            // Add chat engine only if the model supports chat
776
            if card.model_type.supports_chat() {
777
778
779
780
781
782
783
784
785
786
787
                let factory_engine = if let Some(ref factory) = self.chat_engine_factory {
                    match factory(mcid.clone(), card.clone()).await {
                        Ok(engine) => Some(engine),
                        Err(err) => return Err(err).context("python chat_engine_factory"),
                    }
                } else {
                    None
                };

                let chat_engine = if let Some(engine) = factory_engine {
                    engine
788
                } else {
789
790
791
792
793
794
795
                    let tk = tokenizer.clone().ok_or_else(|| {
                        anyhow::anyhow!(
                            "Model has no supported Rust tokenizer and no chat_engine_factory. \
                             Use --dyn-chat-processor vllm/sglang or provide a supported \
                             tokenizer file (tokenizer.json, tiktoken.model, or *.tiktoken)."
                        )
                    })?;
796
797
798
799
800
801
                    entrypoint::build_routed_pipeline::<
                        NvCreateChatCompletionRequest,
                        NvCreateChatCompletionStreamResponse,
                    >(
                        card,
                        &client,
802
                        self.manager.clone(),
803
804
805
                        self.router_config.router_mode,
                        worker_monitor.clone(),
                        kv_chooser.clone(),
806
                        tk,
807
                        prefill_chooser.clone(),
808
                        self.router_config.enforce_disagg,
809
                        self.migration_limit,
810
                        self.migration_max_seq_len,
811
                        self.metrics.clone(),
812
813
814
815
                    )
                    .await
                    .context("build_routed_pipeline")?
                };
816
                worker_set.chat_engine = Some(chat_engine);
817
                tracing::info!("Chat completions is ready");
818
            }
819

820
821
            // Add completions engine only if the model supports completions
            // and we have a tokenizer (completions always uses the Rust preprocessor).
822
            if card.model_type.supports_completions() {
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
                if let Some(tk) = tokenizer {
                    let formatter = PromptFormatter::no_op();
                    let PromptFormatter::OAI(formatter) = formatter;
                    let preprocessor =
                        OpenAIPreprocessor::new_with_parts(card.clone(), formatter, tk.clone())
                            .context("OpenAIPreprocessor::new_with_parts")?;
                    let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
                        NvCreateCompletionRequest,
                        NvCreateCompletionResponse,
                    >(
                        card,
                        &client,
                        self.manager.clone(),
                        self.router_config.router_mode,
                        worker_monitor,
                        kv_chooser,
                        preprocessor,
                        tk,
                        prefill_chooser,
                        self.router_config.enforce_disagg,
                        self.migration_limit,
844
                        self.migration_max_seq_len,
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
                        self.metrics.clone(),
                    )
                    .await
                    .context("build_routed_pipeline_with_preprocessor")?;
                    worker_set.completions_engine = Some(completions_engine);
                    tracing::info!("Completions is ready");
                } else {
                    tracing::warn!(
                        "Skipping completions engine: no Rust tokenizer available for this model"
                    );
                }
            }

            // Verify we built at least one serving engine. A Tokens model that
            // ends up with no chat AND no completions engine (e.g. completions-only
            // model with no tokenizer) should fail fast rather than register an
            // empty WorkerSet that can't serve any requests.
            if !worker_set.has_decode_engine() {
                anyhow::bail!(
                    "Model '{}' requires frontend tokenization/preprocessing (ModelInput::Tokens) \
                     but no serving engine could be built. Provide a working tokenizer config or \
                     perform tokenization in the backend (ModelInput::Text).",
                    card.name()
                );
869
            }
870
871
872
873
874
        } else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
            // Case: Text + Embeddings
            let push_router = PushRouter::<
                NvCreateEmbeddingRequest,
                Annotated<NvCreateEmbeddingResponse>,
875
876
            >::from_client_with_monitor(
                client, self.router_config.router_mode, None
877
878
            )
            .await?;
879
            worker_set.embeddings_engine = Some(Arc::new(push_router));
880
881
882
883
884
885
886
887
888
889
        }
        // Case: Text + (Images, Audio, Videos)
        // Must come before the plain Text+Chat / Text+Completions branches because
        // diffusion models often set both Images and Chat flags. The branch below
        // handles the chat registration internally when supports_chat() is true.
        else if card.model_input == ModelInput::Text
            && (card.model_type.supports_images()
                || card.model_type.supports_audios()
                || card.model_type.supports_videos())
        {
890
            // Image/Audio/Video models can also support chat completions (vLLM omni way)
891
892
893
894
            if card.model_type.supports_chat() {
                let chat_router = PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
895
896
                >::from_client_with_monitor(
                    client.clone(), self.router_config.router_mode, None
897
898
                )
                .await?;
899
                worker_set.chat_engine = Some(Arc::new(chat_router));
900
901
902
903
904
905
            }

            if card.model_type.supports_images() {
                let images_router = PushRouter::<
                    NvCreateImageRequest,
                    Annotated<NvImagesResponse>,
906
907
                >::from_client_with_monitor(
                    client.clone(), self.router_config.router_mode, None
908
909
                )
                .await?;
910
                worker_set.images_engine = Some(Arc::new(images_router));
911
912
913
914
915
916
            }

            if card.model_type.supports_videos() {
                let videos_router = PushRouter::<
                    NvCreateVideoRequest,
                    Annotated<NvVideosResponse>,
917
918
                >::from_client_with_monitor(
                    client.clone(), self.router_config.router_mode, None
919
920
                )
                .await?;
921
                worker_set.videos_engine = Some(Arc::new(videos_router));
922
923
            }

924
925
926
927
            if card.model_type.supports_audios() {
                let audios_router = PushRouter::<
                    NvCreateAudioSpeechRequest,
                    Annotated<NvAudioSpeechResponse>,
928
929
                >::from_client_with_monitor(
                    client.clone(), self.router_config.router_mode, None
930
931
932
933
                )
                .await?;
                worker_set.audios_engine = Some(Arc::new(audios_router));
            }
934
        } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
935
            // Case: Text + Chat (pure text-to-text, no diffusion)
936
937
938
            let push_router = PushRouter::<
                NvCreateChatCompletionRequest,
                Annotated<NvCreateChatCompletionStreamResponse>,
939
940
            >::from_client_with_monitor(
                client, self.router_config.router_mode, None
941
942
            )
            .await?;
943
            worker_set.chat_engine = Some(Arc::new(push_router));
944
        } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
945
            // Case: Text + Completions
946
947
948
            let push_router = PushRouter::<
                NvCreateCompletionRequest,
                Annotated<NvCreateCompletionResponse>,
949
950
            >::from_client_with_monitor(
                client, self.router_config.router_mode, None
951
952
            )
            .await?;
953
            worker_set.completions_engine = Some(Arc::new(push_router));
954
        } else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
955
956
957
958
959
960
961
            // Case 4: Tokens + Embeddings
            // Create preprocessing pipeline similar to Backend
            let frontend = SegmentSource::<
                SingleIn<NvCreateEmbeddingRequest>,
                ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            >::new();

962
            let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
963
            let backend = Backend::from_mdc(card).into_operator();
964
965
966
967

            let router = PushRouter::<
                PreprocessedEmbeddingRequest,
                Annotated<EmbeddingsEngineOutput>,
968
969
            >::from_client_with_monitor(
                client, self.router_config.router_mode, None
970
971
972
            )
            .await?;

973
            // Note: Embeddings don't need KV routing complexity or load monitoring
974
975
976
977
978
979
980
981
982
983
984
            let service_backend = ServiceBackend::from_engine(Arc::new(router));

            // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
            let embedding_engine = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(service_backend)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

985
            worker_set.embeddings_engine = Some(embedding_engine);
986
        } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
987
            // Case 6: Tensor + TensorBased (non-LLM)
988
            // No KV cache concepts - not an LLM model
989
990
991
            let push_router = PushRouter::<
                NvCreateTensorRequest,
                Annotated<NvCreateTensorResponse>,
992
993
            >::from_client_with_monitor(
                client, self.router_config.router_mode, None
994
995
            )
            .await?;
996
            worker_set.tensor_engine = Some(Arc::new(push_router));
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
        } else if card.model_type.supports_prefill() {
            // Case 6: Prefill
            // Guardrail: Verify model_input is Tokens
            if card.model_input != ModelInput::Tokens {
                anyhow::bail!(
                    "Prefill models must use ModelInput::Tokens, got {}",
                    card.model_input.as_str()
                );
            }

            tracing::info!(
                model_name = card.name(),
1009
1010
1011
                "Prefill model detected, registering and activating prefill router"
            );

1012
1013
            // Prefill sets have no engines — we add the WorkerSet first for tracking,
            // then activate the prefill router.
1014
            self.manager
1015
                .add_worker_set(card.name(), &ws_key, worker_set);
1016

1017
1018
1019
1020
            if let Some(tx) = &self.model_update_tx {
                tx.send(ModelUpdate::Added(card.clone())).await.ok();
            }

1021
1022
1023
1024
1025
1026
1027
            // Note: activate_prefill_router is keyed by deployment namespace (not ws_key)
            // because it coordinates between decode and prefill WorkerSets that share
            // the same deployment namespace but have different ws_keys ("ns" vs "ns:prefill").
            let Ok(()) = self
                .manager
                .activate_prefill_router(card.name(), &namespace, endpoint)
            else {
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
                tracing::warn!(
                    model_name = card.name(),
                    "Failed to activate prefill router - prefill model may already be activated"
                );
                return Ok(());
            };

            tracing::info!(
                model_name = card.name(),
                "Prefill model registered and router activated successfully"
1038
            );
1039
1040

            return Ok(());
1041
1042
1043
1044
        } else {
            // Reject unsupported combinations
            anyhow::bail!(
                "Unsupported model configuration: {} with {} input. Supported combinations: \
1045
                Tokens+(Chat|Completions|Prefill), Text+(Chat|Completions|Images), Tokens+Embeddings, Tensor+TensorBased",
1046
1047
                card.model_type,
                card.model_input.as_str()
1048
            );
1049
        }
Ryan Olson's avatar
Ryan Olson committed
1050

1051
1052
        // Add the completed WorkerSet to the Model
        self.manager
1053
            .add_worker_set(card.name(), &ws_key, worker_set);
1054

1055
1056
1057
1058
        if let Some(tx) = &self.model_update_tx {
            tx.send(ModelUpdate::Added(card.clone())).await.ok();
        }

1059
1060
        Ok(())
    }
1061

1062
    /// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
1063
    async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
1064
1065
        let discovery = self.drt.discovery();
        let instances = discovery.list(DiscoveryQuery::AllModels).await?;
1066

1067
1068
1069
        let mut results = Vec::with_capacity(instances.len());
        for instance in instances {
            match instance.deserialize_model::<ModelDeploymentCard>() {
1070
                Ok(card) => {
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
                    let endpoint_id = match &instance {
                        dynamo_runtime::discovery::DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            ..
                        } => EndpointId {
                            namespace: namespace.clone(),
                            component: component.clone(),
                            name: endpoint.clone(),
                        },
                        _ => {
                            tracing::error!(
                                "Unexpected discovery instance type (expected ModelCard)"
                            );
1086
1087
1088
                            continue;
                        }
                    };
1089
                    results.push((endpoint_id, card));
1090
                }
1091
                Err(err) => {
1092
                    tracing::error!(%err, "Failed to deserialize model card");
1093
1094
                    continue;
                }
1095
            }
1096
        }
1097
        Ok(results)
1098
1099
    }

1100
    pub async fn cards_for_model(
1101
1102
        &self,
        model_name: &str,
1103
        namespace_filter: &NamespaceFilter,
1104
    ) -> anyhow::Result<Vec<ModelDeploymentCard>> {
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
        Ok(self
            .cards_for_model_with_endpoints(model_name, namespace_filter)
            .await?
            .into_iter()
            .map(|(_, card)| card)
            .collect())
    }

    /// Same selection as `cards_for_model`, but each card is returned together
    /// with its EndpointId so callers can filter by namespace.
    ///
    /// A card is kept when its name equals `model_name` and its endpoint
    /// namespace is accepted by `namespace_filter`.
    async fn cards_for_model_with_endpoints(
        &self,
        model_name: &str,
        namespace_filter: &NamespaceFilter,
    ) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
        let every_card = self.all_cards().await?;

        let selected: Vec<(EndpointId, ModelDeploymentCard)> = every_card
            .into_iter()
            .filter(|(endpoint_id, card)| {
                let name_ok = card.name() == model_name;
                let namespace_ok = namespace_filter.matches(&endpoint_id.namespace);
                name_ok && namespace_ok
            })
            .collect();

        Ok(selected)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::discovery::WorkerSet;
    use crate::model_card::ModelDeploymentCard;

    /// Namespace (and worker-set key) shared by every WorkerSet in these tests.
    const TEST_NAMESPACE: &str = "ns1";

    /// Build a WorkerSet with a default card and no engines attached.
    fn make_worker_set(namespace: &str) -> WorkerSet {
        let checksum = "test-checksum".to_string();
        WorkerSet::new(
            namespace.to_string(),
            checksum,
            ModelDeploymentCard::default(),
        )
    }

    /// Model types that are only reported when a matching engine exists,
    /// i.e. everything except `Prefill`.
    fn engine_backed_types() -> [ModelType; 7] {
        [
            ModelType::Chat,
            ModelType::Completions,
            ModelType::Embedding,
            ModelType::Images,
            ModelType::Audios,
            ModelType::Videos,
            ModelType::TensorBased,
        ]
    }

    /// Create a manager holding one engine-less WorkerSet per given model name.
    fn manager_with_models(model_names: &[&str]) -> ModelManager {
        let manager = ModelManager::new();
        for model_name in model_names {
            manager.add_worker_set(
                model_name,
                TEST_NAMESPACE,
                make_worker_set(TEST_NAMESPACE),
            );
        }
        manager
    }

    #[test]
    fn test_is_model_type_list_empty_on_empty_manager() {
        let manager = manager_with_models(&[]);
        for model_type in engine_backed_types() {
            assert!(is_model_type_list_empty(&manager, model_type));
        }
        assert!(is_model_type_list_empty(&manager, ModelType::Prefill));
    }

    #[test]
    fn test_is_model_type_list_empty_prefill_present() {
        // A WorkerSet with no engines is treated as a prefill set
        let manager = manager_with_models(&["model-a"]);

        assert!(!is_model_type_list_empty(&manager, ModelType::Prefill));
        // Other types should still be empty since the WorkerSet has no engines
        for model_type in engine_backed_types() {
            assert!(is_model_type_list_empty(&manager, model_type));
        }
    }

    #[test]
    fn test_is_model_type_list_empty_after_removal() {
        let manager = manager_with_models(&["model-a"]);
        assert!(!is_model_type_list_empty(&manager, ModelType::Prefill));

        manager.remove_model("model-a");
        assert!(is_model_type_list_empty(&manager, ModelType::Prefill));
    }

    #[test]
    fn test_is_model_type_list_not_empty_when_other_model_remains() {
        let manager = manager_with_models(&["model-a", "model-b"]);

        // Remove one model — other still provides prefill
        manager.remove_model("model-a");
        assert!(!is_model_type_list_empty(&manager, ModelType::Prefill));

        // Remove the last model — now empty
        manager.remove_model("model-b");
        assert!(is_model_type_list_empty(&manager, ModelType::Prefill));
    }
}