watcher.rs 33.9 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Ryan Olson's avatar
Ryan Olson committed
2
3
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5
use tokio::sync::Notify;
6
use tokio::sync::mpsc::Sender;
7

8
use anyhow::Context as _;
9
use dashmap::DashSet;
10
use futures::StreamExt;
Ryan Olson's avatar
Ryan Olson committed
11

Neelay Shah's avatar
Neelay Shah committed
12
use dynamo_runtime::{
13
    DistributedRuntime,
14
15
16
17
    discovery::{
        DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery, DiscoveryStream,
        ModelCardInstanceId,
    },
18
    pipeline::{
19
20
        ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
        network::egress::push_router::PushRouter,
21
    },
22
    protocols::{EndpointId, annotated::Annotated},
23
};
Ryan Olson's avatar
Ryan Olson committed
24

25
26
use crate::{
    backend::Backend,
27
    discovery::WORKER_TYPE_DECODE,
28
    entrypoint::{self, ChatEngineFactoryCallback, RouterConfig},
29
    http::service::metrics::Metrics,
30
    kv_router::PrefillRouter,
31
    model_card::ModelDeploymentCard,
32
33
    model_type::{ModelInput, ModelType},
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
34
35
36
37
38
39
40
41
    protocols::{
        common::llm_backend::EmbeddingsEngineOutput,
        openai::{
            chat_completions::{
                NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            },
            completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
            embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
42
            images::{NvCreateImageRequest, NvImagesResponse},
43
            videos::{NvCreateVideoRequest, NvVideosResponse},
44
        },
45
        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
46
47
    },
};
48

49
use super::ModelManager;
50
use crate::namespace::is_global_namespace;
51

52
#[derive(Debug, Clone)]
53
pub enum ModelUpdate {
54
55
    Added(ModelDeploymentCard),
    Removed(ModelDeploymentCard),
56
57
}

58
pub struct ModelWatcher {
59
    manager: Arc<ModelManager>,
60
    drt: DistributedRuntime,
61
    router_config: RouterConfig,
62
    migration_limit: u32,
63
    notify_on_model: Notify,
64
    model_update_tx: Option<Sender<ModelUpdate>>,
65
    chat_engine_factory: Option<ChatEngineFactoryCallback>,
66
    metrics: Arc<Metrics>,
67
    registering_models: DashSet<String>,
Ryan Olson's avatar
Ryan Olson committed
68
69
}

70
71
72
73
/// Model type flags walked by `handle_delete` when deciding which
/// [`ModelUpdate::Removed`] notifications to send.
const ALL_MODEL_TYPES: &[ModelType] = &[
    ModelType::Chat,
    ModelType::Completions,
    ModelType::Embedding,
    ModelType::Images,
    ModelType::Audios,
    ModelType::Videos,
    ModelType::TensorBased,
    ModelType::Prefill,
];
80

81
impl ModelWatcher {
82
    pub fn new(
83
        runtime: DistributedRuntime,
84
        model_manager: Arc<ModelManager>,
85
        router_config: RouterConfig,
86
        migration_limit: u32,
87
        chat_engine_factory: Option<ChatEngineFactoryCallback>,
88
        metrics: Arc<Metrics>,
89
90
    ) -> ModelWatcher {
        Self {
91
            manager: model_manager,
92
            drt: runtime,
93
            router_config,
94
            migration_limit,
95
            notify_on_model: Notify::new(),
96
            model_update_tx: None,
97
            chat_engine_factory,
98
            metrics,
99
            registering_models: DashSet::new(),
100
        }
101
    }
Ryan Olson's avatar
Ryan Olson committed
102

103
104
105
106
    /// Install `tx` as the listener for [`ModelUpdate`] notifications, replacing (and
    /// dropping) any previously installed sender.
    pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
        let _ = self.model_update_tx.replace(tx);
    }

