watcher.rs 14.2 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5

6
use anyhow::Context as _;
7
use tokio::sync::{mpsc::Receiver, Notify};
Ryan Olson's avatar
Ryan Olson committed
8

Neelay Shah's avatar
Neelay Shah committed
9
use dynamo_runtime::{
10
    pipeline::{
11
12
        network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
        ServiceBackend, SingleIn, Source,
13
    },
14
15
    protocols::annotated::Annotated,
    transports::etcd::{KeyValue, WatchEvent},
16
    DistributedRuntime,
17
};
Ryan Olson's avatar
Ryan Olson committed
18

19
20
use crate::{
    backend::Backend,
21
    kv_router::{KvPushRouter, KvRouterConfig},
22
    model_type::ModelType,
23
    preprocessor::{OpenAIPreprocessor, PreprocessedRequest},
24
    protocols::common::llm_backend::LLMEngineOutput,
25
26
27
    protocols::openai::chat_completions::{
        NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
    },
28
    protocols::openai::completions::{CompletionRequest, CompletionResponse},
29
    protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
30
};
31

32
use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH};
33

34
pub struct ModelWatcher {
35
    manager: Arc<ModelManager>,
36
    drt: DistributedRuntime,
37
    router_mode: RouterMode,
38
    notify_on_model: Notify,
39
    kv_router_config: Option<KvRouterConfig>,
Ryan Olson's avatar
Ryan Olson committed
40
41
}

42
impl ModelWatcher {
43
    pub fn new(
44
        runtime: DistributedRuntime,
45
        model_manager: Arc<ModelManager>,
46
        router_mode: RouterMode,
47
        kv_router_config: Option<KvRouterConfig>,
48
49
    ) -> ModelWatcher {
        Self {
50
            manager: model_manager,
51
            drt: runtime,
52
            router_mode,
53
            notify_on_model: Notify::new(),
54
            kv_router_config,
55
        }
56
    }
Ryan Olson's avatar
Ryan Olson committed
57

58
59
60
61
62
63
64
65
66
67
68
    /// Wait until we have at least one chat completions model and return it's name.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            self.notify_on_model.notified().await
        }
    }

69
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>) {
70
71
72
73
74
75
76
77
        tracing::debug!("model watcher started");

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
                    let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                        Ok(model_entry) => model_entry,
                        Err(err) => {
78
79
80
81
82
83
84
85
                            match kv.value_str() {
                                Ok(value) => {
                                    tracing::error!(%err, value, "Invalid JSON in model entry")
                                }
                                Err(value_str_err) => {
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                                }
                            }
86
87
88
                            continue;
                        }
                    };
89
90
91
92
93
94
95
96
97
                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
                            continue;
                        }
                    };
                    self.manager.save_model_entry(key, model_entry.clone());

98
                    if self.manager.has_model_any(&model_entry.name) {
99
                        tracing::trace!(name = model_entry.name, "New endpoint for existing model");
100
                        self.notify_on_model.notify_waiters();
101
102
103
                        continue;
                    }

104
                    match self.handle_put(&model_entry).await {
105
106
                        Ok(()) => {
                            tracing::info!(model_name = model_entry.name, "added model");
107
                            self.notify_on_model.notify_waiters();
108
109
110
111
112
113
                        }
                        Err(e) => {
                            tracing::error!(%e, "error adding model {}", model_entry.name);
                        }
                    }
                }
114
115
                WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
                    Ok(Some(model_name)) => {
116
                        tracing::info!("removed model {}", model_name);
117
                    }
118
119
120
                    Ok(None) => {
                        // There are other instances running this model, nothing to do
                    }
121
                    Err(e) => {
122
                        tracing::error!("error removing model: {}", e);
123
                    }
124
                },
125
            }
Ryan Olson's avatar
Ryan Olson committed
126
127
128
        }
    }

129
130
131
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
    async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
132
        let key = kv.key_str()?;
133
        let model_entry = match self.manager.remove_model_entry(key) {
134
135
136
137
138
            Some(entry) => entry,
            None => {
                anyhow::bail!("Missing ModelEntry for {key}");
            }
        };
139
140
141
142
143
144
145
146
        let model_name = model_entry.name;
        let active_instances = self
            .entries_for_model(&model_name)
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
            return Ok(None);
        }
147

148
        // Ignore the errors because model could be either type
149
150
151
        let _ = self.manager.remove_chat_completions_model(&model_name);
        let _ = self.manager.remove_completions_model(&model_name);
        let _ = self.manager.remove_embeddings_model(&model_name);
152

153
        Ok(Some(model_name))
154
    }
Ryan Olson's avatar
Ryan Olson committed
155

156
157
    // Handles a PUT event from etcd, this usually means adding a new model to the list of served
    // models.
