// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::model_card::model::ModelDeploymentCard;
use anyhow::{Context, Result};
use std::path::{Path, PathBuf};

use crate::model_card::model::{ModelInfoType, PromptFormatterArtifact, TokenizerKind};

use super::model::GenerationConfig;

impl ModelDeploymentCard {
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
    /// Allow user to override the name we register this model under.
    /// Corresponds to vllm's `--served-model-name`.
    pub fn set_name(&mut self, name: &str) {
        self.display_name = name.to_string();
        self.service_name = name.to_string();
    }

    /// Build an in-memory ModelDeploymentCard from either:
    /// - a folder containing config.json, tokenizer.json and token_config.json
    /// - a GGUF file
    pub async fn load(config_path: impl AsRef<Path>) -> anyhow::Result<ModelDeploymentCard> {
        let config_path = config_path.as_ref();
        if config_path.is_dir() {
            Self::from_local_path(config_path).await
        } else {
            Self::from_gguf(config_path).await
        }
    }

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
    /// Creates a ModelDeploymentCard from a local directory path.
    ///
    /// Currently HuggingFace format is supported and following files are expected:
    /// - config.json: Model configuration in HuggingFace format
    /// - tokenizer.json: Tokenizer configuration in HuggingFace format
    /// - tokenizer_config.json: Optional prompt formatter configuration
    ///
    /// # Arguments
    /// * `local_root_dir` - Path to the local model directory
    ///
    /// # Errors
    /// Returns an error if:
    /// - The path doesn't exist or isn't a directory
    /// - The path contains invalid Unicode characters
    /// - Required model files are missing or invalid
47
    async fn from_local_path(local_root_dir: impl AsRef<Path>) -> anyhow::Result<Self> {
48
49
50
51
52
53
54
        let local_root_dir = local_root_dir.as_ref();
        check_valid_local_repo_path(local_root_dir)?;
        let repo_id = local_root_dir
            .canonicalize()?
            .to_str()
            .ok_or_else(|| anyhow::anyhow!("Path contains invalid Unicode"))?
            .to_string();
55
56
57
58
        let model_name = local_root_dir
            .file_name()
            .and_then(|n| n.to_str())
            .ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?;
59
        Self::from_repo(&repo_id, model_name).await
60
61
    }

62
63
64
65
66
    async fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
        let model_name = gguf_file
            .iter()
            .next_back()
            .map(|n| n.to_string_lossy().to_string());
67
68
69
70
71
72
73
        let Some(model_name) = model_name else {
            // I think this would only happy on an empty path
            anyhow::bail!(
                "Could not extract model name from path '{}'",
                gguf_file.display()
            );
        };
74
75
76
77
78

        // TODO: we do this in HFConfig also, unify
        let content = super::model::load_gguf(gguf_file)?;
        let context_length = content.get_metadata()[&format!("{}.context_length", content.arch())]
            .to_u32()
79
            .unwrap_or(0);
80
81
        tracing::debug!(context_length, "Loaded context length from GGUF");

82
83
84
        Ok(Self {
            display_name: model_name.to_string(),
            service_name: model_name.to_string(),
85
86
            model_info: Some(ModelInfoType::GGUF(gguf_file.to_path_buf())),
            tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?),
87
            gen_config: None, // AFAICT there is no equivalent in a GGUF
88
            prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
89
            chat_template_file: None,
90
91
92
            prompt_context: None, // TODO - auto-detect prompt context
            revision: 0,
            last_published: None,
93
94
            context_length,
            kv_cache_block_size: 0,
95
96
97
        })
    }

98
99
    #[allow(dead_code)]
    async fn from_ngc_repo(_: &str) -> anyhow::Result<Self> {
100
101
102
        Err(anyhow::anyhow!(
            "ModelDeploymentCard::from_ngc_repo is not implemented"
        ))
103
104
    }

105
    async fn from_repo(repo_id: &str, model_name: &str) -> anyhow::Result<Self> {
106
        // This is usually the right choice
107
108
        let context_length = crate::file_json_field(
            &PathBuf::from(repo_id).join("config.json"),
109
            "max_position_embeddings",
110
        )
111
112
        // But sometimes this is
        .or_else(|_| {
113
114
            crate::file_json_field(
                &PathBuf::from(repo_id).join("tokenizer_config.json"),
115
116
117
118
                "model_max_length",
            )
        })
        // If neither of those are present let the engine default it
119
120
        .unwrap_or(0);

121
122
123
        Ok(Self {
            display_name: model_name.to_string(),
            service_name: model_name.to_string(),
124
125
            model_info: Some(ModelInfoType::from_repo(repo_id).await?),
            tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
126
            gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
127
            prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
128
            chat_template_file: PromptFormatterArtifact::chat_template_from_repo(repo_id).await?,
129
130
131
            prompt_context: None, // TODO - auto-detect prompt context
            revision: 0,
            last_published: None,
132
133
            context_length,
            kv_cache_block_size: 0, // set later
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
        })
    }
}