107
108
109
110
111
112
113
114
115
116
117
    /// Wait until at least one chat completions model is registered and return its name.
    ///
    /// If several chat models are registered, the first entry returned by
    /// `list_chat_completions_models` is used.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case a model gets added and immediately removed again.
        loop {
            // Create the `Notified` future *before* checking the model list. `notify_waiters`
            // stores no permit, but a `Notified` future receives it as soon as it is created,
            // even before it is polled. Creating it after the check leaves a window in which a
            // model added (and `notify_waiters` called) between the check and `notified()` is
            // missed, and we would sleep until some later model arrives.
            let notified = self.notify_on_model.notified();
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            notified.await;
        }
    }

118
    /// Common watch logic with optional namespace filtering
119
120
121
122
123
    pub async fn watch(
        &self,
        mut discovery_stream: DiscoveryStream,
        target_namespace: Option<&str>,
    ) {
124
        let global_namespace = target_namespace.is_none_or(is_global_namespace);
125

126
127
128
129
130
131
132
133
134
        while let Some(result) = discovery_stream.next().await {
            let event = match result {
                Ok(event) => event,
                Err(err) => {
                    tracing::error!(%err, "Error in discovery stream");
                    continue;
                }
            };

135
            match event {
136
                DiscoveryEvent::Added(instance) => {
137
138
                    // Extract ModelCardInstanceId and card from the discovery instance
                    let (mcid, mut card) = match &instance {
139
140
141
142
143
                        DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            instance_id,
144
                            model_suffix,
145
146
                            ..
                        } => {
147
                            let mcid = ModelCardInstanceId {
148
149
                                namespace: namespace.clone(),
                                component: component.clone(),
150
151
152
                                endpoint: endpoint.clone(),
                                instance_id: *instance_id,
                                model_suffix: model_suffix.clone(),
153
154
155
                            };

                            match instance.deserialize_model::<ModelDeploymentCard>() {
156
                                Ok(card) => (mcid, card),
157
158
159
160
161
162
163
164
165
166
                                Err(err) => {
                                    tracing::error!(%err, instance_id, "Failed to deserialize model card");
                                    continue;
                                }
                            }
                        }
                        _ => {
                            tracing::error!(
                                "Unexpected discovery instance type (expected ModelCard)"
                            );
167
168
169
                            continue;
                        }
                    };
170
171
172
173

                    // Filter by namespace if target_namespace is specified
                    if !global_namespace
                        && let Some(target_ns) = target_namespace
174
                        && mcid.namespace != target_ns
175
176
                    {
                        tracing::debug!(
177
                            model_namespace = mcid.namespace,
178
179
180
181
182
183
                            target_namespace = target_ns,
                            "Skipping model from different namespace"
                        );
                        continue;
                    }

184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
                    // If we already have a worker for this model, and the ModelDeploymentCard
                    // cards don't match, alert, and don't add the new instance
                    let can_add =
                        self.manager
                            .is_valid_checksum(card.model_type, card.name(), card.mdcsum());
                    if can_add.is_some_and(|is_valid| !is_valid) {
                        tracing::error!(
                            model_name = card.name(),
                            "Checksum for new model does not match existing model."
                        );

                        // TODO: mark that instance down in clients
                        // Not obvious how to do that given the current design
                        // Instances come from an `InstanceSource` in a `Client` in a `PushRouter`.
                        // Calling `report_instance_down` on the Client should do it (although
                        // needs more testing).
                        // The `PushRouter` is in `ModelMananger` (`self.manager` here), but inside
                        // interface `AsyncEngine` which only has a `generate` method.

                        continue;
                    }
205

206
                    match self.handle_put(&mcid, &mut card).await {
207
                        Ok(()) => {
208
                            tracing::info!(
209
                                model_name = card.name(),
210
                                namespace = mcid.namespace,
211
212
                                "added model"
                            );
213
                            self.notify_on_model.notify_waiters();
214
                        }
215
216
                        Err(err) => {
                            tracing::error!(
217
                                model_name = card.name(),
218
                                namespace = mcid.namespace,
219
                                error = format!("{err:#}"),
220
                                "Error adding model from discovery",
221
                            );
222
223
224
                        }
                    }
                }
225
226
227
228
                DiscoveryEvent::Removed(id) => {
                    // Extract ModelCardInstanceId from the removal event
                    let model_card_instance_id = match &id {
                        DiscoveryInstanceId::Model(mcid) => mcid,
229
                        DiscoveryInstanceId::Endpoint(_) | DiscoveryInstanceId::EventChannel(_) => {
230
231
232
233
234
235
                            tracing::error!(
                                "Unexpected discovery instance type in removal (expected Model)"
                            );
                            continue;
                        }
                    };
236

237
                    match self
238
                        .handle_delete(model_card_instance_id, target_namespace, global_namespace)
239
240
241
242
243
244
245
246
247
248
249
                        .await
                    {
                        Ok(Some(model_name)) => {
                            tracing::info!(model_name, "removed model");
                        }
                        Ok(None) => {
                            // There are other instances running this model, nothing to do
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "error removing model");
                        }
250
                    }
251
                }
252
            }
