watcher.rs 34.2 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Ryan Olson's avatar
Ryan Olson committed
2
3
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5
use tokio::sync::Notify;
6
use tokio::sync::mpsc::Sender;
7

8
use anyhow::Context as _;
9
use dashmap::DashSet;
10
use futures::StreamExt;
Ryan Olson's avatar
Ryan Olson committed
11

Neelay Shah's avatar
Neelay Shah committed
12
use dynamo_runtime::{
13
    DistributedRuntime,
14
15
16
17
    discovery::{
        DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery, DiscoveryStream,
        ModelCardInstanceId,
    },
18
    pipeline::{
19
20
        ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
        network::egress::push_router::PushRouter,
21
    },
22
    protocols::{EndpointId, annotated::Annotated},
23
};
Ryan Olson's avatar
Ryan Olson committed
24

25
26
use crate::{
    backend::Backend,
27
    discovery::WORKER_TYPE_DECODE,
28
    entrypoint::{self, ChatEngineFactoryCallback, RouterConfig},
29
    http::service::metrics::Metrics,
30
    kv_router::PrefillRouter,
31
    model_card::ModelDeploymentCard,
32
33
    model_type::{ModelInput, ModelType},
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
34
35
36
37
38
39
40
41
    protocols::{
        common::llm_backend::EmbeddingsEngineOutput,
        openai::{
            chat_completions::{
                NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            },
            completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
            embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
42
            images::{NvCreateImageRequest, NvImagesResponse},
43
            videos::{NvCreateVideoRequest, NvVideosResponse},
44
        },
45
        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
46
47
    },
};
48

49
use super::ModelManager;
50
use crate::namespace::is_global_namespace;
51

52
#[derive(Debug, Clone)]
53
pub enum ModelUpdate {
54
55
    Added(ModelDeploymentCard),
    Removed(ModelDeploymentCard),
56
57
}

58
pub struct ModelWatcher {
59
    manager: Arc<ModelManager>,
60
    drt: DistributedRuntime,
61
    router_config: RouterConfig,
62
    migration_limit: u32,
63
    notify_on_model: Notify,
64
    model_update_tx: Option<Sender<ModelUpdate>>,
65
    chat_engine_factory: Option<ChatEngineFactoryCallback>,
66
    metrics: Arc<Metrics>,
67
    registering_models: DashSet<String>,
Ryan Olson's avatar
Ryan Olson committed
68
69
}

70
71
72
73
/// Model types walked by `handle_delete` when deciding which removal
/// notifications to publish.
const ALL_MODEL_TYPES: &[ModelType] = &[
    ModelType::Chat,
    ModelType::Completions,
    ModelType::Embedding,
    ModelType::Images,
    ModelType::Audios,
    ModelType::Videos,
    ModelType::TensorBased,
    ModelType::Prefill,
];
80

81
impl ModelWatcher {
82
    pub fn new(
83
        runtime: DistributedRuntime,
84
        model_manager: Arc<ModelManager>,
85
        router_config: RouterConfig,
86
        migration_limit: u32,
87
        chat_engine_factory: Option<ChatEngineFactoryCallback>,
88
        metrics: Arc<Metrics>,
89
90
    ) -> ModelWatcher {
        Self {
91
            manager: model_manager,
92
            drt: runtime,
93
            router_config,
94
            migration_limit,
95
            notify_on_model: Notify::new(),
96
            model_update_tx: None,
97
            chat_engine_factory,
98
            metrics,
99
            registering_models: DashSet::new(),
100
        }
101
    }
Ryan Olson's avatar
Ryan Olson committed
102

103
104
105
106
    /// Attach the channel on which `ModelUpdate` events are published,
    /// replacing any channel set earlier.
    pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
        let _previous = self.model_update_tx.replace(tx);
    }

107
108
109
110
111
112
113
114
115
116
117
    /// Wait until we have at least one chat completions model and return its name.
    ///
    /// The returned name is the first entry reported by the model manager at the
    /// moment a chat completions model is observed.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            // Create the `Notified` future *before* inspecting the model list.
            // `notify_waiters()` only wakes futures that already exist, so creating
            // it after the check would lose a notification fired between the check
            // and the await, leaving this call blocked until the next model arrives.
            let notified = self.notify_on_model.notified();

            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }

            notified.await;
        }
    }

