model_manager.rs 13.2 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::{
    collections::{HashMap, HashSet},
6
    sync::Arc,
7
8
};

9
use parking_lot::{Mutex, RwLock};
10

11
use dynamo_runtime::component::Component;
12
use dynamo_runtime::prelude::DistributedRuntimeProvider;
13

14
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};
15
16
use crate::{
    kv_router::KvRouter,
17
    types::generic::tensor::TensorStreamingEngine,
18
19
20
21
22
    types::openai::{
        chat_completions::OpenAIChatCompletionsStreamingEngine,
        completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
    },
};
23
24
25
26
use crate::{
    kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector},
    model_type::ModelType,
};
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

/// Errors returned by `ModelManager` when adding, removing or looking up engines.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// No engine is registered under the given model name.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// An engine is already registered under the given model name.
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}

// Don't implement Clone for this, put it in an Arc instead.
pub struct ModelManager {
    // We read a lot and write rarely, so these three are RwLock
    completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
    chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
    embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
43
    tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
44

45
    // These are Mutex because we read and write rarely and equally
46
    cards: Mutex<HashMap<String, ModelDeploymentCard>>,
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
    kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
}

impl Default for ModelManager {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelManager {
    pub fn new() -> Self {
        Self {
            completion_engines: RwLock::new(ModelEngines::default()),
            chat_completion_engines: RwLock::new(ModelEngines::default()),
            embeddings_engines: RwLock::new(ModelEngines::default()),
62
            tensor_engines: RwLock::new(ModelEngines::default()),
63
            cards: Mutex::new(HashMap::new()),
64
65
66
67
            kv_choosers: Mutex::new(HashMap::new()),
        }
    }

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
    /// Compares `candidate_checksum` with the checksum stored for `model_name` in each engine
    /// map selected by the units of `model_type`.
    ///
    /// Returns `None` when none of the inspected engine maps holds a checksum for
    /// `model_name`. Otherwise returns `Some(true)` only if every checksum that was found
    /// equals `candidate_checksum`.
    pub fn is_valid_checksum(
        &self,
        model_type: ModelType,
        model_name: &str,
        candidate_checksum: &str,
    ) -> Option<bool> {
        let mut found_any = false;
        let mut all_match = true;
        for unit in model_type.units() {
            let stored = match unit {
                ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
                ModelType::Completions => self.completion_engines.read().checksum(model_name),
                ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
                ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
                // Any other unit has no engine map here, so there is nothing to compare.
                _ => continue,
            };
            // A model that is not registered for this unit contributes no verdict.
            let Some(valid_checksum) = stored else {
                continue;
            };
            tracing::debug!(
                model_name,
                valid_checksum,
                candidate_checksum,
                "is_valid_checksum: check case"
            );
            found_any = true;
            all_match &= valid_checksum == candidate_checksum;
        }
        // The checksum is valid if it is correct for all the ModelType in the bitflag.
        found_any.then_some(all_match)
    }

105
106
    /// Returns a clone of every saved model deployment card.
    pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
        let cards = self.cards.lock();
        cards.values().cloned().collect()
    }

109
    pub fn has_model_any(&self, model: &str) -> bool {
110
111
        self.chat_completion_engines.read().contains(model)
            || self.completion_engines.read().contains(model)
112
113
    }

114
115
116
117
118
    pub fn model_display_names(&self) -> HashSet<String> {
        self.list_chat_completions_models()
            .into_iter()
            .chain(self.list_completions_models())
            .chain(self.list_embeddings_models())
119
            .chain(self.list_tensor_models())
120
121
122
            .collect()
    }

123
    pub fn list_chat_completions_models(&self) -> Vec<String> {
124
        self.chat_completion_engines.read().list()
125
126
127
    }

    pub fn list_completions_models(&self) -> Vec<String> {
128
        self.completion_engines.read().list()
129
130
131
    }

    pub fn list_embeddings_models(&self) -> Vec<String> {
132
        self.embeddings_engines.read().list()
133
134
    }

135
136
137
138
    /// Lists the names of all models that have a tensor engine.
    pub fn list_tensor_models(&self) -> Vec<String> {
        let engines = self.tensor_engines.read();
        engines.list()
    }

139
140
141
    pub fn add_completions_model(
        &self,
        model: &str,
142
        card_checksum: &str,
143
144
        engine: OpenAICompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
145
        let mut clients = self.completion_engines.write();
146
        clients.add(model, card_checksum, engine)
147
148
149
150
151
    }