Ryan Olson's avatar
Ryan Olson committed
253
254
255
        }
    }

256
257
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
258
259
    async fn handle_delete(
        &self,
260
        mcid: &ModelCardInstanceId,
261
262
263
        target_namespace: Option<&str>,
        is_global_namespace: bool,
    ) -> anyhow::Result<Option<String>> {
264
265
        let key = mcid.to_path();
        let card = match self.manager.remove_model_card(&key) {
266
            Some(card) => card,
267
            None => {
268
                anyhow::bail!("Missing ModelDeploymentCard for {}", key);
269
270
            }
        };
271
        let model_name = card.name().to_string();
272
        let active_instances = self
273
            .cards_for_model(&model_name, target_namespace, is_global_namespace)
274
275
276
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
277
278
279
280
281
282
            tracing::debug!(
                model_name,
                target_namespace = ?target_namespace,
                active_instance_count = active_instances.len(),
                "Model has other active instances, not removing"
            );
283
284
            return Ok(None);
        }
285

286
        // Ignore the errors because model could be either type
287
288
289
        let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
        let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
        let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
290
        let images_model_remove_err = self.manager.remove_images_model(&model_name);
291
292
        let videos_model_remove_err = self.manager.remove_videos_model(&model_name);
        let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
293
        let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name);
294
295
296
297

        let mut chat_model_removed = false;
        let mut completions_model_removed = false;
        let mut embeddings_model_removed = false;
298
        let mut images_model_removed = false;
299
300
        let mut videos_model_removed = false;
        let mut tensor_model_removed = false;
301
        let mut prefill_model_removed = false;
302
303
304
305
306
307
308
309
310
311
312

        if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
            chat_model_removed = true;
        }
        if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
        {
            completions_model_removed = true;
        }
        if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
            embeddings_model_removed = true;
        }
313
314
315
        if images_model_remove_err.is_ok() && self.manager.list_images_models().is_empty() {
            images_model_removed = true;
        }
316
317
318
319
320
321
        if videos_model_remove_err.is_ok() && self.manager.list_videos_models().is_empty() {
            videos_model_removed = true;
        }
        if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
            tensor_model_removed = true;
        }
322
323
324
        if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() {
            prefill_model_removed = true;
        }
325

326
327
328
        if !chat_model_removed
            && !completions_model_removed
            && !embeddings_model_removed
329
            && !images_model_removed
330
331
            && !videos_model_removed
            && !tensor_model_removed
332
            && !prefill_model_removed
333
        {
334
            tracing::debug!(
335
                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, images_model_removed: {}, videos_model_removed: {}, tensor_model_removed: {}, prefill_model_removed: {}",
336
337
338
                model_name,
                chat_model_removed,
                completions_model_removed,
339
                embeddings_model_removed,
340
                images_model_removed,
341
342
                videos_model_removed,
                tensor_model_removed,
343
                prefill_model_removed
344
345
346
            );
        } else {
            for model_type in ALL_MODEL_TYPES {
347
                if ((chat_model_removed && *model_type == ModelType::Chat)
348
                    || (completions_model_removed && *model_type == ModelType::Completions)
349
                    || (embeddings_model_removed && *model_type == ModelType::Embedding)
350
                    || (images_model_removed && *model_type == ModelType::Images)
351
352
                    || (videos_model_removed && *model_type == ModelType::Videos)
                    || (tensor_model_removed && *model_type == ModelType::TensorBased)
353
                    || (prefill_model_removed && *model_type == ModelType::Prefill))
354
                    && let Some(tx) = &self.model_update_tx
355
                {
356
                    tx.send(ModelUpdate::Removed(card.clone())).await.ok();
357
358
359
                }
            }
        }
