model_manager.rs 14.3 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::{
    collections::{HashMap, HashSet},
6
    sync::Arc,
7
8
};

9
use parking_lot::{Mutex, RwLock};
10

11
use dynamo_runtime::component::Component;
12
use dynamo_runtime::prelude::DistributedRuntimeProvider;
13

14
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};
15
16
use crate::{
    kv_router::KvRouter,
17
    types::generic::tensor::TensorStreamingEngine,
18
19
20
21
22
    types::openai::{
        chat_completions::OpenAIChatCompletionsStreamingEngine,
        completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
    },
};
23
24
25
26
use crate::{
    kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector},
    model_type::ModelType,
};
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

/// Errors reported by [`ModelManager`] when registering, removing or looking up model engines.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// No engine of the requested kind is registered under this model name.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// An engine of the requested kind is already registered under this model name.
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}

// Don't implement Clone for this, put it in an Arc instead.
pub struct ModelManager {
    // We read a lot and write rarely, so these three are RwLock
    completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
    chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
    embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
43
    tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
44
    prefill_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
45

46
    // These are Mutex because we read and write rarely and equally
47
    cards: Mutex<HashMap<String, ModelDeploymentCard>>,
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
    kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
}

impl Default for ModelManager {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelManager {
    pub fn new() -> Self {
        Self {
            completion_engines: RwLock::new(ModelEngines::default()),
            chat_completion_engines: RwLock::new(ModelEngines::default()),
            embeddings_engines: RwLock::new(ModelEngines::default()),
63
            tensor_engines: RwLock::new(ModelEngines::default()),
64
            prefill_engines: RwLock::new(ModelEngines::default()),
65
            cards: Mutex::new(HashMap::new()),
66
67
68
69
            kv_choosers: Mutex::new(HashMap::new()),
        }
    }

70
71
72
73
74
75
76
77
78
79
80
81
82
    pub fn is_valid_checksum(
        &self,
        model_type: ModelType,
        model_name: &str,
        candidate_checksum: &str,
    ) -> Option<bool> {
        let mut results = vec![];
        for unit in model_type.units() {
            let maybe_valid_checksum = match unit {
                ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
                ModelType::Completions => self.completion_engines.read().checksum(model_name),
                ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
                ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
83
                ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
                _ => {
                    continue;
                }
            };
            if let Some(is_valid) = maybe_valid_checksum.map(|valid_checksum| {
                tracing::debug!(
                    model_name,
                    valid_checksum,
                    candidate_checksum,
                    "is_valid_checksum: check case"
                );
                valid_checksum == candidate_checksum
            }) {
                results.push(is_valid)
            }
        }
        if results.is_empty() {
            None
        } else {
            // The checksum is valid if it is correct for all the ModelType in the bitflag.
            Some(results.into_iter().all(|x| x))
        }
    }

108
109
    pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
        self.cards.lock().values().cloned().collect()
110
111
    }

112
    pub fn has_model_any(&self, model: &str) -> bool {
113
114
        self.chat_completion_engines.read().contains(model)
            || self.completion_engines.read().contains(model)
115
116
    }

117
118
119
120
121
    pub fn model_display_names(&self) -> HashSet<String> {
        self.list_chat_completions_models()
            .into_iter()
            .chain(self.list_completions_models())
            .chain(self.list_embeddings_models())
122
            .chain(self.list_tensor_models())
123
            .chain(self.list_prefill_models())
124
125
126
            .collect()
    }

127
    pub fn list_chat_completions_models(&self) -> Vec<String> {
128
        self.chat_completion_engines.read().list()
129
130
131
    }

    pub fn list_completions_models(&self) -> Vec<String> {
132
        self.completion_engines.read().list()
133
134
135
    }

