hub.rs 9.82 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
#![allow(unexpected_cfgs)]

6
use hf_hub::api::tokio::ApiBuilder;
7
use std::env;
8
9
use std::path::{Path, PathBuf};

10
11
12
13
14
15
#[cfg(feature = "model-express")]
use model_express_client::{
    Client as MxClient, ClientConfig as MxClientConfig, ModelProvider as MxModelProvider,
};
#[cfg(feature = "model-express")]
use model_express_common::download as mx;
16

17
const MODEL_EXPRESS_ENDPOINT_ENV_VAR: &str = "MODEL_EXPRESS_URL";
18
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
19

20
21
22
23
24
25
26
27
28
/// Reports whether `filename` names a model weight file.
///
/// The decision is a case-sensitive suffix match against the known weight
/// file endings listed below.
fn is_weight_file(filename: &str) -> bool {
    const WEIGHT_SUFFIXES: [&str; 5] = [
        ".bin",
        ".safetensors",
        ".h5",
        ".msgpack",
        ".ckpt.index",
    ];
    WEIGHT_SUFFIXES
        .iter()
        .any(|suffix| filename.ends_with(*suffix))
}

29
30
/// Download a model from Hugging Face, preferring the ModelExpress client.
///
/// This variant is compiled only when the `model-express` feature is enabled;
/// without the feature the homonymous hf-hub-only function is used instead.
///
/// Strategy, in order:
/// 1. If the `MODEL_EXPRESS_URL` environment variable is set, connect to the
///    ModelExpress server and ask it to fetch the model.
/// 2. If the server cannot be reached, or its download fails, fall back to a
///    direct download through the ModelExpress library.
/// 3. If ModelExpress fails entirely, or the variable is unset, download with
///    hf-hub.
///
/// Returns the directory the model is in.
/// If `ignore_weights` is true, model weight files will be skipped. Note that
/// the flag is only forwarded to the hf-hub path; the ModelExpress calls made
/// here do not receive it.
#[cfg(feature = "model-express")]
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
    let model_name = name.as_ref().display().to_string();

    // Only use ModelExpress if the environment variable is explicitly set
    if let Ok(endpoint) = env::var(MODEL_EXPRESS_ENDPOINT_ENV_VAR) {
        tracing::info!(
            "ModelExpress endpoint configured, attempting to use ModelExpress for model: {model_name}"
        );

        let mx_config: MxClientConfig = MxClientConfig::default().with_endpoint(endpoint);

        let outcome = match MxClient::new(mx_config).await {
            Err(e) => {
                tracing::warn!(
                    "Cannot connect to ModelExpress server: {e}. Using direct download."
                );
                mx_download_direct(&model_name).await
            }
            Ok(mut mx_client) => {
                tracing::info!("Successfully connected to ModelExpress server");
                let server_reply = mx_client
                    .request_model_with_provider_and_fallback(
                        &model_name,
                        MxModelProvider::HuggingFace,
                    )
                    .await;
                match server_reply {
                    Err(e) => {
                        tracing::warn!(
                            "Server download failed for model '{model_name}': {e}. Falling back to direct download."
                        );
                        mx_download_direct(&model_name).await
                    }
                    Ok(()) => {
                        tracing::info!("Server download succeeded for model: {model_name}");
                        get_mx_model_path_from_cache(&model_name)
                    }
                }
            }
        };

        match outcome {
            Err(e) => {
                // Not fatal: the hf-hub download below is the last resort.
                tracing::warn!(
                    "ModelExpress download failed for model '{model_name}': {e}. Falling back to hf-hub."
                );
            }
            Ok(model_dir) => {
                tracing::info!(
                    "ModelExpress download completed successfully for model: {model_name}"
                );
                return Ok(model_dir);
            }
        }
    }

    tracing::info!("Using hf-hub for model: {model_name}");
    download_with_hf_hub(&model_name, ignore_weights).await
}