impl ModelInfoType {
    /// Resolve model information from a local HuggingFace-style repo directory.
    ///
    /// # Errors
    /// Fails when `<repo_id>/config.json` does not exist. The error carries the repo id as
    /// context.
    pub async fn from_repo(repo_id: &str) -> Result<Self> {
        let probed = Self::try_is_hf_repo(repo_id).await;
        probed.with_context(|| format!("unable to extract model info from repo {}", repo_id))
    }

    /// Returns `HfConfigJson` holding the path of `<repo>/config.json` when that file exists.
    async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
        let config_json = check_for_file(repo, "config.json").await?;
        Ok(Self::HfConfigJson(config_json))
    }
}

impl PromptFormatterArtifact {
    pub async fn from_repo(repo_id: &str) -> Result<Option<Self>> {
        // we should only error if we expect a prompt formatter and it's not found
        // right now, we don't know when to expect it, so we just return Ok(Some/None)
        Ok(Self::try_is_hf_repo(repo_id)
            .await
            .with_context(|| format!("unable to extract prompt format from repo {}", repo_id))
            .ok())
    }

162
163
164
165
166
167
168
169
170
171
172
173
174
    pub async fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
        Ok(Self::chat_template_try_is_hf_repo(repo_id)
            .await
            .with_context(|| format!("unable to extract prompt format from repo {}", repo_id))
            .ok())
    }

    async fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
        Ok(Self::HfChatTemplate(
            check_for_file(repo, "chat_template.jinja").await?,
        ))
    }

175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
    async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
        Ok(Self::HfTokenizerConfigJson(
            check_for_file(repo, "tokenizer_config.json").await?,
        ))
    }
}

impl TokenizerKind {
    /// Resolve the tokenizer for a local HuggingFace-style repo directory.
    ///
    /// # Errors
    /// Fails, with the repo id attached as context, when `tokenizer.json` is missing.
    pub async fn from_repo(repo_id: &str) -> Result<Self> {
        let outcome = Self::try_is_hf_repo(repo_id).await;
        outcome.with_context(|| format!("unable to extract tokenizer kind from repo {}", repo_id))
    }

    /// Returns `HfTokenizerJson` holding the path of `<repo>/tokenizer.json` if it exists.
    async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
        let tokenizer_json = check_for_file(repo, "tokenizer.json").await?;
        Ok(Self::HfTokenizerJson(tokenizer_json))
    }
}

impl GenerationConfig {
    /// Resolve a HuggingFace `generation_config.json` from a local repo directory.
    ///
    /// # Errors
    /// Fails, with the repo id attached as context, when the file is absent.
    pub async fn from_repo(repo_id: &str) -> Result<Self> {
        let located = Self::try_is_hf_repo(repo_id).await;
        located.with_context(|| format!("unable to extract generation config from repo {repo_id}"))
    }

    /// Returns `HfGenerationConfigJson` holding the path of `<repo>/generation_config.json`
    /// if it exists.
    async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
        let generation_json = check_for_file(repo, "generation_config.json").await?;
        Ok(Self::HfGenerationConfigJson(generation_json))
    }
}

/// Checks if the provided path contains the expected file.
async fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<String> {
    let p = PathBuf::from(repo_id).join(file);
    let name = p.display().to_string();
    if !p.exists() {
        anyhow::bail!("File not found: {name}")
216
    }
217
    Ok(name)
218
219
220
221
222
223
224
225
226
227
228
229
}

/// Checks if the provided path is a valid local repository path.
///
/// # Arguments
/// * `path` - Path to validate
///
/// # Errors
/// Returns an error if the path doesn't exist or isn't a directory
fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> {
    let path = path.as_ref();
    if !path.exists() {
230
231
232
233
        return Err(anyhow::anyhow!(
            "Model path does not exist: {}",
            path.display()
        ));
234
235
236
    }

    if !path.is_dir() {
237
238
239
240
        return Err(anyhow::anyhow!(
            "Model path is not a directory: {}",
            path.display()
        ));
241
242
    }
    Ok(())
243
}