118
    /// Common watch logic with optional namespace filtering
119
120
121
122
123
    pub async fn watch(
        &self,
        mut discovery_stream: DiscoveryStream,
        target_namespace: Option<&str>,
    ) {
124
        let global_namespace = target_namespace.is_none_or(is_global_namespace);
125

126
127
128
129
130
131
132
133
134
        while let Some(result) = discovery_stream.next().await {
            let event = match result {
                Ok(event) => event,
                Err(err) => {
                    tracing::error!(%err, "Error in discovery stream");
                    continue;
                }
            };

135
            match event {
136
                DiscoveryEvent::Added(instance) => {
137
138
                    // Extract ModelCardInstanceId and card from the discovery instance
                    let (mcid, mut card) = match &instance {
139
140
141
142
143
                        DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            instance_id,
144
                            model_suffix,
145
146
                            ..
                        } => {
147
                            let mcid = ModelCardInstanceId {
148
149
                                namespace: namespace.clone(),
                                component: component.clone(),
150
151
152
                                endpoint: endpoint.clone(),
                                instance_id: *instance_id,
                                model_suffix: model_suffix.clone(),
153
154
155
                            };

                            match instance.deserialize_model::<ModelDeploymentCard>() {
156
                                Ok(card) => (mcid, card),
157
158
159
160
161
162
163
164
165
166
                                Err(err) => {
                                    tracing::error!(%err, instance_id, "Failed to deserialize model card");
                                    continue;
                                }
                            }
                        }
                        _ => {
                            tracing::error!(
                                "Unexpected discovery instance type (expected ModelCard)"
                            );
167
168
169
                            continue;
                        }
                    };
170
171
172
173

                    // Filter by namespace if target_namespace is specified
                    if !global_namespace
                        && let Some(target_ns) = target_namespace
174
                        && mcid.namespace != target_ns
175
176
                    {
                        tracing::debug!(
177
                            model_namespace = mcid.namespace,
178
179
180
181
182
183
                            target_namespace = target_ns,
                            "Skipping model from different namespace"
                        );
                        continue;
                    }

184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
                    // If we already have a worker for this model, and the ModelDeploymentCard
                    // cards don't match, alert, and don't add the new instance
                    let can_add =
                        self.manager
                            .is_valid_checksum(card.model_type, card.name(), card.mdcsum());
                    if can_add.is_some_and(|is_valid| !is_valid) {
                        tracing::error!(
                            model_name = card.name(),
                            "Checksum for new model does not match existing model."
                        );

                        // TODO: mark that instance down in clients
                        // Not obvious how to do that given the current design
                        // Instances come from an `InstanceSource` in a `Client` in a `PushRouter`.
                        // Calling `report_instance_down` on the Client should do it (although
                        // needs more testing).
                        // The `PushRouter` is in `ModelMananger` (`self.manager` here), but inside
                        // interface `AsyncEngine` which only has a `generate` method.

                        continue;
                    }
205

206
                    match self.handle_put(&mcid, &mut card).await {
207
                        Ok(()) => {
208
                            tracing::info!(
209
                                model_name = card.name(),
210
                                namespace = mcid.namespace,
211
212
                                "added model"
                            );
213
                            self.notify_on_model.notify_waiters();
214
                        }
215
216
                        Err(err) => {
                            tracing::error!(
217
                                model_name = card.name(),
218
                                namespace = mcid.namespace,
219
                                error = format!("{err:#}"),
220
                                "Error adding model from discovery",
221
                            );
222
223
224
                        }
                    }
                }
225
226
227
228
                DiscoveryEvent::Removed(id) => {
                    // Extract ModelCardInstanceId from the removal event
                    let model_card_instance_id = match &id {
                        DiscoveryInstanceId::Model(mcid) => mcid,
229
                        DiscoveryInstanceId::Endpoint(_) | DiscoveryInstanceId::EventChannel(_) => {
230
231
232
233
234
235
                            tracing::error!(
                                "Unexpected discovery instance type in removal (expected Model)"
                            );
                            continue;
                        }
                    };
236

237
                    match self
238
                        .handle_delete(model_card_instance_id, target_namespace, global_namespace)
239
240
241
242
243
244
245
246
247
248
249
                        .await
                    {
                        Ok(Some(model_name)) => {
                            tracing::info!(model_name, "removed model");
                        }
                        Ok(None) => {
                            // There are other instances running this model, nothing to do
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "error removing model");
                        }
250
                    }
251
                }
252
            }
