watcher.rs 19.7 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5
use tokio::sync::mpsc::Sender;
6

7
use anyhow::Context as _;
8
use tokio::sync::{Notify, mpsc::Receiver};
Ryan Olson's avatar
Ryan Olson committed
9

Neelay Shah's avatar
Neelay Shah committed
10
use dynamo_runtime::{
11
    DistributedRuntime,
12
    pipeline::{
13
14
        ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
        network::egress::push_router::PushRouter,
15
    },
16
17
    protocols::annotated::Annotated,
    transports::etcd::{KeyValue, WatchEvent},
18
};
Ryan Olson's avatar
Ryan Olson committed
19

20
21
use crate::{
    backend::Backend,
22
23
    entrypoint,
    kv_router::KvRouterConfig,
24
    model_card::ModelDeploymentCard,
25
26
    model_type::{ModelInput, ModelType},
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
27
28
29
30
31
32
33
34
35
    protocols::{
        common::llm_backend::EmbeddingsEngineOutput,
        openai::{
            chat_completions::{
                NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            },
            completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
            embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
        },
36
        tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
37
38
    },
};
39

40
use super::{MODEL_ROOT_PATH, ModelEntry, ModelManager};
41
use crate::namespace::is_global_namespace;
42

43
44
45
46
47
48
/// Notification published by the watcher when the set of served models changes.
/// Each variant carries the [`ModelType`] that was affected.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelUpdate {
    /// Published for every accepted PUT event, with the model type of the entry.
    Added(ModelType),
    /// Published when removing a model leaves no model of this type registered.
    Removed(ModelType),
}

49
pub struct ModelWatcher {
50
    manager: Arc<ModelManager>,
51
    drt: DistributedRuntime,
52
    router_mode: RouterMode,
53
    notify_on_model: Notify,
54
    model_update_tx: Option<Sender<ModelUpdate>>,
55
    kv_router_config: Option<KvRouterConfig>,
56
    busy_threshold: Option<f64>,
Ryan Olson's avatar
Ryan Olson committed
57
58
}

59
60
61
62
/// Model types examined when publishing removal notifications, in the order
/// the notifications are emitted.
const ALL_MODEL_TYPES: &[ModelType] = &[
    ModelType::Chat,
    ModelType::Completions,
    ModelType::Embedding,
    ModelType::TensorBased,
];
65

66
impl ModelWatcher {
67
    pub fn new(
68
        runtime: DistributedRuntime,
69
        model_manager: Arc<ModelManager>,
70
        router_mode: RouterMode,
71
        kv_router_config: Option<KvRouterConfig>,
72
        busy_threshold: Option<f64>,
73
74
    ) -> ModelWatcher {
        Self {
75
            manager: model_manager,
76
            drt: runtime,
77
            router_mode,
78
            notify_on_model: Notify::new(),
79
            model_update_tx: None,
80
            kv_router_config,
81
            busy_threshold,
82
        }
83
    }
Ryan Olson's avatar
Ryan Olson committed
84

85
86
87
88
    /// Installs the channel on which [`ModelUpdate`] notifications are published.
    /// Any sender installed earlier is replaced and dropped.
    pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
        drop(self.model_update_tx.replace(tx));
    }

89
90
91
92
93
94
95
96
97
98
99
    /// Wait until we have at least one chat completions model and return its name.
    ///
    /// Returns the first name reported by the model manager's chat completions
    /// list. The future only resolves once such a model exists.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            // Create the `Notified` future *before* looking at the model list.
            // `notify_waiters()` stores no permit, so a model added between the
            // check and a later `notified()` call would otherwise be missed and
            // we would sleep until some unrelated future event. A `Notified`
            // future receives `notify_waiters()` wakeups from the moment it is
            // created, even before it is first polled.
            let model_added = self.notify_on_model.notified();
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            model_added.await;
        }
    }

