// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
17
#![allow(unexpected_cfgs)]

18
use hf_hub::api::tokio::ApiBuilder;
19
use std::env;
20
21
use std::path::{Path, PathBuf};

22
23
24
25
26
27
#[cfg(feature = "model-express")]
use model_express_client::{
    Client as MxClient, ClientConfig as MxClientConfig, ModelProvider as MxModelProvider,
};
#[cfg(feature = "model-express")]
use model_express_common::download as mx;
28

29
/// Name of the environment variable that, when set, makes `from_hf` try ModelExpress
/// (its value is used as the ModelExpress endpoint).
const MODEL_EXPRESS_ENDPOINT_ENV_VAR: &str = "MODEL_EXPRESS_URL";
/// Name of the environment variable read for the Hugging Face token handed to hf-hub.
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
31

32
33
34
35
36
37
38
39
40
/// Reports whether `filename` ends with one of the known model-weight suffixes.
///
/// The comparison is a plain, case-sensitive suffix test on the whole name.
fn is_weight_file(filename: &str) -> bool {
    const WEIGHT_SUFFIXES: [&str; 5] = [
        ".bin",
        ".safetensors",
        ".h5",
        ".msgpack",
        ".ckpt.index",
    ];

    WEIGHT_SUFFIXES
        .iter()
        .any(|suffix| filename.ends_with(suffix))
}

41
42
/// Download a model from Hugging Face, preferring ModelExpress when it is configured.
///
/// Compiled only when the `model-express` feature is enabled; without that feature the
/// hf-hub-only function of the same name is compiled instead.
///
/// ModelExpress is tried only if the environment variable named by
/// `MODEL_EXPRESS_ENDPOINT_ENV_VAR` is set. In that case:
/// - a client is connected to the endpoint and asked to fetch the model; on success the
///   model directory is looked up in the local cache;
/// - if connecting fails, or the server-side request fails, the model is downloaded
///   directly through the ModelExpress download helper;
/// - if whichever of those paths was taken ends in an error, the function falls back to
///   hf-hub.
///
/// When the variable is not set, hf-hub is used straight away.
///
/// Returns the directory the model is in.
///
/// If `ignore_weights` is true, model weight files will be skipped. Note that the flag is
/// only forwarded to the hf-hub download; the ModelExpress calls made here do not receive it.
///
/// # Errors
///
/// Returns the error of the hf-hub download when that final fallback fails.
#[cfg(feature = "model-express")]
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
    let name = name.as_ref();
    let model_name = name.display().to_string();

    // Only use ModelExpress if the environment variable is explicitly set
    if let Ok(endpoint) = env::var(MODEL_EXPRESS_ENDPOINT_ENV_VAR) {
        tracing::info!(
            "ModelExpress endpoint configured, attempting to use ModelExpress for model: {model_name}"
        );

        // `endpoint` and `config` are not needed afterwards, so both are moved, not cloned.
        let config: MxClientConfig = MxClientConfig::default().with_endpoint(endpoint);

        let result = match MxClient::new(config).await {
            Ok(mut client) => {
                tracing::info!("Successfully connected to ModelExpress server");
                match client
                    .request_model_with_provider_and_fallback(
                        &model_name,
                        MxModelProvider::HuggingFace,
                    )
                    .await
                {
                    Ok(()) => {
                        tracing::info!("Server download succeeded for model: {model_name}");
                        get_mx_model_path_from_cache(&model_name)
                    }
                    Err(e) => {
                        tracing::warn!(
                            "Server download failed for model '{model_name}': {e}. Falling back to direct download."
                        );
                        mx_download_direct(&model_name).await
                    }
                }
            }
            Err(e) => {
                tracing::warn!(
                    "Cannot connect to ModelExpress server: {e}. Using direct download."
                );
                mx_download_direct(&model_name).await
            }
        };

        match result {
            Ok(path) => {
                tracing::info!(
                    "ModelExpress download completed successfully for model: {model_name}"
                );
                return Ok(path);
            }
            Err(e) => {
                // Not fatal: hf-hub below is the last resort.
                tracing::warn!(
                    "ModelExpress download failed for model '{model_name}': {e}. Falling back to hf-hub."
                );
            }
        }
    }

    tracing::info!("Using hf-hub for model: {model_name}");
    download_with_hf_hub(&model_name, ignore_weights).await
}

