local_model.rs 18.2 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
// SPDX-License-Identifier: Apache-2.0

use std::fs;
use std::path::{Path, PathBuf};

7
use anyhow::Context as _;
8
use dynamo_runtime::component::Endpoint;
9
use dynamo_runtime::discovery::DiscoveryInstance;
10
use dynamo_runtime::discovery::DiscoverySpec;
11
use dynamo_runtime::protocols::EndpointId;
12
use dynamo_runtime::slug::Slug;
13
14
use dynamo_runtime::traits::DistributedRuntimeProvider;

15
use crate::entrypoint::RouterConfig;
16
use crate::model_card::ModelDeploymentCard;
17
use crate::model_type::{ModelInput, ModelType};
18
use crate::preprocessor::media::{MediaDecoder, MediaFetcher};
19
use crate::request_template::RequestTemplate;
20

21
22
23
pub mod runtime_config;

use runtime_config::ModelRuntimeConfig;
24

25
26
27
28
/// Fallback name for a model when the user supplies none. This name is usually never shown to
/// anyone, for example in a text chat.
const DEFAULT_NAME: &str = "dynamo";

29
30
31
32
/// KV cache block size used when none is configured; engines don't usually provide a default,
/// so we supply one.
const DEFAULT_KV_CACHE_BLOCK_SIZE: u32 = 16_u32;

/// Default HTTP port. We can't have it default to 0, so pick something.
///
/// `pub` because the bindings use it for consistency.
pub const DEFAULT_HTTP_PORT: u16 = 8080;
35
36
37

pub struct LocalModelBuilder {
    model_path: Option<PathBuf>,
38
    source_path: Option<PathBuf>,
39
40
41
42
43
44
    model_name: Option<String>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
    router_config: Option<RouterConfig>,
    kv_cache_block_size: u32,
45
    http_host: Option<String>,
46
    http_port: u16,
47
    http_metrics_port: Option<u16>,
Graham King's avatar
Graham King committed
48
49
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
50
    migration_limit: u32,
51
    is_mocker: bool,
52
53
    extra_engine_args: Option<PathBuf>,
    runtime_config: ModelRuntimeConfig,
54
    user_data: Option<serde_json::Value>,
55
    custom_template_path: Option<PathBuf>,
56
    namespace: Option<String>,
57
    namespace_prefix: Option<String>,
58
59
    media_decoder: Option<MediaDecoder>,
    media_fetcher: Option<MediaFetcher>,
60
61
}

62
impl Default for LocalModelBuilder {
63
    fn default() -> Self {
64
65
        LocalModelBuilder {
            kv_cache_block_size: DEFAULT_KV_CACHE_BLOCK_SIZE,
66
            http_host: Default::default(),
67
            http_port: DEFAULT_HTTP_PORT,
68
            http_metrics_port: None,
Graham King's avatar
Graham King committed
69
70
            tls_cert_path: Default::default(),
            tls_key_path: Default::default(),
71
            model_path: Default::default(),
72
            source_path: Default::default(),
73
74
75
76
77
            model_name: Default::default(),
            endpoint_id: Default::default(),
            context_length: Default::default(),
            template_file: Default::default(),
            router_config: Default::default(),
78
            migration_limit: Default::default(),
79
            is_mocker: Default::default(),
80
81
            extra_engine_args: Default::default(),
            runtime_config: Default::default(),
82
            user_data: Default::default(),
83
            custom_template_path: Default::default(),
84
            namespace: Default::default(),
85
            namespace_prefix: Default::default(),
86
87
            media_decoder: Default::default(),
            media_fetcher: Default::default(),
88
89
90
91
        }
    }
}