100
101
102
    /// Common watch logic with optional namespace filtering
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>, target_namespace: Option<&str>) {
        let global_namespace = target_namespace.is_none_or(is_global_namespace);
103
104
105
106
107
108
109

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
                    let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                        Ok(model_entry) => model_entry,
                        Err(err) => {
110
111
112
113
114
115
116
117
                            match kv.value_str() {
                                Ok(value) => {
                                    tracing::error!(%err, value, "Invalid JSON in model entry")
                                }
                                Err(value_str_err) => {
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                                }
                            }
118
119
120
                            continue;
                        }
                    };
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135

                    // Filter by namespace if target_namespace is specified
                    if !global_namespace
                        && let Some(target_ns) = target_namespace
                        && model_entry.endpoint_id.namespace != target_ns
                    {
                        tracing::debug!(
                            model_namespace = model_entry.endpoint_id.namespace,
                            target_namespace = target_ns,
                            model_name = model_entry.name,
                            "Skipping model from different namespace"
                        );
                        continue;
                    }

136
137
138
139
140
141
142
143
144
                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
                            continue;
                        }
                    };
                    self.manager.save_model_entry(key, model_entry.clone());

145
146
147
148
149
150
                    if let Some(tx) = &self.model_update_tx {
                        tx.send(ModelUpdate::Added(model_entry.model_type))
                            .await
                            .ok();
                    }

151
                    if self.manager.has_model_any(&model_entry.name) {
152
153
154
155
156
                        tracing::trace!(
                            name = model_entry.name,
                            namespace = model_entry.endpoint_id.namespace,
                            "New endpoint for existing model"
                        );
157
                        self.notify_on_model.notify_waiters();
158
159
160
                        continue;
                    }

161
                    match self.handle_put(&model_entry).await {
162
                        Ok(()) => {
163
164
165
166
167
                            tracing::info!(
                                model_name = model_entry.name,
                                namespace = model_entry.endpoint_id.namespace,
                                "added model"
                            );
168
                            self.notify_on_model.notify_waiters();
169
                        }
170
171
172
                        Err(err) => {
                            tracing::error!(
                                error = format!("{err:#}"),
173
174
175
                                "error adding model {} from namespace {}",
                                model_entry.name,
                                model_entry.endpoint_id.namespace,
176
                            );
177
178
179
                        }
                    }
                }
180
181
                WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
                    Ok(Some(model_name)) => {
182
                        tracing::info!(model_name, "removed model");
183
                    }
184
185
186
                    Ok(None) => {
                        // There are other instances running this model, nothing to do
                    }
187
                    Err(e) => {
188
                        tracing::error!(error = %e, "error removing model");
189
                    }
190
                },
191
            }
Ryan Olson's avatar
Ryan Olson committed
192
193
194
        }
    }

195
196
197
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
    async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
198
        let key = kv.key_str()?;
199
        let model_entry = match self.manager.remove_model_entry(key) {
200
201
202
203
204
            Some(entry) => entry,
            None => {
                anyhow::bail!("Missing ModelEntry for {key}");
            }
        };
205
206
207
208
209
210
211
212
        let model_name = model_entry.name;
        let active_instances = self
            .entries_for_model(&model_name)
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
            return Ok(None);
        }
213

214
        // Ignore the errors because model could be either type
215
216
217
        let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
        let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
        let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
218
        let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
219
220
221
222

        let mut chat_model_removed = false;
        let mut completions_model_removed = false;
        let mut embeddings_model_removed = false;
223
        let mut tensor_model_removed = false;
224
225
226
227
228
229
230
231
232
233
234

        if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
            chat_model_removed = true;
        }
        if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
        {
            completions_model_removed = true;
        }
        if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
            embeddings_model_removed = true;
        }
235
236
237
        if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
            tensor_model_removed = true;
        }
238