Ryan Olson's avatar
Ryan Olson committed
253
254
255
        }
    }

256
257
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
258
259
    async fn handle_delete(
        &self,
260
        mcid: &ModelCardInstanceId,
261
262
263
        target_namespace: Option<&str>,
        is_global_namespace: bool,
    ) -> anyhow::Result<Option<String>> {
264
265
        let key = mcid.to_path();
        let card = match self.manager.remove_model_card(&key) {
266
            Some(card) => card,
267
            None => {
268
                anyhow::bail!("Missing ModelDeploymentCard for {}", key);
269
270
            }
        };
271
        let model_name = card.name().to_string();
272
        let active_instances = self
273
            .cards_for_model(&model_name, target_namespace, is_global_namespace)
274
275
276
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
277
278
279
280
281
282
            tracing::debug!(
                model_name,
                target_namespace = ?target_namespace,
                active_instance_count = active_instances.len(),
                "Model has other active instances, not removing"
            );
283
284
            return Ok(None);
        }
285

286
        // Ignore the errors because model could be either type
287
288
289
        let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
        let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
        let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
290
        let images_model_remove_err = self.manager.remove_images_model(&model_name);
291
292
        let videos_model_remove_err = self.manager.remove_videos_model(&model_name);
        let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
293
        let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name);
294
295
296
297

        let mut chat_model_removed = false;
        let mut completions_model_removed = false;
        let mut embeddings_model_removed = false;
298
        let mut images_model_removed = false;
299
300
        let mut videos_model_removed = false;
        let mut tensor_model_removed = false;
301
        let mut prefill_model_removed = false;
302
303
304
305
306
307
308
309
310
311
312

        if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
            chat_model_removed = true;
        }
        if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
        {
            completions_model_removed = true;
        }
        if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
            embeddings_model_removed = true;
        }
313
314
315
        if images_model_remove_err.is_ok() && self.manager.list_images_models().is_empty() {
            images_model_removed = true;
        }
316
317
318
319
320
321
        if videos_model_remove_err.is_ok() && self.manager.list_videos_models().is_empty() {
            videos_model_removed = true;
        }
        if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
            tensor_model_removed = true;
        }
322
323
324
        if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() {
            prefill_model_removed = true;
        }
325

326
327
328
        if !chat_model_removed
            && !completions_model_removed
            && !embeddings_model_removed
329
            && !images_model_removed
330
331
            && !videos_model_removed
            && !tensor_model_removed
332
            && !prefill_model_removed
333
        {
334
            tracing::debug!(
335
                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, images_model_removed: {}, videos_model_removed: {}, tensor_model_removed: {}, prefill_model_removed: {}",
336
337
338
                model_name,
                chat_model_removed,
                completions_model_removed,
339
                embeddings_model_removed,
340
                images_model_removed,
341
342
                videos_model_removed,
                tensor_model_removed,
343
                prefill_model_removed
344
345
346
            );
        } else {
            for model_type in ALL_MODEL_TYPES {
347
                if ((chat_model_removed && *model_type == ModelType::Chat)
348
                    || (completions_model_removed && *model_type == ModelType::Completions)
349
                    || (embeddings_model_removed && *model_type == ModelType::Embedding)
350
                    || (images_model_removed && *model_type == ModelType::Images)
351
352
                    || (videos_model_removed && *model_type == ModelType::Videos)
                    || (tensor_model_removed && *model_type == ModelType::TensorBased)
353
                    || (prefill_model_removed && *model_type == ModelType::Prefill))
354
                    && let Some(tx) = &self.model_update_tx
355
                {
356
                    tx.send(ModelUpdate::Removed(card.clone())).await.ok();
357
358
359
                }
            }
        }