360

361
        Ok(Some(model_name))
362
    }
Ryan Olson's avatar
Ryan Olson committed
363

364
    // Handles a PUT event from store, this usually means adding a new model to the list of served
365
    // models.
366
367
    async fn handle_put(
        &self,
368
        mcid: &ModelCardInstanceId,
369
370
        card: &mut ModelDeploymentCard,
    ) -> anyhow::Result<()> {
371
372
373
        // Check if model is already registered before downloading config.
        // This prevents duplicate HuggingFace API calls when multiple workers register
        // the same model.
374
375
376
377
378
379
380
381
382
        // Prefill and decode models are tracked separately, so registering one
        // doesn't block the other (they can arrive in any order).
        let already_registered = if card.model_type.supports_prefill() {
            self.manager.has_prefill_model(card.name())
        } else {
            self.manager.has_decode_model(card.name())
        };

        if already_registered {
383
384
            self.manager
                .save_model_card(&mcid.to_path(), card.clone())?;
385
386
            tracing::debug!(
                model_name = card.name(),
387
                namespace = mcid.namespace,
388
                model_type = %card.model_type,
389
390
391
392
393
394
395
396
397
398
399
400
401
402
                "Model already registered, skipping config download"
            );
            return Ok(());
        }

        // Use registering_models set to prevent concurrent registrations.
        let model_key = card.name().to_string();
        if !self.registering_models.insert(model_key.clone()) {
            self.manager
                .save_model_card(&mcid.to_path(), card.clone())?;
            tracing::debug!(
                model_name = card.name(),
                namespace = mcid.namespace,
                "Model registration in progress by another worker, skipping"
403
404
405
406
            );
            return Ok(());
        }

407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
        // We acquired the registration lock. Use a helper to ensure cleanup on all exit paths.
        let result = self.do_model_registration(mcid, card).await;

        // Always remove from registering set, whether success or failure
        self.registering_models.remove(&model_key);

        result
    }

    /// Inner function that performs the actual model registration.
    /// Called by handle_put after acquiring the registration lock.
    async fn do_model_registration(
        &self,
        mcid: &ModelCardInstanceId,
        card: &mut ModelDeploymentCard,
    ) -> anyhow::Result<()> {
        card.download_config().await?;

        let component = self
            .drt
            .namespace(&mcid.namespace)?
            .component(&mcid.component)?;
        let endpoint = component.endpoint(&mcid.endpoint);
        let client = endpoint.client().await?;
        tracing::debug!(model_name = card.name(), "adding model");
        self.manager
            .save_model_card(&mcid.to_path(), card.clone())?;

435
436
437
        if let Some(tx) = &self.model_update_tx {
            tx.send(ModelUpdate::Added(card.clone())).await.ok();
        }
438

439
        let checksum = card.mdcsum();
440
441
442

        if card.model_input == ModelInput::Tokens
            && (card.model_type.supports_chat() || card.model_type.supports_completions())
443
444
445
446
447
        {
            // Case 1: Tokens + (Chat OR Completions OR Both)
            // A model that expects pre-processed requests meaning it's up to us whether we
            // handle Chat or Completions requests, so handle whatever the model supports.

448
            let endpoint = component.endpoint(&mcid.endpoint);
449
450
451
452
453
454
455
456
457
            // Create the KV router whenever any local routed pipeline will be built.
            // The chat factory builds its own router, but completions currently always
            // uses the local routed pipeline and therefore still needs a chooser.
            let needs_local_chat_pipeline =
                card.model_type.supports_chat() && self.chat_engine_factory.is_none();
            let needs_local_completions_pipeline = card.model_type.supports_completions();
            let kv_chooser = if self.router_config.router_mode == RouterMode::KV
                && (needs_local_chat_pipeline || needs_local_completions_pipeline)
            {
458
459
                Some(
                    self.manager
460
461
462
463
                        .kv_chooser_for(
                            &endpoint,
                            card.kv_cache_block_size,
                            Some(self.router_config.kv_router_config),
464
                            WORKER_TYPE_DECODE, // This is the decode router
465
                        )
466
467
468
469
470
                        .await?,
                )
            } else {
                None
            };
471

472
            // This is expensive, we are loading ~10MiB JSON, so only do it once
473
            let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
474

475
476
            // Create prefill chooser once if we're building pipelines
            // Both chat and completions will share the same prefill chooser instance
477
            let model_name = card.name().to_string();
478
479
            let prefill_chooser = self
                .manager
480
                .register_prefill_router(model_name.clone())
481
482
                .map(|rx| {
                    // Create prefill-specific config with track_active_blocks disabled
483
                    let mut prefill_config = self.router_config.kv_router_config;
484
485
486
487
488
                    prefill_config.router_track_active_blocks = false;

                    PrefillRouter::new(
                        rx,
                        self.manager.clone(),
489
                        self.router_config.router_mode,
490
491
                        card.kv_cache_block_size,
                        Some(prefill_config),
492
                        self.router_config.enforce_disagg,
493
                        model_name.clone(), // Pass model name for worker monitor lookup
494
495
496
                    )
                });

497
498
499
500
501
502
503
504
505
            // Get or create the worker monitor for this model.
            // Always create the monitor for Prometheus metrics (active_decode_blocks, active_prefill_tokens,
            // worker TTFT/ITL cleanup). The thresholds control busy detection behavior only.
            // LoadThresholdConfig allows dynamic threshold updates via the ModelManager.
            let worker_monitor = Some(self.manager.get_or_create_worker_monitor(
                card.name(),
                client.clone(),
                self.router_config.load_threshold_config.clone(),
            ));
506

507
            // Add chat engine only if the model supports chat
508
            if card.model_type.supports_chat() {
509
510
511
512
513
514
515
516
517
518
519
                let factory_engine = if let Some(ref factory) = self.chat_engine_factory {
                    match factory(mcid.clone(), card.clone()).await {
                        Ok(engine) => Some(engine),
                        Err(err) => return Err(err).context("python chat_engine_factory"),
                    }
                } else {
                    None
                };

                let chat_engine = if let Some(engine) = factory_engine {
                    engine
520
521
522
523
524
525
526
                } else {
                    entrypoint::build_routed_pipeline::<
                        NvCreateChatCompletionRequest,
                        NvCreateChatCompletionStreamResponse,
                    >(
                        card,
                        &client,
527
                        self.manager.clone(),
528
529
530
531
532
533
                        self.router_config.router_mode,
                        worker_monitor.clone(),
                        kv_chooser.clone(),
                        tokenizer_hf.clone(),
                        prefill_chooser.clone(),
                        self.router_config.enforce_disagg,
534
                        self.migration_limit,
535
                        self.metrics.clone(),
536
537
538
539
                    )
                    .await
                    .context("build_routed_pipeline")?
                };
540
                self.manager
541
                    .add_chat_completions_model(card.name(), checksum, chat_engine)
542
                    .context("add_chat_completions_model")?;
543
                tracing::info!("Chat completions is ready");
544
            }
545

546
            // Add completions engine only if the model supports completions.
547
            if card.model_type.supports_completions() {
548
549
                let formatter = PromptFormatter::no_op();
                let PromptFormatter::OAI(formatter) = formatter;
550
551
552
553
                let preprocessor = OpenAIPreprocessor::new_with_parts(
                    card.clone(),
                    formatter,
                    tokenizer_hf.clone(),
554
555
                )
                .context("OpenAIPreprocessor::new_with_parts")?;
556
                let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
557
558
559
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(
560
                    card,
561
                    &client,
562
                    self.manager.clone(),
563
                    self.router_config.router_mode,
564
                    worker_monitor,
565
                    kv_chooser,
566
                    preprocessor,
567
                    tokenizer_hf,
568
                    prefill_chooser,
569
                    self.router_config.enforce_disagg,
570
                    self.migration_limit,
571
                    self.metrics.clone(),
572
                )
573
574
                .await
                .context("build_routed_pipeline_with_preprocessor")?;
575
                self.manager
576
                    .add_completions_model(card.name(), checksum, completions_engine)
577
                    .context("add_completions_model")?;
578
                tracing::info!("Completions is ready");
579
            }
580
581
582
583
584
585
        } else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
            // Case: Text + Embeddings
            let push_router = PushRouter::<
                NvCreateEmbeddingRequest,
                Annotated<NvCreateEmbeddingResponse>,
            >::from_client_with_threshold(
586
                client, self.router_config.router_mode, None, None
587
588
589
590
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
591
                .add_embeddings_model(card.name(), checksum, engine)?;
592
        } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
593
            // Case 3: Text + Chat
594
595
596
597
598
599
600
            let push_router = PushRouter::<
                NvCreateChatCompletionRequest,
                Annotated<NvCreateChatCompletionStreamResponse>,
            >::from_client_with_threshold(
                client, self.router_config.router_mode, None, None
            )
            .await?;
601
602
            let engine = Arc::new(push_router);
            self.manager
603
                .add_chat_completions_model(card.name(), checksum, engine)?;
604
        } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