/// Download a model from Hugging Face using hf-hub directly.
///
/// This variant is compiled when the `model-express` feature is not enabled.
/// If the `MODEL_EXPRESS_URL` environment variable is set anyway, a warning is
/// logged and the download still goes through hf-hub.
///
/// Returns the directory the model is in.
/// If `ignore_weights` is true, model weight files will be skipped.
#[cfg(not(feature = "model-express"))]
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
    let model_name = name.as_ref().display().to_string();

    let mx_endpoint_configured = env::var(MODEL_EXPRESS_ENDPOINT_ENV_VAR).is_ok();
    if mx_endpoint_configured {
        tracing::warn!(
            "ModelExpress endpoint configured but model-express feature not enabled. Using hf-hub."
        );
    }

    tracing::info!("Using hf-hub for model: {model_name}");
    download_with_hf_hub(&model_name, ignore_weights).await
}

/// Download `model_name` directly through the ModelExpress library, without
/// going through a ModelExpress server.
///
/// The cache directory passed to the download is the one resolved by
/// `get_model_express_cache_dir`.
#[cfg(feature = "model-express")]
async fn mx_download_direct(model_name: &str) -> anyhow::Result<PathBuf> {
    mx::download_model(
        model_name,
        MxModelProvider::HuggingFace,
        Some(get_model_express_cache_dir()),
    )
    .await
}

/// Attempt to download a model from Hugging Face with hf-hub
/// Returns the directory it is in
/// If ignore_weights is true, model weight files will be skipped
async fn download_with_hf_hub(model_name: &str, ignore_weights: bool) -> anyhow::Result<PathBuf> {
125
    let token = env::var(HF_TOKEN_ENV_VAR).ok();
126

127
    let api = ApiBuilder::from_env()
128
129
        .with_progress(true)
        .with_token(token)
130
        .high()
131
        .build()?;
132

133
    let repo = api.model(model_name.to_string());
134

135
136
    let info = repo.info().await
        .map_err(|e| anyhow::anyhow!("Failed to fetch model '{model_name}' from HuggingFace: {e}. Is this a valid HuggingFace ID?"))?;
137
138
139

    if info.siblings.is_empty() {
        return Err(anyhow::anyhow!(
140
            "Model '{model_name}' exists but contains no downloadable files."
141
142
        ));
    }
143

144
    let mut model_path = PathBuf::new();
145
146
    let mut files_downloaded = false;

147
148
    for sibling in info.siblings {
        if is_ignored_file(&sibling.rfilename) || is_image_file(&sibling.rfilename) {
149
150
            continue;
        }
151

152
        if ignore_weights && is_weight_file(&sibling.rfilename) {
153
154
155
            continue;
        }

156
        match repo.get(&sibling.rfilename).await {
157
            Ok(path) => {
158
                model_path = path;
159
160
161
162
                files_downloaded = true;
            }
            Err(e) => {
                return Err(anyhow::anyhow!(
163
164
                    "Failed to download file '{}' from model '{model_name}': {e}",
                    sibling.rfilename
165
166
167
                ));
            }
        }
168
    }
169
170

    if !files_downloaded {
171
172
173
174
175
        let file_type = if ignore_weights {
            "non-weight"
        } else {
            "valid"
        };
176
        return Err(anyhow::anyhow!(
177
            "No {file_type} files found for model '{model_name}'."
178
179
180
        ));
    }

181
182
183
184
185
186
    match model_path.parent() {
        Some(path) => Ok(path.to_path_buf()),
        None => Err(anyhow::anyhow!(
            "Invalid HF cache path: {}",
            model_path.display()
        )),
187
188
189
    }
}

190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
/// Reports whether `filename` is one of the repository housekeeping files
/// that are never downloaded.
///
/// The comparison is an exact, case-sensitive match on the whole name, so a
/// file of the same name inside a subdirectory is not ignored.
fn is_ignored_file(filename: &str) -> bool {
    matches!(
        filename,
        ".gitattributes" | "LICENSE" | "LICENSE.txt" | "README.md" | "USE_POLICY.md"
    )
}