    pub fn list_embeddings_models(&self) -> Vec<String> {
136
        self.embeddings_engines.read().list()
137
138
    }

139
140
141
142
    pub fn list_tensor_models(&self) -> Vec<String> {
        self.tensor_engines.read().list()
    }

143
144
145
146
    pub fn list_prefill_models(&self) -> Vec<String> {
        self.prefill_engines.read().list()
    }

147
148
149
    pub fn add_completions_model(
        &self,
        model: &str,
150
        card_checksum: &str,
151
152
        engine: OpenAICompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
153
        let mut clients = self.completion_engines.write();
154
        clients.add(model, card_checksum, engine)
155
156
157
158
159
    }

    pub fn add_chat_completions_model(
        &self,
        model: &str,
160
        card_checksum: &str,
161
162
        engine: OpenAIChatCompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
163
        let mut clients = self.chat_completion_engines.write();
164
        clients.add(model, card_checksum, engine)
165
166
167
168
169
    }

    pub fn add_embeddings_model(
        &self,
        model: &str,
170
        card_checksum: &str,
171
172
        engine: OpenAIEmbeddingsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
173
        let mut clients = self.embeddings_engines.write();
174
        clients.add(model, card_checksum, engine)
175
176
    }

177
178
179
    pub fn add_tensor_model(
        &self,
        model: &str,
180
        card_checksum: &str,
181
182
183
        engine: TensorStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
184
        clients.add(model, card_checksum, engine)
185
186
    }

187
188
189
190
191
192
193
194
195
196
    pub fn add_prefill_model(
        &self,
        model: &str,
        card_checksum: &str,
        engine: TensorStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.prefill_engines.write();
        clients.add(model, card_checksum, engine)
    }

197
    pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
198
        let mut clients = self.completion_engines.write();
199
200
201
202
        clients.remove(model)
    }

    pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
203
        let mut clients = self.chat_completion_engines.write();
204
205
206
207
        clients.remove(model)
    }

    pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
208
        let mut clients = self.embeddings_engines.write();
209
210
211
        clients.remove(model)
    }

212
213
214
215
216
    pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
        clients.remove(model)
    }

217
218
219
220
221
    pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let mut clients = self.prefill_engines.write();
        clients.remove(model)
    }

222
    pub fn get_embeddings_engine(
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
        &self,
        model: &str,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
        self.embeddings_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    pub fn get_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
        self.completion_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    pub fn get_chat_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
        self.chat_completion_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

255
256
257
258
259
260
261
262
263
264
265
    pub fn get_tensor_engine(
        &self,
        model: &str,
    ) -> Result<TensorStreamingEngine, ModelManagerError> {
        self.tensor_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

266
267
268
269
270
271
272
273
274
275
276
    pub fn get_prefill_engine(
        &self,
        model: &str,
    ) -> Result<TensorStreamingEngine, ModelManagerError> {
        self.prefill_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

277
278
279
280
281
    /// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
    /// deleted.
    pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
        self.cards.lock().insert(key.to_string(), card);
        Ok(())
282
283
    }

284
285
286
    /// Remove and return model card for this instance's etcd key. We do this when the instance stops.
    pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
        self.cards.lock().remove(key)
287
288
289
290
291
292
    }

    pub async fn kv_chooser_for(
        &self,
        model_name: &str,
        component: &Component,
293
        kv_cache_block_size: u32,
294
        kv_router_config: Option<KvRouterConfig>,
295
296
    ) -> anyhow::Result<Arc<KvRouter>> {
        if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
297
298
299
300
301
302
303
304
305
306
            // Check if the existing router has a different block size
            if kv_chooser.block_size() != kv_cache_block_size {
                tracing::warn!(
                    model_name = %model_name,
                    existing_block_size = %kv_chooser.block_size(),
                    requested_block_size = %kv_cache_block_size,
                    "KV Router block size mismatch! Model is requesting a different kv_cache_block_size than the existing router. \
                     This will cause routing to fail silently. Consider using the same block size or restarting the router."
                );
            }
307
308
309
            return Ok(kv_chooser);
        }

310
        // Create new KV router with etcd registration
311
312
313
314
        let etcd_client = component
            .drt()
            .etcd_client()
            .ok_or_else(|| anyhow::anyhow!("KV routing requires etcd (dynamic mode)"))?;
315
        let router_uuid = uuid::Uuid::new_v4();
316
        let router_key = format!(
317
318
319
320
            "{}/{}/{}",
            KV_ROUTERS_ROOT_PATH,
            component.path(),
            router_uuid
321
322
323
324
325
326
327
328
329
        );
        etcd_client
            .kv_create(
                &router_key,
                serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?,
                None, // use primary lease
            )
            .await?;

330
        let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
331
332
333
334
        let chooser = KvRouter::new(
            component.clone(),
            kv_cache_block_size,
            Some(selector),
335
            kv_router_config,
336
            router_uuid.to_string(),
337
338
        )
        .await?;
339
340
341
342
343
344
        let new_kv_chooser = Arc::new(chooser);
        self.kv_choosers
            .lock()
            .insert(model_name.to_string(), new_kv_chooser.clone());
        Ok(new_kv_chooser)
    }