605
606
607
608
609
            // Case 2: Text + Completions
            let push_router = PushRouter::<
                NvCreateCompletionRequest,
                Annotated<NvCreateCompletionResponse>,
            >::from_client_with_threshold(
610
                client, self.router_config.router_mode, None, None
611
612
613
614
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
615
                .add_completions_model(card.name(), checksum, engine)?;
616
        } else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
617
618
619
620
621
622
623
624
            // Case 4: Tokens + Embeddings

            // Create preprocessing pipeline similar to Backend
            let frontend = SegmentSource::<
                SingleIn<NvCreateEmbeddingRequest>,
                ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            >::new();

625
            let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
626
            let backend = Backend::from_mdc(card).into_operator();
627
628
629
630
631

            let router = PushRouter::<
                PreprocessedEmbeddingRequest,
                Annotated<EmbeddingsEngineOutput>,
            >::from_client_with_threshold(
632
                client, self.router_config.router_mode, None, None
633
634
635
            )
            .await?;

636
            // Note: Embeddings don't need KV routing complexity or load monitoring
637
638
639
640
641
642
643
644
645
646
647
648
            let service_backend = ServiceBackend::from_engine(Arc::new(router));

            // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
            let embedding_engine = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(service_backend)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            self.manager
649
                .add_embeddings_model(card.name(), checksum, embedding_engine)?;
