// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};

use parking_lot::{Mutex, RwLock};

use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider;

use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};

use crate::{
    kv_router::KvRouter,
    types::generic::tensor::TensorStreamingEngine,
    types::openai::{
        chat_completions::OpenAIChatCompletionsStreamingEngine,
        completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
    },
};

#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}

// Don't implement Clone for this, put it in an Arc instead.
pub struct ModelManager {
    // We read a lot and write rarely, so these three are RwLock
    completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
    chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
    embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
40
    tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
41
42

    // These two are Mutex because we read and write rarely and equally
43
    cards: Mutex<HashMap<String, ModelDeploymentCard>>,
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
    kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
}

impl Default for ModelManager {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelManager {
    pub fn new() -> Self {
        Self {
            completion_engines: RwLock::new(ModelEngines::default()),
            chat_completion_engines: RwLock::new(ModelEngines::default()),
            embeddings_engines: RwLock::new(ModelEngines::default()),
59
            tensor_engines: RwLock::new(ModelEngines::default()),
60
            cards: Mutex::new(HashMap::new()),
61
62
63
64
            kv_choosers: Mutex::new(HashMap::new()),
        }
    }

65
66
    pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
        self.cards.lock().values().cloned().collect()
67
68
    }

69
    pub fn has_model_any(&self, model: &str) -> bool {
70
71
        self.chat_completion_engines.read().contains(model)
            || self.completion_engines.read().contains(model)
72
73
    }

74
75
76
77
78
    pub fn model_display_names(&self) -> HashSet<String> {
        self.list_chat_completions_models()
            .into_iter()
            .chain(self.list_completions_models())
            .chain(self.list_embeddings_models())
79
            .chain(self.list_tensor_models())
80
81
82
            .collect()
    }

83
    pub fn list_chat_completions_models(&self) -> Vec<String> {
84
        self.chat_completion_engines.read().list()
85
86
87
    }

    pub fn list_completions_models(&self) -> Vec<String> {
88
        self.completion_engines.read().list()
89
90
91
    }

    pub fn list_embeddings_models(&self) -> Vec<String> {
92
        self.embeddings_engines.read().list()
93
94
    }

95
96
97
98
    pub fn list_tensor_models(&self) -> Vec<String> {
        self.tensor_engines.read().list()
    }

99
100
101
102
103
    pub fn add_completions_model(
        &self,
        model: &str,
        engine: OpenAICompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
104
        let mut clients = self.completion_engines.write();
105
106
107
108
109
110
111
112
        clients.add(model, engine)
    }

    pub fn add_chat_completions_model(
        &self,
        model: &str,
        engine: OpenAIChatCompletionsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
113
        let mut clients = self.chat_completion_engines.write();
114
115
116
117
118
119
120
121
        clients.add(model, engine)
    }

    pub fn add_embeddings_model(
        &self,
        model: &str,
        engine: OpenAIEmbeddingsStreamingEngine,
    ) -> Result<(), ModelManagerError> {
122
        let mut clients = self.embeddings_engines.write();
123
124
125
        clients.add(model, engine)
    }

126
127
128
129
130
131
132
133
134
    pub fn add_tensor_model(
        &self,
        model: &str,
        engine: TensorStreamingEngine,
    ) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
        clients.add(model, engine)
    }

135
    pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
136
        let mut clients = self.completion_engines.write();
137
138
139
140
        clients.remove(model)
    }

    pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
141
        let mut clients = self.chat_completion_engines.write();
142
143
144
145
        clients.remove(model)
    }

    pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
146
        let mut clients = self.embeddings_engines.write();
147
148
149
        clients.remove(model)
    }

150
151
152
153
154
    pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
        let mut clients = self.tensor_engines.write();
        clients.remove(model)
    }

155
    pub fn get_embeddings_engine(
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
        &self,
        model: &str,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
        self.embeddings_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    pub fn get_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
        self.completion_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

    pub fn get_chat_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
        self.chat_completion_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

188
189
190
191
192
193
194
195
196
197
198
    pub fn get_tensor_engine(
        &self,
        model: &str,
    ) -> Result<TensorStreamingEngine, ModelManagerError> {
        self.tensor_engines
            .read()
            .get(model)
            .cloned()
            .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
    }

199
    /// Save a ModelDeploymentCard from an instance's etcd `models/` key so we can fetch it later when the key is
200
    /// deleted from etcd.
201
202
    pub fn save_model_card(&self, key: &str, entry: ModelDeploymentCard) {
        self.cards.lock().insert(key.to_string(), entry);
203
204
    }

205
206
207
    /// Remove and return model card for this instance's etcd key. We do this when the instance stops.
    pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
        self.cards.lock().remove(key)
208
209
210
211
212
213
    }

    pub async fn kv_chooser_for(
        &self,
        model_name: &str,
        component: &Component,
214
        kv_cache_block_size: u32,
215
        kv_router_config: Option<KvRouterConfig>,
216
217
    ) -> anyhow::Result<Arc<KvRouter>> {
        if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
218
219
220
221
222
223
224
225
226
227
            // Check if the existing router has a different block size
            if kv_chooser.block_size() != kv_cache_block_size {
                tracing::warn!(
                    model_name = %model_name,
                    existing_block_size = %kv_chooser.block_size(),
                    requested_block_size = %kv_cache_block_size,
                    "KV Router block size mismatch! Model is requesting a different kv_cache_block_size than the existing router. \
                     This will cause routing to fail silently. Consider using the same block size or restarting the router."
                );
            }
228
229
230
            return Ok(kv_chooser);
        }

