watcher.rs 25.9 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5
use tokio::sync::Notify;
6
use tokio::sync::mpsc::Sender;
7

8
use anyhow::Context as _;
9
use futures::StreamExt;
Ryan Olson's avatar
Ryan Olson committed
10

Neelay Shah's avatar
Neelay Shah committed
11
use dynamo_runtime::{
12
    DistributedRuntime,
13
    discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoveryStream},
14
    pipeline::{
15
16
        ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
        network::egress::push_router::PushRouter,
17
    },
18
    protocols::{EndpointId, annotated::Annotated},
19
};
Ryan Olson's avatar
Ryan Olson committed
20

21
22
use crate::{
    backend::Backend,
23
    entrypoint,
24
    kv_router::{KvRouterConfig, PrefillRouter},
25
    model_card::ModelDeploymentCard,
26
27
    model_type::{ModelInput, ModelType},
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
28
29
30
31
32
33
34
35
36
    protocols::{
        common::llm_backend::EmbeddingsEngineOutput,
        openai::{
            chat_completions::{
                NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            },
            completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
            embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
        },
37
        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
38
39
    },
};
40

41
use super::ModelManager;
42
use crate::namespace::is_global_namespace;
43

44
/// Notification sent by [`ModelWatcher`] to an optional subscriber channel
/// (see `ModelWatcher::set_notify_on_model_update`).
#[derive(Debug, Clone)]
pub enum ModelUpdate {
    /// Sent from `handle_put` just before a newly discovered model is registered.
    Added(ModelDeploymentCard),
    /// Sent from `handle_delete` after a model has been unregistered.
    Removed(ModelDeploymentCard),
}

50
/// Consumes discovery events and registers / unregisters models with a [`ModelManager`].
pub struct ModelWatcher {
    /// Registry that stores model cards and the engines built for them.
    manager: Arc<ModelManager>,
    /// Runtime used to resolve namespaces, components and endpoints, and to query discovery.
    drt: DistributedRuntime,
    /// Routing mode handed to every router / pipeline built by this watcher.
    router_mode: RouterMode,
    /// Signalled (via `notify_waiters`) each time a model is added successfully.
    notify_on_model: Notify,
    /// Optional subscriber that receives [`ModelUpdate`] messages.
    model_update_tx: Option<Sender<ModelUpdate>>,
    /// Optional KV router configuration, used when building KV choosers and prefill routers.
    kv_router_config: Option<KvRouterConfig>,
    /// Optional busy threshold forwarded to the routed pipeline builders.
    busy_threshold: Option<f64>,
}

60
61
62
63
/// Every model type `handle_delete` iterates over when deciding which
/// `ModelUpdate::Removed` notifications to send.
const ALL_MODEL_TYPES: &[ModelType] = &[
    ModelType::Chat,
    ModelType::Completions,
    ModelType::Embedding,
    ModelType::TensorBased,
    ModelType::Prefill,
];
67

68
impl ModelWatcher {
69
    pub fn new(
70
        runtime: DistributedRuntime,
71
        model_manager: Arc<ModelManager>,
72
        router_mode: RouterMode,
73
        kv_router_config: Option<KvRouterConfig>,
74
        busy_threshold: Option<f64>,
75
76
    ) -> ModelWatcher {
        Self {
77
            manager: model_manager,
78
            drt: runtime,
79
            router_mode,
80
            notify_on_model: Notify::new(),
81
            model_update_tx: None,
82
            kv_router_config,
83
            busy_threshold,
84
        }
85
    }
Ryan Olson's avatar
Ryan Olson committed
86

87
88
89
90
    /// Install `tx` as the channel that receives [`ModelUpdate`] messages,
    /// replacing (and dropping) any previously installed sender.
    pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
        drop(self.model_update_tx.replace(tx));
    }

91
92
93
94
95
96
97
98
99
100
101
    /// Wait until we have at least one chat completions model and return its name.
    ///
    /// The name returned is the first entry of the manager's chat completions list at
    /// the moment a model is observed.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            // Create the `Notified` future BEFORE inspecting the list. `notify_waiters()`
            // only wakes futures that already exist, so creating it after the check would
            // let a notification fired in between be lost and leave us waiting forever.
            let model_added = self.notify_on_model.notified();
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            model_added.await;
        }
    }

102
    /// Common watch logic with optional namespace filtering