    pub fn add_chat_completions_model(
        &self,
        model: &str,
152
        card_checksum: &str,
153
154
        engine: OpenAIChatCompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
155
        let mut clients = self.chat_completion_engines.write();
156
        clients.add(model, card_checksum, engine)
157
158
159
160
161
    }

    pub fn add_embeddings_model(
        &self,
        model: &str,
162
        card_checksum: &str,
163
164
        engine: OpenAIEmbeddingsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
165
        let mut clients = self.embeddings_engines.write();
166
        clients.add(model, card_checksum, engine)
167
168
    }

169
170
171
    pub fn add_tensor_model(
        &self,
        model: &str,
172
        card_checksum: &str,
173
174
175
        engine: TensorStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
176
        clients.add(model, card_checksum, engine)
177
178
    }

179
    /// Unregisters the completions engine for `model` along with its stored checksum.
    ///
    /// # Errors
    /// Returns `ModelManagerError::ModelNotFound` if `model` has no completions engine.
    pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.completion_engines.write().remove(model)
    }

    /// Unregisters the chat completions engine for `model` along with its stored checksum.
    ///
    /// # Errors
    /// Returns `ModelManagerError::ModelNotFound` if `model` has no chat completions engine.
    pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.chat_completion_engines.write().remove(model)
    }

    /// Unregisters the embeddings engine for `model` along with its stored checksum.
    ///
    /// # Errors
    /// Returns `ModelManagerError::ModelNotFound` if `model` has no embeddings engine.
    pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.embeddings_engines.write().remove(model)
    }

194
195
196
197
198
    /// Unregisters the tensor engine for `model` along with its stored checksum.
    ///
    /// # Errors
    /// Returns `ModelManagerError::ModelNotFound` if `model` has no tensor engine.
    pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
        self.tensor_engines.write().remove(model)
    }

199
    pub fn get_embeddings_engine(
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
        &self,
        model: &str,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
        self.embeddings_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Returns a clone of the completions engine registered for `model`.
    ///
    /// # Errors
    /// Returns `ModelManagerError::ModelNotFound` if `model` has no completions engine.
    pub fn get_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
        self.completion_engines
            .read()
            .get(model)
            .cloned()
            // Build the error lazily so a successful lookup does not allocate the name.
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

    /// Returns a clone of the chat completions engine registered for `model`.
    ///
    /// # Errors
    /// Returns `ModelManagerError::ModelNotFound` if `model` has no chat completions engine.
    pub fn get_chat_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
        self.chat_completion_engines
            .read()
            .get(model)
            .cloned()
            // Build the error lazily so a successful lookup does not allocate the name.
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

232
233
234
235
236
237
238
239
240
241
242
    /// Returns a clone of the tensor engine registered for `model`.
    ///
    /// # Errors
    /// Returns `ModelManagerError::ModelNotFound` if `model` has no tensor engine.
    pub fn get_tensor_engine(
        &self,
        model: &str,
    ) -> Result<TensorStreamingEngine, ModelManagerError> {
        self.tensor_engines
            .read()
            .get(model)
            .cloned()
            // Build the error lazily so a successful lookup does not allocate the name.
            .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
    }

243
244
245
246
247
    /// Stores `card` under an instance's ModelDeploymentCard key so it can still be fetched
    /// after that key has been deleted. A card already saved under `key` is replaced.
    ///
    /// Always returns `Ok(())`.
    pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
        let mut cards = self.cards.lock();
        cards.insert(key.to_owned(), card);
        Ok(())
    }

250
251
252
    /// Removes the model card saved under this instance's etcd key and hands it back, or
    /// returns `None` if nothing was saved under `key`. We do this when the instance stops.
    pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
        let mut cards = self.cards.lock();
        cards.remove(key)
    }

    pub async fn kv_chooser_for(
        &self,
        model_name: &str,
        component: &Component,
259
        kv_cache_block_size: u32,
260
        kv_router_config: Option<KvRouterConfig>,
261
262
    ) -> anyhow::Result<Arc<KvRouter>> {
        if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
263
264
265
266
267
268
269
270
271
272
            // Check if the existing router has a different block size
            if kv_chooser.block_size() != kv_cache_block_size {
                tracing::warn!(
                    model_name = %model_name,
                    existing_block_size = %kv_chooser.block_size(),
                    requested_block_size = %kv_cache_block_size,
                    "KV Router block size mismatch! Model is requesting a different kv_cache_block_size than the existing router. \
                     This will cause routing to fail silently. Consider using the same block size or restarting the router."
                );
            }
273
274
275
            return Ok(kv_chooser);
        }

276
        // Create new KV router with etcd registration
