watcher.rs 24.6 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5
use tokio::sync::mpsc::Sender;
6

7
use anyhow::Context as _;
8
use tokio::sync::{Notify, mpsc::Receiver};
Ryan Olson's avatar
Ryan Olson committed
9

Neelay Shah's avatar
Neelay Shah committed
10
use dynamo_runtime::{
11
    DistributedRuntime,
12
    pipeline::{
13
14
        ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
        network::egress::push_router::PushRouter,
15
    },
16
17
    protocols::{EndpointId, annotated::Annotated},
    transports::etcd::WatchEvent,
18
};
Ryan Olson's avatar
Ryan Olson committed
19

20
21
use crate::{
    backend::Backend,
22
23
    entrypoint,
    kv_router::KvRouterConfig,
24
    model_card::{self, ModelDeploymentCard},
25
26
    model_type::{ModelInput, ModelType},
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
27
28
29
30
31
32
33
34
35
    protocols::{
        common::llm_backend::EmbeddingsEngineOutput,
        openai::{
            chat_completions::{
                NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            },
            completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
            embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
        },
36
        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
37
38
    },
};
39

40
use super::ModelManager;
41
use crate::namespace::is_global_namespace;
42

43
#[derive(Debug, Clone)]
44
pub enum ModelUpdate {
45
46
    Added(ModelDeploymentCard),
    Removed(ModelDeploymentCard),
47
48
}

49
pub struct ModelWatcher {
50
    manager: Arc<ModelManager>,
51
    drt: DistributedRuntime,
52
    router_mode: RouterMode,
53
    notify_on_model: Notify,
54
    model_update_tx: Option<Sender<ModelUpdate>>,
55
    kv_router_config: Option<KvRouterConfig>,
56
    busy_threshold: Option<f64>,
Ryan Olson's avatar
Ryan Olson committed
57
58
}

59
60
61
62
/// The model types `handle_delete` inspects when deciding which `ModelUpdate::Removed`
/// notifications to send. Notifications go out in this order.
const ALL_MODEL_TYPES: &[ModelType] = &[
    ModelType::Chat,
    ModelType::Completions,
    ModelType::Embedding,
    ModelType::TensorBased,
];
65

66
impl ModelWatcher {
67
    pub fn new(
68
        runtime: DistributedRuntime,
69
        model_manager: Arc<ModelManager>,
70
        router_mode: RouterMode,
71
        kv_router_config: Option<KvRouterConfig>,
72
        busy_threshold: Option<f64>,
73
74
    ) -> ModelWatcher {
        Self {
75
            manager: model_manager,
76
            drt: runtime,
77
            router_mode,
78
            notify_on_model: Notify::new(),
79
            model_update_tx: None,
80
            kv_router_config,
81
            busy_threshold,
82
        }
83
    }
Ryan Olson's avatar
Ryan Olson committed
84

85
86
87
88
    /// Register the channel that receives `ModelUpdate` notifications.
    ///
    /// Any previously registered sender is dropped and replaced.
    pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
        let _ = self.model_update_tx.replace(tx);
    }

89
90
91
92
93
94
95
96
97
98
99
    /// Wait until at least one chat completions model is registered and return its name.
    ///
    /// Loops because a model may be added and removed again before the list is re-checked.
    pub async fn wait_for_chat_model(&self) -> String {
        loop {
            // Create the `Notified` future *before* checking the model list. Tokio guarantees
            // that a `Notified` receives `notify_waiters()` wakeups from the moment it is
            // created, even before it is first polled. So a model registered between the check
            // and the `.await` below can no longer be missed (the old order lost that wakeup).
            let notified = self.notify_on_model.notified();
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            notified.await;
        }
    }

100
101
102
    /// Common watch logic with optional namespace filtering
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>, target_namespace: Option<&str>) {
        let global_namespace = target_namespace.is_none_or(is_global_namespace);
103
104
105
106

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
107
108
                    let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) {
                        Ok(card) => card,
109
                        Err(err) => {
110
111
                            match kv.value_str() {
                                Ok(value) => {
112
                                    tracing::error!(%err, value, "Invalid JSON in model card")
113
114
                                }
                                Err(value_str_err) => {
115
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
116
117
                                }
                            }
118
119
120
                            continue;
                        }
                    };
