hub.rs 8.5 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::env;
5
6
use std::path::{Path, PathBuf};

7
use hf_hub::Cache;
8
use modelexpress_client::{
9
10
    Client as MxClient, ClientConfig as MxClientConfig, ModelProvider as MxModelProvider,
};
11
use modelexpress_common::download as mx;
12

13
use dynamo_runtime::config::environment_names::model as env_model;
14

15
16
17
18
19
20
21
22
23
24
25
26
27
28
/// Check if a model is already cached in the HuggingFace hub cache directory.
/// Returns the path to the cached model directory if found, None otherwise.
///
/// Uses hf-hub's Cache API to check for cached files. For tokenizer-only downloads
/// (ignore_weights=true), we check for config.json and tokenizer files.
/// For full downloads, we also require weight files to be present.
fn get_cached_model_path(model_name: &str, ignore_weights: bool) -> Option<PathBuf> {
    let cache = Cache::new(get_model_express_cache_dir());
    let repo = cache.model(model_name.to_string());

    // Check for required config file
    let config_path = repo.get("config.json")?;

    // Check for tokenizer files (at least one must exist)
Nikita's avatar
Nikita committed
29
30
31
32
    let has_tokenizer = repo.get("tokenizer.json").is_some()
        || repo.get("tokenizer_config.json").is_some()
        || repo.get("tiktoken.model").is_some()
        || has_tiktoken_file(config_path.parent()?);
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

    if !has_tokenizer {
        return None;
    }

    // For full downloads, check for weight files
    if !ignore_weights {
        // Check common weight file patterns - at least one must exist
        let has_weights = repo.get("model.safetensors").is_some()
            || repo.get("pytorch_model.bin").is_some()
            || repo.get("model.safetensors.index.json").is_some()
            || repo.get("pytorch_model.bin.index.json").is_some();

        if !has_weights {
            return None;
        }
    }

    // Return the parent directory (snapshot dir) containing the model files
    let snapshot_path = config_path.parent()?.to_path_buf();
    tracing::info!("Found cached model '{model_name}' at {snapshot_path:?}, skipping download");
    Some(snapshot_path)
Nikita's avatar
Nikita committed
55
56
57
58
59
60
61
62
63
}

/// Return `true` when `dir` directly contains an entry whose extension is `tiktoken`
/// (for example `qwen.tiktoken`).
///
/// The scan is not recursive. A directory that cannot be read, and individual entries
/// that fail to load, are treated as "no match" rather than reported as errors.
fn has_tiktoken_file(dir: &Path) -> bool {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return false;
    };
    entries
        .filter_map(Result::ok)
        .any(|entry| entry.path().extension().is_some_and(|ext| ext == "tiktoken"))
}

/// Check whether Hugging Face offline mode is enabled via the `HF_HUB_OFFLINE`
/// environment variable.
///
/// Mirrors how `huggingface_hub` interprets boolean environment variables: the value is
/// compared case-insensitively against `1`, `on`, `yes` and `true`. Any other value, or
/// an unset / non-UTF-8 variable, means online mode. (Previously only `1` and `true`
/// were recognised, so e.g. `HF_HUB_OFFLINE=yes` was silently ignored.)
fn is_offline_mode() -> bool {
    const TRUTHY: [&str; 4] = ["1", "on", "yes", "true"];
    env::var(env_model::huggingface::HF_HUB_OFFLINE).is_ok_and(|value| {
        TRUTHY
            .iter()
            .any(|truthy| value.eq_ignore_ascii_case(truthy))
    })
}

73
74
/// Download a model using ModelExpress client. The client first requests for the model
/// from the server and fallbacks to direct download in case of server failure.
75
/// If ignore_weights is true, model weight files will be skipped
76
/// Returns the path to the model files
77
78
79
///
/// If HF_HUB_OFFLINE=1 is set and the model is already cached, returns the cached
/// path without making any API calls to HuggingFace.
80
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
81
    let name = name.as_ref();
82
83
    let model_name = name.display().to_string();

84
85
86
87
88
89
90
91
92
93
94
95
96
    // In offline mode, check cache first and return immediately if found
    if is_offline_mode() {
        if let Some(cached_path) = get_cached_model_path(&model_name, ignore_weights) {
            tracing::info!(
                "Offline mode: using cached model '{model_name}' without API validation"
            );
            return Ok(cached_path);
        }
        tracing::warn!(
            "Offline mode enabled but model '{model_name}' not found in cache, attempting download anyway"
        );
    }

97
    let mut config: MxClientConfig = MxClientConfig::default();
98
    if let Ok(endpoint) = env::var(env_model::model_express::MODEL_EXPRESS_URL) {
99
100
        config = config.with_endpoint(endpoint);
    }
101

