"csrc/quantization/marlin/generate_kernels.py" did not exist on "ce96857fdd2bf2390aaa2183561fd1a0f5c464c7"
watcher.rs 16.2 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Ryan Olson's avatar
Ryan Olson committed
4
use std::sync::Arc;
5

6
use anyhow::Context as _;
7
use tokio::sync::{mpsc::Receiver, Notify};
Ryan Olson's avatar
Ryan Olson committed
8

Neelay Shah's avatar
Neelay Shah committed
9
use dynamo_runtime::{
10
    pipeline::{
11
12
        network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
        ServiceBackend, SingleIn, Source,
13
    },
14
15
    protocols::annotated::Annotated,
    transports::etcd::{KeyValue, WatchEvent},
16
    DistributedRuntime,
17
};
Ryan Olson's avatar
Ryan Olson committed
18

19
20
use crate::{
    backend::Backend,
21
    kv_router::{KvPushRouter, KvRouterConfig},
22
    migration::Migration,
23
    model_type::ModelType,
24
25
    preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, PreprocessedRequest},
    protocols::common::llm_backend::{EmbeddingsEngineOutput, LLMEngineOutput},
26
27
28
    protocols::openai::chat_completions::{
        NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
    },
29
    protocols::openai::completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
30
    protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
31
};
32

33
use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH};
34

35
pub struct ModelWatcher {
36
    manager: Arc<ModelManager>,
37
    drt: DistributedRuntime,
38
    router_mode: RouterMode,
39
    notify_on_model: Notify,
40
    kv_router_config: Option<KvRouterConfig>,
Ryan Olson's avatar
Ryan Olson committed
41
42
}

43
impl ModelWatcher {
44
    pub fn new(
45
        runtime: DistributedRuntime,
46
        model_manager: Arc<ModelManager>,
47
        router_mode: RouterMode,
48
        kv_router_config: Option<KvRouterConfig>,
49
50
    ) -> ModelWatcher {
        Self {
51
            manager: model_manager,
52
            drt: runtime,
53
            router_mode,
54
            notify_on_model: Notify::new(),
55
            kv_router_config,
56
        }
57
    }
Ryan Olson's avatar
Ryan Olson committed
58

59
60
61
62
63
64
65
66
67
68
69
    /// Wait until we have at least one chat completions model and return it's name.
    pub async fn wait_for_chat_model(&self) -> String {
        // Loop in case it gets added and immediately deleted
        loop {
            if let Some(model_name) = self.manager.list_chat_completions_models().first() {
                return model_name.to_owned();
            }
            self.notify_on_model.notified().await
        }
    }

70
    pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>) {
71
72
73
74
75
76
77
78
        tracing::debug!("model watcher started");

        while let Some(event) = events_rx.recv().await {
            match event {
                WatchEvent::Put(kv) => {
                    let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                        Ok(model_entry) => model_entry,
                        Err(err) => {
79
80
81
82
83
84
85
86
                            match kv.value_str() {
                                Ok(value) => {
                                    tracing::error!(%err, value, "Invalid JSON in model entry")
                                }
                                Err(value_str_err) => {
                                    tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                                }
                            }
87
88
89
                            continue;
                        }
                    };
90
91
92
93
94
95
96
97
98
                    let key = match kv.key_str() {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
                            continue;
                        }
                    };
                    self.manager.save_model_entry(key, model_entry.clone());

99
                    if self.manager.has_model_any(&model_entry.name) {
100
                        tracing::trace!(name = model_entry.name, "New endpoint for existing model");
101
                        self.notify_on_model.notify_waiters();
102
103
104
                        continue;
                    }

105
                    match self.handle_put(&model_entry).await {
106
107
                        Ok(()) => {
                            tracing::info!(model_name = model_entry.name, "added model");
108
                            self.notify_on_model.notify_waiters();
109
                        }
110
111
112
113
114
115
                        Err(err) => {
                            tracing::error!(
                                error = format!("{err:#}"),
                                "error adding model {}",
                                model_entry.name
                            );
116
117
118
                        }
                    }
                }
119
120
                WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
                    Ok(Some(model_name)) => {
121
                        tracing::info!("removed model {}", model_name);
122
                    }
123
124
125
                    Ok(None) => {
                        // There are other instances running this model, nothing to do
                    }
126
                    Err(e) => {
127
                        tracing::error!("error removing model: {}", e);
128
                    }
129
                },