121
122
123
124
125
126
127
128
129
130
131
132
133
134
                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model card key, skipping");
                            continue;
                        }
                    };
                    let endpoint_id = match etcd_key_extract(key) {
                        Ok((eid, _)) => eid,
                        Err(err) => {
                            tracing::error!(%key, model_name = card.name(), %err, "Failed extracting EndpointId from key. Ignoring instance.");
                            continue;
                        }
                    };
135
136
137
138

                    // Filter by namespace if target_namespace is specified
                    if !global_namespace
                        && let Some(target_ns) = target_namespace
139
                        && endpoint_id.namespace != target_ns
140
141
                    {
                        tracing::debug!(
142
                            model_namespace = endpoint_id.namespace,
143
                            target_namespace = target_ns,
144
                            model_name = card.name(),
145
146
147
148
149
                            "Skipping model from different namespace"
                        );
                        continue;
                    }

150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
                    // If we already have a worker for this model, and the ModelDeploymentCard
                    // cards don't match, alert, and don't add the new instance
                    let can_add =
                        self.manager
                            .is_valid_checksum(card.model_type, card.name(), card.mdcsum());
                    if can_add.is_some_and(|is_valid| !is_valid) {
                        tracing::error!(
                            model_name = card.name(),
                            "Checksum for new model does not match existing model."
                        );

                        // TODO: mark that instance down in clients
                        // Not obvious how to do that given the current design
                        // Instances come from an `InstanceSource` in a `Client` in a `PushRouter`.
                        // Calling `report_instance_down` on the Client should do it (although
                        // needs more testing).
                        // The `PushRouter` is in `ModelMananger` (`self.manager` here), but inside
                        // interface `AsyncEngine` which only has a `generate` method.

                        continue;
                    }
171

172
                    match self.handle_put(key, &endpoint_id, &mut card).await {
173
                        Ok(()) => {
174
                            tracing::info!(
175
176
                                model_name = card.name(),
                                namespace = endpoint_id.namespace,
177
178
                                "added model"
                            );
179
                            self.notify_on_model.notify_waiters();
180
                        }
181
182
                        Err(err) => {
                            tracing::error!(
183
184
                                model_name = card.name(),
                                namespace = endpoint_id.namespace,
185
                                error = format!("{err:#}"),
186
                                "Error adding model from discovery",
187
                            );
188
189
190
                        }
                    }
                }
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
                WatchEvent::Delete(kv) => {
                    let Ok(deleted_key) = kv.key_str() else {
                        tracing::warn!("Invalid UTF-8 in etcd delete notification key: {kv:?}");
                        continue;
                    };
                    match self
                        .handle_delete(deleted_key, target_namespace, global_namespace)
                        .await
                    {
                        Ok(Some(model_name)) => {
                            tracing::info!(model_name, "removed model");
                        }
                        Ok(None) => {
                            // There are other instances running this model, nothing to do
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "error removing model");
                        }
209
                    }
210
                }
211
            }
Ryan Olson's avatar
Ryan Olson committed
212
213
214
        }
    }

215
216
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
217
218
    async fn handle_delete(
        &self,
219
        key: &str,
220
221
222
        target_namespace: Option<&str>,
        is_global_namespace: bool,
    ) -> anyhow::Result<Option<String>> {
223
224
        let card = match self.manager.remove_model_card(key) {
            Some(card) => card,
225
            None => {
226
                anyhow::bail!("Missing ModelDeploymentCard for {key}");
227
228
            }
        };
229
        let model_name = card.name().to_string();
230
        let active_instances = self
231
            .cards_for_model(&model_name, target_namespace, is_global_namespace)
232
233
234
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
235
236
237
238
239
240
            tracing::debug!(
                model_name,
                target_namespace = ?target_namespace,
                active_instance_count = active_instances.len(),
                "Model has other active instances, not removing"
            );
241
242
            return Ok(None);
        }