650
        } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
651
            // Case 6: Tensor + TensorBased (non-LLM)
652
            // No KV cache concepts - not an LLM model
653
654
655
656
            let push_router = PushRouter::<
                NvCreateTensorRequest,
                Annotated<NvCreateTensorResponse>,
            >::from_client_with_threshold(
657
                client, self.router_config.router_mode, None, None
658
659
660
            )
            .await?;
            let engine = Arc::new(push_router);
661
662
            self.manager
                .add_tensor_model(card.name(), checksum, engine)?;
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
        }
        // Case: Text + (Images, Audio, Videos)
        else if card.model_input == ModelInput::Text
            && (card.model_type.supports_images()
                || card.model_type.supports_audios()
                || card.model_type.supports_videos())
        {
            // Image Models can support chat completions (vllm omni way)
            // So register chat_completions model as well
            if card.model_type.supports_chat() {
                let chat_router = PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
                >::from_client_with_threshold(
                    client.clone(),
                    self.router_config.router_mode,
                    None,
                    None,
                )
                .await?;
                self.manager.add_chat_completions_model(
                    card.name(),
                    checksum,
                    Arc::new(chat_router),
                )?;
            }
689

690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
            // This is ModelType::Images : registers /v1/images/* endpoints
            if card.model_type.supports_images() {
                let images_router = PushRouter::<
                    NvCreateImageRequest,
                    Annotated<NvImagesResponse>,
                >::from_client_with_threshold(
                    client.clone(), self.router_config.router_mode, None, None
                )
                .await?;
                self.manager
                    .add_images_model(card.name(), checksum, Arc::new(images_router))?;
            }

            // This is ModelType::Videos : registers /v1/videos/* endpoints
            if card.model_type.supports_videos() {
                let videos_router = PushRouter::<
                    NvCreateVideoRequest,
                    Annotated<NvVideosResponse>,
                >::from_client_with_threshold(
                    client.clone(), self.router_config.router_mode, None, None
                )
                .await?;
                self.manager
                    .add_videos_model(card.name(), checksum, Arc::new(videos_router))?;
            }

            // TODO: add audio models support
717
718
719
720
721
722
723
724
725
726
727
728
        } else if card.model_type.supports_prefill() {
            // Case 6: Prefill
            // Guardrail: Verify model_input is Tokens
            if card.model_input != ModelInput::Tokens {
                anyhow::bail!(
                    "Prefill models must use ModelInput::Tokens, got {}",
                    card.model_input.as_str()
                );
            }

            tracing::info!(
                model_name = card.name(),
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
                "Prefill model detected, registering and activating prefill router"
            );

            // Register prefill model for tracking (no engine needed, just lifecycle)
            self.manager
                .add_prefill_model(card.name(), checksum)
                .context("add_prefill_model")?;

            // Activate the prefill router with the endpoint for this prefill model
            let Ok(()) = self.manager.activate_prefill_router(card.name(), endpoint) else {
                tracing::warn!(
                    model_name = card.name(),
                    "Failed to activate prefill router - prefill model may already be activated"
                );
                return Ok(());
            };

            tracing::info!(
                model_name = card.name(),
                "Prefill model registered and router activated successfully"
749
            );
750
751
752
753
        } else {
            // Reject unsupported combinations
            anyhow::bail!(
                "Unsupported model configuration: {} with {} input. Supported combinations: \
754
                Tokens+(Chat|Completions|Prefill), Text+(Chat|Completions|Images), Tokens+Embeddings, Tensor+TensorBased",
755
756
                card.model_type,
                card.model_input.as_str()
757
            );
758
        }