130
            }
Ryan Olson's avatar
Ryan Olson committed
131
132
133
        }
    }

134
135
136
    /// If the last instance running this model has gone delete it.
    /// Returns the name of the model we just deleted, if any.
    async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
137
        let key = kv.key_str()?;
138
        let model_entry = match self.manager.remove_model_entry(key) {
139
140
141
142
143
            Some(entry) => entry,
            None => {
                anyhow::bail!("Missing ModelEntry for {key}");
            }
        };
144
145
146
147
148
149
150
151
        let model_name = model_entry.name;
        let active_instances = self
            .entries_for_model(&model_name)
            .await
            .with_context(|| model_name.clone())?;
        if !active_instances.is_empty() {
            return Ok(None);
        }
152

153
        // Ignore the errors because model could be either type
154
155
156
        let _ = self.manager.remove_chat_completions_model(&model_name);
        let _ = self.manager.remove_completions_model(&model_name);
        let _ = self.manager.remove_embeddings_model(&model_name);
157

158
        Ok(Some(model_name))
159
    }
Ryan Olson's avatar
Ryan Olson committed
160

161
162
    // Handles a PUT event from etcd, this usually means adding a new model to the list of served
    // models.
163
    async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> {
164
        let endpoint_id = model_entry.endpoint.clone();
165
        let component = self
166
167
            .drt
            .namespace(&endpoint_id.namespace)?
168
169
            .component(&endpoint_id.component)?;
        let client = component.endpoint(&endpoint_id.name).client().await?;
170

171
172
173
174
        let Some(etcd_client) = self.drt.etcd_client() else {
            // Should be impossible because we only get here on an etcd event
            anyhow::bail!("Missing etcd_client");
        };
175
        let card = match model_entry.load_mdc(&etcd_client).await {
176
177
178
179
180
181
182
183
184
            Ok(card) => {
                tracing::debug!(card.display_name, "adding model");
                Some(card)
            }
            Err(err) => {
                tracing::info!(%err, "load_mdc did not complete");
                None
            }
        };
185

186
187
188
189
190
191
192
193
194
195
196
197
198
199
        match model_entry.model_type {
            ModelType::Backend => {
                // A Backend model expects pre-processed requests meaning it's up to us whether we
                // handle Chat or Completions requests, so handle both.

                let Some(mut card) = card else {
                    anyhow::bail!("Missing model deployment card");
                };
                // Download tokenizer.json etc to local disk
                // This cache_dir is a tempfile::TempDir will be deleted on drop. I _think_
                // OpenAIPreprocessor::new loads the files, so we can delete them after this
                // function. Needs checking carefully, possibly we need to store it in state.
                let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);

200
                // Chat Completions
201
202
203
204
205
206
                let frontend = SegmentSource::<
                    SingleIn<NvCreateChatCompletionRequest>,
                    ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
207
                let migration = Migration::from_mdc(card.clone()).await?.into_operator();
208
209
210
211
212
213
                let router =
                    PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
                        client.clone(),
                        self.router_mode,
                    )
                    .await?;
214
                let service_backend = match self.router_mode {
215
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
216
217
                        ServiceBackend::from_engine(Arc::new(router))
                    }
218
                    RouterMode::KV => {
219
220
                        let chooser = self
                            .manager
221
222
223
224
                            .kv_chooser_for(
                                &model_entry.name,
                                &component,
                                card.kv_cache_block_size,
225
                                self.kv_router_config,
226
                            )
227
                            .await?;
228
                        let kv_push_router = KvPushRouter::new(router, chooser);
229
230
231
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };
232

233
234
235
                let chat_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
236
                    .link(migration.forward_edge())?
237
                    .link(service_backend)?
238
                    .link(migration.backward_edge())?
239
240
241
242
243
244
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_chat_completions_model(&model_entry.name, chat_engine)?;

245
                // Completions
246
                let frontend = SegmentSource::<
247
                    SingleIn<NvCreateCompletionRequest>,
248
                    ManyOut<Annotated<NvCreateCompletionResponse>>,
249
250
251
                >::new();
                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();
252
                let migration = Migration::from_mdc(card.clone()).await?.into_operator();
253
254
255
256
257
258
                let router =
                    PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
                        client,
                        self.router_mode,
                    )
                    .await?;