243

244
        // Ignore the errors because model could be either type
245
246
247
        let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
        let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
        let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
248
        let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
249
250
251
252

        let mut chat_model_removed = false;
        let mut completions_model_removed = false;
        let mut embeddings_model_removed = false;
253
        let mut tensor_model_removed = false;
254
255
256
257
258
259
260
261
262
263
264

        if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
            chat_model_removed = true;
        }
        if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
        {
            completions_model_removed = true;
        }
        if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
            embeddings_model_removed = true;
        }
265
266
267
        if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
            tensor_model_removed = true;
        }
268

269
270
271
272
273
        if !chat_model_removed
            && !completions_model_removed
            && !embeddings_model_removed
            && !tensor_model_removed
        {
274
            tracing::debug!(
275
                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}",
276
277
278
                model_name,
                chat_model_removed,
                completions_model_removed,
279
280
                embeddings_model_removed,
                tensor_model_removed
281
282
283
            );
        } else {
            for model_type in ALL_MODEL_TYPES {
284
                if ((chat_model_removed && *model_type == ModelType::Chat)
285
                    || (completions_model_removed && *model_type == ModelType::Completions)
286
287
                    || (embeddings_model_removed && *model_type == ModelType::Embedding)
                    || (tensor_model_removed && *model_type == ModelType::TensorBased))
288
                    && let Some(tx) = &self.model_update_tx
289
                {
290
                    tx.send(ModelUpdate::Removed(card.clone())).await.ok();
291
292
293
                }
            }
        }
294

295
        Ok(Some(model_name))
296
    }
Ryan Olson's avatar
Ryan Olson committed
297

298
299
    // Handles a PUT event from etcd, this usually means adding a new model to the list of served
    // models.
300
301
302
303
304
305
306
    async fn handle_put(
        &self,
        key: &str,
        endpoint_id: &EndpointId,
        card: &mut ModelDeploymentCard,
    ) -> anyhow::Result<()> {
        card.move_from_nats(self.drt.nats_client()).await?;
307
        let component = self
308
309
            .drt
            .namespace(&endpoint_id.namespace)?
310
            .component(&endpoint_id.component)?;
311
        let client = component.endpoint(&endpoint_id.name).client().await?;
312
313
        tracing::debug!(model_name = card.name(), "adding model");
        self.manager.save_model_card(key, card.clone())?;
314

315
316
317
318
        if self.manager.has_model_any(card.name()) {
            tracing::debug!(
                model_name = card.name(),
                namespace = endpoint_id.namespace,
319
320
                "New endpoint for existing model"
            );
321
            //self.notify_on_model.notify_waiters();
322
323
324
325
326
327
            return Ok(());
        }

        if let Some(tx) = &self.model_update_tx {
            tx.send(ModelUpdate::Added(card.clone())).await.ok();
        }
328
        let checksum = card.mdcsum();
329
330
331

        if card.model_input == ModelInput::Tokens
            && (card.model_type.supports_chat() || card.model_type.supports_completions())
332
333
334
335
336
337
338
339
340
        {
            // Case 1: Tokens + (Chat OR Completions OR Both)
            // A model that expects pre-processed requests meaning it's up to us whether we
            // handle Chat or Completions requests, so handle whatever the model supports.

            let kv_chooser = if self.router_mode == RouterMode::KV {
                Some(
                    self.manager
                        .kv_chooser_for(
341
                            card.name(),
342
343
344
345
346
347
348
349
350
                            &component,
                            card.kv_cache_block_size,
                            self.kv_router_config,
                        )
                        .await?,
                )
            } else {
                None
            };
351

352
            // This is expensive, we are loading ~10MiB JSON, so only do it once
353
            let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
354

355
            // Add chat engine only if the model supports chat
356
            if card.model_type.supports_chat() {
357
358
359
360
                let chat_engine = entrypoint::build_routed_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(
361
                    card,
362
363
364
365
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser.clone(),
366
                    tokenizer_hf.clone(),
367
                )
368
369
                .await
                .context("build_routed_pipeline")?;
370
                self.manager
371
                    .add_chat_completions_model(card.name(), checksum, chat_engine)
372
                    .context("add_chat_completions_model")?;
373
                tracing::info!("Chat completions is ready");
374
            }
375

376
            // Add completions engine only if the model supports completions
377
            if card.model_type.supports_completions() {
378
379
                let formatter = PromptFormatter::no_op();
                let PromptFormatter::OAI(formatter) = formatter;
380
381
382
383
                let preprocessor = OpenAIPreprocessor::new_with_parts(
                    card.clone(),
                    formatter,
                    tokenizer_hf.clone(),
384
385
                )
                .context("OpenAIPreprocessor::new_with_parts")?;
386
                let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
387
388
389
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(
390
                    card,
391
392
393
394
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser,
395
                    preprocessor,
396
                    tokenizer_hf,
397
                )
398
399
                .await
                .context("build_routed_pipeline_with_preprocessor")?;
400
                self.manager
401
                    .add_completions_model(card.name(), checksum, completions_engine)
402
                    .context("add_completions_model")?;
403
                tracing::info!("Completions is ready");
404
            }
405
406
407
408
409
410
411
412
413
414
415
416
        } else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
            // Case: Text + Embeddings
            let push_router = PushRouter::<
                NvCreateEmbeddingRequest,
                Annotated<NvCreateEmbeddingResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
                .add_embeddings_model(&model_entry.name, engine)?;
417
        } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
