watcher.rs 13.7 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5

6
use anyhow::Context as _;
7
use tokio::sync::{mpsc::Receiver, Notify};
Ryan Olson's avatar
Ryan Olson committed
8

Neelay Shah's avatar
Neelay Shah committed
9
use dynamo_runtime::{
10
    pipeline::{
11
12
        network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
        ServiceBackend, SingleIn, Source,
13
    },
14
15
    protocols::annotated::Annotated,
    transports::etcd::{KeyValue, WatchEvent},
16
    DistributedRuntime,
17
};
Ryan Olson's avatar
Ryan Olson committed
18

19
20
use crate::{
    backend::Backend,
21
    kv_router::KvPushRouter,
22
23
24
    model_type::ModelType,
    preprocessor::{BackendInput, OpenAIPreprocessor},
    protocols::common::llm_backend::LLMEngineOutput,
25
26
27
    protocols::openai::chat_completions::{
        NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
    },
28
    protocols::openai::completions::{CompletionRequest, CompletionResponse},
29
    protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
30
};
31

32
use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH};
33

34
pub struct ModelWatcher {
35
    manager: Arc<ModelManager>,
36
    drt: DistributedRuntime,
37
    router_mode: RouterMode,
38
    notify_on_model: Notify,
Ryan Olson's avatar
Ryan Olson committed
39
40
}

41
42
impl ModelWatcher {
    pub async fn new(
43
        runtime: DistributedRuntime,
44
        model_manager: Arc<ModelManager>,
45
        router_mode: RouterMode,
46
47
48
    ) -> anyhow::Result<ModelWatcher> {
        Ok(Self {
            manager: model_manager,
49
            drt: runtime,
50
            router_mode,
51
            notify_on_model: Notify::new(),
52
53
        })
    }
Ryan Olson's avatar
Ryan Olson committed
54

55
56
57
58
59
60
61
62
63
64
65
    /// Wait until we have at least one chat completions model and return its name.
    ///
    /// The returned name is the first element of
    /// `ModelManager::list_chat_completions_models`.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            // Create the `Notified` future *before* inspecting the registry. `watch` signals
            // with `notify_waiters()`, which stores no permit: a future created only after
            // the check would miss a model registered in between and could wait forever.
            // A `Notified` future receives `notify_waiters()` wakeups from the moment it is
            // created, even if it has not been polled yet.
            let model_event = self.notify_on_model.notified();
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            model_event.await;
        }
    }

66
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>) {
67
68
69
70
71
72
73
74
        tracing::debug!("model watcher started");

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
                    let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                        Ok(model_entry) => model_entry,
                        Err(err) => {
75
76
77
78
79
80
81
82
                            match kv.value_str() {
                                Ok(value) => {
                                    tracing::error!(%err, value, "Invalid JSON in model entry")
                                }
                                Err(value_str_err) => {
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                                }
                            }
83
84
85
                            continue;
                        }
                    };
86
87
88
89
90
91
92
93
94
                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
                            continue;
                        }
                    };
                    self.manager.save_model_entry(key, model_entry.clone());

95
96
97
98
99
                    if self.manager.has_model_any(&model_entry.name) {
                        tracing::trace!(
                            service_name = model_entry.name,
                            "New endpoint for existing model"
                        );
100
                        self.notify_on_model.notify_waiters();
101
102
103
                        continue;
                    }

104
                    match self.handle_put(&model_entry).await {
105
106
                        Ok(()) => {
                            tracing::info!(model_name = model_entry.name, "added model");
107
                            self.notify_on_model.notify_waiters();
108
109
110
111
112
113
                        }
                        Err(e) => {
                            tracing::error!(%e, "error adding model {}", model_entry.name);
                        }
                    }
                }
114
115
                WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
                    Ok(Some(model_name)) => {
116
                        tracing::info!("removed model {}", model_name);
117
                    }
118
119
120
                    Ok(None) => {
                        // There are other instances running this model, nothing to do
                    }
121
                    Err(e) => {
122
                        tracing::error!("error removing model: {}", e);
123
                    }
124
                },
125
            }
Ryan Olson's avatar
Ryan Olson committed
126
127
128
        }
    }

129
130
131
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
    async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
132
        let key = kv.key_str()?;
133
        let model_entry = match self.manager.remove_model_entry(key) {
134
135
136
137
138
            Some(entry) => entry,
            None => {
                anyhow::bail!("Missing ModelEntry for {key}");
            }
        };
139
140
141
142
143
144
145
146
        let model_name = model_entry.name;
        let active_instances = self
            .entries_for_model(&model_name)
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
            return Ok(None);
        }
147

148
        // Ignore the errors because model could be either type
149
150
151
        let _ = self.manager.remove_chat_completions_model(&model_name);
        let _ = self.manager.remove_completions_model(&model_name);
        let _ = self.manager.remove_embeddings_model(&model_name);