360

361
        Ok(Some(model_name))
362
    }
Ryan Olson's avatar
Ryan Olson committed
363

364
    // Handles a PUT event from store, this usually means adding a new model to the list of served
365
    // models.
366
367
    async fn handle_put(
        &self,
368
        mcid: &ModelCardInstanceId,
369
370
        card: &mut ModelDeploymentCard,
    ) -> anyhow::Result<()> {
371
372
373
        // Check if model is already registered before downloading config.
        // This prevents duplicate HuggingFace API calls when multiple workers register
        // the same model.
374
375
376
377
378
379
380
381
382
        // Prefill and decode models are tracked separately, so registering one
        // doesn't block the other (they can arrive in any order).
        let already_registered = if card.model_type.supports_prefill() {
            self.manager.has_prefill_model(card.name())
        } else {
            self.manager.has_decode_model(card.name())
        };

        if already_registered {
383
384
            self.manager
                .save_model_card(&mcid.to_path(), card.clone())?;
385
386
            tracing::debug!(
                model_name = card.name(),
387
                namespace = mcid.namespace,
388
                model_type = %card.model_type,
389
390
391
392
393
394
395
396
397
398
399
400
401
402
                "Model already registered, skipping config download"
            );
            return Ok(());
        }

        // Use registering_models set to prevent concurrent registrations.
        let model_key = card.name().to_string();
        if !self.registering_models.insert(model_key.clone()) {
            self.manager
                .save_model_card(&mcid.to_path(), card.clone())?;
            tracing::debug!(
                model_name = card.name(),
                namespace = mcid.namespace,
                "Model registration in progress by another worker, skipping"
403
404
405
406
            );
            return Ok(());
        }

407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
        // We acquired the registration lock. Use a helper to ensure cleanup on all exit paths.
        let result = self.do_model_registration(mcid, card).await;

        // Always remove from registering set, whether success or failure
        self.registering_models.remove(&model_key);

        result
    }

    /// Inner function that performs the actual model registration.
    /// Called by handle_put after acquiring the registration lock.
    async fn do_model_registration(
        &self,
        mcid: &ModelCardInstanceId,
        card: &mut ModelDeploymentCard,
    ) -> anyhow::Result<()> {
        card.download_config().await?;

        let component = self
            .drt
            .namespace(&mcid.namespace)?
            .component(&mcid.component)?;
        let endpoint = component.endpoint(&mcid.endpoint);
        let client = endpoint.client().await?;
        tracing::debug!(model_name = card.name(), "adding model");
        self.manager
            .save_model_card(&mcid.to_path(), card.clone())?;

435
436
437
        if let Some(tx) = &self.model_update_tx {
            tx.send(ModelUpdate::Added(card.clone())).await.ok();
        }
438

439
        let checksum = card.mdcsum();
440
441
442

        if card.model_input == ModelInput::Tokens
            && (card.model_type.supports_chat() || card.model_type.supports_completions())
443
444
445
446
447
        {
            // Case 1: Tokens + (Chat OR Completions OR Both)
            // A model that expects pre-processed requests meaning it's up to us whether we
            // handle Chat or Completions requests, so handle whatever the model supports.

448
            let endpoint = component.endpoint(&mcid.endpoint);
449
450
451
452
453
454
455
456
457
            // Create the KV router whenever any local routed pipeline will be built.
            // The chat factory builds its own router, but completions currently always
            // uses the local routed pipeline and therefore still needs a chooser.
            let needs_local_chat_pipeline =
                card.model_type.supports_chat() && self.chat_engine_factory.is_none();
            let needs_local_completions_pipeline = card.model_type.supports_completions();
            let kv_chooser = if self.router_config.router_mode == RouterMode::KV
                && (needs_local_chat_pipeline || needs_local_completions_pipeline)
            {
458
459
                Some(
                    self.manager
460
461
462
463
                        .kv_chooser_for(
                            &endpoint,
                            card.kv_cache_block_size,
                            Some(self.router_config.kv_router_config),
464
                            WORKER_TYPE_DECODE, // This is the decode router
465
                        )
466
467
468
469
470
                        .await?,
                )
            } else {
                None
            };
471

472
            // This is expensive, we are loading ~10MiB JSON, so only do it once
473
            let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
474

475
476
            // Create prefill chooser once if we're building pipelines
            // Both chat and completions will share the same prefill chooser instance
477
            let model_name = card.name().to_string();
478
479
            let prefill_chooser = self
                .manager
480
                .register_prefill_router(model_name.clone())
481
482
                .map(|rx| {
                    // Create prefill-specific config with track_active_blocks disabled
483
                    let mut prefill_config = self.router_config.kv_router_config;
484
485
486
487
488
                    prefill_config.router_track_active_blocks = false;

                    PrefillRouter::new(
                        rx,
                        self.manager.clone(),
489
                        self.router_config.router_mode,
490
491
                        card.kv_cache_block_size,
                        Some(prefill_config),
492
                        self.router_config.enforce_disagg,
493
                        model_name.clone(), // Pass model name for worker monitor lookup
494
495
496
                    )
                });

497
498
499
500
501
502
503
504
505
            // Get or create the worker monitor for this model.
            // Always create the monitor for Prometheus metrics (active_decode_blocks, active_prefill_tokens,
            // worker TTFT/ITL cleanup). The thresholds control busy detection behavior only.
            // LoadThresholdConfig allows dynamic threshold updates via the ModelManager.
            let worker_monitor = Some(self.manager.get_or_create_worker_monitor(
                card.name(),
                client.clone(),
                self.router_config.load_threshold_config.clone(),
            ));
506

507
            // Add chat engine only if the model supports chat
508
            if card.model_type.supports_chat() {
509
510
511
512
513
514
515
516
517
518
519
                let factory_engine = if let Some(ref factory) = self.chat_engine_factory {
                    match factory(mcid.clone(), card.clone()).await {
                        Ok(engine) => Some(engine),
                        Err(err) => return Err(err).context("python chat_engine_factory"),
                    }
                } else {
                    None
                };

                let chat_engine = if let Some(engine) = factory_engine {
                    engine
520
521
522
523
524
525
526
                } else {
                    entrypoint::build_routed_pipeline::<
                        NvCreateChatCompletionRequest,
                        NvCreateChatCompletionStreamResponse,
                    >(
                        card,
                        &client,
527
                        self.manager.clone(),
528
529
530
531
532
533
                        self.router_config.router_mode,
                        worker_monitor.clone(),
                        kv_chooser.clone(),
                        tokenizer_hf.clone(),
                        prefill_chooser.clone(),
                        self.router_config.enforce_disagg,
534
                        self.migration_limit,
535
                        self.metrics.clone(),
536
537
538
539
                    )
                    .await
                    .context("build_routed_pipeline")?
                };
540
                self.manager
541
                    .add_chat_completions_model(card.name(), checksum, chat_engine)
542
                    .context("add_chat_completions_model")?;
543
                tracing::info!("Chat completions is ready");
544
            }
545

546
            // Add completions engine only if the model supports completions.
547
            if card.model_type.supports_completions() {
548
549
                let formatter = PromptFormatter::no_op();
                let PromptFormatter::OAI(formatter) = formatter;
550
551
552
553
                let preprocessor = OpenAIPreprocessor::new_with_parts(
                    card.clone(),
                    formatter,
                    tokenizer_hf.clone(),
554
555
                )
                .context("OpenAIPreprocessor::new_with_parts")?;
556
                let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
557
558
559
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(
560
                    card,
561
                    &client,
562
                    self.manager.clone(),
563
                    self.router_config.router_mode,
564
                    worker_monitor,
565
                    kv_chooser,
566
                    preprocessor,
567
                    tokenizer_hf,
568
                    prefill_chooser,
569
                    self.router_config.enforce_disagg,
570
                    self.migration_limit,
571
                    self.metrics.clone(),
572
                )
573
574
                .await
                .context("build_routed_pipeline_with_preprocessor")?;
575
                self.manager
576
                    .add_completions_model(card.name(), checksum, completions_engine)
577
                    .context("add_completions_model")?;
578
                tracing::info!("Completions is ready");
579
            }
580
581
582
583
584
585
        } else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
            // Case: Text + Embeddings
            let push_router = PushRouter::<
                NvCreateEmbeddingRequest,
                Annotated<NvCreateEmbeddingResponse>,
            >::from_client_with_threshold(
586
                client, self.router_config.router_mode, None, None
587
588
589
590
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
591
                .add_embeddings_model(card.name(), checksum, engine)?;
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
        }
        // Case: Text + (Images, Audio, Videos)
        // Must come before the plain Text+Chat / Text+Completions branches because
        // diffusion models often set both Images and Chat flags. The branch below
        // handles the chat registration internally when supports_chat() is true.
        else if card.model_input == ModelInput::Text
            && (card.model_type.supports_images()
                || card.model_type.supports_audios()
                || card.model_type.supports_videos())
        {
            // Image Models can support chat completions (vllm omni way)
            // So register chat_completions model as well
            if card.model_type.supports_chat() {
                let chat_router = PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
                >::from_client_with_threshold(
                    client.clone(),
                    self.router_config.router_mode,
                    None,
                    None,
                )
                .await?;
                self.manager.add_chat_completions_model(
                    card.name(),
                    checksum,
                    Arc::new(chat_router),
                )?;
            }

            // This is ModelType::Images : registers /v1/images/* endpoints
            if card.model_type.supports_images() {
                let images_router = PushRouter::<
                    NvCreateImageRequest,
                    Annotated<NvImagesResponse>,
                >::from_client_with_threshold(
                    client.clone(), self.router_config.router_mode, None, None
                )
                .await?;
                self.manager
                    .add_images_model(card.name(), checksum, Arc::new(images_router))?;
            }

            // This is ModelType::Videos : registers /v1/videos/* endpoints
            if card.model_type.supports_videos() {
                let videos_router = PushRouter::<
                    NvCreateVideoRequest,
                    Annotated<NvVideosResponse>,
                >::from_client_with_threshold(
                    client.clone(), self.router_config.router_mode, None, None
                )
                .await?;
                self.manager
                    .add_videos_model(card.name(), checksum, Arc::new(videos_router))?;
            }

            // TODO: add audio models support
649
        } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
650
            // Case: Text + Chat (pure text-to-text, no diffusion)
651
652
653
654
655
656
657
            let push_router = PushRouter::<
                NvCreateChatCompletionRequest,
                Annotated<NvCreateChatCompletionStreamResponse>,
            >::from_client_with_threshold(
                client, self.router_config.router_mode, None, None
            )
            .await?;
658
659
            let engine = Arc::new(push_router);
            self.manager
660
                .add_chat_completions_model(card.name(), checksum, engine)?;
661
        } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