92
impl LocalModelBuilder {
93
    /// The path must exist, the model is already downloaded
94
95
    pub fn model_path(&mut self, model_path: PathBuf) -> &mut Self {
        self.model_path = Some(model_path);
96
        self
97
98
    }

99
100
101
102
103
104
105
106
    /// The HF name of the model before we downloaded it, or a local path if
    /// that was given on the cmd line. We need this because `model_path` is always
    /// a local path.
    pub fn source_path(&mut self, source_path: PathBuf) -> &mut Self {
        self.source_path = Some(source_path);
        self
    }

107
108
109
    pub fn model_name(&mut self, model_name: Option<String>) -> &mut Self {
        self.model_name = model_name;
        self
110
111
    }

112
113
    pub fn endpoint_id(&mut self, endpoint_id: Option<EndpointId>) -> &mut Self {
        self.endpoint_id = endpoint_id;
114
        self
115
116
    }

117
118
119
    pub fn context_length(&mut self, context_length: Option<u32>) -> &mut Self {
        self.context_length = context_length;
        self
120
121
    }

122
123
124
125
    /// Passing None resets it to default
    pub fn kv_cache_block_size(&mut self, kv_cache_block_size: Option<u32>) -> &mut Self {
        self.kv_cache_block_size = kv_cache_block_size.unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE);
        self
126
127
    }

128
129
130
131
132
    pub fn http_host(&mut self, host: Option<String>) -> &mut Self {
        self.http_host = host;
        self
    }

Graham King's avatar
Graham King committed
133
134
135
136
137
    pub fn http_port(&mut self, port: u16) -> &mut Self {
        self.http_port = port;
        self
    }

138
139
140
141
142
    pub fn http_metrics_port(&mut self, port: Option<u16>) -> &mut Self {
        self.http_metrics_port = port;
        self
    }

Graham King's avatar
Graham King committed
143
144
145
146
147
148
149
    pub fn tls_cert_path(&mut self, p: Option<PathBuf>) -> &mut Self {
        self.tls_cert_path = p;
        self
    }

    pub fn tls_key_path(&mut self, p: Option<PathBuf>) -> &mut Self {
        self.tls_key_path = p;
150
        self
151
152
    }

153
154
    pub fn router_config(&mut self, router_config: Option<RouterConfig>) -> &mut Self {
        self.router_config = router_config;
155
156
157
        self
    }

158
159
160
161
162
    pub fn namespace(&mut self, namespace: Option<String>) -> &mut Self {
        self.namespace = namespace;
        self
    }

163
164
165
166
167
    pub fn namespace_prefix(&mut self, namespace_prefix: Option<String>) -> &mut Self {
        self.namespace_prefix = namespace_prefix;
        self
    }

168
169
170
    pub fn request_template(&mut self, template_file: Option<PathBuf>) -> &mut Self {
        self.template_file = template_file;
        self
171
172
    }

173
174
175
176
177
    pub fn custom_template_path(&mut self, custom_template_path: Option<PathBuf>) -> &mut Self {
        self.custom_template_path = custom_template_path;
        self
    }

178
179
180
181
182
    pub fn migration_limit(&mut self, migration_limit: Option<u32>) -> &mut Self {
        self.migration_limit = migration_limit.unwrap_or(0);
        self
    }

183
184
185
186
187
    pub fn is_mocker(&mut self, is_mocker: bool) -> &mut Self {
        self.is_mocker = is_mocker;
        self
    }

188
189
190
191
192
193
194
195
196
197
    pub fn extra_engine_args(&mut self, extra_engine_args: Option<PathBuf>) -> &mut Self {
        self.extra_engine_args = extra_engine_args;
        self
    }

    pub fn runtime_config(&mut self, runtime_config: ModelRuntimeConfig) -> &mut Self {
        self.runtime_config = runtime_config;
        self
    }

198
199
200
201
202
    pub fn user_data(&mut self, user_data: Option<serde_json::Value>) -> &mut Self {
        self.user_data = user_data;
        self
    }

203
204
205
206
207
208
209
210
211
212
    pub fn media_decoder(&mut self, media_decoder: Option<MediaDecoder>) -> &mut Self {
        self.media_decoder = media_decoder;
        self
    }

    pub fn media_fetcher(&mut self, media_fetcher: Option<MediaFetcher>) -> &mut Self {
        self.media_fetcher = media_fetcher;
        self
    }

213
214
215
216
217
218
219
220
    /// Make an LLM ready for use:
    /// - Download it from Hugging Face (and NGC in future) if necessary
    /// - Resolve the path
    /// - Load it's ModelDeploymentCard card
    /// - Name it correctly
    ///
    /// The model name will depend on what "model_path" is:
    /// - A folder: The last part of the folder name: "/data/llms/Qwen2.5-3B-Instruct" -> "Qwen2.5-3B-Instruct"
