"tests/vscode:/vscode.git/clone" did not exist on "4c04eef706fcd26c8e046cd797fa9b14a5fb361d"
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5

6
use anyhow::Context as _;
7
use tokio::sync::{mpsc::Receiver, Notify};
Ryan Olson's avatar
Ryan Olson committed
8

Neelay Shah's avatar
Neelay Shah committed
9
use dynamo_runtime::{
10
    pipeline::{
11
12
        network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
        ServiceBackend, SingleIn, Source,
13
    },
14
15
    protocols::annotated::Annotated,
    transports::etcd::{KeyValue, WatchEvent},
16
    DistributedRuntime,
17
};
Ryan Olson's avatar
Ryan Olson committed
18

19
20
use crate::{
    backend::Backend,
21
    kv_router::{KvPushRouter, KvRouterConfig},
22
    model_type::ModelType,
23
    preprocessor::{OpenAIPreprocessor, PreprocessedRequest},
24
    protocols::common::llm_backend::LLMEngineOutput,
25
26
27
    protocols::openai::chat_completions::{
        NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
    },
28
    protocols::openai::completions::{CompletionRequest, CompletionResponse},
29
    protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
30
};
31

32
use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH};
33

34
pub struct ModelWatcher {
35
    manager: Arc<ModelManager>,
36
    drt: DistributedRuntime,
37
    router_mode: RouterMode,
38
    notify_on_model: Notify,
39
    kv_router_config: Option<KvRouterConfig>,
Ryan Olson's avatar
Ryan Olson committed
40
41
}

42
impl ModelWatcher {
43
    pub fn new(
44
        runtime: DistributedRuntime,
45
        model_manager: Arc<ModelManager>,
46
        router_mode: RouterMode,
47
        kv_router_config: Option<KvRouterConfig>,
48
49
    ) -> ModelWatcher {
        Self {
50
            manager: model_manager,
51
            drt: runtime,
52
            router_mode,
53
            notify_on_model: Notify::new(),
54
            kv_router_config,
55
        }
56
    }
Ryan Olson's avatar
Ryan Olson committed
57

58
59
60
61
62
63
64
65
66
67
68
    /// Wait until we have at least one chat completions model and return its name.
    ///
    /// If several chat completions models are registered, the first one reported by the
    /// manager is returned.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            // Create the `Notified` future *before* inspecting the model list.
            // `notify_waiters()` only wakes futures that already exist, so registering
            // after the check would lose a notification sent in between and could leave
            // this call waiting even though a model is available.
            let model_added = self.notify_on_model.notified();
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            model_added.await;
        }
    }

69
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>) {
70
71
72
73
74
75
76
77
        tracing::debug!("model watcher started");

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
                    let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                        Ok(model_entry) => model_entry,
                        Err(err) => {
78
79
80
81
82
83
84
85
                            match kv.value_str() {
                                Ok(value) => {
                                    tracing::error!(%err, value, "Invalid JSON in model entry")
                                }
                                Err(value_str_err) => {
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                                }
                            }
86
87
88
                            continue;
                        }
                    };
89
90
91
92
93
94
95
96
97
                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
                            continue;
                        }
                    };
                    self.manager.save_model_entry(key, model_entry.clone());

98
                    if self.manager.has_model_any(&model_entry.name) {
99
                        tracing::trace!(name = model_entry.name, "New endpoint for existing model");
100
                        self.notify_on_model.notify_waiters();
101
102
103
                        continue;
                    }

104
                    match self.handle_put(&model_entry).await {
105
106
                        Ok(()) => {
                            tracing::info!(model_name = model_entry.name, "added model");
107
                            self.notify_on_model.notify_waiters();
108
                        }
109
110
111
112
113
114
                        Err(err) => {
                            tracing::error!(
                                error = format!("{err:#}"),
                                "error adding model {}",
                                model_entry.name
                            );
115
116
117
                        }
                    }
                }
118
119
                WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
                    Ok(Some(model_name)) => {
120
                        tracing::info!("removed model {}", model_name);
121
                    }
122
123
124
                    Ok(None) => {
                        // There are other instances running this model, nothing to do
                    }
125
                    Err(e) => {
126
                        tracing::error!("error removing model: {}", e);
127
                    }
128
                },
129
            }
Ryan Olson's avatar
Ryan Olson committed
130
131
132
        }
    }

133
134
135
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
    async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
136
        let key = kv.key_str()?;
137
        let model_entry = match self.manager.remove_model_entry(key) {
138
139
140
141
142
            Some(entry) => entry,
            None => {
                anyhow::bail!("Missing ModelEntry for {key}");
            }
        };
143
144
145
146
147
148
149
150
        let model_name = model_entry.name;
        let active_instances = self
            .entries_for_model(&model_name)
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
            return Ok(None);
        }
151

152
        // Ignore the errors because model could be either type
153
154
155
        let _ = self.manager.remove_chat_completions_model(&model_name);
        let _ = self.manager.remove_completions_model(&model_name);
        let _ = self.manager.remove_embeddings_model(&model_name);