/// Download a model from Hugging Face through hf-hub.
///
/// This variant is compiled when the `model-express` feature is disabled. If the
/// ModelExpress endpoint variable is set anyway, a warning is logged and hf-hub is still used.
///
/// Returns the directory the model is in.
/// If `ignore_weights` is true, model weight files will be skipped.
#[cfg(not(feature = "model-express"))]
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
    let model_name = name.as_ref().display().to_string();

    // The endpoint cannot be honoured in this build; tell the user why it is ignored.
    let endpoint_configured = env::var(MODEL_EXPRESS_ENDPOINT_ENV_VAR).is_ok();
    if endpoint_configured {
        tracing::warn!(
            "ModelExpress endpoint configured but model-express feature not enabled. Using hf-hub."
        );
    }

    tracing::info!("Using hf-hub for model: {model_name}");
    let downloaded = download_with_hf_hub(&model_name, ignore_weights).await;
    downloaded
}

// Downloads `model_name` from the Hugging Face provider via the ModelExpress download
// helper, targeting the cache directory chosen by `get_model_express_cache_dir`.
#[cfg(feature = "model-express")]
async fn mx_download_direct(model_name: &str) -> anyhow::Result<PathBuf> {
    mx::download_model(
        model_name,
        MxModelProvider::HuggingFace,
        Some(get_model_express_cache_dir()),
    )
    .await
}

/// Attempt to download a model from Hugging Face with hf-hub
/// Returns the directory it is in
/// If ignore_weights is true, model weight files will be skipped
async fn download_with_hf_hub(model_name: &str, ignore_weights: bool) -> anyhow::Result<PathBuf> {
137
    let token = env::var(HF_TOKEN_ENV_VAR).ok();
138

139
    let api = ApiBuilder::from_env()
140
141
        .with_progress(true)
        .with_token(token)
142
        .high()
143
        .build()?;
144

145
    let repo = api.model(model_name.to_string());
146

147
148
    let info = repo.info().await
        .map_err(|e| anyhow::anyhow!("Failed to fetch model '{model_name}' from HuggingFace: {e}. Is this a valid HuggingFace ID?"))?;
149
150
151

    if info.siblings.is_empty() {
        return Err(anyhow::anyhow!(
152
            "Model '{model_name}' exists but contains no downloadable files."
153
154
        ));
    }
155

156
    let mut model_path = PathBuf::new();
157
158
    let mut files_downloaded = false;

159
160
    for sibling in info.siblings {
        if is_ignored_file(&sibling.rfilename) || is_image_file(&sibling.rfilename) {
161
162
            continue;
        }
163

164
        if ignore_weights && is_weight_file(&sibling.rfilename) {
165
166
167
            continue;
        }

168
        match repo.get(&sibling.rfilename).await {
169
            Ok(path) => {
170
                model_path = path;
171
172
173
174
                files_downloaded = true;
            }
            Err(e) => {
                return Err(anyhow::anyhow!(
175
176
                    "Failed to download file '{}' from model '{model_name}': {e}",
                    sibling.rfilename
177
178
179
                ));
            }
        }
180
    }
181
182

    if !files_downloaded {
183
184
185
186
187
        let file_type = if ignore_weights {
            "non-weight"
        } else {
            "valid"
        };
188
        return Err(anyhow::anyhow!(
189
            "No {file_type} files found for model '{model_name}'."
190
191
192
        ));
    }

193
194
195
196
197
198
    match model_path.parent() {
        Some(path) => Ok(path.to_path_buf()),
        None => Err(anyhow::anyhow!(
            "Invalid HF cache path: {}",
            model_path.display()
        )),
199
200
201
    }
}

202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
/// Reports whether `filename` is exactly one of the repository housekeeping files that are
/// never downloaded (git attributes, license, readme, use policy).
fn is_ignored_file(filename: &str) -> bool {
    matches!(
        filename,
        ".gitattributes" | "LICENSE" | "LICENSE.txt" | "README.md" | "USE_POLICY.md"
    )
}