/// Reports whether `filename` has a PNG or JPEG image extension.
///
/// The match is ASCII case-insensitive and anchored on the leading dot, so
/// `photo.Jpeg` is recognised while a name that merely ends in the letters
/// `PNG`/`JPG`/`JPEG` without a dot is not.
fn is_image_file(filename: &str) -> bool {
    const IMAGE_EXTENSIONS: [&str; 3] = [".png", ".jpg", ".jpeg"];
    // Lower-case once so every extension is compared case-insensitively.
    let lowered = filename.to_ascii_lowercase();
    IMAGE_EXTENSIONS
        .iter()
        .any(|extension| lowered.ends_with(*extension))
}

/// Resolve the on-disk directory of a model inside the ModelExpress cache.
///
/// The directory is the cache directory joined with `model_name`. An error is
/// returned when that directory does not exist.
#[cfg(feature = "model-express")]
fn get_mx_model_path_from_cache(model_name: &str) -> anyhow::Result<PathBuf> {
    let model_dir = get_model_express_cache_dir().join(model_name);

    if model_dir.exists() {
        Ok(model_dir)
    } else {
        Err(anyhow::anyhow!(
            "Model '{model_name}' was downloaded but directory not found at expected location: {}",
            model_dir.display()
        ))
    }
}

/// Determine the cache directory used for ModelExpress downloads.
///
/// Resolution order:
/// 1. the `HF_HUB_CACHE` environment variable,
/// 2. the `MODEL_EXPRESS_PATH` environment variable,
/// 3. `.cache/huggingface/hub` under `HOME`, else under `USERPROFILE`, else
///    under the current directory (`.`).
#[cfg(feature = "model-express")]
fn get_model_express_cache_dir() -> PathBuf {
    // Explicit overrides win, in this order.
    for override_var in ["HF_HUB_CACHE", "MODEL_EXPRESS_PATH"] {
        if let Ok(cache_path) = env::var(override_var) {
            return PathBuf::from(cache_path);
        }
    }

    let base_dir = match env::var("HOME") {
        Ok(home) => home,
        Err(_) => env::var("USERPROFILE").unwrap_or_else(|_| ".".to_string()),
    };

    PathBuf::from(base_dir).join(".cache/huggingface/hub")
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: `from_hf` can be called and awaited. The outcome is
    /// deliberately discarded, so neither success nor failure is asserted.
    #[tokio::test]
    async fn test_from_hf_with_model_express() {
        let model = PathBuf::from("test-model");
        let _outcome: anyhow::Result<PathBuf> = from_hf(model, false).await;
    }

    /// The resolved cache directory is non-empty and either absolute or
    /// rooted at the current directory.
    #[cfg(feature = "model-express")]
    #[test]
    fn test_get_model_express_cache_dir() {
        let resolved = get_model_express_cache_dir();
        assert!(!resolved.to_string_lossy().is_empty());
        assert!(resolved.is_absolute() || resolved.starts_with("."));
    }

    #[test]
    fn test_is_ignored_file() {
        for name in [
            ".gitattributes",
            "LICENSE",
            "LICENSE.txt",
            "README.md",
            "USE_POLICY.md",
        ] {
            assert!(is_ignored_file(name), "{name} should be ignored");
        }

        for name in ["model.bin", "tokenizer.json", "config.json"] {
            assert!(!is_ignored_file(name), "{name} should not be ignored");
        }
    }

    #[test]
    fn test_is_weight_file() {
        for name in [
            "model.bin",
            "model.safetensors",
            "model.h5",
            "model.msgpack",
            "model.ckpt.index",
        ] {
            assert!(is_weight_file(name), "{name} should be a weight file");
        }

        for name in ["tokenizer.json", "config.json", "README.md"] {
            assert!(!is_weight_file(name), "{name} should not be a weight file");
        }
    }

    #[test]
    fn test_is_image_file() {
        for name in [
            "image.png",
            "image.PNG",
            "photo.jpg",
            "photo.JPG",
            "picture.jpeg",
            "picture.JPEG",
        ] {
            assert!(is_image_file(name), "{name} should be an image file");
        }

        for name in ["model.bin", "tokenizer.json", "config.json", "README.md"] {
            assert!(!is_image_file(name), "{name} should not be an image file");
        }
    }
}