231
        // Create new KV router with etcd registration
232
233
234
235
        let etcd_client = component
            .drt()
            .etcd_client()
            .ok_or_else(|| anyhow::anyhow!("KV routing requires etcd (dynamic mode)"))?;
236
        let router_uuid = uuid::Uuid::new_v4();
237
        let router_key = format!(
238
239
240
241
            "{}/{}/{}",
            KV_ROUTERS_ROOT_PATH,
            component.path(),
            router_uuid
242
243
244
245
246
247
248
249
250
        );
        etcd_client
            .kv_create(
                &router_key,
                serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?,
                None, // use primary lease
            )
            .await?;

251
        let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
252
253
254
255
        let chooser = KvRouter::new(
            component.clone(),
            kv_cache_block_size,
            Some(selector),
256
            kv_router_config,
257
            router_uuid.to_string(),
258
259
        )
        .await?;
260
261
262
263
264
265
        let new_kv_chooser = Arc::new(chooser);
        self.kv_choosers
            .lock()
            .insert(model_name.to_string(), new_kv_chooser.clone());
        Ok(new_kv_chooser)
    }
266

267
268
269
270
    fn get_kv_chooser(&self, model_name: &str) -> Option<Arc<KvRouter>> {
        self.kv_choosers.lock().get(model_name).cloned()
    }

271
    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
272
        self.cards
273
274
            .lock()
            .values()
275
276
            .find(|c| c.display_name == model)
            .and_then(|c| c.runtime_config.tool_call_parser.as_ref())
277
            .map(|parser| parser.to_string())
278
    }
279
280
281
282
283
284
285
286
287

    /// Creates parsing options with tool call parser and reasoning parser for the specified model.
    /// Currently reasoning parser is not implemented (returns None).
    pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
        let tool_call_parser = self.get_model_tool_call_parser(model);
        let reasoning_parser = None; // TODO: Implement reasoning parser

        crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser)
    }
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
}

/// Registry of engines of a single type `E`, keyed by model name.
pub struct ModelEngines<E> {
    /// Optional default model name
    default: Option<String>,
    engines: HashMap<String, E>,
}

// Hand-written rather than `#[derive(Default)]`, because the derive would add an
// `E: Default` bound that engine types need not satisfy.
impl<E> Default for ModelEngines<E> {
    fn default() -> Self {
        Self {
            default: None,
            engines: HashMap::new(),
        }
    }
}

impl<E> ModelEngines<E> {
    /// Records `model` as the default model name. Does not check that it is registered.
    // Not currently called; the `allow` keeps the unused-code lint quiet.
    #[allow(dead_code)]
    fn set_default(&mut self, model: &str) {
        self.default = Some(model.to_string());
    }

    /// Forgets the default model name, if any.
    // Not currently called; the `allow` keeps the unused-code lint quiet.
    #[allow(dead_code)]
    fn clear_default(&mut self) {
        self.default = None;
    }

    /// Registers `engine` under `model`.
    ///
    /// # Errors
    /// [`ModelManagerError::ModelAlreadyExists`] if `model` is already registered. In that case
    /// the existing engine is kept and `engine` is dropped.
    fn add(&mut self, model: &str, engine: E) -> Result<(), ModelManagerError> {
        // Entry API: one hash lookup instead of `contains_key` followed by `insert`.
        match self.engines.entry(model.to_string()) {
            std::collections::hash_map::Entry::Occupied(_) => {
                Err(ModelManagerError::ModelAlreadyExists(model.to_string()))
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(engine);
                Ok(())
            }
        }
    }

    /// Unregisters (and drops) the engine stored under `model`.
    ///
    /// # Errors
    /// [`ModelManagerError::ModelNotFound`] if nothing is registered under `model`.
    fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
        match self.engines.remove(model) {
            Some(_) => Ok(()),
            None => Err(ModelManagerError::ModelNotFound(model.to_string())),
        }
    }

    /// Borrows the engine registered under `model`, if any.
    fn get(&self, model: &str) -> Option<&E> {
        self.engines.get(model)
    }

    /// Returns `true` if an engine is registered under `model`.
    fn contains(&self, model: &str) -> bool {
        self.engines.contains_key(model)
    }

    /// Returns the registered model names, in no particular order.
    pub fn list(&self) -> Vec<String> {
        self.engines.keys().cloned().collect()
    }
}