156

157
        Ok(Some(model_name))
158
    }
Ryan Olson's avatar
Ryan Olson committed
159

160
161
    // Handles a PUT event from etcd, this usually means adding a new model to the list of served
    // models.
162
    async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> {
163
        let endpoint_id = model_entry.endpoint.clone();
164
        let component = self
165
166
            .drt
            .namespace(&endpoint_id.namespace)?
167
168
            .component(&endpoint_id.component)?;
        let client = component.endpoint(&endpoint_id.name).client().await?;
169

170
171
172
173
        let Some(etcd_client) = self.drt.etcd_client() else {
            // Should be impossible because we only get here on an etcd event
            anyhow::bail!("Missing etcd_client");
        };
174
        let card = match model_entry.load_mdc(&etcd_client).await {
175
176
177
178
179
180
181
182
183
184
            Ok(card) => {
                tracing::debug!(card.display_name, "adding model");
                Some(card)
            }
            Err(err) => {
                // `dynamo serve` isn't using MDC yet so can't be an error
                tracing::info!(%err, "load_mdc did not complete");
                None
            }
        };
185

186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
        match model_entry.model_type {
            ModelType::Backend => {
                // A Backend model expects pre-processed requests meaning it's up to us whether we
                // handle Chat or Completions requests, so handle both.

                let Some(mut card) = card else {
                    anyhow::bail!("Missing model deployment card");
                };
                // Download tokenizer.json etc to local disk
                // This cache_dir is a tempfile::TempDir will be deleted on drop. I _think_
                // OpenAIPreprocessor::new loads the files, so we can delete them after this
                // function. Needs checking carefully, possibly we need to store it in state.
                let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);

                let frontend = SegmentSource::<
                    SingleIn<NvCreateChatCompletionRequest>,
                    ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
206
207
208
209
210
211
                let router =
                    PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
                        client.clone(),
                        self.router_mode,
                    )
                    .await?;
212
                let service_backend = match self.router_mode {
213
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
214
215
                        ServiceBackend::from_engine(Arc::new(router))
                    }
216
                    RouterMode::KV => {
217
218
                        let chooser = self
                            .manager
219
220
221
222
223
224
                            .kv_chooser_for(
                                &model_entry.name,
                                &component,
                                card.kv_cache_block_size,
                                self.kv_router_config.clone(),
                            )
225
                            .await?;
226
                        let kv_push_router = KvPushRouter::new(router, chooser);
227
228
229
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };
230

231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
                let chat_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_chat_completions_model(&model_entry.name, chat_engine)?;

                let frontend = SegmentSource::<
                    SingleIn<CompletionRequest>,
                    ManyOut<Annotated<CompletionResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
247
248
249
250
251
252
                let router =
                    PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
                        client,
                        self.router_mode,
                    )
                    .await?;
253
                let service_backend = match self.router_mode {
254
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
255
256
                        ServiceBackend::from_engine(Arc::new(router))
                    }
257
                    RouterMode::KV => {
258
259
                        let chooser = self
                            .manager
260
261
262
263
264
265
                            .kv_chooser_for(
                                &model_entry.name,
                                &component,
                                card.kv_cache_block_size,
                                self.kv_router_config.clone(),
                            )
266
                            .await?;
267
                        let kv_push_router = KvPushRouter::new(router, chooser);
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };

                let completions_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_completions_model(&model_entry.name, completions_engine)?;
            }
            ModelType::Chat => {
                let push_router = PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_chat_completions_model(&model_entry.name, engine)?;
            }
            ModelType::Completion => {
                let push_router =
                    PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
                        client,
                        Default::default(),
                    )
                    .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_completions_model(&model_entry.name, engine)?;
            }
303
304
305
306
307
308
309
310
311
312
            ModelType::Embedding => {
                let push_router = PushRouter::<
                    NvCreateEmbeddingRequest,
                    Annotated<NvCreateEmbeddingResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_embeddings_model(&model_entry.name, engine)?;
            }
313
        }
Ryan Olson's avatar
Ryan Olson committed
314

315
316
        Ok(())
    }
317

318
    /// All the registered ModelEntry, one per instance
319
    pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
        let Some(etcd_client) = self.drt.etcd_client() else {
            anyhow::bail!("all_entries: Missing etcd client");
        };
        let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
        let mut entries = Vec::with_capacity(kvs.len());
        for kv in kvs {
            let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                Ok(model_entry) => model_entry,
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
                            tracing::error!(%err, value, "Invalid JSON in model entry")
                        }
                        Err(value_str_err) => {
                            tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                        }
                    }
                    continue;
                }
            };
            entries.push(model_entry);
341
        }
342
343
344
        Ok(entries)
    }

345
    /// The registered `ModelEntry` whose name equals `model_name`, one per instance.
    ///
    /// # Errors
    ///
    /// Propagates any failure from `all_entries`.
    pub async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
        let mut all = self.all_entries().await?;
        all.retain(|entry| entry.name == model_name);
        Ok(all)
    }
Ryan Olson's avatar
Ryan Olson committed
350
}