103
104
105
106
107
    pub async fn watch(
        &self,
        mut discovery_stream: DiscoveryStream,
        target_namespace: Option<&str>,
    ) {
108
        let global_namespace = target_namespace.is_none_or(is_global_namespace);
109

110
111
112
113
114
115
116
117
118
        while let Some(result) = discovery_stream.next().await {
            let event = match result {
                Ok(event) => event,
                Err(err) => {
                    tracing::error!(%err, "Error in discovery stream");
                    continue;
                }
            };

119
            match event {
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
                DiscoveryEvent::Added(instance) => {
                    // Extract EndpointId, instance_id, and card from the discovery instance
                    let (endpoint_id, instance_id, mut card) = match &instance {
                        DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            instance_id,
                            ..
                        } => {
                            let eid = EndpointId {
                                namespace: namespace.clone(),
                                component: component.clone(),
                                name: endpoint.clone(),
                            };

                            match instance.deserialize_model::<ModelDeploymentCard>() {
                                Ok(card) => (eid, *instance_id, card),
                                Err(err) => {
                                    tracing::error!(%err, instance_id, "Failed to deserialize model card");
                                    continue;
                                }
                            }
                        }
                        _ => {
                            tracing::error!(
                                "Unexpected discovery instance type (expected ModelCard)"
                            );
148
149
150
                            continue;
                        }
                    };
151
152
153
154

                    // Filter by namespace if target_namespace is specified
                    if !global_namespace
                        && let Some(target_ns) = target_namespace
155
                        && endpoint_id.namespace != target_ns
156
157
                    {
                        tracing::debug!(
158
                            model_namespace = endpoint_id.namespace,
159
160
161
162
163
164
                            target_namespace = target_ns,
                            "Skipping model from different namespace"
                        );
                        continue;
                    }

165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
                    // If we already have a worker for this model, and the ModelDeploymentCard
                    // cards don't match, alert, and don't add the new instance
                    let can_add =
                        self.manager
                            .is_valid_checksum(card.model_type, card.name(), card.mdcsum());
                    if can_add.is_some_and(|is_valid| !is_valid) {
                        tracing::error!(
                            model_name = card.name(),
                            "Checksum for new model does not match existing model."
                        );

                        // TODO: mark that instance down in clients
                        // Not obvious how to do that given the current design
                        // Instances come from an `InstanceSource` in a `Client` in a `PushRouter`.
                        // Calling `report_instance_down` on the Client should do it (although
                        // needs more testing).
                        // The `PushRouter` is in `ModelMananger` (`self.manager` here), but inside
                        // interface `AsyncEngine` which only has a `generate` method.

                        continue;
                    }
186

187
188
189
190
                    // Use instance_id as the HashMap key (simpler and sufficient since keys are opaque)
                    let key = format!("{:x}", instance_id);

                    match self.handle_put(&key, &endpoint_id, &mut card).await {
191
                        Ok(()) => {
192
                            tracing::info!(
193
194
                                model_name = card.name(),
                                namespace = endpoint_id.namespace,
195
196
                                "added model"
                            );
197
                            self.notify_on_model.notify_waiters();
198
                        }
199
200
                        Err(err) => {
                            tracing::error!(
201
202
                                model_name = card.name(),
                                namespace = endpoint_id.namespace,
203
                                error = format!("{err:#}"),
204
                                "Error adding model from discovery",
205
                            );
206
207
208
                        }
                    }
                }
209
210
211
212
                DiscoveryEvent::Removed(instance_id) => {
                    // Use instance_id hex as the HashMap key (matches what we saved with)
                    let key = format!("{:x}", instance_id);

213
                    match self
214
                        .handle_delete(&key, target_namespace, global_namespace)
215
216
217
218
219
220
221
222
223
224
225
                        .await
                    {
                        Ok(Some(model_name)) => {
                            tracing::info!(model_name, "removed model");
                        }
                        Ok(None) => {
                            // There are other instances running this model, nothing to do
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "error removing model");
                        }
226
                    }
227
                }
228
            }
Ryan Olson's avatar
Ryan Olson committed
229
230
231
        }
    }