102
103
104
105
106
107
108
109
110
111
112
113
114
    let result = match MxClient::new(config).await {
        Ok(mut client) => {
            tracing::info!("Successfully connected to ModelExpress server");
            match client
                .request_model_with_provider_and_fallback(
                    &model_name,
                    MxModelProvider::HuggingFace,
                    ignore_weights,
                )
                .await
            {
                Ok(()) => {
                    tracing::info!("Server download succeeded for model: {model_name}");
115
116
117
118
                    match client
                        .get_model_path(&model_name, MxModelProvider::HuggingFace)
                        .await
                    {
119
120
121
122
123
124
125
126
                        Ok(path) => Ok(path),
                        Err(e) => {
                            tracing::warn!(
                                "Failed to resolve local model path after server download for '{model_name}': {e}. \
                                Falling back to direct download."
                            );
                            mx_download_direct(&model_name, ignore_weights).await
                        }
127
128
                    }
                }
129
130
131
132
133
134
                Err(e) => {
                    tracing::warn!(
                        "Server download failed for model '{model_name}': {e}. Falling back to direct download."
                    );
                    mx_download_direct(&model_name, ignore_weights).await
                }
135
136
            }
        }
137
138
139
        Err(e) => {
            tracing::warn!("Cannot connect to ModelExpress server: {e}. Using direct download.");
            mx_download_direct(&model_name, ignore_weights).await
140
        }
141
    };
142

143
144
145
146
    match result {
        Ok(path) => {
            tracing::info!("ModelExpress download completed successfully for model: {model_name}");
            Ok(path)
147
        }
148
149
150
        Err(e) => {
            tracing::warn!("ModelExpress download failed for model '{model_name}': {e}");
            Err(e)
151
        }
152
    }
153
154
}

155
156
// Direct download using the ModelExpress client.
async fn mx_download_direct(model_name: &str, ignore_weights: bool) -> anyhow::Result<PathBuf> {
157
    let cache_dir = get_model_express_cache_dir();
158
159
160
161
162
163
164
    mx::download_model(
        model_name,
        MxModelProvider::HuggingFace,
        Some(cache_dir),
        ignore_weights,
    )
    .await
165
166
}

167
168
// TODO: remove in the future. This is a temporary workaround to find common
// cache directory between client and server.
169
fn get_model_express_cache_dir() -> PathBuf {
170
171
    // Check HF_HUB_CACHE environment variable
    // reference: https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhubcache
172
    if let Ok(cache_path) = env::var(env_model::huggingface::HF_HUB_CACHE) {
173
174
175
        return PathBuf::from(cache_path);
    }

176
177
    // Check HF_HOME environment variable (standard Hugging Face cache directory)
    // reference: https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhome
178
    if let Ok(hf_home) = env::var(env_model::huggingface::HF_HOME) {
179
180
181
        return PathBuf::from(hf_home).join("hub");
    }

182
    if let Ok(cache_path) = env::var(env_model::model_express::MODEL_EXPRESS_CACHE_PATH) {
183
184
        return PathBuf::from(cache_path);
    }
185

186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
    let home = env::var("HOME")
        .or_else(|_| env::var("USERPROFILE"))
        .unwrap_or_else(|_| ".".to_string());

    PathBuf::from(home).join(".cache/huggingface/hub")
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Captures a set of environment variables and restores them when dropped, so a
    /// test that mutates the environment leaves it unchanged even if it panics.
    struct EnvGuard(Vec<(&'static str, Option<std::ffi::OsString>)>);

    impl EnvGuard {
        fn capture(keys: &[&'static str]) -> Self {
            Self(keys.iter().map(|&key| (key, env::var_os(key))).collect())
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, value) in &self.0 {
                // SAFETY: only used from `#[serial]` tests. Every test in this module that
                // reads or writes the environment is `#[serial]`, so none of them runs
                // concurrently with this mutation.
                unsafe {
                    match value {
                        Some(value) => env::set_var(key, value),
                        None => env::remove_var(key),
                    }
                }
            }
        }
    }

    // Exercises the real client/fallback path. The outcome depends on network access
    // and on whether a ModelExpress server is reachable, so only "no panic" is checked.
    // Serialized because it reads cache/offline environment variables.
    #[tokio::test]
    #[serial_test::serial]
    async fn test_from_hf_with_model_express() {
        let test_path = PathBuf::from("test-model");
        let _result: anyhow::Result<PathBuf> = from_hf(test_path, false).await;
    }

    // Serialized because it reads environment variables that other tests mutate.
    #[serial_test::serial]
    #[test]
    fn test_get_model_express_cache_dir() {
        let cache_dir = get_model_express_cache_dir();
        assert!(!cache_dir.to_string_lossy().is_empty());
        assert!(cache_dir.is_absolute() || cache_dir.starts_with("."));
    }

    #[serial_test::serial]
    #[test]
    fn test_get_model_express_cache_dir_with_hf_home() {
        // Restore the caller's environment afterwards, including on assertion failure.
        let _restore = EnvGuard::capture(&[
            env_model::huggingface::HF_HUB_CACHE,
            env_model::model_express::MODEL_EXPRESS_CACHE_PATH,
            env_model::huggingface::HF_HOME,
        ]);

        // SAFETY: this test is `#[serial]`, as is every other environment-touching test in
        // this module, so no other test thread here reads the environment concurrently.
        unsafe {
            // Clear higher-priority cache env vars so HF_HOME is the one being tested.
            env::remove_var(env_model::huggingface::HF_HUB_CACHE);
            env::remove_var(env_model::model_express::MODEL_EXPRESS_CACHE_PATH);
            env::set_var(env_model::huggingface::HF_HOME, "/custom/cache/path");
        }

        let cache_dir = get_model_express_cache_dir();
        assert_eq!(cache_dir, PathBuf::from("/custom/cache/path/hub"));
    }
}