239
240
241
242
243
        if !chat_model_removed
            && !completions_model_removed
            && !embeddings_model_removed
            && !tensor_model_removed
        {
244
            tracing::debug!(
245
                "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}",
246
247
248
                model_name,
                chat_model_removed,
                completions_model_removed,
249
250
                embeddings_model_removed,
                tensor_model_removed
251
252
253
            );
        } else {
            for model_type in ALL_MODEL_TYPES {
254
                if ((chat_model_removed && *model_type == ModelType::Chat)
255
                    || (completions_model_removed && *model_type == ModelType::Completions)
256
257
                    || (embeddings_model_removed && *model_type == ModelType::Embedding)
                    || (tensor_model_removed && *model_type == ModelType::TensorBased))
258
                    && let Some(tx) = &self.model_update_tx
259
                {
260
                    tx.send(ModelUpdate::Removed(*model_type)).await.ok();
261
262
263
                }
            }
        }
264

265
        Ok(Some(model_name))
266
    }
Ryan Olson's avatar
Ryan Olson committed
267

268
269
    // Handles a PUT event from etcd, this usually means adding a new model to the list of served
    // models.
270
    async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> {
271
        let endpoint_id = &model_entry.endpoint_id;
272
        let component = self
273
274
            .drt
            .namespace(&endpoint_id.namespace)?
275
            .component(&endpoint_id.component)?;
276
        let client = component.endpoint(&endpoint_id.name).client().await?;
277
278
        let model_slug = model_entry.slug();
        let card = match ModelDeploymentCard::load_from_store(&model_slug, &self.drt).await {
279
280
281
282
283
284
285
286
            Ok(Some(mut card)) => {
                tracing::debug!(card.display_name, "adding model");
                // Ensure runtime_config is populated
                if let Some(rc) = model_entry.runtime_config.clone() {
                    card.runtime_config = rc;
                }
                card
            }
287
288
            Ok(None) => {
                anyhow::bail!("Missing ModelDeploymentCard in storage under key {model_slug}");
289
290
            }
            Err(err) => {
291
292
293
                anyhow::bail!(
                    "Error fetching ModelDeploymentCard from storage under key {model_slug}. {err}"
                );
294
295
            }
        };
296

297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
        if model_entry.model_input == ModelInput::Tokens
            && (model_entry.model_type.supports_chat()
                || model_entry.model_type.supports_completions())
        {
            // Case 1: Tokens + (Chat OR Completions OR Both)
            // A model that expects pre-processed requests meaning it's up to us whether we
            // handle Chat or Completions requests, so handle whatever the model supports.

            let kv_chooser = if self.router_mode == RouterMode::KV {
                Some(
                    self.manager
                        .kv_chooser_for(
                            &model_entry.name,
                            &component,
                            card.kv_cache_block_size,
                            self.kv_router_config,
                        )
                        .await?,
                )
            } else {
                None
            };
319

320
            // This is expensive, we are loading ~10MiB JSON, so only do it once
321
            let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
322

323
324
            // Add chat engine only if the model supports chat
            if model_entry.model_type.supports_chat() {
325
326
327
328
329
330
331
332
333
                let chat_engine = entrypoint::build_routed_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(
                    &card,
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser.clone(),
334
                    tokenizer_hf.clone(),
335
                )
336
337
                .await
                .context("build_routed_pipeline")?;
338
                self.manager
339
340
                    .add_chat_completions_model(&model_entry.name, chat_engine)
                    .context("add_chat_completions_model")?;
341
                tracing::info!("Chat completions is ready");
342
            }
343

344
345
346
347
            // Add completions engine only if the model supports completions
            if model_entry.model_type.supports_completions() {
                let formatter = PromptFormatter::no_op();
                let PromptFormatter::OAI(formatter) = formatter;
348
349
350
351
                let preprocessor = OpenAIPreprocessor::new_with_parts(
                    card.clone(),
                    formatter,
                    tokenizer_hf.clone(),
352
353
                )
                .context("OpenAIPreprocessor::new_with_parts")?;
354
                let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
355
356
357
358
359
360
361
362
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(
                    &card,
                    &client,
                    self.router_mode,
                    self.busy_threshold,
                    kv_chooser,
363
                    preprocessor,
364
                    tokenizer_hf,
365
                )
366
367
                .await
                .context("build_routed_pipeline_with_preprocessor")?;
368
                self.manager
369
370
                    .add_completions_model(&model_entry.name, completions_engine)
                    .context("add_completions_model")?;
371
                tracing::info!("Completions is ready");
372
            }
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
        } else if model_entry.model_input == ModelInput::Text
            && model_entry.model_type.supports_chat()
        {
            // Case 3: Text + Chat
            let push_router = PushRouter::<
                NvCreateChatCompletionRequest,
                Annotated<NvCreateChatCompletionStreamResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
                .add_chat_completions_model(&model_entry.name, engine)?;
        } else if model_entry.model_input == ModelInput::Text
            && model_entry.model_type.supports_completions()
        {
            // Case 2: Text + Completions
            let push_router = PushRouter::<
                NvCreateCompletionRequest,
                Annotated<NvCreateCompletionResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager
                .add_completions_model(&model_entry.name, engine)?;
        } else if model_entry.model_input == ModelInput::Tokens
            && model_entry.model_type.supports_embedding()
        {
            // Case 4: Tokens + Embeddings

            // Create preprocessing pipeline similar to Backend
            let frontend = SegmentSource::<
                SingleIn<NvCreateEmbeddingRequest>,
                ManyOut<Annotated<NvCreateEmbeddingResponse>>,
            >::new();

412
413
            let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
            let backend = Backend::from_mdc(&card).into_operator();
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436

            let router = PushRouter::<
                PreprocessedEmbeddingRequest,
                Annotated<EmbeddingsEngineOutput>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;

            // Note: Embeddings don't need KV routing complexity
            let service_backend = ServiceBackend::from_engine(Arc::new(router));

            // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
            let embedding_engine = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(service_backend)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            self.manager
                .add_embeddings_model(&model_entry.name, embedding_engine)?;
437
438
439
440
441
442
443
444
445
446
447
448
449
        } else if model_entry.model_input == ModelInput::Tensor
            && model_entry.model_type.supports_tensor()
        {
            // Case 5: Tensor + Tensor (non-LLM)
            let push_router = PushRouter::<
                NvCreateTensorRequest,
                Annotated<NvCreateTensorResponse>,
            >::from_client_with_threshold(
                client, self.router_mode, self.busy_threshold
            )
            .await?;
            let engine = Arc::new(push_router);
            self.manager.add_tensor_model(&model_entry.name, engine)?;
450
451
452
453
        } else {
            // Reject unsupported combinations
            anyhow::bail!(
                "Unsupported model configuration: {} with {} input. Supported combinations: \
454
                Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
455
456
457
                model_entry.model_type,
                model_entry.model_input.as_str()
            );
458
        }
Ryan Olson's avatar
Ryan Olson committed
459

460
461
        Ok(())
    }
462

463
    /// All the registered ModelEntry, one per instance
464
    pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
        let Some(etcd_client) = self.drt.etcd_client() else {
            anyhow::bail!("all_entries: Missing etcd client");
        };
        let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
        let mut entries = Vec::with_capacity(kvs.len());
        for kv in kvs {
            let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                Ok(model_entry) => model_entry,
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
                            tracing::error!(%err, value, "Invalid JSON in model entry")
                        }
                        Err(value_str_err) => {
                            tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                        }
                    }
                    continue;
                }
            };
            entries.push(model_entry);
486
        }
487
488
489
        Ok(entries)
    }

490
    /// The registered entries whose name equals `model_name`, one per instance.
    ///
    /// Propagates any error from `all_entries`.
    pub async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
        let matching = self
            .all_entries()
            .await?
            .into_iter()
            .filter(|entry| entry.name == model_name)
            .collect();
        Ok(matching)
    }
Ryan Olson's avatar
Ryan Olson committed
495
}