662
            // Case: Text + Completions
663
664
665
666
            let push_router = PushRouter::<
                NvCreateCompletionRequest,
                Annotated<NvCreateCompletionResponse>,
            >::from_client_with_threshold(
667
                client, self.router_config.router_mode, None, None
668
669
670
671
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
672
                .add_completions_model(card.name(), checksum, engine)?;
673
        } else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
674
675
676
677
678
679
680
681
            // Case 4: Tokens + Embeddings

            // Create preprocessing pipeline similar to Backend
            let frontend = SegmentSource::<
                SingleIn<NvCreateEmbeddingRequest>,
                ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            >::new();

682
            let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
683
            let backend = Backend::from_mdc(card).into_operator();
684
685
686
687
688

            let router = PushRouter::<
                PreprocessedEmbeddingRequest,
                Annotated<EmbeddingsEngineOutput>,
            >::from_client_with_threshold(
689
                client, self.router_config.router_mode, None, None
690
691
692
            )
            .await?;

693
            // Note: Embeddings don't need KV routing complexity or load monitoring
694
695
696
697
698
699
700
701
702
703
704
705
            let service_backend = ServiceBackend::from_engine(Arc::new(router));

            // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
            let embedding_engine = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(service_backend)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            self.manager
706
                .add_embeddings_model(card.name(), checksum, embedding_engine)?;