259
                let service_backend = match self.router_mode {
260
                    RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
261
262
                        ServiceBackend::from_engine(Arc::new(router))
                    }
263
                    RouterMode::KV => {
264
265
                        let chooser = self
                            .manager
266
267
268
269
                            .kv_chooser_for(
                                &model_entry.name,
                                &component,
                                card.kv_cache_block_size,
270
                                self.kv_router_config,
271
                            )
272
                            .await?;
273
                        let kv_push_router = KvPushRouter::new(router, chooser);
274
275
276
277
278
279
280
                        ServiceBackend::from_engine(Arc::new(kv_push_router))
                    }
                };

                let completions_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
281
                    .link(migration.forward_edge())?
282
                    .link(service_backend)?
283
                    .link(migration.backward_edge())?
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;
                self.manager
                    .add_completions_model(&model_entry.name, completions_engine)?;
            }
            ModelType::Chat => {
                let push_router = PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
                >::from_client(client, Default::default())
                .await?;
                let engine = Arc::new(push_router);
                self.manager
                    .add_chat_completions_model(&model_entry.name, engine)?;
            }
            ModelType::Completion => {
301
302
                let push_router = PushRouter::<
                    NvCreateCompletionRequest,
303
                    Annotated<NvCreateCompletionResponse>,
304
305
                >::from_client(client, Default::default())
                .await?;
306
307
308
309
                let engine = Arc::new(push_router);
                self.manager
                    .add_completions_model(&model_entry.name, engine)?;
            }
310
            ModelType::Embedding => {
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
                let Some(mut card) = card else {
                    anyhow::bail!("Missing model deployment card for embedding model");
                };

                // Download tokenizer files to local disk
                let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);

                // Create preprocessing pipeline similar to Backend
                let frontend = SegmentSource::<
                    SingleIn<NvCreateEmbeddingRequest>,
                    ManyOut<Annotated<NvCreateEmbeddingResponse>>,
                >::new();

                let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
                let backend = Backend::from_mdc(card.clone()).await?.into_operator();

                let router = PushRouter::<
                    PreprocessedEmbeddingRequest,
                    Annotated<EmbeddingsEngineOutput>,
                >::from_client(client, self.router_mode)
331
                .await?;
332
333
334
335
336
337
338
339
340
341
342
343
344

                // Note: Embeddings don't need KV routing complexity
                let service_backend = ServiceBackend::from_engine(Arc::new(router));

                // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
                let embedding_engine = frontend
                    .link(preprocessor.forward_edge())?
                    .link(backend.forward_edge())?
                    .link(service_backend)?
                    .link(backend.backward_edge())?
                    .link(preprocessor.backward_edge())?
                    .link(frontend)?;

345
                self.manager
346
                    .add_embeddings_model(&model_entry.name, embedding_engine)?;
347
            }
348
        }
Ryan Olson's avatar
Ryan Olson committed
349

350
351
        Ok(())
    }
352

353
    /// All the registered ModelEntry, one per instance
354
    pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
        let Some(etcd_client) = self.drt.etcd_client() else {
            anyhow::bail!("all_entries: Missing etcd client");
        };
        let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
        let mut entries = Vec::with_capacity(kvs.len());
        for kv in kvs {
            let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                Ok(model_entry) => model_entry,
                Err(err) => {
                    match kv.value_str() {
                        Ok(value) => {
                            tracing::error!(%err, value, "Invalid JSON in model entry")
                        }
                        Err(value_str_err) => {
                            tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
                        }
                    }
                    continue;
                }
            };
            entries.push(model_entry);
376
        }
377
378
379
        Ok(entries)
    }

380
    pub async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
381
382
383
        let mut all = self.all_entries().await?;
        all.retain(|entry| entry.name == model_name);
        Ok(all)
384
    }
Ryan Olson's avatar
Ryan Olson committed
385
}