// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::env;
5
6
use std::path::{Path, PathBuf};

7
use modelexpress_client::{
8
9
    Client as MxClient, ClientConfig as MxClientConfig, ModelProvider as MxModelProvider,
};
10
use modelexpress_common::download as mx;
11

12
/// Name of the environment variable that overrides the ModelExpress server endpoint.
/// Example: export MODEL_EXPRESS_URL=http://localhost:8001
13
const MODEL_EXPRESS_ENDPOINT_ENV_VAR: &str = "MODEL_EXPRESS_URL";
14

15
16
/// Download a model using ModelExpress client. The client first requests for the model
/// from the server and fallbacks to direct download in case of server failure.
17
/// If ignore_weights is true, model weight files will be skipped
18
/// Returns the path to the model files
19
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
20
    let name = name.as_ref();
21
22
    let model_name = name.display().to_string();

23
    let mut config: MxClientConfig = MxClientConfig::default();
24
    if let Ok(endpoint) = env::var(MODEL_EXPRESS_ENDPOINT_ENV_VAR) {
25
26
        config = config.with_endpoint(endpoint);
    }
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
    let result = match MxClient::new(config).await {
        Ok(mut client) => {
            tracing::info!("Successfully connected to ModelExpress server");
            match client
                .request_model_with_provider_and_fallback(
                    &model_name,
                    MxModelProvider::HuggingFace,
                    ignore_weights,
                )
                .await
            {
                Ok(()) => {
                    tracing::info!("Server download succeeded for model: {model_name}");
                    match client.get_model_path(&model_name).await {
                        Ok(path) => Ok(path),
                        Err(e) => {
                            tracing::warn!(
                                "Failed to resolve local model path after server download for '{model_name}': {e}. \
                                Falling back to direct download."
                            );
                            mx_download_direct(&model_name, ignore_weights).await
                        }
50
51
                    }
                }
52
53
54
55
56
57
                Err(e) => {
                    tracing::warn!(
                        "Server download failed for model '{model_name}': {e}. Falling back to direct download."
                    );
                    mx_download_direct(&model_name, ignore_weights).await
                }
58
59
            }
        }
60
61
62
        Err(e) => {
            tracing::warn!("Cannot connect to ModelExpress server: {e}. Using direct download.");
            mx_download_direct(&model_name, ignore_weights).await
63
        }
64
    };
65

66
67
68
69
    match result {
        Ok(path) => {
            tracing::info!("ModelExpress download completed successfully for model: {model_name}");
            Ok(path)
70
        }
71
72
73
        Err(e) => {
            tracing::warn!("ModelExpress download failed for model '{model_name}': {e}");
            Err(e)
74
        }
75
    }
76
77
}

78
79
// Direct download using the ModelExpress client.
async fn mx_download_direct(model_name: &str, ignore_weights: bool) -> anyhow::Result<PathBuf> {
80
    let cache_dir = get_model_express_cache_dir();
81
82
83
84
85
86
87
    mx::download_model(
        model_name,
        MxModelProvider::HuggingFace,
        Some(cache_dir),
        ignore_weights,
    )
    .await
88
89
}

90
91
// TODO: remove in the future. This is a temporary workaround to find a common
// cache directory between client and server.
//
// Resolution order:
//   1. `HF_HUB_CACHE`, if set
//   2. `MODEL_EXPRESS_CACHE_PATH`, if set
//   3. `<home>/.cache/huggingface/hub`, where `<home>` is `HOME`, else
//      `USERPROFILE`, else the current directory (`.`)
fn get_model_express_cache_dir() -> PathBuf {
    // Explicit overrides, highest priority first; the first readable one wins.
    const OVERRIDE_VARS: [&str; 2] = ["HF_HUB_CACHE", "MODEL_EXPRESS_CACHE_PATH"];

    if let Some(explicit) = OVERRIDE_VARS.iter().find_map(|var| env::var(var).ok()) {
        return PathBuf::from(explicit);
    }

    let home_dir = match env::var("HOME") {
        Ok(home) => home,
        Err(_) => env::var("USERPROFILE").unwrap_or_else(|_| String::from(".")),
    };

    let mut cache_dir = PathBuf::from(home_dir);
    cache_dir.push(".cache/huggingface/hub");
    cache_dir
}

#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: drives `from_hf` for a placeholder model name. The outcome is
    // deliberately ignored, so this only checks that the call runs to completion.
    #[tokio::test]
    async fn test_from_hf_with_model_express() {
        let model = PathBuf::from("test-model");
        let _outcome: anyhow::Result<PathBuf> = from_hf(model, false).await;
    }

    // The resolved cache directory must be non-empty and either absolute or
    // rooted at the current directory.
    #[test]
    fn test_get_model_express_cache_dir() {
        let resolved = get_model_express_cache_dir();

        let is_non_empty = !resolved.to_string_lossy().is_empty();
        assert!(is_non_empty);

        let is_anchored = resolved.is_absolute() || resolved.starts_with(".");
        assert!(is_anchored);
    }
}