watcher.rs 23.9 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5
use tokio::sync::mpsc::Sender;
6

7
use anyhow::Context as _;
8
use tokio::sync::{Notify, mpsc::Receiver};
Ryan Olson's avatar
Ryan Olson committed
9

Neelay Shah's avatar
Neelay Shah committed
10
use dynamo_runtime::{
11
    DistributedRuntime,
12
    pipeline::{
13
14
        ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
        network::egress::push_router::PushRouter,
15
    },
16
17
    protocols::{EndpointId, annotated::Annotated},
    transports::etcd::WatchEvent,
18
};
Ryan Olson's avatar
Ryan Olson committed
19

20
21
use crate::{
    backend::Backend,
22
23
    entrypoint,
    kv_router::KvRouterConfig,
24
    model_card::{self, ModelDeploymentCard},
25
26
    model_type::{ModelInput, ModelType},
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
27
28
29
30
31
32
33
34
35
    protocols::{
        common::llm_backend::EmbeddingsEngineOutput,
        openai::{
            chat_completions::{
                NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            },
            completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
            embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
        },
36
        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
37
38
    },
};
39

40
use super::ModelManager;
41
use crate::namespace::is_global_namespace;
42

43
#[derive(Debug, Clone)]
44
pub enum ModelUpdate {
45
46
    Added(ModelDeploymentCard),
    Removed(ModelDeploymentCard),
47
48
}

49
pub struct ModelWatcher {
50
    manager: Arc<ModelManager>,
51
    drt: DistributedRuntime,
52
    router_mode: RouterMode,
53
    notify_on_model: Notify,
54
    model_update_tx: Option<Sender<ModelUpdate>>,
55
    kv_router_config: Option<KvRouterConfig>,
56
    busy_threshold: Option<f64>,
Ryan Olson's avatar
Ryan Olson committed
57
58
}

59
60
61
62
/// Every model type the watcher checks when deciding which removal updates to publish,
/// in the order they are checked.
const ALL_MODEL_TYPES: &[ModelType] = &[
    ModelType::Chat,
    ModelType::Completions,
    ModelType::Embedding,
    ModelType::TensorBased,
];
65

66
impl ModelWatcher {
67
    pub fn new(
68
        runtime: DistributedRuntime,
69
        model_manager: Arc<ModelManager>,
70
        router_mode: RouterMode,
71
        kv_router_config: Option<KvRouterConfig>,
72
        busy_threshold: Option<f64>,
73
74
    ) -> ModelWatcher {
        Self {
75
            manager: model_manager,
76
            drt: runtime,
77
            router_mode,
78
            notify_on_model: Notify::new(),
79
            model_update_tx: None,
80
            kv_router_config,
81
            busy_threshold,
82
        }
83
    }
Ryan Olson's avatar
Ryan Olson committed
84

85
86
87
88
    /// Attach the channel on which `ModelUpdate::Added` / `ModelUpdate::Removed` are published.
    ///
    /// Replaces any previously attached sender.
    pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
        let previous = self.model_update_tx.replace(tx);
        drop(previous);
    }

89
90
91
92
93
94
95
96
97
98
99
    /// Wait until we have at least one chat completions model and return its name.
    ///
    /// The returned name is the first entry of `list_chat_completions_models()`.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            // Create the `Notified` future *before* looking at the model list. `watch` wakes
            // us with `notify_waiters()`, which stores no permit: a future created after the
            // call never sees it. Creating it first closes the window in which a model added
            // between the check and the wait would be missed.
            let model_added = self.notify_on_model.notified();

            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }

            model_added.await;
        }
    }

100
101
102
    /// Common watch logic with optional namespace filtering
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>, target_namespace: Option<&str>) {
        let global_namespace = target_namespace.is_none_or(is_global_namespace);
103
104
105
106

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
107
108
                    let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) {
                        Ok(card) => card,
109
                        Err(err) => {
110
111
                            match kv.value_str() {
                                Ok(value) => {
112
                                    tracing::error!(%err, value, "Invalid JSON in model card")
113
114
                                }
                                Err(value_str_err) => {
115
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
116
117
                                }
                            }
118
119
120
                            continue;
                        }
                    };