232
233
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
234
235
    async fn handle_delete(
        &self,
236
        key: &str,
237
238
239
        target_namespace: Option<&str>,
        is_global_namespace: bool,
    ) -> anyhow::Result<Option<String>> {
240
241
        let card = match self.manager.remove_model_card(key) {
            Some(card) => card,
242
            None => {
243
                anyhow::bail!("Missing ModelDeploymentCard for {key}");
244
245
            }
        };
246
        let model_name = card.name().to_string();
247
        let active_instances = self
248
            .cards_for_model(&model_name, target_namespace, is_global_namespace)
249
250
251
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
252
253
254
255
256
257
            tracing::debug!(
                model_name,
                target_namespace = ?target_namespace,
                active_instance_count = active_instances.len(),
                "Model has other active instances, not removing"
            );
258
259
            return Ok(None);
        }
260

261
        // Ignore the errors because model could be either type
262
263
264
        let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
        let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
        let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
265
        let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
266
        let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name);
267
268
269
270

        let mut chat_model_removed = false;
        let mut completions_model_removed = false;
        let mut embeddings_model_removed = false;
271
        let mut tensor_model_removed = false;
272
        let mut prefill_model_removed = false;
273
274
275
276
277
278
279
280
281
282
283

        if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
            chat_model_removed = true;
        }
        if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
        {
            completions_model_removed = true;
        }
        if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
            embeddings_model_removed = true;
        }
284
285
286
        if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
            tensor_model_removed = true;
        }
287
288
289
        if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() {
            prefill_model_removed = true;
        }
290

291
292
293
294
        if !chat_model_removed
            && !completions_model_removed
            && !embeddings_model_removed
            && !tensor_model_removed
295
            && !prefill_model_removed
296
        {
297
            tracing::debug!(
298
                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}, prefill_model_removed: {}",
299
300
301
                model_name,
                chat_model_removed,
                completions_model_removed,
302
                embeddings_model_removed,
303
304
                tensor_model_removed,
                prefill_model_removed
305
306
307
            );
        } else {
            for model_type in ALL_MODEL_TYPES {
308
                if ((chat_model_removed && *model_type == ModelType::Chat)
309
                    || (completions_model_removed && *model_type == ModelType::Completions)
310
                    || (embeddings_model_removed && *model_type == ModelType::Embedding)
311
312
                    || (tensor_model_removed && *model_type == ModelType::TensorBased)
                    || (prefill_model_removed && *model_type == ModelType::Prefill))
313
                    && let Some(tx) = &self.model_update_tx
314
                {
315
                    tx.send(ModelUpdate::Removed(card.clone())).await.ok();
316
317
318
                }
            }
        }
319

320
        Ok(Some(model_name))
321
    }
Ryan Olson's avatar
Ryan Olson committed
322

323
    // Handles a PUT event from store, this usually means adding a new model to the list of served
324
    // models.