158
    async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> {
159
        let endpoint_id = model_entry.endpoint.clone();
160
        let component = self
161
162
            .drt
            .namespace(&endpoint_id.namespace)?
163
164
            .component(&endpoint_id.component)?;
        let client = component.endpoint(&endpoint_id.name).client().await?;
165

166
167
168
169
        let Some(etcd_client) = self.drt.etcd_client() else {
            // Should be impossible because we only get here on an etcd event
            anyhow::bail!("Missing etcd_client");
        };
170
        let card = match model_entry.load_mdc(&etcd_client).await {
171
172
173
174
175
176
177
178
179
180
            Ok(card) => {
                tracing::debug!(card.display_name, "adding model");
                Some(card)
            }
            Err(err) => {
                // `dynamo serve` isn't using MDC yet so can't be an error
                tracing::info!(%err, "load_mdc did not complete");
                None
            }
        };
181

182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
        match model_entry.model_type {
            ModelType::Backend => {
                // A Backend model expects pre-processed requests meaning it's up to us whether we
                // handle Chat or Completions requests, so handle both.

                let Some(mut card) = card else {
                    anyhow::bail!("Missing model deployment card");
                };
                // Download tokenizer.json etc to local disk
                // This cache_dir is a tempfile::TempDir will be deleted on drop. I _think_
                // OpenAIPreprocessor::new loads the files, so we can delete them after this
                // function. Needs checking carefully, possibly we need to store it in state.
                let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);

                let frontend = SegmentSource::<
                    SingleIn<NvCreateChatCompletionRequest>,
                    ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
202
203
204
205
206
207
                let router =
                    PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
                        client.clone(),
                        self.router_mode,
                    )
                    .await?;
208
                let service_backend = match self.router_mode {
209
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
210
211
                        ServiceBackend::from_engine(Arc::new(router))
                    }
212
                    RouterMode::KV => {
213
214
                        let chooser = self
                            .manager
215
216
217
218
219
220
                            .kv_chooser_for(
                                &model_entry.name,
                                &component,
                                card.kv_cache_block_size,
                                self.kv_router_config.clone(),
                            )
221
                            .await?;
222
                        let kv_push_router = KvPushRouter::new(router, chooser);
223
224
225
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };
226

227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
                let chat_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_chat_completions_model(&model_entry.name, chat_engine)?;

                let frontend = SegmentSource::<
                    SingleIn<CompletionRequest>,
                    ManyOut<Annotated<CompletionResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
243
244
245
246
247
248
                let router =
                    PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
                        client,
                        self.router_mode,
                    )
                    .await?;
249
                let service_backend = match self.router_mode {
250
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
251
252
                        ServiceBackend::from_engine(Arc::new(router))
                    }
253
                    RouterMode::KV => {
254
255
                        let chooser = self
                            .manager
256
257
258
259
260
261
                            .kv_chooser_for(
                                &model_entry.name,
                                &component,
                                card.kv_cache_block_size,
                                self.kv_router_config.clone(),
                            )
262
                            .await?;
263
                        let kv_push_router = KvPushRouter::new(router, chooser);
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };

                let completions_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_completions_model(&model_entry.name, completions_engine)?;
            }
            ModelType::Chat => {
                let push_router = PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_chat_completions_model(&model_entry.name, engine)?;
            }
            ModelType::Completion => {
                let push_router =
                    PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
                        client,
                        Default::default(),
                    )
                    .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_completions_model(&model_entry.name, engine)?;
            }
299
300
301
302
303
304
305
306
307
308
            ModelType::Embedding => {
                let push_router = PushRouter::<
                    NvCreateEmbeddingRequest,
                    Annotated<NvCreateEmbeddingResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_embeddings_model(&model_entry.name, engine)?;
            }
309
        }
Ryan Olson's avatar
Ryan Olson committed
310

311
312
        Ok(())
    }
313

314
    /// All the registered ModelEntry, one per instance
315
    pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
        let Some(etcd_client) = self.drt.etcd_client() else {
            anyhow::bail!("all_entries: Missing etcd client");
        };
        let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
        let mut entries = Vec::with_capacity(kvs.len());
        for kv in kvs {
            let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                Ok(model_entry) => model_entry,
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
                            tracing::error!(%err, value, "Invalid JSON in model entry")
                        }
                        Err(value_str_err) => {
                            tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                        }
                    }
                    continue;
                }
            };
            entries.push(model_entry);
337
        }
338
339
340
        Ok(entries)
    }

341
    pub async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
342
343
344
        let mut all = self.all_entries().await?;
        all.retain(|entry| entry.name == model_name);
        Ok(all)
345
    }
Ryan Olson's avatar
Ryan Olson committed
346
}