121
122
123
124
125
126
127
128
129
130
131
132
133
134
                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model card key, skipping");
                            continue;
                        }
                    };
                    let endpoint_id = match etcd_key_extract(key) {
                        Ok((eid, _)) => eid,
                        Err(err) => {
                            tracing::error!(%key, model_name = card.name(), %err, "Failed extracting EndpointId from key. Ignoring instance.");
                            continue;
                        }
                    };
135
136
137
138

                    // Filter by namespace if target_namespace is specified
                    if !global_namespace
                        && let Some(target_ns) = target_namespace
139
                        && endpoint_id.namespace != target_ns
140
141
                    {
                        tracing::debug!(
142
                            model_namespace = endpoint_id.namespace,
143
                            target_namespace = target_ns,
144
                            model_name = card.name(),
145
146
147
148
149
                            "Skipping model from different namespace"
                        );
                        continue;
                    }

150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
                    // If we already have a worker for this model, and the ModelDeploymentCard
                    // cards don't match, alert, and don't add the new instance
                    let can_add =
                        self.manager
                            .is_valid_checksum(card.model_type, card.name(), card.mdcsum());
                    if can_add.is_some_and(|is_valid| !is_valid) {
                        tracing::error!(
                            model_name = card.name(),
                            "Checksum for new model does not match existing model."
                        );

                        // TODO: mark that instance down in clients
                        // Not obvious how to do that given the current design
                        // Instances come from an `InstanceSource` in a `Client` in a `PushRouter`.
                        // Calling `report_instance_down` on the Client should do it (although
                        // needs more testing).
                        // The `PushRouter` is in `ModelMananger` (`self.manager` here), but inside
                        // interface `AsyncEngine` which only has a `generate` method.

                        continue;
                    }
171

172
                    match self.handle_put(key, &endpoint_id, &mut card).await {
173
                        Ok(()) => {
174
                            tracing::info!(
175
176
                                model_name = card.name(),
                                namespace = endpoint_id.namespace,
177
178
                                "added model"
                            );
179
                            self.notify_on_model.notify_waiters();
180
                        }
181
182
                        Err(err) => {
                            tracing::error!(
183
184
                                model_name = card.name(),
                                namespace = endpoint_id.namespace,
185
                                error = format!("{err:#}"),
186
                                "Error adding model from discovery",
187
                            );
188
189
190
                        }
                    }
                }
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
                WatchEvent::Delete(kv) => {
                    let Ok(deleted_key) = kv.key_str() else {
                        tracing::warn!("Invalid UTF-8 in etcd delete notification key: {kv:?}");
                        continue;
                    };
                    match self
                        .handle_delete(deleted_key, target_namespace, global_namespace)
                        .await
                    {
                        Ok(Some(model_name)) => {
                            tracing::info!(model_name, "removed model");
                        }
                        Ok(None) => {
                            // There are other instances running this model, nothing to do
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "error removing model");
                        }
209
                    }
210
                }
211
            }
Ryan Olson's avatar
Ryan Olson committed
212
213
214
        }
    }

215
216
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
217
218
    async fn handle_delete(
        &self,
219
        key: &str,
220
221
222
        target_namespace: Option<&str>,
        is_global_namespace: bool,
    ) -> anyhow::Result<Option<String>> {
223
224
        let card = match self.manager.remove_model_card(key) {
            Some(card) => card,
225
            None => {
226
                anyhow::bail!("Missing ModelDeploymentCard for {key}");
227
228
            }
        };
229
        let model_name = card.name().to_string();
230
        let active_instances = self
231
            .cards_for_model(&model_name, target_namespace, is_global_namespace)
232
233
234
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
235
236
237
238
239
240
            tracing::debug!(
                model_name,
                target_namespace = ?target_namespace,
                active_instance_count = active_instances.len(),
                "Model has other active instances, not removing"
            );
241
242
            return Ok(None);
        }
