Unverified Commit 6f14e941 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Add a checksum to ModelDeploymentCard fields (#2934)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 09c7b73c
...@@ -747,6 +747,8 @@ dependencies = [ ...@@ -747,6 +747,8 @@ dependencies = [
"cc", "cc",
"cfg-if 1.0.3", "cfg-if 1.0.3",
"constant_time_eq", "constant_time_eq",
"memmap2",
"rayon-core",
] ]
[[package]] [[package]]
......
...@@ -611,6 +611,8 @@ dependencies = [ ...@@ -611,6 +611,8 @@ dependencies = [
"cc", "cc",
"cfg-if 1.0.3", "cfg-if 1.0.3",
"constant_time_eq", "constant_time_eq",
"memmap2",
"rayon-core",
] ]
[[package]] [[package]]
......
...@@ -94,7 +94,7 @@ xxhash-rust = { workspace = true } ...@@ -94,7 +94,7 @@ xxhash-rust = { workspace = true }
akin = "0.4.0" akin = "0.4.0"
bitflags = { version = "2.4", features = ["serde"] } bitflags = { version = "2.4", features = ["serde"] }
blake3 = "1" blake3 = { version = "1.8", features=["mmap", "rayon"] }
bytemuck = "1.22" bytemuck = "1.22"
candle-core = { version = "0.9.1" } candle-core = { version = "0.9.1" }
derive-getters = "0.5" derive-getters = "0.5"
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod checked_file;
pub mod dtype; pub mod dtype;
pub mod versioned; pub mod versioned;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::{
fmt::Display,
path::{Path, PathBuf},
str::FromStr,
};
use either::Either;
use serde::{
Deserialize, Deserializer, Serialize, Serializer,
de::{self, Visitor},
ser::SerializeStruct as _,
};
use url::Url;
#[derive(Clone, Debug)]
pub struct CheckedFile {
/// Either a path on local disk or a remote URL (usually nats object store)
path: Either<PathBuf, Url>,
/// Checksum of the contents of path
checksum: Checksum,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Checksum {
/// The checksum is a hex encoded string of the file's content
hash: String,
/// Checksum algorithm
algorithm: CryptographicHashMethods,
}
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub enum CryptographicHashMethods {
#[serde(rename = "blake3")]
BLAKE3,
}
impl CheckedFile {
pub fn from_disk<P: Into<PathBuf>>(filepath: P) -> anyhow::Result<Self> {
let path: PathBuf = filepath.into();
if !path.exists() {
anyhow::bail!("File not found: {}", path.display());
}
if !path.is_file() {
anyhow::bail!("Not a file: {}", path.display());
}
let hash = b3sum(&path)?;
Ok(CheckedFile {
path: Either::Left(path),
checksum: Checksum::blake3(hash),
})
}
/// Replace the local disk path with a remote URL.
/// Just updates the field, doesn't move any files.
pub fn move_to_url(&mut self, u: url::Url) {
self.path = Either::Right(u);
}
/// Replace a remove URL with local disk path.
/// Just updates the field, doesn't move any files.
pub fn move_to_disk<P: Into<PathBuf>>(&mut self, p: P) {
self.path = Either::Left(p.into());
}
pub fn path(&self) -> Option<&Path> {
match self.path.as_ref() {
Either::Left(p) => Some(p),
Either::Right(_) => None,
}
}
pub fn url(&self) -> Option<&Url> {
match self.path.as_ref() {
Either::Left(_) => None,
Either::Right(u) => Some(u),
}
}
pub fn is_nats_url(&self) -> bool {
matches!(self.path.as_ref(), Either::Right(u) if u.scheme() == "nats")
}
pub fn checksum(&self) -> &Checksum {
&self.checksum
}
/// Does the given file checksum to the same value as this CheckedFile?
pub fn checksum_matches<P: AsRef<Path> + std::fmt::Debug>(&self, disk_file: P) -> bool {
match b3sum(&disk_file) {
Ok(h) => Checksum::blake3(h) == self.checksum,
Err(error) => {
tracing::error!(disk_file = %disk_file.as_ref().display(), checked_file = self.to_string(), %error, "Checksum does not match");
false
}
}
}
}
impl Display for CheckedFile {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let p = match &self.path {
Either::Left(local) => local.display().to_string(),
Either::Right(url) => url.to_string(),
};
write!(f, "({p}, {})", self.checksum)
}
}
impl Serialize for CheckedFile {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut cf = serializer.serialize_struct("CheckedFile", 2)?;
match &self.path {
Either::Left(path) => cf.serialize_field("path", &path)?,
Either::Right(url) => cf.serialize_field("path", &url)?,
};
cf.serialize_field("checksum", &self.checksum)?;
cf.end()
}
}
/// Internal type to simplify deserializing
#[derive(Deserialize)]
struct WireCheckedFile {
path: String,
checksum: Checksum,
}
// Convert from the temporary struct to CheckedFile with path type logic.
impl From<WireCheckedFile> for CheckedFile {
fn from(temp: WireCheckedFile) -> Self {
// Try to parse as a URL; if successful, use Either::Right(Url), else use Either::Left(PathBuf).
match Url::parse(&temp.path) {
Ok(url) => CheckedFile {
path: Either::Right(url),
checksum: temp.checksum,
},
Err(_) => CheckedFile {
path: Either::Left(PathBuf::from(temp.path)),
checksum: temp.checksum,
},
}
}
}
// Implement Deserialize for CheckedFile using the temporary struct.
impl<'de> Deserialize<'de> for CheckedFile {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// Deserialize into WireCheckedFile, then convert to CheckedFile.
let temp = WireCheckedFile::deserialize(deserializer)?;
Ok(CheckedFile::from(temp))
}
}
fn b3sum<T: AsRef<Path> + std::fmt::Debug>(path: T) -> anyhow::Result<String> {
let path = path.as_ref();
let metadata = std::fs::metadata(path)?;
let filesize = metadata.len();
let mut hasher = blake3::Hasher::new();
if filesize > 128_000 {
// multithreaded. blake3 recommend this above 128 KiB.
hasher.update_mmap_rayon(path)?;
} else {
// Uses mmap above 16 KiB, normal load otherwise.
hasher.update_mmap(path)?;
}
let hash = hasher.finalize();
Ok(hash.to_string())
}
impl Checksum {
pub fn blake3(hash: impl Into<String>) -> Self {
Self::new(hash, CryptographicHashMethods::BLAKE3)
}
pub fn new(hash: impl Into<String>, algorithm: CryptographicHashMethods) -> Self {
Self {
hash: hash.into(),
algorithm,
}
}
}
impl Serialize for Checksum {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let serialized_str = format!("{}:{}", self.algorithm, self.hash);
serializer.serialize_str(&serialized_str)
}
}
impl<'de> Deserialize<'de> for Checksum {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ChecksumVisitor;
impl Visitor<'_> for ChecksumVisitor {
type Value = Checksum;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string in the format `{algo}:{hash}`")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let parts: Vec<&str> = value.split(':').collect();
if parts.len() != 2 {
return Err(de::Error::invalid_value(de::Unexpected::Str(value), &self));
}
let algorithm = parts[0].parse().map_err(|_| {
de::Error::invalid_value(de::Unexpected::Str(parts[0]), &"invalid algorithm")
})?;
Ok(Checksum::new(parts[1], algorithm))
}
}
deserializer.deserialize_str(ChecksumVisitor)
}
}
impl TryFrom<&str> for Checksum {
type Error = anyhow::Error;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let parts: Vec<&str> = value.split(':').collect();
if parts.len() != 2 {
anyhow::bail!("Invalid checksum format; expect `algo:hash`; got: {value}");
}
let algo = match parts[0] {
"blake3" => CryptographicHashMethods::BLAKE3,
_ => {
anyhow::bail!("Unsupported cryptographic hash method: {}", parts[0]);
}
};
Ok(Checksum::new(parts[1], algo))
}
}
impl FromStr for CryptographicHashMethods {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"blake3" => Ok(CryptographicHashMethods::BLAKE3),
_ => Err(format!("Unsupported algorithm: {}", s)),
}
}
}
impl Display for CryptographicHashMethods {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CryptographicHashMethods::BLAKE3 => write!(f, "blake3"),
}
}
}
impl Display for Checksum {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}:{}", self.algorithm, self.hash)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialization_blake3() {
let checksum = Checksum::blake3("a12c3d4");
let serialized = serde_json::to_string(&checksum).unwrap();
assert_eq!(serialized.trim(), "\"blake3:a12c3d4\"");
}
#[test]
fn test_deserialization_blake3() {
let s = "\"blake3:abcd1234\"";
let deserialized: Checksum = serde_json::from_str(s).unwrap();
assert_eq!(deserialized.algorithm, CryptographicHashMethods::BLAKE3);
assert_eq!(deserialized.hash, "abcd1234");
}
#[test]
fn test_deserialization_invalid_format() {
let s = "\"invalidformat\"";
let result: Result<Checksum, _> = serde_json::from_str(s);
assert!(result.is_err());
let s = "\"blake3:invalid:format\"";
let result: Result<Checksum, _> = serde_json::from_str(s);
assert!(result.is_err());
}
#[test]
fn test_checked_file_from_disk() {
let root = env!("CARGO_MANIFEST_DIR"); // ${WORKSPACE}/lib/llm
let full_path = format!("{root}/tests/data/sample-models/TinyLlama_v1.1/config.json");
let cf = CheckedFile::from_disk(full_path).unwrap();
let expected =
Checksum::blake3("62bc124be974d3a25db05bedc99422660c26715e5bbda0b37d14bd84a0c65ab2");
assert_eq!(expected, *cf.checksum());
}
}
...@@ -119,7 +119,7 @@ pub async fn start_kv_router_background( ...@@ -119,7 +119,7 @@ pub async fn start_kv_router_background(
))?; ))?;
match nats_client match nats_client
.object_store_download_data::<Vec<RouterEvent>>(url) .object_store_download_data::<Vec<RouterEvent>>(&url)
.await .await
{ {
Ok(events) => { Ok(events) => {
...@@ -353,7 +353,7 @@ async fn perform_snapshot_and_purge( ...@@ -353,7 +353,7 @@ async fn perform_snapshot_and_purge(
resources resources
.nats_client .nats_client
.object_store_upload_data(&events, url) .object_store_upload_data(&events, &url)
.await .await
.map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?; .map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?;
......
...@@ -19,13 +19,13 @@ use std::path::{Path, PathBuf}; ...@@ -19,13 +19,13 @@ use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use crate::common::checked_file::CheckedFile;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats}; use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer; use tokenizers::Tokenizer as HfTokenizer;
use url::Url;
use crate::gguf::{Content, ContentConfig, ModelConfigLike}; use crate::gguf::{Content, ContentConfig, ModelConfigLike};
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
...@@ -39,14 +39,14 @@ const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5); ...@@ -39,14 +39,14 @@ const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5);
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum ModelInfoType { pub enum ModelInfoType {
HfConfigJson(String), HfConfigJson(CheckedFile),
GGUF(PathBuf), GGUF(PathBuf),
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum TokenizerKind { pub enum TokenizerKind {
HfTokenizerJson(String), HfTokenizerJson(CheckedFile),
GGUF(Box<HfTokenizer>), GGUF(Box<HfTokenizer>),
} }
...@@ -65,8 +65,8 @@ pub enum TokenizerKind { ...@@ -65,8 +65,8 @@ pub enum TokenizerKind {
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact { pub enum PromptFormatterArtifact {
HfTokenizerConfigJson(String), HfTokenizerConfigJson(CheckedFile),
HfChatTemplate(String), HfChatTemplate(CheckedFile),
GGUF(PathBuf), GGUF(PathBuf),
} }
...@@ -83,7 +83,7 @@ pub enum PromptContextMixin { ...@@ -83,7 +83,7 @@ pub enum PromptContextMixin {
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum GenerationConfig { pub enum GenerationConfig {
HfGenerationConfigJson(String), HfGenerationConfigJson(CheckedFile),
GGUF(PathBuf), GGUF(PathBuf),
} }
...@@ -223,8 +223,11 @@ impl ModelDeploymentCard { ...@@ -223,8 +223,11 @@ impl ModelDeploymentCard {
pub fn tokenizer_hf(&self) -> anyhow::Result<HfTokenizer> { pub fn tokenizer_hf(&self) -> anyhow::Result<HfTokenizer> {
match &self.tokenizer { match &self.tokenizer {
Some(TokenizerKind::HfTokenizerJson(file)) => { Some(TokenizerKind::HfTokenizerJson(checked_file)) => {
HfTokenizer::from_file(file).map_err(anyhow::Error::msg) let p = checked_file.path().ok_or_else(||
anyhow::anyhow!("Tokenizer is URL-backed ({:?}); call move_from_nats() before tokenizer_hf()", checked_file.url())
)?;
HfTokenizer::from_file(p).map_err(anyhow::Error::msg)
} }
Some(TokenizerKind::GGUF(t)) => Ok(*t.clone()), Some(TokenizerKind::GGUF(t)) => Ok(*t.clone()),
None => { None => {
...@@ -253,22 +256,23 @@ impl ModelDeploymentCard { ...@@ -253,22 +256,23 @@ impl ModelDeploymentCard {
macro_rules! nats_upload { macro_rules! nats_upload {
($field:expr, $enum_variant:path, $filename:literal) => { ($field:expr, $enum_variant:path, $filename:literal) => {
if let Some($enum_variant(src_file)) = $field.take() { if let Some($enum_variant(src_file)) = $field.as_mut()
if !nats::is_nats_url(&src_file) { && let Some(path) = src_file.path()
let target = format!("nats://{nats_addr}/{bucket_name}/{}", $filename); {
nats_client let target = format!("nats://{nats_addr}/{bucket_name}/{}", $filename);
.object_store_upload( let dest = url::Url::parse(&target)?;
&std::path::PathBuf::from(&src_file), nats_client.object_store_upload(path, &dest).await?;
url::Url::parse(&target)?, src_file.move_to_url(dest);
)
.await?;
$field = Some($enum_variant(target));
}
} }
}; };
} }
nats_upload!(self.model_info, ModelInfoType::HfConfigJson, "config.json"); nats_upload!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
nats_upload!(
self.gen_config,
GenerationConfig::HfGenerationConfigJson,
"generation_config.json"
);
nats_upload!( nats_upload!(
self.prompt_formatter, self.prompt_formatter,
PromptFormatterArtifact::HfTokenizerConfigJson, PromptFormatterArtifact::HfTokenizerConfigJson,
...@@ -284,11 +288,6 @@ impl ModelDeploymentCard { ...@@ -284,11 +288,6 @@ impl ModelDeploymentCard {
TokenizerKind::HfTokenizerJson, TokenizerKind::HfTokenizerJson,
"tokenizer.json" "tokenizer.json"
); );
nats_upload!(
self.gen_config,
GenerationConfig::HfGenerationConfigJson,
"generation_config.json"
);
Ok(()) Ok(())
} }
...@@ -310,19 +309,29 @@ impl ModelDeploymentCard { ...@@ -310,19 +309,29 @@ impl ModelDeploymentCard {
macro_rules! nats_download { macro_rules! nats_download {
($field:expr, $enum_variant:path, $filename:literal) => { ($field:expr, $enum_variant:path, $filename:literal) => {
if let Some($enum_variant(src_url)) = $field.take() { if let Some($enum_variant(src_file)) = $field.as_mut()
if nats::is_nats_url(&src_url) { && let Some(src_url) = src_file.url()
let target = target_dir.path().join($filename); {
nats_client let target = target_dir.path().join($filename);
.object_store_download(Url::parse(&src_url)?, &target) nats_client.object_store_download(src_url, &target).await?;
.await?; if !src_file.checksum_matches(&target) {
$field = Some($enum_variant(target.display().to_string())); anyhow::bail!(
"Invalid {} in NATS for {}, checksum does not match.",
$filename,
self.display_name
);
} }
src_file.move_to_disk(target);
} }
}; };
} }
nats_download!(self.model_info, ModelInfoType::HfConfigJson, "config.json"); nats_download!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
nats_download!(
self.gen_config,
GenerationConfig::HfGenerationConfigJson,
"generation_config.json"
);
nats_download!( nats_download!(
self.prompt_formatter, self.prompt_formatter,
PromptFormatterArtifact::HfTokenizerConfigJson, PromptFormatterArtifact::HfTokenizerConfigJson,
...@@ -338,11 +347,6 @@ impl ModelDeploymentCard { ...@@ -338,11 +347,6 @@ impl ModelDeploymentCard {
TokenizerKind::HfTokenizerJson, TokenizerKind::HfTokenizerJson,
"tokenizer.json" "tokenizer.json"
); );
nats_download!(
self.gen_config,
GenerationConfig::HfGenerationConfigJson,
"generation_config.json"
);
Ok(target_dir) Ok(target_dir)
} }
...@@ -499,7 +503,7 @@ impl ModelDeploymentCard { ...@@ -499,7 +503,7 @@ impl ModelDeploymentCard {
})?; })?;
Some(PromptFormatterArtifact::HfChatTemplate( Some(PromptFormatterArtifact::HfChatTemplate(
template_path.display().to_string(), CheckedFile::from_disk(template_path)?,
)) ))
} else { } else {
PromptFormatterArtifact::chat_template_from_repo(repo_id)? PromptFormatterArtifact::chat_template_from_repo(repo_id)?
...@@ -563,8 +567,13 @@ pub trait ModelInfo: Send + Sync { ...@@ -563,8 +567,13 @@ pub trait ModelInfo: Send + Sync {
impl ModelInfoType { impl ModelInfoType {
pub fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> { pub fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> {
match self { match self {
Self::HfConfigJson(info) => HFConfig::from_json_file(info), Self::HfConfigJson(checked_file) => {
Self::GGUF(path) => HFConfig::from_gguf(path), let Some(path) = checked_file.path() else {
anyhow::bail!("model info is not a local path: {checked_file:?}");
};
Ok(HFConfig::from_json_file(path)?)
}
Self::GGUF(path) => Ok(HFConfig::from_gguf(path)?),
} }
} }
pub fn is_gguf(&self) -> bool { pub fn is_gguf(&self) -> bool {
...@@ -615,9 +624,9 @@ struct HFTextConfig { ...@@ -615,9 +624,9 @@ struct HFTextConfig {
} }
impl HFConfig { impl HFConfig {
fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> { fn from_json_file<P: AsRef<Path>>(file: P) -> Result<Arc<dyn ModelInfo>> {
let file_pathbuf = PathBuf::from(file); let file_path = file.as_ref();
let contents = std::fs::read_to_string(file)?; let contents = std::fs::read_to_string(file_path)?;
let mut config: Self = serde_json::from_str(&contents)?; let mut config: Self = serde_json::from_str(&contents)?;
if config.text_config.is_none() { if config.text_config.is_none() {
let text_config: HFTextConfig = serde_json::from_str(&contents)?; let text_config: HFTextConfig = serde_json::from_str(&contents)?;
...@@ -630,17 +639,15 @@ impl HFConfig { ...@@ -630,17 +639,15 @@ impl HFConfig {
); );
}; };
let gencfg_path = file_path
.parent()
.unwrap_or_else(|| Path::new(""))
.join("generation_config.json");
if text_config.bos_token_id.is_none() { if text_config.bos_token_id.is_none() {
let bos_token_id = crate::file_json_field::<TokenIdType>( let bos_token_id = crate::file_json_field::<TokenIdType>(&gencfg_path, "bos_token_id")
&Path::join( .context(
file_pathbuf.parent().unwrap_or(&PathBuf::from("")), "missing bos_token_id in generation_config.json and config.json, cannot load",
"generation_config.json", )?;
),
"bos_token_id",
)
.context(
"missing bos_token_id in generation_config.json and config.json, cannot load",
)?;
text_config.bos_token_id = Some(bos_token_id); text_config.bos_token_id = Some(bos_token_id);
} }
// Now that we have it for sure, set it in the non-Option field // Now that we have it for sure, set it in the non-Option field
...@@ -672,7 +679,7 @@ impl HFConfig { ...@@ -672,7 +679,7 @@ impl HFConfig {
} else { } else {
tracing::error!( tracing::error!(
?v, ?v,
file, path = %file_path.display(),
"eos_token_id is not a number or an array, cannot use" "eos_token_id is not a number or an array, cannot use"
); );
None None
...@@ -680,13 +687,7 @@ impl HFConfig { ...@@ -680,13 +687,7 @@ impl HFConfig {
}) })
.or_else(|| { .or_else(|| {
// Maybe it's in generation_config.json // Maybe it's in generation_config.json
crate::file_json_field( crate::file_json_field(&gencfg_path, "eos_token_id")
&Path::join(
file_pathbuf.parent().unwrap_or(&PathBuf::from("")),
"generation_config.json",
),
"eos_token_id",
)
.inspect_err( .inspect_err(
|err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"), |err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"),
) )
...@@ -794,12 +795,17 @@ fn capitalize(s: &str) -> String { ...@@ -794,12 +795,17 @@ fn capitalize(s: &str) -> String {
impl ModelInfoType { impl ModelInfoType {
pub fn from_repo(repo_id: &str) -> Result<Self> { pub fn from_repo(repo_id: &str) -> Result<Self> {
Self::try_is_hf_repo(repo_id) let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("config.json"))
.with_context(|| format!("unable to extract model info from repo {}", repo_id)) .with_context(|| format!("unable to extract config.json from repo {repo_id}"))?;
Ok(Self::HfConfigJson(f))
} }
}
fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> { impl GenerationConfig {
Ok(Self::HfConfigJson(check_for_file(repo, "config.json")?)) pub fn from_repo(repo_id: &str) -> Result<Self> {
let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("generation_config.json"))
.with_context(|| format!("unable to extract generation_config from repo {repo_id}"))?;
Ok(Self::HfGenerationConfigJson(f))
} }
} }
...@@ -807,68 +813,26 @@ impl PromptFormatterArtifact { ...@@ -807,68 +813,26 @@ impl PromptFormatterArtifact {
pub fn from_repo(repo_id: &str) -> Result<Option<Self>> { pub fn from_repo(repo_id: &str) -> Result<Option<Self>> {
// we should only error if we expect a prompt formatter and it's not found // we should only error if we expect a prompt formatter and it's not found
// right now, we don't know when to expect it, so we just return Ok(Some/None) // right now, we don't know when to expect it, so we just return Ok(Some/None)
Ok(Self::try_is_hf_repo(repo_id) match CheckedFile::from_disk(PathBuf::from(repo_id).join("tokenizer_config.json")) {
.with_context(|| format!("unable to extract prompt format from repo {}", repo_id)) Ok(f) => Ok(Some(Self::HfTokenizerConfigJson(f))),
.ok()) Err(_) => Ok(None),
}
} }
pub fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> { pub fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
Ok(Self::chat_template_try_is_hf_repo(repo_id) match CheckedFile::from_disk(PathBuf::from(repo_id).join("chat_template.jinja")) {
.with_context(|| format!("unable to extract prompt format from repo {}", repo_id)) Ok(f) => Ok(Some(Self::HfChatTemplate(f))),
.ok()) Err(_) => Ok(None),
} }
fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfChatTemplate(check_for_file(
repo,
"chat_template.jinja",
)?))
}
fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerConfigJson(check_for_file(
repo,
"tokenizer_config.json",
)?))
} }
} }
impl TokenizerKind { impl TokenizerKind {
pub fn from_repo(repo_id: &str) -> Result<Self> { pub fn from_repo(repo_id: &str) -> Result<Self> {
Self::try_is_hf_repo(repo_id) let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("tokenizer.json"))
.with_context(|| format!("unable to extract tokenizer kind from repo {}", repo_id)) .with_context(|| format!("unable to extract tokenizer kind from repo {repo_id}"))?;
} Ok(Self::HfTokenizerJson(f))
fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerJson(check_for_file(
repo,
"tokenizer.json",
)?))
}
}
impl GenerationConfig {
pub fn from_repo(repo_id: &str) -> Result<Self> {
Self::try_is_hf_repo(repo_id)
.with_context(|| format!("unable to extract generation config from repo {repo_id}"))
}
fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfGenerationConfigJson(check_for_file(
repo,
"generation_config.json",
)?))
}
}
/// Checks if the provided path contains the expected file.
fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<String> {
let p = PathBuf::from(repo_id).join(file);
let name = p.display().to_string();
if !p.exists() {
anyhow::bail!("File not found: {name}")
} }
Ok(name)
} }
/// Checks if the provided path is a valid local repository path. /// Checks if the provided path is a valid local repository path.
...@@ -905,7 +869,7 @@ mod tests { ...@@ -905,7 +869,7 @@ mod tests {
pub fn test_config_json_llama3() -> anyhow::Result<()> { pub fn test_config_json_llama3() -> anyhow::Result<()> {
let config_file = Path::new(env!("CARGO_MANIFEST_DIR")) let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json"); .join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
let config = HFConfig::from_json_file(&config_file.display().to_string())?; let config = HFConfig::from_json_file(&config_file)?;
assert_eq!(config.bos_token_id(), 128000); assert_eq!(config.bos_token_id(), 128000);
Ok(()) Ok(())
} }
...@@ -914,7 +878,7 @@ mod tests { ...@@ -914,7 +878,7 @@ mod tests {
pub fn test_config_json_llama4() -> anyhow::Result<()> { pub fn test_config_json_llama4() -> anyhow::Result<()> {
let config_file = Path::new(env!("CARGO_MANIFEST_DIR")) let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json"); .join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json");
let config = HFConfig::from_json_file(&config_file.display().to_string())?; let config = HFConfig::from_json_file(&config_file)?;
assert_eq!(config.bos_token_id(), 200000); assert_eq!(config.bos_token_id(), 200000);
Ok(()) Ok(())
} }
......
...@@ -23,20 +23,34 @@ impl PromptFormatter { ...@@ -23,20 +23,34 @@ impl PromptFormatter {
.as_ref() .as_ref()
.ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))? .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
{ {
PromptFormatterArtifact::HfTokenizerConfigJson(file) => { PromptFormatterArtifact::HfTokenizerConfigJson(checked_file) => {
let Some(file) = checked_file.path() else {
anyhow::bail!(
"HfTokenizerConfigJson for {} is a URL, cannot load",
mdc.display_name
);
};
let content = std::fs::read_to_string(file) let content = std::fs::read_to_string(file)
.with_context(|| format!("fs:read_to_string '{file}'"))?; .with_context(|| format!("fs:read_to_string '{}'", file.display()))?;
let mut config: ChatTemplate = serde_json::from_str(&content)?; let mut config: ChatTemplate = serde_json::from_str(&content)?;
// Some HF model (i.e. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8) // Some HF model (i.e. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
// stores the chat template in a separate file, we check if the file exists and // stores the chat template in a separate file, we check if the file exists and
// put the chat template into config as normalization. // put the chat template into config as normalization.
// This may also be a custom template provided via CLI flag. // This may also be a custom template provided via CLI flag.
if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) = if let Some(PromptFormatterArtifact::HfChatTemplate(checked_file)) =
mdc.chat_template_file.as_ref() mdc.chat_template_file.as_ref()
{ {
let chat_template = std::fs::read_to_string(chat_template_file) let Some(chat_template_file) = checked_file.path() else {
.with_context(|| format!("fs:read_to_string '{}'", chat_template_file))?; anyhow::bail!(
"HfChatTemplate for {} is a URL, cannot load",
mdc.display_name
);
};
let chat_template =
std::fs::read_to_string(chat_template_file).with_context(|| {
format!("fs:read_to_string '{}'", chat_template_file.display())
})?;
// clean up the string to remove newlines // clean up the string to remove newlines
let chat_template = chat_template.replace('\n', ""); let chat_template = chat_template.replace('\n', "");
config.chat_template = Some(ChatTemplateValue(either::Left(chat_template))); config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
......
...@@ -173,10 +173,10 @@ impl Client { ...@@ -173,10 +173,10 @@ impl Client {
} }
/// Upload file to NATS at this URL /// Upload file to NATS at this URL
pub async fn object_store_upload(&self, filepath: &Path, nats_url: Url) -> anyhow::Result<()> { pub async fn object_store_upload(&self, filepath: &Path, nats_url: &Url) -> anyhow::Result<()> {
let mut disk_file = TokioFile::open(filepath).await?; let mut disk_file = TokioFile::open(filepath).await?;
let (bucket_name, key) = url_to_bucket_and_key(&nats_url)?; let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
let bucket = self.get_or_create_bucket(&bucket_name, true).await?; let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
let key_meta = async_nats::jetstream::object_store::ObjectMetadata { let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
...@@ -193,12 +193,12 @@ impl Client { ...@@ -193,12 +193,12 @@ impl Client {
/// Download file from NATS at this URL /// Download file from NATS at this URL
pub async fn object_store_download( pub async fn object_store_download(
&self, &self,
nats_url: Url, nats_url: &Url,
filepath: &Path, filepath: &Path,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut disk_file = TokioFile::create(filepath).await?; let mut disk_file = TokioFile::create(filepath).await?;
let (bucket_name, key) = url_to_bucket_and_key(&nats_url)?; let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
let bucket = self.get_or_create_bucket(&bucket_name, false).await?; let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
let mut obj_reader = bucket.get(&key).await.map_err(|e| { let mut obj_reader = bucket.get(&key).await.map_err(|e| {
...@@ -225,7 +225,7 @@ impl Client { ...@@ -225,7 +225,7 @@ impl Client {
} }
/// Upload a serializable struct to NATS object store using bincode /// Upload a serializable struct to NATS object store using bincode
pub async fn object_store_upload_data<T>(&self, data: &T, nats_url: Url) -> anyhow::Result<()> pub async fn object_store_upload_data<T>(&self, data: &T, nats_url: &Url) -> anyhow::Result<()>
where where
T: Serialize, T: Serialize,
{ {
...@@ -233,7 +233,7 @@ impl Client { ...@@ -233,7 +233,7 @@ impl Client {
let binary_data = bincode::serialize(data) let binary_data = bincode::serialize(data)
.map_err(|e| anyhow::anyhow!("Failed to serialize data with bincode: {e}"))?; .map_err(|e| anyhow::anyhow!("Failed to serialize data with bincode: {e}"))?;
let (bucket_name, key) = url_to_bucket_and_key(&nats_url)?; let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
let bucket = self.get_or_create_bucket(&bucket_name, true).await?; let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
let key_meta = async_nats::jetstream::object_store::ObjectMetadata { let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
...@@ -251,11 +251,11 @@ impl Client { ...@@ -251,11 +251,11 @@ impl Client {
} }
/// Download and deserialize a struct from NATS object store using bincode /// Download and deserialize a struct from NATS object store using bincode
pub async fn object_store_download_data<T>(&self, nats_url: Url) -> anyhow::Result<T> pub async fn object_store_download_data<T>(&self, nats_url: &Url) -> anyhow::Result<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
{ {
let (bucket_name, key) = url_to_bucket_and_key(&nats_url)?; let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
let bucket = self.get_or_create_bucket(&bucket_name, false).await?; let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
let mut obj_reader = bucket.get(&key).await.map_err(|e| { let mut obj_reader = bucket.get(&key).await.map_err(|e| {
...@@ -1078,13 +1078,13 @@ mod tests { ...@@ -1078,13 +1078,13 @@ mod tests {
// Upload the data // Upload the data
client client
.object_store_upload_data(&test_data, url.clone()) .object_store_upload_data(&test_data, &url)
.await .await
.expect("Failed to upload data"); .expect("Failed to upload data");
// Download the data // Download the data
let downloaded_data: TestData = client let downloaded_data: TestData = client
.object_store_download_data(url.clone()) .object_store_download_data(&url)
.await .await
.expect("Failed to download data"); .expect("Failed to download data");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment