watcher.rs 13.6 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5

6
use anyhow::Context as _;
7
use tokio::sync::{mpsc::Receiver, Notify};
Ryan Olson's avatar
Ryan Olson committed
8

Neelay Shah's avatar
Neelay Shah committed
9
use dynamo_runtime::{
10
    pipeline::{
11
12
        network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
        ServiceBackend, SingleIn, Source,
13
    },
14
15
    protocols::annotated::Annotated,
    transports::etcd::{KeyValue, WatchEvent},
16
    DistributedRuntime,
17
};
Ryan Olson's avatar
Ryan Olson committed
18

19
20
use crate::{
    backend::Backend,
21
    kv_router::KvPushRouter,
22
23
24
    model_type::ModelType,
    preprocessor::{BackendInput, OpenAIPreprocessor},
    protocols::common::llm_backend::LLMEngineOutput,
25
26
27
    protocols::openai::chat_completions::{
        NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
    },
28
    protocols::openai::completions::{CompletionRequest, CompletionResponse},
29
    protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
30
};
31

32
use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH};
33

34
pub struct ModelWatcher {
35
    manager: Arc<ModelManager>,
36
    drt: DistributedRuntime,
37
    router_mode: RouterMode,
38
    notify_on_model: Notify,
Ryan Olson's avatar
Ryan Olson committed
39
40
}

41
impl ModelWatcher {
42
    pub fn new(
43
        runtime: DistributedRuntime,
44
        model_manager: Arc<ModelManager>,
45
        router_mode: RouterMode,
46
47
    ) -> ModelWatcher {
        Self {
48
            manager: model_manager,
49
            drt: runtime,
50
            router_mode,
51
            notify_on_model: Notify::new(),
52
        }
53
    }
Ryan Olson's avatar
Ryan Olson committed
54

55
56
57
58
59
60
61
62
63
64
65
    /// Wait until we have at least one chat completions model and return it's name.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            self.notify_on_model.notified().await
        }
    }

66
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>) {
67
68
69
70
71
72
73
74
        tracing::debug!("model watcher started");

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
                    let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                        Ok(model_entry) => model_entry,
                        Err(err) => {
75
76
77
78
79
80
81
82
                            match kv.value_str() {
                                Ok(value) => {
                                    tracing::error!(%err, value, "Invalid JSON in model entry")
                                }
                                Err(value_str_err) => {
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                                }
                            }
83
84
85
                            continue;
                        }
                    };
86
87
88
89
90
91
92
93
94
                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
                            continue;
                        }
                    };
                    self.manager.save_model_entry(key, model_entry.clone());

95
                    if self.manager.has_model_any(&model_entry.name) {
96
                        tracing::trace!(name = model_entry.name, "New endpoint for existing model");
97
                        self.notify_on_model.notify_waiters();
98
99
100
                        continue;
                    }

101
                    match self.handle_put(&model_entry).await {
102
103
                        Ok(()) => {
                            tracing::info!(model_name = model_entry.name, "added model");
104
                            self.notify_on_model.notify_waiters();
105
106
107
108
109
110
                        }
                        Err(e) => {
                            tracing::error!(%e, "error adding model {}", model_entry.name);
                        }
                    }
                }
111
112
                WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
                    Ok(Some(model_name)) => {
113
                        tracing::info!("removed model {}", model_name);
114
                    }
115
116
117
                    Ok(None) => {
                        // There are other instances running this model, nothing to do
                    }
118
                    Err(e) => {
119
                        tracing::error!("error removing model: {}", e);
120
                    }
121
                },
122
            }
Ryan Olson's avatar
Ryan Olson committed
123
124
125
        }
    }

126
127
128
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
    async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
129
        let key = kv.key_str()?;
130
        let model_entry = match self.manager.remove_model_entry(key) {
131
132
133
134
135
            Some(entry) => entry,
            None => {
                anyhow::bail!("Missing ModelEntry for {key}");
            }
        };
136
137
138
139
140
141
142
143
        let model_name = model_entry.name;
        let active_instances = self
            .entries_for_model(&model_name)
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
            return Ok(None);
        }
144

145
        // Ignore the errors because model could be either type
146
147
148
        let _ = self.manager.remove_chat_completions_model(&model_name);
        let _ = self.manager.remove_completions_model(&model_name);
        let _ = self.manager.remove_embeddings_model(&model_name);
149

150
        Ok(Some(model_name))
151
    }
Ryan Olson's avatar
Ryan Olson committed
152

153
154
    // Handles a PUT event from etcd, this usually means adding a new model to the list of served
    // models.
155
    async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> {
156
        let endpoint_id = model_entry.endpoint.clone();
157
        let component = self
158
159
            .drt
            .namespace(&endpoint_id.namespace)?
160
161
            .component(&endpoint_id.component)?;
        let client = component.endpoint(&endpoint_id.name).client().await?;
162

163
164
165
166
        let Some(etcd_client) = self.drt.etcd_client() else {
            // Should be impossible because we only get here on an etcd event
            anyhow::bail!("Missing etcd_client");
        };
167
        let card = match model_entry.load_mdc(&etcd_client).await {
168
169
170
171
172
173
174
175
176
177
            Ok(card) => {
                tracing::debug!(card.display_name, "adding model");
                Some(card)
            }
            Err(err) => {
                // `dynamo serve` isn't using MDC yet so can't be an error
                tracing::info!(%err, "load_mdc did not complete");
                None
            }
        };
178

179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
        match model_entry.model_type {
            ModelType::Backend => {
                // A Backend model expects pre-processed requests meaning it's up to us whether we
                // handle Chat or Completions requests, so handle both.

                let Some(mut card) = card else {
                    anyhow::bail!("Missing model deployment card");
                };
                // Download tokenizer.json etc to local disk
                // This cache_dir is a tempfile::TempDir will be deleted on drop. I _think_
                // OpenAIPreprocessor::new loads the files, so we can delete them after this
                // function. Needs checking carefully, possibly we need to store it in state.
                let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);

                let frontend = SegmentSource::<
                    SingleIn<NvCreateChatCompletionRequest>,
                    ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
                let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
                    client.clone(),
201
                    self.router_mode,
202
203
204
                )
                .await?;
                let service_backend = match self.router_mode {
205
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
206
207
                        ServiceBackend::from_engine(Arc::new(router))
                    }
208
                    RouterMode::KV => {
209
210
                        let chooser = self
                            .manager
211
                            .kv_chooser_for(&model_entry.name, &component, card.kv_cache_block_size)
212
                            .await?;
213
                        let kv_push_router = KvPushRouter::new(router, chooser);
214
215
216
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };
217

218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
                let chat_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_chat_completions_model(&model_entry.name, chat_engine)?;

                let frontend = SegmentSource::<
                    SingleIn<CompletionRequest>,
                    ManyOut<Annotated<CompletionResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
                let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
235
                    client,
236
                    self.router_mode,
237
                )
238
                .await?;
239
                let service_backend = match self.router_mode {
240
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
241
242
                        ServiceBackend::from_engine(Arc::new(router))
                    }
243
                    RouterMode::KV => {
244
245
                        let chooser = self
                            .manager
246
                            .kv_chooser_for(&model_entry.name, &component, card.kv_cache_block_size)
247
                            .await?;
248
                        let kv_push_router = KvPushRouter::new(router, chooser);
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };

                let completions_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_completions_model(&model_entry.name, completions_engine)?;
            }
            ModelType::Chat => {
                let push_router = PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_chat_completions_model(&model_entry.name, engine)?;
            }
            ModelType::Completion => {
                let push_router =
                    PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
                        client,
                        Default::default(),
                    )
                    .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_completions_model(&model_entry.name, engine)?;
            }
284
285
286
287
288
289
290
291
292
293
            ModelType::Embedding => {
                let push_router = PushRouter::<
                    NvCreateEmbeddingRequest,
                    Annotated<NvCreateEmbeddingResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_embeddings_model(&model_entry.name, engine)?;
            }
294
        }
Ryan Olson's avatar
Ryan Olson committed
295

296
297
        Ok(())
    }
298

299
    /// All the registered ModelEntry, one per instance
300
    pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
        let Some(etcd_client) = self.drt.etcd_client() else {
            anyhow::bail!("all_entries: Missing etcd client");
        };
        let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
        let mut entries = Vec::with_capacity(kvs.len());
        for kv in kvs {
            let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                Ok(model_entry) => model_entry,
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
                            tracing::error!(%err, value, "Invalid JSON in model entry")
                        }
                        Err(value_str_err) => {
                            tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                        }
                    }
                    continue;
                }
            };
            entries.push(model_entry);
322
        }
323
324
325
        Ok(entries)
    }

326
    pub async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
327
328
329
        let mut all = self.all_entries().await?;
        all.retain(|entry| entry.name == model_name);
        Ok(all)
330
    }
Ryan Olson's avatar
Ryan Olson committed
331
}