"docs/vscode:/vscode.git/clone" did not exist on "5206ab20ba7477e5457c2e64469590d548fa15e6"
watcher.rs 20.9 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5
use tokio::sync::mpsc::Sender;
6

7
use anyhow::Context as _;
8
use tokio::sync::{Notify, mpsc::Receiver};
Ryan Olson's avatar
Ryan Olson committed
9

Neelay Shah's avatar
Neelay Shah committed
10
use dynamo_runtime::{
11
    DistributedRuntime,
12
    pipeline::{
13
14
        ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
        network::egress::push_router::PushRouter,
15
    },
16
    protocols::annotated::Annotated,
17
    storage::key_value_store::Key,
18
    transports::etcd::{KeyValue, WatchEvent},
19
};
Ryan Olson's avatar
Ryan Olson committed
20

21
22
use crate::{
    backend::Backend,
23
24
    entrypoint,
    kv_router::KvRouterConfig,
25
    model_card::ModelDeploymentCard,
26
27
    model_type::{ModelInput, ModelType},
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
28
29
30
31
32
33
34
35
36
    protocols::{
        common::llm_backend::EmbeddingsEngineOutput,
        openai::{
            chat_completions::{
                NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            },
            completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
            embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
        },
37
        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
38
39
    },
};
40

41
use super::{MODEL_ROOT_PATH, ModelEntry, ModelManager};
42
use crate::namespace::is_global_namespace;
43

44
#[derive(Debug, Clone)]
45
pub enum ModelUpdate {
46
47
    Added(ModelDeploymentCard),
    Removed(ModelDeploymentCard),
48
49
}

50
pub struct ModelWatcher {
51
    manager: Arc<ModelManager>,
52
    drt: DistributedRuntime,
53
    router_mode: RouterMode,
54
    notify_on_model: Notify,
55
    model_update_tx: Option<Sender<ModelUpdate>>,
56
    kv_router_config: Option<KvRouterConfig>,
57
    busy_threshold: Option<f64>,
Ryan Olson's avatar
Ryan Olson committed
58
59
}

60
61
62
63
/// Every model type for which `handle_delete` may emit a `ModelUpdate::Removed` event.
/// Kept as a slice so iteration yields `&ModelType`.
const ALL_MODEL_TYPES: &[ModelType] = &[
    ModelType::Chat,
    ModelType::Completions,
    ModelType::Embedding,
    ModelType::TensorBased,
];
66

67
impl ModelWatcher {
68
    pub fn new(
69
        runtime: DistributedRuntime,
70
        model_manager: Arc<ModelManager>,
71
        router_mode: RouterMode,
72
        kv_router_config: Option<KvRouterConfig>,
73
        busy_threshold: Option<f64>,
74
75
    ) -> ModelWatcher {
        Self {
76
            manager: model_manager,
77
            drt: runtime,
78
            router_mode,
79
            notify_on_model: Notify::new(),
80
            model_update_tx: None,
81
            kv_router_config,
82
            busy_threshold,
83
        }
84
    }
Ryan Olson's avatar
Ryan Olson committed
85

86
87
88
89
    pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
        self.model_update_tx = Some(tx);
    }

90
91
92
93
94
95
96
97
98
99
100
    /// Wait until we have at least one chat completions model and return it's name.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            self.notify_on_model.notified().await
        }
    }

101
102
103
    /// Common watch logic with optional namespace filtering
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>, target_namespace: Option<&str>) {
        let global_namespace = target_namespace.is_none_or(is_global_namespace);
104
105
106
107
108
109
110

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
                    let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                        Ok(model_entry) => model_entry,
                        Err(err) => {
111
112
113
114
115
116
117
118
                            match kv.value_str() {
                                Ok(value) => {
                                    tracing::error!(%err, value, "Invalid JSON in model entry")
                                }
                                Err(value_str_err) => {
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                                }
                            }
119
120
121
                            continue;
                        }
                    };
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

                    // Filter by namespace if target_namespace is specified
                    if !global_namespace
                        && let Some(target_ns) = target_namespace
                        && model_entry.endpoint_id.namespace != target_ns
                    {
                        tracing::debug!(
                            model_namespace = model_entry.endpoint_id.namespace,
                            target_namespace = target_ns,
                            model_name = model_entry.name,
                            "Skipping model from different namespace"
                        );
                        continue;
                    }