/// Reports whether `filename` has a PNG or JPEG extension (`.png`, `.jpg`, `.jpeg`).
///
/// The extension is matched case-insensitively and must include the leading dot, so
/// `photo.JPG` and `logo.Png` are images while a name that merely ends in the letters
/// `PNG` (e.g. `notesPNG`) is not.
fn is_image_file(filename: &str) -> bool {
    const IMAGE_EXTENSIONS: [&str; 3] = [".png", ".jpg", ".jpeg"];

    // Lower-casing once lets a single suffix table cover every capitalisation.
    let lowered = filename.to_ascii_lowercase();
    IMAGE_EXTENSIONS.iter().any(|ext| lowered.ends_with(ext))
}

/// Resolves the directory for `model_name` inside the ModelExpress cache directory.
///
/// The candidate is the cache directory joined with `model_name`; an error is returned
/// when nothing exists at that location.
#[cfg(feature = "model-express")]
fn get_mx_model_path_from_cache(model_name: &str) -> anyhow::Result<PathBuf> {
    let model_dir = get_model_express_cache_dir().join(model_name);

    if model_dir.exists() {
        Ok(model_dir)
    } else {
        Err(anyhow::anyhow!(
            "Model '{model_name}' was downloaded but directory not found at expected location: {}",
            model_dir.display()
        ))
    }
}

/// Picks the cache directory used for ModelExpress downloads.
///
/// Precedence: `HF_HUB_CACHE`, then `MODEL_EXPRESS_PATH`, then `.cache/huggingface/hub`
/// under `HOME` (or `USERPROFILE`, or `.` when neither is set).
#[cfg(feature = "model-express")]
fn get_model_express_cache_dir() -> PathBuf {
    // Explicit overrides, checked in priority order.
    for override_var in ["HF_HUB_CACHE", "MODEL_EXPRESS_PATH"] {
        if let Ok(cache_path) = env::var(override_var) {
            return PathBuf::from(cache_path);
        }
    }

    let home_dir = match env::var("HOME") {
        Ok(dir) => dir,
        Err(_) => env::var("USERPROFILE").unwrap_or_else(|_| ".".to_string()),
    };

    let mut cache_dir = PathBuf::from(home_dir);
    cache_dir.push(".cache/huggingface/hub");
    cache_dir
}

/// Unit tests for the filename classifiers and a smoke test for `from_hf`.
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test only: the call is awaited but its result is deliberately ignored,
    // so neither success nor failure of the download is asserted.
    #[tokio::test]
    async fn test_from_hf_with_model_express() {
        let test_path = PathBuf::from("test-model");
        let _result: anyhow::Result<PathBuf> = from_hf(test_path, false).await;
    }

    // The resolved cache directory must be non-empty and either absolute or
    // relative to the current directory (the "." fallback).
    #[cfg(feature = "model-express")]
    #[test]
    fn test_get_model_express_cache_dir() {
        let cache_dir = get_model_express_cache_dir();
        assert!(!cache_dir.to_string_lossy().is_empty());
        assert!(cache_dir.is_absolute() || cache_dir.starts_with("."));
    }

    #[test]
    fn test_is_ignored_file() {
        assert!(is_ignored_file(".gitattributes"));
        assert!(is_ignored_file("LICENSE"));
        assert!(is_ignored_file("LICENSE.txt"));
        assert!(is_ignored_file("README.md"));
        assert!(is_ignored_file("USE_POLICY.md"));

        assert!(!is_ignored_file("model.bin"));
        assert!(!is_ignored_file("tokenizer.json"));
        assert!(!is_ignored_file("config.json"));
    }

    #[test]
    fn test_is_weight_file() {
        assert!(is_weight_file("model.bin"));
        assert!(is_weight_file("model.safetensors"));
        assert!(is_weight_file("model.h5"));
        assert!(is_weight_file("model.msgpack"));
        assert!(is_weight_file("model.ckpt.index"));

        assert!(!is_weight_file("tokenizer.json"));
        assert!(!is_weight_file("config.json"));
        assert!(!is_weight_file("README.md"));
    }

    #[test]
    fn test_is_image_file() {
        assert!(is_image_file("image.png"));
        assert!(is_image_file("image.PNG"));
        assert!(is_image_file("photo.jpg"));
        assert!(is_image_file("photo.JPG"));
        assert!(is_image_file("picture.jpeg"));
        assert!(is_image_file("picture.JPEG"));

        assert!(!is_image_file("model.bin"));
        assert!(!is_image_file("tokenizer.json"));
        assert!(!is_image_file("config.json"));
        assert!(!is_image_file("README.md"));
    }
}