221
222
223
224
225
226
227
228
    /// - An HF repo: The HF repo name: "Qwen/Qwen3-0.6B" stays the same
    pub async fn build(&mut self) -> anyhow::Result<LocalModel> {
        // Generate an endpoint ID for this model if the user didn't provide one.
        // The user only provides one if exposing the model.
        let endpoint_id = self
            .endpoint_id
            .take()
            .unwrap_or_else(|| internal_endpoint("local_model"));
229

230
231
232
233
234
235
        let template = self
            .template_file
            .as_deref()
            .map(RequestTemplate::load)
            .transpose()?;

236
        // frontend and echo engine don't need a path.
237
        if self.model_path.is_none() {
238
239
240
            let mut card = ModelDeploymentCard::with_name_only(
                self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
            );
Yan Ru Pei's avatar
Yan Ru Pei committed
241
            card.kv_cache_block_size = self.kv_cache_block_size;
242
            card.migration_limit = self.migration_limit;
243
            card.user_data = self.user_data.take();
244
            card.runtime_config = self.runtime_config.clone();
245
246
            card.media_decoder = self.media_decoder.clone();
            card.media_fetcher = self.media_fetcher.clone();
247

248
            return Ok(LocalModel {
249
                card,
250
251
252
                full_path: PathBuf::new(),
                endpoint_id,
                template,
253
                http_host: self.http_host.take(),
254
                http_port: self.http_port,
255
                http_metrics_port: self.http_metrics_port,
Graham King's avatar
Graham King committed
256
257
                tls_cert_path: self.tls_cert_path.take(),
                tls_key_path: self.tls_key_path.take(),
258
                router_config: self.router_config.take().unwrap_or_default(),
259
                runtime_config: self.runtime_config.clone(),
260
                namespace: self.namespace.clone(),
261
                namespace_prefix: self.namespace_prefix.clone(),
262
                migration_limit: self.migration_limit,
263
264
265
266
267
            });
        }

        // Main logic. We are running a model.
        let model_path = self.model_path.take().unwrap();
268
269
270
271
272
273
274
        if !model_path.exists() {
            anyhow::bail!(
                "Path does not exist: '{}'. Use LocalModel::fetch to download it.",
                model_path.display(),
            );
        }
        let model_path = fs::canonicalize(model_path)?;
275

276
        let mut card =
277
            ModelDeploymentCard::load_from_disk(&model_path, self.custom_template_path.as_deref())?;
278
279
280
281
282
        // Source path is the `--model-path` the user passed. By now our `model_path` is the local
        // path of the downloaded model.
        if let Some(source_path) = self.source_path.take() {
            card.set_source_path(source_path);
        }
283
284
        // The served model name defaults to the full model path.
        // This matches what vllm and sglang do.
285
286
        let alt = card.source_path().to_string();
        card.set_name(self.model_name.as_deref().unwrap_or(&alt));
287

288
        card.kv_cache_block_size = self.kv_cache_block_size;
289

290
291
292
293
        // Override max number of tokens in context. We usually only do this to limit kv cache allocation.
        if let Some(context_length) = self.context_length {
            card.context_length = context_length;
        }
294

295
        card.migration_limit = self.migration_limit;
296
        card.user_data = self.user_data.take();
297
        card.runtime_config = self.runtime_config.clone();
298
299
        card.media_decoder = self.media_decoder.clone();
        card.media_fetcher = self.media_fetcher.clone();
300

301
302
        Ok(LocalModel {
            card,
303
            full_path: model_path,
304
305
            endpoint_id,
            template,
306
            http_host: self.http_host.take(),
307
            http_port: self.http_port,
308
            http_metrics_port: self.http_metrics_port,
Graham King's avatar
Graham King committed
309
310
            tls_cert_path: self.tls_cert_path.take(),
            tls_key_path: self.tls_key_path.take(),
311
            router_config: self.router_config.take().unwrap_or_default(),
312
            runtime_config: self.runtime_config.clone(),
313
            namespace: self.namespace.clone(),
314
            namespace_prefix: self.namespace_prefix.clone(),
315
            migration_limit: self.migration_limit,
316
317
318
319
320
321
322
323
324
325
        })
    }
}