137
138
139
140
141
142
143
144
                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
                            continue;
                        }
                    };

145
                    match self.handle_put(key, &model_entry).await {
146
                        Ok(()) => {
147
148
149
150
151
                            tracing::info!(
                                model_name = model_entry.name,
                                namespace = model_entry.endpoint_id.namespace,
                                "added model"
                            );
152
                            self.notify_on_model.notify_waiters();
153
                        }
154
155
156
                        Err(err) => {
                            tracing::error!(
                                error = format!("{err:#}"),
157
158
159
                                "error adding model {} from namespace {}",
                                model_entry.name,
                                model_entry.endpoint_id.namespace,
160
                            );
161
162
163
                        }
                    }
                }
164
165
166
167
                WatchEvent::Delete(kv) => match self
                    .handle_delete(&kv, target_namespace, global_namespace)
                    .await
                {
168
                    Ok(Some(model_name)) => {
169
                        tracing::info!(model_name, "removed model");
170
                    }
171
172
173
                    Ok(None) => {
                        // There are other instances running this model, nothing to do
                    }
174
                    Err(e) => {
175
                        tracing::error!(error = %e, "error removing model");
176
                    }
177
                },
178
            }
Ryan Olson's avatar
Ryan Olson committed
179
180
181
        }
    }

182
183
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
184
185
186
187
188
189
    async fn handle_delete(
        &self,
        kv: &KeyValue,
        target_namespace: Option<&str>,
        is_global_namespace: bool,
    ) -> anyhow::Result<Option<String>> {
190
        let key = kv.key_str()?;
191
192
        let card = match self.manager.remove_model_card(key) {
            Some(card) => card,
193
            None => {
194
                anyhow::bail!("Missing ModelDeploymentCard for {key}");
195
196
            }
        };
197
        let model_name = card.display_name.clone();
198
        let active_instances = self
199
            .entries_for_model(&model_name, target_namespace, is_global_namespace)
200
201
202
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
203
204
205
206
207
208
            tracing::debug!(
                model_name,
                target_namespace = ?target_namespace,
                active_instance_count = active_instances.len(),
                "Model has other active instances, not removing"
            );
209
210
            return Ok(None);
        }
211

212
        // Ignore the errors because model could be either type
213
214
215
        let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
        let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
        let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
216
        let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
217
218
219
220

        let mut chat_model_removed = false;
        let mut completions_model_removed = false;
        let mut embeddings_model_removed = false;
221
        let mut tensor_model_removed = false;
222
223
224
225
226
227
228
229
230
231
232

        if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
            chat_model_removed = true;
        }
        if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
        {
            completions_model_removed = true;
        }
        if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
            embeddings_model_removed = true;
        }
233
234
235
        if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
            tensor_model_removed = true;
        }
236

237
238
239
240
241
        if !chat_model_removed
            && !completions_model_removed
            && !embeddings_model_removed
            && !tensor_model_removed
        {
242
            tracing::debug!(
243
                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}",
244
245
246
                model_name,
                chat_model_removed,
                completions_model_removed,
247
248
                embeddings_model_removed,
                tensor_model_removed
249
250
251
            );
        } else {
            for model_type in ALL_MODEL_TYPES {
252
                if ((chat_model_removed && *model_type == ModelType::Chat)
253
                    || (completions_model_removed && *model_type == ModelType::Completions)
254
255
                    || (embeddings_model_removed && *model_type == ModelType::Embedding)
                    || (tensor_model_removed && *model_type == ModelType::TensorBased))
256
                    && let Some(tx) = &self.model_update_tx
257
                {
258
                    tx.send(ModelUpdate::Removed(card.clone())).await.ok();
259
260
261
                }
            }
        }
