watcher.rs 13.7 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5

6
use anyhow::Context as _;
7
use tokio::sync::{mpsc::Receiver, Notify};
Ryan Olson's avatar
Ryan Olson committed
8

Neelay Shah's avatar
Neelay Shah committed
9
use dynamo_runtime::{
10
    pipeline::{
11
12
        network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
        ServiceBackend, SingleIn, Source,
13
    },
14
15
    protocols::annotated::Annotated,
    transports::etcd::{KeyValue, WatchEvent},
16
    DistributedRuntime,
17
};
Ryan Olson's avatar
Ryan Olson committed
18

19
20
use crate::{
    backend::Backend,
21
    kv_router::KvPushRouter,
22
    model_type::ModelType,
23
    preprocessor::{OpenAIPreprocessor, PreprocessedRequest},
24
    protocols::common::llm_backend::LLMEngineOutput,
25
26
27
    protocols::openai::chat_completions::{
        NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
    },
28
    protocols::openai::completions::{CompletionRequest, CompletionResponse},
29
    protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
30
};
31

32
use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH};
33

34
pub struct ModelWatcher {
35
    manager: Arc<ModelManager>,
36
    drt: DistributedRuntime,
37
    router_mode: RouterMode,
38
    notify_on_model: Notify,
Ryan Olson's avatar
Ryan Olson committed
39
40
}

41
impl ModelWatcher {
42
    pub fn new(
43
        runtime: DistributedRuntime,
44
        model_manager: Arc<ModelManager>,
45
        router_mode: RouterMode,
46
47
    ) -> ModelWatcher {
        Self {
48
            manager: model_manager,
49
            drt: runtime,
50
            router_mode,
51
            notify_on_model: Notify::new(),
52
        }
53
    }
Ryan Olson's avatar
Ryan Olson committed
54

55
56
57
58
59
60
61
62
63
64
65
    /// Wait until we have at least one chat completions model and return its name.
    ///
    /// Returns the first name reported by the model manager's chat completions list.
    /// The future only resolves once such a model exists; it never times out.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            // Create the `Notified` future *before* inspecting the model list.
            // `notify_waiters()` stores no permit, so a notification fired between the
            // check below and a later call to `notified()` would be lost and this task
            // could wait forever. A `Notified` future receives `notify_waiters()`
            // wake-ups from the moment it is created, even before it is first polled.
            let notified = self.notify_on_model.notified();
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            notified.await;
        }
    }

66
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>) {
67
68
69
70
71
72
73
74
        tracing::debug!("model watcher started");

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
                    let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                        Ok(model_entry) => model_entry,
                        Err(err) => {
75
76
77
78
79
80
81
82
                            match kv.value_str() {
                                Ok(value) => {
                                    tracing::error!(%err, value, "Invalid JSON in model entry")
                                }
                                Err(value_str_err) => {
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                                }
                            }
83
84
85
                            continue;
                        }
                    };
86
87
88
89
90
91
92
93
94
                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
                            continue;
                        }
                    };
                    self.manager.save_model_entry(key, model_entry.clone());

95
                    if self.manager.has_model_any(&model_entry.name) {
96
                        tracing::trace!(name = model_entry.name, "New endpoint for existing model");
97
                        self.notify_on_model.notify_waiters();
98
99
100
                        continue;
                    }

101
                    match self.handle_put(&model_entry).await {
102
103
                        Ok(()) => {
                            tracing::info!(model_name = model_entry.name, "added model");
104
                            self.notify_on_model.notify_waiters();
105
106
107
108
109
110
                        }
                        Err(e) => {
                            tracing::error!(%e, "error adding model {}", model_entry.name);
                        }
                    }
                }
111
112
                WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
                    Ok(Some(model_name)) => {
113
                        tracing::info!("removed model {}", model_name);
114
                    }
115
116
117
                    Ok(None) => {
                        // There are other instances running this model, nothing to do
                    }
118
                    Err(e) => {
119
                        tracing::error!("error removing model: {}", e);
120
                    }
121
                },
122
            }
Ryan Olson's avatar
Ryan Olson committed
123
124
125
        }
    }

126
127
128
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
    async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
129
        let key = kv.key_str()?;
130
        let model_entry = match self.manager.remove_model_entry(key) {
131
132
133
134
135
            Some(entry) => entry,
            None => {
                anyhow::bail!("Missing ModelEntry for {key}");
            }
        };
136
137
138
139
140
141
142
143
        let model_name = model_entry.name;
        let active_instances = self
            .entries_for_model(&model_name)
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
            return Ok(None);
        }
144

145
        // Ignore the errors because model could be either type
146
147
148
        let _ = self.manager.remove_chat_completions_model(&model_name);
        let _ = self.manager.remove_completions_model(&model_name);
        let _ = self.manager.remove_embeddings_model(&model_name);
149

150
        Ok(Some(model_name))
151
    }