707
        } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
708
            // Case 5: Tensor + TensorBased (non-LLM)
709
            // No KV cache concepts - not an LLM model
710
711
712
713
            let push_router = PushRouter::<
                NvCreateTensorRequest,
                Annotated<NvCreateTensorResponse>,
            >::from_client_with_threshold(
714
                client, self.router_config.router_mode, None, None
715
716
717
            )
            .await?;
            let engine = Arc::new(push_router);
718
719
            self.manager
                .add_tensor_model(card.name(), checksum, engine)?;
720
721
722
723
724
725
726
727
728
729
730
731
        } else if card.model_type.supports_prefill() {
            // Case 6: Prefill
            // Guardrail: Verify model_input is Tokens
            if card.model_input != ModelInput::Tokens {
                anyhow::bail!(
                    "Prefill models must use ModelInput::Tokens, got {}",
                    card.model_input.as_str()
                );
            }

            tracing::info!(
                model_name = card.name(),
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
                "Prefill model detected, registering and activating prefill router"
            );

            // Register prefill model for tracking (no engine needed, just lifecycle)
            self.manager
                .add_prefill_model(card.name(), checksum)
                .context("add_prefill_model")?;

            // Activate the prefill router with the endpoint for this prefill model
            let Ok(()) = self.manager.activate_prefill_router(card.name(), endpoint) else {
                tracing::warn!(
                    model_name = card.name(),
                    "Failed to activate prefill router - prefill model may already be activated"
                );
                return Ok(());
            };

            tracing::info!(
                model_name = card.name(),
                "Prefill model registered and router activated successfully"
752
            );
753
754
755
756
        } else {
            // Reject unsupported combinations
            anyhow::bail!(
                "Unsupported model configuration: {} with {} input. Supported combinations: \
757
                Tokens+(Chat|Completions|Prefill), Text+(Chat|Completions|Images), Tokens+Embeddings, Tensor+TensorBased",
758
759
                card.model_type,
                card.model_input.as_str()
760
            );
761
        }