262

263
        Ok(Some(model_name))
264
    }
Ryan Olson's avatar
Ryan Olson committed
265

266
267
    // Handles a PUT event from etcd, this usually means adding a new model to the list of served
    // models.
268
    async fn handle_put(&self, key: &str, model_entry: &ModelEntry) -> anyhow::Result<()> {
269
        let endpoint_id = &model_entry.endpoint_id;
270
        let component = self
271
272
            .drt
            .namespace(&endpoint_id.namespace)?
273
            .component(&endpoint_id.component)?;
274
        let client = component.endpoint(&endpoint_id.name).client().await?;
275
        let model_slug = model_entry.slug();
276
277
278
279
280
281
        let card = match ModelDeploymentCard::load_from_store(
            &Key::from_raw(model_slug.to_string()),
            &self.drt,
        )
        .await
        {
282
283
284
285
286
287
288
289
            Ok(Some(mut card)) => {
                tracing::debug!(card.display_name, "adding model");
                // Ensure runtime_config is populated
                if let Some(rc) = model_entry.runtime_config.clone() {
                    card.runtime_config = rc;
                }
                card
            }
290
291
            Ok(None) => {
                anyhow::bail!("Missing ModelDeploymentCard in storage under key {model_slug}");
292
293
            }
            Err(err) => {
294
295
296
                anyhow::bail!(
                    "Error fetching ModelDeploymentCard from storage under key {model_slug}. {err}"
                );
297
298
            }
        };
299

300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
        self.manager.save_model_card(key, card.clone());

        if self.manager.has_model_any(&model_entry.name) {
            tracing::trace!(
                name = model_entry.name,
                namespace = model_entry.endpoint_id.namespace,
                "New endpoint for existing model"
            );
            self.notify_on_model.notify_waiters();
            return Ok(());
        }

        if let Some(tx) = &self.model_update_tx {
            tx.send(ModelUpdate::Added(card.clone())).await.ok();
        }

        if card.model_input == ModelInput::Tokens
            && (card.model_type.supports_chat() || card.model_type.supports_completions())
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
        {
            // Case 1: Tokens + (Chat OR Completions OR Both)
            // A model that expects pre-processed requests meaning it's up to us whether we
            // handle Chat or Completions requests, so handle whatever the model supports.

            let kv_chooser = if self.router_mode == RouterMode::KV {
                Some(
                    self.manager
                        .kv_chooser_for(
                            &model_entry.name,
                            &component,
                            card.kv_cache_block_size,
                            self.kv_router_config,
                        )
                        .await?,
                )
            } else {
                None
            };
337

338
            // This is expensive, we are loading ~10MiB JSON, so only do it once
339
            let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
340

341
            // Add chat engine only if the model supports chat
342
            if card.model_type.supports_chat() {
343
344
345
346
347
348
349
350
351
                let chat_engine = entrypoint::build_routed_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(
                    &card,
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser.clone(),
352
                    tokenizer_hf.clone(),
353
                )
354
355
                .await
                .context("build_routed_pipeline")?;
356
                self.manager
357
358
                    .add_chat_completions_model(&model_entry.name, chat_engine)
                    .context("add_chat_completions_model")?;
359
                tracing::info!("Chat completions is ready");
360
            }
361

362
            // Add completions engine only if the model supports completions
363
            if card.model_type.supports_completions() {
364
365
                let formatter = PromptFormatter::no_op();
                let PromptFormatter::OAI(formatter) = formatter;
366
367
368
369
                let preprocessor = OpenAIPreprocessor::new_with_parts(
                    card.clone(),
                    formatter,
                    tokenizer_hf.clone(),
370
371
                )
                .context("OpenAIPreprocessor::new_with_parts")?;
372
                let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
373
374
375
376
377
378
379
380
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(
                    &card,
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser,
381
                    preprocessor,
382
                    tokenizer_hf,
383
                )
384
385
                .await
                .context("build_routed_pipeline_with_preprocessor")?;
386
                self.manager
387
388
                    .add_completions_model(&model_entry.name, completions_engine)
                    .context("add_completions_model")?;
389
                tracing::info!("Completions is ready");
390
            }
391
392
393
394
395
396
397
398
399
400
401
402
        } else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
            // Case: Text + Embeddings
            let push_router = PushRouter::<
                NvCreateEmbeddingRequest,
                Annotated<NvCreateEmbeddingResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
                .add_embeddings_model(&model_entry.name, engine)?;
403
        } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