#[derive(Debug, Clone)]
pub struct LocalModel {
    full_path: PathBuf,
    card: ModelDeploymentCard,
    endpoint_id: EndpointId,
    template: Option<RequestTemplate>,
326
    http_host: Option<String>,
Graham King's avatar
Graham King committed
327
    http_port: u16,
328
    http_metrics_port: Option<u16>,
Graham King's avatar
Graham King committed
329
330
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
331
    router_config: RouterConfig,
332
    runtime_config: ModelRuntimeConfig,
333
    namespace: Option<String>,
334
    namespace_prefix: Option<String>,
335
    migration_limit: u32,
336
337
338
}

impl LocalModel {
339
340
341
342
343
344
345
346
347
    /// Ensure a model is accessible locally, returning it's path.
    /// Downloads the model from Hugging Face if necessary.
    /// If ignore_weights is true, model weight files will be skipped and only the model config
    /// will be downloaded.
    /// Returns the path to the model files
    pub async fn fetch(remote_name: &str, ignore_weights: bool) -> anyhow::Result<PathBuf> {
        super::hub::from_hf(remote_name, ignore_weights).await
    }

348
349
350
351
352
353
354
355
    pub fn card(&self) -> &ModelDeploymentCard {
        &self.card
    }

    pub fn path(&self) -> &Path {
        &self.full_path
    }

356
    /// Human friendly model name. This is the correct name.
357
358
359
360
    pub fn display_name(&self) -> &str {
        &self.card.display_name
    }

361
362
    /// The name under which we make this model available over HTTP.
    /// A slugified version of the model's name, for use in NATS, etcd, etc.
363
    pub fn service_name(&self) -> &str {
364
        self.card.slug().as_ref()
365
366
367
368
369
370
    }

    pub fn request_template(&self) -> Option<RequestTemplate> {
        self.template.clone()
    }

371
372
373
374
    pub fn http_host(&self) -> Option<String> {
        self.http_host.clone()
    }

375
376
377
378
    pub fn http_port(&self) -> u16 {
        self.http_port
    }

379
380
381
382
    pub fn http_metrics_port(&self) -> Option<u16> {
        self.http_metrics_port
    }

Graham King's avatar
Graham King committed
383
384
385
386
387
388
389
390
    pub fn tls_cert_path(&self) -> Option<&Path> {
        self.tls_cert_path.as_deref()
    }

    pub fn tls_key_path(&self) -> Option<&Path> {
        self.tls_key_path.as_deref()
    }

391
392
393
394
    pub fn router_config(&self) -> &RouterConfig {
        &self.router_config
    }

395
396
397
398
    pub fn runtime_config(&self) -> &ModelRuntimeConfig {
        &self.runtime_config
    }

399
400
401
402
    pub fn migration_limit(&self) -> u32 {
        self.migration_limit
    }

403
404
405
406
    pub fn namespace(&self) -> Option<&str> {
        self.namespace.as_deref()
    }

407
408
409
410
    pub fn namespace_prefix(&self) -> Option<&str> {
        self.namespace_prefix.as_deref()
    }

411
412
413
414
415
416
417
418
419
    /// An endpoint to identify this model by.
    pub fn endpoint_id(&self) -> &EndpointId {
        &self.endpoint_id
    }