243

244
        // Ignore the errors because model could be either type
245
246
247
        let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
        let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
        let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
248
        let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
249
250
251
252

        let mut chat_model_removed = false;
        let mut completions_model_removed = false;
        let mut embeddings_model_removed = false;
253
        let mut tensor_model_removed = false;
254
255
256
257
258
259
260
261
262
263
264

        if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
            chat_model_removed = true;
        }
        if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
        {
            completions_model_removed = true;
        }
        if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
            embeddings_model_removed = true;
        }
265
266
267
        if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
            tensor_model_removed = true;
        }
268

269
270
271
272
273
        if !chat_model_removed
            && !completions_model_removed
            && !embeddings_model_removed
            && !tensor_model_removed
        {
274
            tracing::debug!(
275
                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}",
276
277
278
                model_name,
                chat_model_removed,
                completions_model_removed,
279
280
                embeddings_model_removed,
                tensor_model_removed
281
282
283
            );
        } else {
            for model_type in ALL_MODEL_TYPES {
284
                if ((chat_model_removed && *model_type == ModelType::Chat)
285
                    || (completions_model_removed && *model_type == ModelType::Completions)
286
287
                    || (embeddings_model_removed && *model_type == ModelType::Embedding)
                    || (tensor_model_removed && *model_type == ModelType::TensorBased))
288
                    && let Some(tx) = &self.model_update_tx
289
                {
290
                    tx.send(ModelUpdate::Removed(card.clone())).await.ok();
291
292
293
                }
            }
        }
294

295
        Ok(Some(model_name))
296
    }
Ryan Olson's avatar
Ryan Olson committed
297

298
299
    // Handles a PUT event from etcd, this usually means adding a new model to the list of served
    // models.
300
301
302
303
304
305
306
    async fn handle_put(
        &self,
        key: &str,
        endpoint_id: &EndpointId,
        card: &mut ModelDeploymentCard,
    ) -> anyhow::Result<()> {
        card.move_from_nats(self.drt.nats_client()).await?;
307
        let component = self
308
309
            .drt
            .namespace(&endpoint_id.namespace)?
310
            .component(&endpoint_id.component)?;
311
        let client = component.endpoint(&endpoint_id.name).client().await?;
312
313
        tracing::debug!(model_name = card.name(), "adding model");
        self.manager.save_model_card(key, card.clone())?;
314

315
316
317
318
        if self.manager.has_model_any(card.name()) {
            tracing::debug!(
                model_name = card.name(),
                namespace = endpoint_id.namespace,
319
320
321
322
323
324
325
326
                "New endpoint for existing model"
            );
            return Ok(());
        }

        if let Some(tx) = &self.model_update_tx {
            tx.send(ModelUpdate::Added(card.clone())).await.ok();
        }
327
        let checksum = card.mdcsum();
328
329
330

        if card.model_input == ModelInput::Tokens
            && (card.model_type.supports_chat() || card.model_type.supports_completions())
331
332
333
334
335
336
337
338
339
        {
            // Case 1: Tokens + (Chat OR Completions OR Both)
            // A model that expects pre-processed requests meaning it's up to us whether we
            // handle Chat or Completions requests, so handle whatever the model supports.

            let kv_chooser = if self.router_mode == RouterMode::KV {
                Some(
                    self.manager
                        .kv_chooser_for(
340
                            card.name(),
341
342
343
344
345
346
347
348
349
                            &component,
                            card.kv_cache_block_size,
                            self.kv_router_config,
                        )
                        .await?,
                )
            } else {
                None
            };
350

351
            // This is expensive, we are loading ~10MiB JSON, so only do it once
352
            let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
353

354
            // Add chat engine only if the model supports chat
355
            if card.model_type.supports_chat() {
356
357
358
359
                let chat_engine = entrypoint::build_routed_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(
360
                    card,
361
362
363
364
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser.clone(),
365
                    tokenizer_hf.clone(),
366
                )
367
368
                .await
                .context("build_routed_pipeline")?;
369
                self.manager
370
                    .add_chat_completions_model(card.name(), checksum, chat_engine)
371
                    .context("add_chat_completions_model")?;
372
                tracing::info!("Chat completions is ready");
373
            }
374

375
            // Add completions engine only if the model supports completions
376
            if card.model_type.supports_completions() {
377
378
                let formatter = PromptFormatter::no_op();
                let PromptFormatter::OAI(formatter) = formatter;
379
380
381
382
                let preprocessor = OpenAIPreprocessor::new_with_parts(
                    card.clone(),
                    formatter,
                    tokenizer_hf.clone(),
383
384
                )
                .context("OpenAIPreprocessor::new_with_parts")?;
385
                let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
386
387
388
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(
389
                    card,
390
391
392
393
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser,
394
                    preprocessor,
395
                    tokenizer_hf,
396
                )
397
398
                .await
                .context("build_routed_pipeline_with_preprocessor")?;
399
                self.manager
400
                    .add_completions_model(card.name(), checksum, completions_engine)
401
                    .context("add_completions_model")?;
402
                tracing::info!("Completions is ready");
403
            }
404
405
406
407
408
409
410
411
412
413
414
        } else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
            // Case: Text + Embeddings
            let push_router = PushRouter::<
                NvCreateEmbeddingRequest,
                Annotated<NvCreateEmbeddingResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
415
                .add_embeddings_model(card.name(), checksum, engine)?;
416
        } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