404
405
406
407
408
409
410
411
412
413
414
            // Case 3: Text + Chat
            let push_router = PushRouter::<
                NvCreateChatCompletionRequest,
                Annotated<NvCreateChatCompletionStreamResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
                .add_chat_completions_model(&model_entry.name, engine)?;
415
        } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
416
417
418
419
420
421
422
423
424
425
426
            // Case 2: Text + Completions
            let push_router = PushRouter::<
                NvCreateCompletionRequest,
                Annotated<NvCreateCompletionResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
                .add_completions_model(&model_entry.name, engine)?;
427
        } else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
428
429
430
431
432
433
434
435
            // Case 4: Tokens + Embeddings

            // Create preprocessing pipeline similar to Backend
            let frontend = SegmentSource::<
                SingleIn<NvCreateEmbeddingRequest>,
                ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            >::new();

436
437
            let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
            let backend = Backend::from_mdc(&card).into_operator();
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460

            let router = PushRouter::<
                PreprocessedEmbeddingRequest,
                Annotated<EmbeddingsEngineOutput>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;

            // Note: Embeddings don't need KV routing complexity
            let service_backend = ServiceBackend::from_engine(Arc::new(router));

            // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
            let embedding_engine = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(service_backend)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            self.manager
                .add_embeddings_model(&model_entry.name, embedding_engine)?;
461
        } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
462
463
464
465
466
467
468
469
470
471
            // Case 5: Tensor + Tensor (non-LLM)
            let push_router = PushRouter::<
                NvCreateTensorRequest,
                Annotated<NvCreateTensorResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager.add_tensor_model(&model_entry.name, engine)?;
472
473
474
475
        } else {
            // Reject unsupported combinations
            anyhow::bail!(
                "Unsupported model configuration: {} with {} input. Supported combinations: \
476
                Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
477
478
                card.model_type,
                card.model_input.as_str()
479
            );
480
        }
Ryan Olson's avatar
Ryan Olson committed
481

482
483
        Ok(())
    }
484

485
    /// All the registered ModelEntry, one per instance
486
    pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
        let Some(etcd_client) = self.drt.etcd_client() else {
            anyhow::bail!("all_entries: Missing etcd client");
        };
        let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
        let mut entries = Vec::with_capacity(kvs.len());
        for kv in kvs {
            let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                Ok(model_entry) => model_entry,
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
                            tracing::error!(%err, value, "Invalid JSON in model entry")
                        }
                        Err(value_str_err) => {
                            tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                        }
                    }
                    continue;
                }
            };
            entries.push(model_entry);
508
        }
509
510
511
        Ok(entries)
    }

512
513
514
515
516
517
    pub async fn entries_for_model(
        &self,
        model_name: &str,
        target_namespace: Option<&str>,
        is_global_namespace: bool,
    ) -> anyhow::Result<Vec<ModelEntry>> {
518
        let mut all = self.all_entries().await?;
519
520
521
522
523
524
525
526
527
        all.retain(|entry| {
            let matches_name = entry.name == model_name;
            let matches_namespace = match (is_global_namespace, target_namespace) {
                (true, _) => true,
                (false, None) => true,
                (false, Some(target_ns)) => entry.endpoint_id.namespace == target_ns,
            };
            matches_name && matches_namespace
        });
528
        Ok(all)
529
    }
Ryan Olson's avatar
Ryan Olson committed
530
}