418
419
420
421
422
423
424
425
426
427
            // Case 3: Text + Chat
            let push_router = PushRouter::<
                NvCreateChatCompletionRequest,
                Annotated<NvCreateChatCompletionStreamResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
428
                .add_chat_completions_model(card.name(), checksum, engine)?;
429
        } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
430
431
432
433
434
435
436
437
438
439
            // Case 2: Text + Completions
            let push_router = PushRouter::<
                NvCreateCompletionRequest,
                Annotated<NvCreateCompletionResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
440
                .add_completions_model(card.name(), checksum, engine)?;
441
        } else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
442
443
444
445
446
447
448
449
            // Case 4: Tokens + Embeddings

            // Create preprocessing pipeline similar to Backend
            let frontend = SegmentSource::<
                SingleIn<NvCreateEmbeddingRequest>,
                ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            >::new();

450
            let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
451
            let backend = Backend::from_mdc(card).into_operator();
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473

            let router = PushRouter::<
                PreprocessedEmbeddingRequest,
                Annotated<EmbeddingsEngineOutput>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;

            // Note: Embeddings don't need KV routing complexity
            let service_backend = ServiceBackend::from_engine(Arc::new(router));

            // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
            let embedding_engine = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(service_backend)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            self.manager
474
                .add_embeddings_model(card.name(), checksum, embedding_engine)?;
475
        } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
476
477
478
479
480
481
482
483
484
            // Case 5: Tensor + Tensor (non-LLM)
            let push_router = PushRouter::<
                NvCreateTensorRequest,
                Annotated<NvCreateTensorResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
485
486
            self.manager
                .add_tensor_model(card.name(), checksum, engine)?;
487
488
489
490
        } else {
            // Reject unsupported combinations
            anyhow::bail!(
                "Unsupported model configuration: {} with {} input. Supported combinations: \
491
                Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
492
493
                card.model_type,
                card.model_input.as_str()
494
            );
495
        }
Ryan Olson's avatar
Ryan Olson committed
496

497
498
        Ok(())
    }
499

500
501
    /// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
    pub async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
502
        let Some(etcd_client) = self.drt.etcd_client() else {
503
            anyhow::bail!("all_cards: Missing etcd client");
504
        };
505
506
        let kvs = etcd_client.kv_get_prefix(model_card::ROOT_PATH).await?;
        let mut results = Vec::with_capacity(kvs.len());