277
278
279
280
        let etcd_client = component
            .drt()
            .etcd_client()
            .ok_or_else(|| anyhow::anyhow!("KV routing requires etcd (dynamic mode)"))?;
281
        let router_uuid = uuid::Uuid::new_v4();
282
        let router_key = format!(
283
284
285
286
            "{}/{}/{}",
            KV_ROUTERS_ROOT_PATH,
            component.path(),
            router_uuid
287
288
289
290
291
292
293
294
295
        );
        etcd_client
            .kv_create(
                &router_key,
                serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?,
                None, // use primary lease
            )
            .await?;

296
        let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
297
298
299
300
        let chooser = KvRouter::new(
            component.clone(),
            kv_cache_block_size,
            Some(selector),
301
            kv_router_config,
302
            router_uuid.to_string(),
303
304
        )
        .await?;
305
306
307
308
309
310
        let new_kv_chooser = Arc::new(chooser);
        self.kv_choosers
            .lock()
            .insert(model_name.to_string(), new_kv_chooser.clone());
        Ok(new_kv_chooser)
    }
311

312
313
314
315
    /// Returns the cached KV router for `model_name`, if one has been stored.
    fn get_kv_chooser(&self, model_name: &str) -> Option<Arc<KvRouter>> {
        let routers = self.kv_choosers.lock();
        routers.get(model_name).map(Arc::clone)
    }

316
    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
317
        self.cards
318
319
            .lock()
            .values()
320
321
            .find(|c| c.display_name == model)
            .and_then(|c| c.runtime_config.tool_call_parser.as_ref())
322
            .map(|parser| parser.to_string())
323
    }
324
325
326
327
328
329
330
331
332

    /// Builds the parsing options for `model` from its tool call parser (if any).
    /// The reasoning parser is not implemented yet and is always passed as `None`.
    pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
        crate::protocols::openai::ParsingOptions::new(
            self.get_model_tool_call_parser(model),
            None, // TODO: Implement reasoning parser
        )
    }
333
334
335
336
337
338
}

/// Engines of one kind keyed by model name, plus the card checksum recorded for each model.
pub struct ModelEngines<E> {
    /// Optional default model name
    default: Option<String>,
    /// Key: Model name, value: the engine registered for that model.
    engines: HashMap<String, E>,
    /// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
    /// same card.
    checksums: HashMap<String, String>,
}

impl<E> Default for ModelEngines<E> {
    fn default() -> Self {
        Self {
            default: None,
            engines: HashMap::new(),
349
            checksums: HashMap::new(),
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
        }
    }
}

impl<E> ModelEngines<E> {
    /// Records `model` as the default model name.
    #[allow(dead_code)]
    fn set_default(&mut self, model: &str) {
        self.default = Some(String::from(model));
    }

    /// Forgets the default model name, if one was set.
    #[allow(dead_code)]
    fn clear_default(&mut self) {
        let _ = self.default.take();
    }

365
    /// Registers `engine` under `model` and records `checksum` for it.
    ///
    /// Returns `ModelManagerError::ModelAlreadyExists` without changing anything if an engine
    /// is already registered under `model`.
    fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
        if self.engines.contains_key(model) {
            Err(ModelManagerError::ModelAlreadyExists(model.to_string()))
        } else {
            self.engines.insert(model.to_string(), engine);
            self.checksums
                .insert(model.to_string(), checksum.to_string());
            Ok(())
        }
    }

    /// Removes the engine registered under `model` together with its checksum.
    ///
    /// Returns `ModelManagerError::ModelNotFound` if no engine is registered under `model`;
    /// in that case the checksum map is left untouched.
    fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
        match self.engines.remove(model) {
            Some(_) => {
                self.checksums.remove(model);
                Ok(())
            }
            None => Err(ModelManagerError::ModelNotFound(model.to_string())),
        }
    }

    /// Returns a reference to the engine registered under `model`, if any.
    fn get(&self, model: &str) -> Option<&E> {
        self.engines.get(model)
    }

    /// Returns `true` if an engine is registered under `model`.
    fn contains(&self, model: &str) -> bool {
        self.engines.get(model).is_some()
    }

    /// Returns the names of all registered models, in the map's iteration order.
    pub fn list(&self) -> Vec<String> {
        self.engines.keys().cloned().collect()
    }
394
395
396
397
398
399

    /// Returns the checksum recorded for `model`, if any.
    ///
    /// The result is a newly allocated String for caller convenience: every current call
    /// site needs an owned String.
    pub fn checksum(&self, model: &str) -> Option<String> {
        self.checksums.get(model).cloned()
    }
400
}