345

346
347
348
349
    fn get_kv_chooser(&self, model_name: &str) -> Option<Arc<KvRouter>> {
        self.kv_choosers.lock().get(model_name).cloned()
    }

350
    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
351
        self.cards
352
353
            .lock()
            .values()
354
355
            .find(|c| c.display_name == model)
            .and_then(|c| c.runtime_config.tool_call_parser.as_ref())
356
            .map(|parser| parser.to_string())
357
    }
358
359
360
361
362
363
364
365
366

    /// Creates parsing options with tool call parser and reasoning parser for the specified model.
    /// Currently reasoning parser is not implemented (returns None).
    pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
        let tool_call_parser = self.get_model_tool_call_parser(model);
        let reasoning_parser = None; // TODO: Implement reasoning parser

        crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser)
    }
367
368
369
370
371
372
}

/// Engine table for one engine kind: model name → engine, plus model name → card checksum.
pub struct ModelEngines<E> {
    /// Optional default model name.
    default: Option<String>,
    /// Key: model name, value: the engine registered for it.
    engines: HashMap<String, E>,
    /// Key: model name, value: checksum of the ModelDeploymentCard it was registered with.
    /// New instances of the same model must present the same card.
    checksums: HashMap<String, String>,
}

impl<E> Default for ModelEngines<E> {
    fn default() -> Self {
        Self {
            default: None,
            engines: HashMap::new(),
383
            checksums: HashMap::new(),
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
        }
    }
}

impl<E> ModelEngines<E> {
    #[allow(dead_code)]
    fn set_default(&mut self, model: &str) {
        self.default = Some(model.to_string());
    }

    #[allow(dead_code)]
    fn clear_default(&mut self) {
        self.default = None;
    }

399
    fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
400
401
402
403
        if self.engines.contains_key(model) {
            return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
        }
        self.engines.insert(model.to_string(), engine);
404
405
        self.checksums
            .insert(model.to_string(), checksum.to_string());
406
407
408
409
410
411
412
        Ok(())
    }

    fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
        if self.engines.remove(model).is_none() {
            return Err(ModelManagerError::ModelNotFound(model.to_string()));
        }
413
        let _ = self.checksums.remove(model);
414
415
416
417
418
419
420
421
422
423
424
425
426
427
        Ok(())
    }

    fn get(&self, model: &str) -> Option<&E> {
        self.engines.get(model)
    }

    fn contains(&self, model: &str) -> bool {
        self.engines.contains_key(model)
    }

    pub fn list(&self) -> Vec<String> {
        self.engines.keys().map(|k| k.to_owned()).collect()
    }
428
429
430
431
432
433

    /// Returns a newly allocated String for called convenience. All the places I use
    /// this I need a String.
    pub fn checksum(&self, model: &str) -> Option<String> {
        self.checksums.get(model).map(|s| s.to_string())
    }
434
}