507
        for kv in kvs {
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
            let maybe_convert = serde_json::from_slice::<ModelDeploymentCard>(kv.value());
            let r = match maybe_convert {
                Ok(card) => {
                    let maybe_endpoint_id = kv.key_str().map_err(|err| err.into()).and_then(|k| {
                        etcd_key_extract(k).map(|(endpoint_id, _instance_id)| endpoint_id)
                    });
                    let endpoint_id = match maybe_endpoint_id {
                        Ok(eid) => eid,
                        Err(err) => {
                            tracing::error!(%err, "Skipping invalid etcd key, not string or not EndpointId");
                            continue;
                        }
                    };
                    (endpoint_id, card)
                }
523
524
525
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
526
                            tracing::error!(%err, value, "Invalid JSON in model card");
527
528
                        }
                        Err(value_str_err) => {
529
                            tracing::error!(original_error=%err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON");
530
531
532
533
534
                        }
                    }
                    continue;
                }
            };
535
            results.push(r);
536
        }
537
        Ok(results)
538
539
    }

540
    pub async fn cards_for_model(
541
542
543
544
        &self,
        model_name: &str,
        target_namespace: Option<&str>,
        is_global_namespace: bool,
545
546
547
548
    ) -> anyhow::Result<Vec<ModelDeploymentCard>> {
        let mut all = self.all_cards().await?;
        all.retain(|(endpoint_id, card)| {
            let matches_name = card.name() == model_name;
549
550
551
            let matches_namespace = match (is_global_namespace, target_namespace) {
                (true, _) => true,
                (false, None) => true,
552
                (false, Some(target_ns)) => endpoint_id.namespace == target_ns,
553
554
555
            };
            matches_name && matches_namespace
        });
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
        Ok(all.into_iter().map(|(_eid, card)| card).collect())
    }
}

/// The ModelDeploymentCard is published in etcd with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad".
/// Extract the EndpointId and instance_id from that.
fn etcd_key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> {
    let parts: Vec<&str> = s.split('/').collect();
    let start_idx = if !parts.is_empty() && parts[0] == "v1" {
        1
    } else {
        0
    };

    // Need at least prefix model_card::ROOT_PATH + 3 parts: namespace, component, name
    if parts.len() <= start_idx + 3 {
        anyhow::bail!("Invalid format: not enough path segments in {s}");
    }

    if parts.get(start_idx) != Some(&model_card::ROOT_PATH) {
        anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}");
    }

    let endpoint_id = EndpointId {
        namespace: parts[start_idx + 1].to_string(),
        component: parts[start_idx + 2].to_string(),
        name: parts[start_idx + 3].to_string(),
    };
    Ok((endpoint_id, parts[parts.len() - 1].to_string()))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_etcd_key_extract() {
        // The leading "v1" segment is optional; both forms must parse identically.
        for prefix in ["v1/", ""] {
            let input = format!(
                "{prefix}{}/dynamo/backend/generate/694d9981145a61ad",
                model_card::ROOT_PATH
            );
            let (endpoint_id, instance_id) = etcd_key_extract(&input).unwrap();
            assert_eq!(endpoint_id.namespace, "dynamo");
            assert_eq!(endpoint_id.component, "backend");
            assert_eq!(endpoint_id.name, "generate");
            assert_eq!(instance_id, "694d9981145a61ad");
        }
    }

    #[test]
    fn test_etcd_key_extract_rejects_malformed_keys() {
        // Endpoint name missing: only namespace and component follow ROOT_PATH.
        let too_short = format!("v1/{}/dynamo/backend", model_card::ROOT_PATH);
        assert!(etcd_key_extract(&too_short).is_err());

        // Empty key.
        assert!(etcd_key_extract("").is_err());

        // Enough segments, but the root segment is not the model-card ROOT_PATH.
        assert!(
            etcd_key_extract("v1/not-a-model-card-root/dynamo/backend/generate/694d").is_err()
        );
    }
}