417
418
419
420
421
422
423
424
425
426
            // Case 3: Text + Chat
            let push_router = PushRouter::<
                NvCreateChatCompletionRequest,
                Annotated<NvCreateChatCompletionStreamResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
427
                .add_chat_completions_model(card.name(), checksum, engine)?;
428
        } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
429
430
431
432
433
434
435
436
437
438
            // Case 2: Text + Completions
            let push_router = PushRouter::<
                NvCreateCompletionRequest,
                Annotated<NvCreateCompletionResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
439
                .add_completions_model(card.name(), checksum, engine)?;
440
        } else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
441
442
443
444
445
446
447
448
            // Case 4: Tokens + Embeddings

            // Create preprocessing pipeline similar to Backend
            let frontend = SegmentSource::<
                SingleIn<NvCreateEmbeddingRequest>,
                ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            >::new();

449
            let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
450
            let backend = Backend::from_mdc(card).into_operator();
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472

            let router = PushRouter::<
                PreprocessedEmbeddingRequest,
                Annotated<EmbeddingsEngineOutput>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;

            // Note: Embeddings don't need KV routing complexity
            let service_backend = ServiceBackend::from_engine(Arc::new(router));

            // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
            let embedding_engine = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(service_backend)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            self.manager
473
                .add_embeddings_model(card.name(), checksum, embedding_engine)?;
474
        } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
475
476
477
478
479
480
481
482
483
            // Case 5: Tensor + Tensor (non-LLM)
            let push_router = PushRouter::<
                NvCreateTensorRequest,
                Annotated<NvCreateTensorResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
484
485
            self.manager
                .add_tensor_model(card.name(), checksum, engine)?;
486
487
488
489
        } else {
            // Reject unsupported combinations
            anyhow::bail!(
                "Unsupported model configuration: {} with {} input. Supported combinations: \
490
                Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
491
492
                card.model_type,
                card.model_input.as_str()
493
            );
494
        }
Ryan Olson's avatar
Ryan Olson committed
495

496
497
        Ok(())
    }
498

499
500
    /// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
    pub async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
501
        let Some(etcd_client) = self.drt.etcd_client() else {
502
            anyhow::bail!("all_cards: Missing etcd client");
503
        };