325
326
327
328
329
330
    async fn handle_put(
        &self,
        key: &str,
        endpoint_id: &EndpointId,
        card: &mut ModelDeploymentCard,
    ) -> anyhow::Result<()> {
331
332
        card.download_config().await?;

333
        let component = self
334
335
            .drt
            .namespace(&endpoint_id.namespace)?
336
            .component(&endpoint_id.component)?;
337
338
        let endpoint = component.endpoint(&endpoint_id.name);
        let client = endpoint.client().await?;
339
340
        tracing::debug!(model_name = card.name(), "adding model");
        self.manager.save_model_card(key, card.clone())?;
341

342
343
344
345
346
347
348
349
350
351
        // Check if we should skip registration:
        // - Skip if a model with this name already exists
        // - UNLESS this is a prefill model and no prefill model exists yet for this name
        let is_new_prefill = card.model_type.supports_prefill()
            && !self
                .manager
                .list_prefill_models()
                .contains(&card.name().to_string());

        if self.manager.has_model_any(card.name()) && !is_new_prefill {
352
353
354
            tracing::debug!(
                model_name = card.name(),
                namespace = endpoint_id.namespace,
355
356
                model_type = %card.model_type,
                "New endpoint for existing model, skipping"
357
358
359
360
361
362
363
            );
            return Ok(());
        }

        if let Some(tx) = &self.model_update_tx {
            tx.send(ModelUpdate::Added(card.clone())).await.ok();
        }
364
        let checksum = card.mdcsum();
365
366
367

        if card.model_input == ModelInput::Tokens
            && (card.model_type.supports_chat() || card.model_type.supports_completions())
368
369
370
371
372
373
374
375
        {
            // Case 1: Tokens + (Chat OR Completions OR Both)
            // A model that expects pre-processed requests meaning it's up to us whether we
            // handle Chat or Completions requests, so handle whatever the model supports.

            let kv_chooser = if self.router_mode == RouterMode::KV {
                Some(
                    self.manager
376
                        .kv_chooser_for(&component, card.kv_cache_block_size, self.kv_router_config)
377
378
379
380
381
                        .await?,
                )
            } else {
                None
            };
382

383
            // This is expensive, we are loading ~10MiB JSON, so only do it once
384
            let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
385

386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
            // Create prefill chooser once if we're building pipelines
            // Both chat and completions will share the same prefill chooser instance
            let prefill_chooser = self
                .manager
                .register_prefill_router(card.name().to_string())
                .map(|rx| {
                    // Create prefill-specific config with track_active_blocks disabled
                    let mut prefill_config = self.kv_router_config.unwrap_or_default();
                    prefill_config.router_track_active_blocks = false;

                    PrefillRouter::new(
                        rx,
                        self.manager.clone(),
                        self.router_mode,
                        card.kv_cache_block_size,
                        Some(prefill_config),
                    )
                });

405
            // Add chat engine only if the model supports chat
406
            if card.model_type.supports_chat() {
407
408
409
410
                let chat_engine = entrypoint::build_routed_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(
411
                    card,
412
413
414
415
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser.clone(),
416
                    tokenizer_hf.clone(),
417
                    prefill_chooser.clone(),
418
                )
419
420
                .await
                .context("build_routed_pipeline")?;
421
                self.manager
422
                    .add_chat_completions_model(card.name(), checksum, chat_engine)
423
                    .context("add_chat_completions_model")?;
424
                tracing::info!("Chat completions is ready");
425
            }
426

427
            // Add completions engine only if the model supports completions
428
            if card.model_type.supports_completions() {
429
430
                let formatter = PromptFormatter::no_op();
                let PromptFormatter::OAI(formatter) = formatter;
431
432
433
434
                let preprocessor = OpenAIPreprocessor::new_with_parts(
                    card.clone(),
                    formatter,
                    tokenizer_hf.clone(),
435
436
                )
                .context("OpenAIPreprocessor::new_with_parts")?;
437
                let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
438
439
440
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(
441
                    card,
442
443
444
445
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser,
446
                    preprocessor,
447
                    tokenizer_hf,
448
                    prefill_chooser,
449
                )
450
451
                .await
                .context("build_routed_pipeline_with_preprocessor")?;
452
                self.manager
453
                    .add_completions_model(card.name(), checksum, completions_engine)
454
                    .context("add_completions_model")?;
455
                tracing::info!("Completions is ready");
456
            }
457
458
459
460
461
462
        } else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
            // Case: Text + Embeddings
            let push_router = PushRouter::<
                NvCreateEmbeddingRequest,
                Annotated<NvCreateEmbeddingResponse>,
            >::from_client_with_threshold(
463
                client, self.router_mode, None, None
464
465
466
467
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
468
                .add_embeddings_model(card.name(), checksum, engine)?;
469
        } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
470
            // Case 3: Text + Chat
471
472
473
474
475
476
            let push_router =
                PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
                >::from_client_with_threshold(client, self.router_mode, None, None)
                .await?;
477
478
            let engine = Arc::new(push_router);
            self.manager
479
                .add_chat_completions_model(card.name(), checksum, engine)?;
480
        } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
481
482
483
484
485
            // Case 2: Text + Completions
            let push_router = PushRouter::<
                NvCreateCompletionRequest,
                Annotated<NvCreateCompletionResponse>,
            >::from_client_with_threshold(
486
                client, self.router_mode, None, None
487
488
489
490
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
491
                .add_completions_model(card.name(), checksum, engine)?;
492
        } else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
493
494
495
496
497
498
499
500
            // Case 4: Tokens + Embeddings

            // Create preprocessing pipeline similar to Backend
            let frontend = SegmentSource::<
                SingleIn<NvCreateEmbeddingRequest>,
                ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            >::new();

501
            let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
502
            let backend = Backend::from_mdc(card).into_operator();