    /// Drop the LocalModel returning it's ModelDeploymentCard.
    /// For the case where we only need the card and don't want to clone it.
    pub fn into_card(self) -> ModelDeploymentCard {
        self.card
420
421
    }

422
    /// Attach this model to the endpoint. This registers it on the network
423
    /// allowing ingress to discover it.
424
425
426
    ///
    /// For base models, pass `lora_name = None`.
    /// For LoRA adapters, pass `lora_name = Some("adapter-name")`.
427
428
429
430
    pub async fn attach(
        &mut self,
        endpoint: &Endpoint,
        model_type: ModelType,
431
        model_input: ModelInput,
432
        lora_info: Option<crate::model_card::LoraInfo>,
433
    ) -> anyhow::Result<()> {
434
435
        self.card.model_type = model_type;
        self.card.model_input = model_input;
436
        self.card.lora = lora_info.clone();
437

438
        // Compute model_suffix from lora_name if present
439
440
441
        let model_suffix = lora_info
            .as_ref()
            .map(|info| Slug::slugify(&info.name).to_string());
442
443
444
445
446
447
448
449
450
451
452
453
454
455

        let suffix_for_log = model_suffix
            .as_ref()
            .map(|s| format!("/{}", s))
            .unwrap_or_default();
        tracing::debug!(
            "Registering MDC at path: {}/{}/{}/{:x}{}",
            endpoint.component().namespace().name(),
            endpoint.component().name(),
            endpoint.name(),
            endpoint.drt().connection_id(),
            suffix_for_log
        );

456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
        let source_path = PathBuf::from(self.card.source_path());
        if !source_path.exists() {
            // The consumers of MDC (frontend) might not have the same local path as us, so
            // replace disk paths with a custom URL like "hf://Qwen/Qwen3-0.6B/config.json".
            //
            // We can't do this if the model came from disk, as it might not be the same version
            // as on Hugging Face (if it exists there at all).
            //
            // The URL is not used by anything. Frontend will download the repo and edit these
            // paths to be local, so only the filename part matters currently.
            // Possibly we should just use the filenames here. The URL feels nicer to me, it makes
            // each field fully identified and fetchable independently.
            self.card
                .move_to_url(&format!("hf://{}/", self.card.source_path()))
                .context("move_to_url")?;
        }

473
        // Register the Model Deployment Card via discovery interface
474
        // The model_suffix (for LoRA) will be appended AFTER the instance_id
475
        let discovery = endpoint.drt().discovery();
476
        let spec = DiscoverySpec::from_model_with_suffix(
477
478
479
480
            endpoint.component().namespace().name().to_string(),
            endpoint.component().name().to_string(),
            endpoint.name().to_string(),
            &self.card,
481
            model_suffix,
482
483
        )?;
        let _instance = discovery.register(spec).await?;
484

485
        Ok(())
486
    }
487
488

    /// Helper associated function to detach a model from an endpoint
489
490
491
492
493
494
495
    ///
    /// For base models, pass `lora_name = None`.
    /// For LoRA adapters, pass `lora_name = Some("adapter-name")`.
    pub async fn detach_from_endpoint(
        endpoint: &Endpoint,
        lora_name: Option<&str>,
    ) -> anyhow::Result<()> {
496
497
498
499
        let drt = endpoint.drt();
        let instance_id = drt.connection_id();
        let endpoint_id = endpoint.id();

500
501
502
        // Compute model_suffix from lora_name if present
        let model_suffix = lora_name.map(|name| Slug::slugify(name).to_string());

503
504
505
506
507
508
        let instance = DiscoveryInstance::Model {
            namespace: endpoint_id.namespace,
            component: endpoint_id.component,
            endpoint: endpoint_id.name,
            instance_id,
            card_json: serde_json::Value::Null,
509
            model_suffix,
510
511
512
513
514
        };

        let discovery = drt.discovery();
        discovery.unregister(instance).await?;

515
516
517
518
519
520
521
522
        if let Some(lora_name) = lora_name {
            tracing::info!(
                "Successfully unregistered LoRA '{}' from discovery",
                lora_name
            );
        } else {
            tracing::info!("Successfully unregistered model from discovery");
        }
523
524
525

        Ok(())
    }
526
}
527
528
529
530
531
532
533
534
535
536

/// Build a throwaway [`EndpointId`] for internal communication.
///
/// The namespace is a slugified random UUID (v4), so it is not hard-coded. Several instances
/// may run on the same machine (e.g. GPUs 0-3 and 4-7). `engine` becomes the component, and
/// the endpoint name is always `"generate"`.
fn internal_endpoint(engine: &str) -> EndpointId {
    let random_id = uuid::Uuid::new_v4().to_string();
    let namespace = Slug::slugify(&random_id).to_string();
    EndpointId {
        namespace,
        component: engine.to_owned(),
        name: String::from("generate"),
    }
}