Ryan Olson's avatar
Ryan Olson committed
762

763
764
        Ok(())
    }
765

766
    /// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
767
    async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
768
769
        let discovery = self.drt.discovery();
        let instances = discovery.list(DiscoveryQuery::AllModels).await?;
770

771
772
773
        let mut results = Vec::with_capacity(instances.len());
        for instance in instances {
            match instance.deserialize_model::<ModelDeploymentCard>() {
774
                Ok(card) => {
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
                    // Extract EndpointId from the instance
                    let endpoint_id = match &instance {
                        dynamo_runtime::discovery::DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            ..
                        } => EndpointId {
                            namespace: namespace.clone(),
                            component: component.clone(),
                            name: endpoint.clone(),
                        },
                        _ => {
                            tracing::error!(
                                "Unexpected discovery instance type (expected ModelCard)"
                            );
791
792
793
                            continue;
                        }
                    };
794
                    results.push((endpoint_id, card));
795
                }
796
                Err(err) => {
797
                    tracing::error!(%err, "Failed to deserialize model card");
798
799
                    continue;
                }
800
            }
801
        }
802
        Ok(results)
803
804
    }

805
    pub async fn cards_for_model(
806
807
808
809
        &self,
        model_name: &str,
        target_namespace: Option<&str>,
        is_global_namespace: bool,
810
811
812
813
    ) -> anyhow::Result<Vec<ModelDeploymentCard>> {
        let mut all = self.all_cards().await?;
        all.retain(|(endpoint_id, card)| {
            let matches_name = card.name() == model_name;
814
815
816
            let matches_namespace = match (is_global_namespace, target_namespace) {
                (true, _) => true,
                (false, None) => true,
817
                (false, Some(target_ns)) => endpoint_id.namespace == target_ns,
818
819
820
            };
            matches_name && matches_namespace
        });
821
822
823
        Ok(all.into_iter().map(|(_eid, card)| card).collect())
    }
}