152

153
        Ok(Some(model_name))
154
    }
Ryan Olson's avatar
Ryan Olson committed
155

156
157
    // Handles a PUT event from etcd, this usually means adding a new model to the list of served
    // models.
158
    async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> {
159
        let endpoint_id = model_entry.endpoint.clone();
160
        let component = self
161
162
            .drt
            .namespace(&endpoint_id.namespace)?
163
164
            .component(&endpoint_id.component)?;
        let client = component.endpoint(&endpoint_id.name).client().await?;
165

166
167
168
169
        let Some(etcd_client) = self.drt.etcd_client() else {
            // Should be impossible because we only get here on an etcd event
            anyhow::bail!("Missing etcd_client");
        };
170
        let card = match model_entry.load_mdc(&etcd_client).await {
171
172
173
174
175
176
177
178
179
180
            Ok(card) => {
                tracing::debug!(card.display_name, "adding model");
                Some(card)
            }
            Err(err) => {
                // `dynamo serve` isn't using MDC yet so can't be an error
                tracing::info!(%err, "load_mdc did not complete");
                None
            }
        };
181

182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
        match model_entry.model_type {
            ModelType::Backend => {
                // A Backend model expects pre-processed requests meaning it's up to us whether we
                // handle Chat or Completions requests, so handle both.

                let Some(mut card) = card else {
                    anyhow::bail!("Missing model deployment card");
                };
                // Download tokenizer.json etc to local disk
                // This cache_dir is a tempfile::TempDir will be deleted on drop. I _think_
                // OpenAIPreprocessor::new loads the files, so we can delete them after this
                // function. Needs checking carefully, possibly we need to store it in state.
                let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);

                let frontend = SegmentSource::<
                    SingleIn<NvCreateChatCompletionRequest>,
                    ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
                let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
                    client.clone(),
204
                    self.router_mode,
205
206
207
                )
                .await?;
                let service_backend = match self.router_mode {
208
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
209
210
                        ServiceBackend::from_engine(Arc::new(router))
                    }
211
                    RouterMode::KV => {
212
213
214
215
                        let chooser = self
                            .manager
                            .kv_chooser_for(&model_entry.name, &component)
                            .await?;
216
                        let kv_push_router = KvPushRouter::new(router, chooser);
217
218
219
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };
220

221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
                let chat_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_chat_completions_model(&model_entry.name, chat_engine)?;

                let frontend = SegmentSource::<
                    SingleIn<CompletionRequest>,
                    ManyOut<Annotated<CompletionResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
                let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
238
                    client,
239
                    self.router_mode,
240
                )
241
                .await?;
242
                let service_backend = match self.router_mode {
243
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
244
245
                        ServiceBackend::from_engine(Arc::new(router))
                    }
246
                    RouterMode::KV => {
247
248
249
250
                        let chooser = self
                            .manager
                            .kv_chooser_for(&model_entry.name, &component)
                            .await?;
251
                        let kv_push_router = KvPushRouter::new(router, chooser);
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };

                let completions_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_completions_model(&model_entry.name, completions_engine)?;
            }
            ModelType::Chat => {
                let push_router = PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_chat_completions_model(&model_entry.name, engine)?;
            }
            ModelType::Completion => {
                let push_router =
                    PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
                        client,
                        Default::default(),
                    )
                    .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_completions_model(&model_entry.name, engine)?;
            }
287
288
289
290
291
292
293
294
295
296
            ModelType::Embedding => {
                let push_router = PushRouter::<
                    NvCreateEmbeddingRequest,
                    Annotated<NvCreateEmbeddingResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_embeddings_model(&model_entry.name, engine)?;
            }
297
        }
Ryan Olson's avatar
Ryan Olson committed
298

299
300
        Ok(())
    }
301

302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
    /// All the registered ModelEntry, one per instance
    async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
        let Some(etcd_client) = self.drt.etcd_client() else {
            anyhow::bail!("all_entries: Missing etcd client");
        };
        let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
        let mut entries = Vec::with_capacity(kvs.len());
        for kv in kvs {
            let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                Ok(model_entry) => model_entry,
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
                            tracing::error!(%err, value, "Invalid JSON in model entry")
                        }
                        Err(value_str_err) => {
                            tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                        }
                    }
                    continue;
                }
            };
            entries.push(model_entry);
325
        }
326
327
328
329
330
331
332
        Ok(entries)
    }

    /// Returns the registered entries whose `name` equals `model_name`.
    ///
    /// # Errors
    /// Propagates any failure from `all_entries`.
    async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
        let matching = self
            .all_entries()
            .await?
            .into_iter()
            .filter(|entry| entry.name == model_name)
            .collect();
        Ok(matching)
    }
Ryan Olson's avatar
Ryan Olson committed
334
}