504
505
        let kvs = etcd_client.kv_get_prefix(model_card::ROOT_PATH).await?;
        let mut results = Vec::with_capacity(kvs.len());
506
        for kv in kvs {
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
            let maybe_convert = serde_json::from_slice::<ModelDeploymentCard>(kv.value());
            let r = match maybe_convert {
                Ok(card) => {
                    let maybe_endpoint_id = kv.key_str().map_err(|err| err.into()).and_then(|k| {
                        etcd_key_extract(k).map(|(endpoint_id, _instance_id)| endpoint_id)
                    });
                    let endpoint_id = match maybe_endpoint_id {
                        Ok(eid) => eid,
                        Err(err) => {
                            tracing::error!(%err, "Skipping invalid etcd key, not string or not EndpointId");
                            continue;
                        }
                    };
                    (endpoint_id, card)
                }
522
523
524
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
525
                            tracing::error!(%err, value, "Invalid JSON in model card");
526
527
                        }
                        Err(value_str_err) => {
528
                            tracing::error!(original_error=%err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON");
529
530
531
532
533
                        }
                    }
                    continue;
                }
            };
534
            results.push(r);
535
        }
536
        Ok(results)
537
538
    }

539
    pub async fn cards_for_model(
540
541
542
543
        &self,
        model_name: &str,
        target_namespace: Option<&str>,
        is_global_namespace: bool,
544
545
546
547
    ) -> anyhow::Result<Vec<ModelDeploymentCard>> {
        let mut all = self.all_cards().await?;
        all.retain(|(endpoint_id, card)| {
            let matches_name = card.name() == model_name;
548
549
550
            let matches_namespace = match (is_global_namespace, target_namespace) {
                (true, _) => true,
                (false, None) => true,
551
                (false, Some(target_ns)) => endpoint_id.namespace == target_ns,
552
553
554
            };
            matches_name && matches_namespace
        });
555
556
557
558
559
560
561
        Ok(all.into_iter().map(|(_eid, card)| card).collect())
    }
}

/// The ModelDeploymentCard is published in etcd with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad".
/// Extract the EndpointId and instance_id from that.
fn etcd_key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> {
562
563
564
    if !s.starts_with(model_card::ROOT_PATH) {
        anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}");
    }
565
566
    let parts: Vec<&str> = s.split('/').collect();

567
568
    // Need at least prefix model_card::ROOT_PATH (2 parts) + namespace, component, name (3 parts)
    if parts.len() <= 5 {
569
570
571
572
        anyhow::bail!("Invalid format: not enough path segments in {s}");
    }

    let endpoint_id = EndpointId {
573
574
575
        namespace: parts[2].to_string(),
        component: parts[3].to_string(),
        name: parts[4].to_string(),
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
    };
    Ok((endpoint_id, parts[parts.len() - 1].to_string()))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed key yields the endpoint triple and the trailing instance id.
    #[test]
    fn test_etcd_key_extract() {
        let input = format!(
            "{}/dynamo/backend/generate/694d9981145a61ad",
            model_card::ROOT_PATH
        );
        let (endpoint_id, instance_id) = etcd_key_extract(&input).unwrap();
        assert_eq!(endpoint_id.namespace, "dynamo");
        assert_eq!(endpoint_id.component, "backend");
        assert_eq!(endpoint_id.name, "generate");
        assert_eq!(instance_id, "694d9981145a61ad");
    }

    /// A key without the trailing instance id segment is rejected.
    #[test]
    fn test_etcd_key_extract_missing_instance_id() {
        let input = format!("{}/dynamo/backend/generate", model_card::ROOT_PATH);
        assert!(etcd_key_extract(&input).is_err());
    }

    /// A key that does not start with the model card root path is rejected.
    #[test]
    fn test_etcd_key_extract_foreign_prefix() {
        let input = format!(
            "other/{}/dynamo/backend/generate/694d9981145a61ad",
            model_card::ROOT_PATH
        );
        assert!(etcd_key_extract(&input).is_err());
    }
}