Ryan Olson's avatar
Ryan Olson committed
759

760
761
        Ok(())
    }
    /// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
764
    async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
765
766
        let discovery = self.drt.discovery();
        let instances = discovery.list(DiscoveryQuery::AllModels).await?;
767

768
769
770
        let mut results = Vec::with_capacity(instances.len());
        for instance in instances {
            match instance.deserialize_model::<ModelDeploymentCard>() {
771
                Ok(card) => {
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
                    // Extract EndpointId from the instance
                    let endpoint_id = match &instance {
                        dynamo_runtime::discovery::DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            ..
                        } => EndpointId {
                            namespace: namespace.clone(),
                            component: component.clone(),
                            name: endpoint.clone(),
                        },
                        _ => {
                            tracing::error!(
                                "Unexpected discovery instance type (expected ModelCard)"
                            );
788
789
790
                            continue;
                        }
                    };
791
                    results.push((endpoint_id, card));
792
                }
793
                Err(err) => {
794
                    tracing::error!(%err, "Failed to deserialize model card");
795
796
                    continue;
                }
797
            }
798
        }
799
        Ok(results)
800
801
    }
    pub async fn cards_for_model(
803
804
805
806
        &self,
        model_name: &str,
        target_namespace: Option<&str>,
        is_global_namespace: bool,
807
808
809
810
    ) -> anyhow::Result<Vec<ModelDeploymentCard>> {
        let mut all = self.all_cards().await?;
        all.retain(|(endpoint_id, card)| {
            let matches_name = card.name() == model_name;
811
812
813
            let matches_namespace = match (is_global_namespace, target_namespace) {
                (true, _) => true,
                (false, None) => true,
814
                (false, Some(target_ns)) => endpoint_id.namespace == target_ns,
815
816
817
            };
            matches_name && matches_namespace
        });
818
819
820
        Ok(all.into_iter().map(|(_eid, card)| card).collect())
    }
}