503
504
505
506
507

            let router = PushRouter::<
                PreprocessedEmbeddingRequest,
                Annotated<EmbeddingsEngineOutput>,
            >::from_client_with_threshold(
508
                client, self.router_mode, None, None
509
510
511
            )
            .await?;

512
            // Note: Embeddings don't need KV routing complexity or load monitoring
513
514
515
516
517
518
519
520
521
522
523
524
            let service_backend = ServiceBackend::from_engine(Arc::new(router));

            // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
            let embedding_engine = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(service_backend)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            self.manager
525
                .add_embeddings_model(card.name(), checksum, embedding_engine)?;
526
        } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
527
            // Case 5: Tensor + Tensor (non-LLM)
528
            // No KV cache concepts - not an LLM model
529
530
531
532
            let push_router = PushRouter::<
                NvCreateTensorRequest,
                Annotated<NvCreateTensorResponse>,
            >::from_client_with_threshold(
533
                client, self.router_mode, None, None
534
535
536
            )
            .await?;
            let engine = Arc::new(push_router);
537
538
            self.manager
                .add_tensor_model(card.name(), checksum, engine)?;
539
540
541
542
543
544
545
546
547
548
549
550
        } else if card.model_type.supports_prefill() {
            // Case 6: Prefill
            // Guardrail: Verify model_input is Tokens
            if card.model_input != ModelInput::Tokens {
                anyhow::bail!(
                    "Prefill models must use ModelInput::Tokens, got {}",
                    card.model_input.as_str()
                );
            }

            tracing::info!(
                model_name = card.name(),
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
                "Prefill model detected, registering and activating prefill router"
            );

            // Register prefill model for tracking (no engine needed, just lifecycle)
            self.manager
                .add_prefill_model(card.name(), checksum)
                .context("add_prefill_model")?;

            // Activate the prefill router with the endpoint for this prefill model
            let Ok(()) = self.manager.activate_prefill_router(card.name(), endpoint) else {
                tracing::warn!(
                    model_name = card.name(),
                    "Failed to activate prefill router - prefill model may already be activated"
                );
                return Ok(());
            };

            tracing::info!(
                model_name = card.name(),
                "Prefill model registered and router activated successfully"
571
            );
572
573
574
575
        } else {
            // Reject unsupported combinations
            anyhow::bail!(
                "Unsupported model configuration: {} with {} input. Supported combinations: \
576
                Tokens+(Chat|Completions|Prefill), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
577
578
                card.model_type,
                card.model_input.as_str()
579
            );
580
        }
Ryan Olson's avatar
Ryan Olson committed
581

582
583
        Ok(())
    }
584

585
    /// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
586
    async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
587
588
        let discovery = self.drt.discovery();
        let instances = discovery.list(DiscoveryQuery::AllModels).await?;
589

590
591
592
        let mut results = Vec::with_capacity(instances.len());
        for instance in instances {
            match instance.deserialize_model::<ModelDeploymentCard>() {
593
                Ok(card) => {
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
                    // Extract EndpointId from the instance
                    let endpoint_id = match &instance {
                        dynamo_runtime::discovery::DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            ..
                        } => EndpointId {
                            namespace: namespace.clone(),
                            component: component.clone(),
                            name: endpoint.clone(),
                        },
                        _ => {
                            tracing::error!(
                                "Unexpected discovery instance type (expected ModelCard)"
                            );
610
611
612
                            continue;
                        }
                    };
613
                    results.push((endpoint_id, card));
614
                }
615
                Err(err) => {
616
                    tracing::error!(%err, "Failed to deserialize model card");
617
618
                    continue;
                }
619
            }
620
        }
621
        Ok(results)
622
623
    }

624
    pub async fn cards_for_model(
625
626
627
628
        &self,
        model_name: &str,
        target_namespace: Option<&str>,
        is_global_namespace: bool,
629
630
631
632
    ) -> anyhow::Result<Vec<ModelDeploymentCard>> {
        let mut all = self.all_cards().await?;
        all.retain(|(endpoint_id, card)| {
            let matches_name = card.name() == model_name;
633
634
635
            let matches_namespace = match (is_global_namespace, target_namespace) {
                (true, _) => true,
                (false, None) => true,
636
                (false, Some(target_ns)) => endpoint_id.namespace == target_ns,
637
638
639
            };
            matches_name && matches_namespace
        });
640
641
642
        Ok(all.into_iter().map(|(_eid, card)| card).collect())
    }
}