Ryan Olson's avatar
Ryan Olson committed
152

153
154
    // Handles a PUT event from etcd, this usually means adding a new model to the list of served
    // models.
155
    async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> {
156
        let endpoint_id = model_entry.endpoint.clone();
157
        let component = self
158
159
            .drt
            .namespace(&endpoint_id.namespace)?
160
161
            .component(&endpoint_id.component)?;
        let client = component.endpoint(&endpoint_id.name).client().await?;
162

163
164
165
166
        let Some(etcd_client) = self.drt.etcd_client() else {
            // Should be impossible because we only get here on an etcd event
            anyhow::bail!("Missing etcd_client");
        };
167
        let card = match model_entry.load_mdc(&etcd_client).await {
168
169
170
171
172
173
174
175
176
177
            Ok(card) => {
                tracing::debug!(card.display_name, "adding model");
                Some(card)
            }
            Err(err) => {
                // `dynamo serve` isn't using MDC yet so can't be an error
                tracing::info!(%err, "load_mdc did not complete");
                None
            }
        };
178

179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
        match model_entry.model_type {
            ModelType::Backend => {
                // A Backend model expects pre-processed requests meaning it's up to us whether we
                // handle Chat or Completions requests, so handle both.

                let Some(mut card) = card else {
                    anyhow::bail!("Missing model deployment card");
                };
                // Download tokenizer.json etc to local disk
                // This cache_dir is a tempfile::TempDir will be deleted on drop. I _think_
                // OpenAIPreprocessor::new loads the files, so we can delete them after this
                // function. Needs checking carefully, possibly we need to store it in state.
                let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);

                let frontend = SegmentSource::<
                    SingleIn<NvCreateChatCompletionRequest>,
                    ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
199
200
201
202
203
204
                let router =
                    PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
                        client.clone(),
                        self.router_mode,
                    )
                    .await?;
205
                let service_backend = match self.router_mode {
206
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
207
208
                        ServiceBackend::from_engine(Arc::new(router))
                    }
209
                    RouterMode::KV => {
210
211
                        let chooser = self
                            .manager
212
                            .kv_chooser_for(&model_entry.name, &component, card.kv_cache_block_size)
213
                            .await?;
214
                        let kv_push_router = KvPushRouter::new(router, chooser);
215
216
217
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };
218

219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
                let chat_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_chat_completions_model(&model_entry.name, chat_engine)?;

                let frontend = SegmentSource::<
                    SingleIn<CompletionRequest>,
                    ManyOut<Annotated<CompletionResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
235
236
237
238
239
240
                let router =
                    PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
                        client,
                        self.router_mode,
                    )
                    .await?;
241
                let service_backend = match self.router_mode {
242
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
243
244
                        ServiceBackend::from_engine(Arc::new(router))
                    }
245
                    RouterMode::KV => {
246
247
                        let chooser = self
                            .manager
248
                            .kv_chooser_for(&model_entry.name, &component, card.kv_cache_block_size)
249
                            .await?;
250
                        let kv_push_router = KvPushRouter::new(router, chooser);
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };

                let completions_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_completions_model(&model_entry.name, completions_engine)?;
            }
            ModelType::Chat => {
                let push_router = PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_chat_completions_model(&model_entry.name, engine)?;
            }
            ModelType::Completion => {
                let push_router =
                    PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
                        client,
                        Default::default(),
                    )
                    .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_completions_model(&model_entry.name, engine)?;
            }
286
287
288
289
290
291
292
293
294
295
            ModelType::Embedding => {
                let push_router = PushRouter::<
                    NvCreateEmbeddingRequest,
                    Annotated<NvCreateEmbeddingResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_embeddings_model(&model_entry.name, engine)?;
            }
296
        }
Ryan Olson's avatar
Ryan Olson committed
297

298
299
        Ok(())
    }
300

301
    /// All the registered ModelEntry, one per instance
302
    pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
        let Some(etcd_client) = self.drt.etcd_client() else {
            anyhow::bail!("all_entries: Missing etcd client");
        };
        let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
        let mut entries = Vec::with_capacity(kvs.len());
        for kv in kvs {
            let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                Ok(model_entry) => model_entry,
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
                            tracing::error!(%err, value, "Invalid JSON in model entry")
                        }
                        Err(value_str_err) => {
                            tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                        }
                    }
                    continue;
                }
            };
            entries.push(model_entry);
324
        }
325
326
327
        Ok(entries)
    }

328
    /// Returns the registered entries whose `name` equals `model_name`, in listing order.
    ///
    /// # Errors
    ///
    /// Propagates any failure from [`Self::all_entries`].
    pub async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
        let matching: Vec<ModelEntry> = self
            .all_entries()
            .await?
            .into_iter()
            .filter(|entry| entry.name == model_name)
            .collect();
        Ok(matching)
    }
Ryan Olson's avatar
Ryan Olson committed
333
}