Unverified Commit c920cbd9 authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: Add --custom-jinja-template argument to pass a custom chat template for vLLM (#2829)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent dea5f887
......@@ -47,6 +47,7 @@ class Config:
migration_limit: int = 0
kv_port: Optional[int] = None
port_range: DynamoPortRange
custom_jinja_template: Optional[str] = None
# mirror vLLM
model: str
......@@ -100,7 +101,7 @@ def parse_args() -> Config:
help="List of connectors to use in order (e.g., --connector nixl lmcache). "
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
)
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args
# To avoid name conflicts with different backends, adopted prefix "dyn-" for dynamo specific args
parser.add_argument(
"--dyn-tool-call-parser",
type=str,
......@@ -115,6 +116,12 @@ def parse_args() -> Config:
choices=get_reasoning_parser_names(),
help="Reasoning parser name for the model.",
)
parser.add_argument(
"--custom-jinja-template",
type=str,
default=None,
help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
......@@ -148,6 +155,7 @@ def parse_args() -> Config:
)
config.tool_call_parser = args.dyn_tool_call_parser
config.reasoning_parser = args.dyn_reasoning_parser
config.custom_jinja_template = args.custom_jinja_template
# Check for conflicting flags
has_kv_transfer_config = (
hasattr(engine_args, "kv_transfer_config")
......
......@@ -258,6 +258,7 @@ async def init(runtime: DistributedRuntime, config: Config):
kv_cache_block_size=config.engine_args.block_size,
migration_limit=config.migration_limit,
runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
)
try:
......
......@@ -1210,6 +1210,27 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "derive_more"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.106",
"unicode-xid",
]
[[package]]
name = "dialoguer"
version = "0.11.0"
......@@ -1433,8 +1454,10 @@ dependencies = [
"anyhow",
"dynamo-async-openai",
"lazy_static",
"num-traits",
"openai-harmony",
"regex",
"rustpython-parser",
"serde",
"serde_json",
"tracing",
......@@ -1453,6 +1476,7 @@ dependencies = [
"dlpark",
"dynamo-async-openai",
"dynamo-llm",
"dynamo-parsers",
"dynamo-runtime",
"either",
"futures",
......@@ -2205,6 +2229,15 @@ dependencies = [
"version_check",
]
[[package]]
name = "getopts"
version = "0.2.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe4fbac503b8d1f88e6676011885f34b7174f46e59956bba534ba83abded4df"
dependencies = [
"unicode-width",
]
[[package]]
name = "getrandom"
version = "0.2.16"
......@@ -2322,6 +2355,9 @@ name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
......@@ -2799,12 +2835,33 @@ dependencies = [
"serde",
]
[[package]]
name = "is-macro"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d57a3e447e24c22647738e4607f1df1e0ec6f72e16182c4cd199f647cdfb0e4"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.106",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.12.1"
......@@ -2884,6 +2941,12 @@ dependencies = [
"winapi-build",
]
[[package]]
name = "lalrpop-util"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553"
[[package]]
name = "lazy_static"
version = "1.5.0"
......@@ -3021,6 +3084,64 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30"
[[package]]
name = "malachite"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fbdf9cb251732db30a7200ebb6ae5d22fe8e11397364416617d2c2cf0c51cb5"
dependencies = [
"malachite-base",
"malachite-nz",
"malachite-q",
]
[[package]]
name = "malachite-base"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ea0ed76adf7defc1a92240b5c36d5368cfe9251640dcce5bd2d0b7c1fd87aeb"
dependencies = [
"hashbrown 0.14.5",
"itertools 0.11.0",
"libm",
"ryu",
]
[[package]]
name = "malachite-bigint"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d149aaa2965d70381709d9df4c7ee1fc0de1c614a4efc2ee356f5e43d68749f8"
dependencies = [
"derive_more",
"malachite",
"num-integer",
"num-traits",
"paste",
]
[[package]]
name = "malachite-nz"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34a79feebb2bc9aa7762047c8e5495269a367da6b5a90a99882a0aeeac1841f7"
dependencies = [
"itertools 0.11.0",
"libm",
"malachite-base",
]
[[package]]
name = "malachite-q"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f235d5747b1256b47620f5640c2a17a88c7569eebdf27cd9cb130e1a619191"
dependencies = [
"itertools 0.11.0",
"malachite-base",
"malachite-nz",
]
[[package]]
name = "matchers"
version = "0.1.0"
......@@ -3718,6 +3839,44 @@ dependencies = [
"indexmap 2.11.0",
]
[[package]]
name = "phf"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
dependencies = [
"phf_shared",
]
[[package]]
name = "phf_codegen"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a"
dependencies = [
"phf_generator",
"phf_shared",
]
[[package]]
name = "phf_generator"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d"
dependencies = [
"phf_shared",
"rand 0.8.5",
]
[[package]]
name = "phf_shared"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
dependencies = [
"siphasher",
]
[[package]]
name = "pin-project"
version = "1.1.10"
......@@ -4761,6 +4920,63 @@ dependencies = [
"untrusted",
]
[[package]]
name = "rustpython-ast"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cdaf8ee5c1473b993b398c174641d3aa9da847af36e8d5eb8291930b72f31a5"
dependencies = [
"is-macro",
"malachite-bigint",
"rustpython-parser-core",
"static_assertions",
]
[[package]]
name = "rustpython-parser"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "868f724daac0caf9bd36d38caf45819905193a901e8f1c983345a68e18fb2abb"
dependencies = [
"anyhow",
"is-macro",
"itertools 0.11.0",
"lalrpop-util",
"log",
"malachite-bigint",
"num-traits",
"phf",
"phf_codegen",
"rustc-hash 1.1.0",
"rustpython-ast",
"rustpython-parser-core",
"tiny-keccak",
"unic-emoji-char",
"unic-ucd-ident",
"unicode_names2",
]
[[package]]
name = "rustpython-parser-core"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4b6c12fa273825edc7bccd9a734f0ad5ba4b8a2f4da5ff7efe946f066d0f4ad"
dependencies = [
"is-macro",
"memchr",
"rustpython-parser-vendored",
]
[[package]]
name = "rustpython-parser-vendored"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04fcea49a4630a3a5d940f4d514dc4f575ed63c14c3e3ed07146634aed7f67a6"
dependencies = [
"memchr",
"once_cell",
]
[[package]]
name = "rustversion"
version = "1.0.22"
......@@ -5100,6 +5316,12 @@ dependencies = [
"quote",
]
[[package]]
name = "siphasher"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
[[package]]
name = "slab"
version = "0.4.11"
......@@ -5426,6 +5648,15 @@ dependencies = [
"time-core",
]
[[package]]
name = "tiny-keccak"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237"
dependencies = [
"crunchy",
]
[[package]]
name = "tinystr"
version = "0.8.1"
......@@ -5889,6 +6120,58 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unic-char-property"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8c57a407d9b6fa02b4795eb81c5b6652060a15a7903ea981f3d723e6c0be221"
dependencies = [
"unic-char-range",
]
[[package]]
name = "unic-char-range"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0398022d5f700414f6b899e10b8348231abf9173fa93144cbc1a43b9793c1fbc"
[[package]]
name = "unic-common"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d7ff825a6a654ee85a63e80f92f054f904f21e7d12da4e22f9834a4aaa35bc"
[[package]]
name = "unic-emoji-char"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b07221e68897210270a38bde4babb655869637af0f69407f96053a34f76494d"
dependencies = [
"unic-char-property",
"unic-char-range",
"unic-ucd-version",
]
[[package]]
name = "unic-ucd-ident"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e230a37c0381caa9219d67cf063aa3a375ffed5bf541a452db16e744bdab6987"
dependencies = [
"unic-char-property",
"unic-char-range",
"unic-ucd-version",
]
[[package]]
name = "unic-ucd-version"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96bd2f2237fe450fcd0a1d2f5f4e91711124f7857ba2e964247776ebeeb7b0c4"
dependencies = [
"unic-common",
]
[[package]]
name = "unicase"
version = "2.8.1"
......@@ -5922,12 +6205,40 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_categories"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unicode_names2"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1673eca9782c84de5f81b82e4109dcfb3611c8ba0d52930ec4a9478f547b2dd"
dependencies = [
"phf",
"unicode_names2_generator",
]
[[package]]
name = "unicode_names2_generator"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91e5b84611016120197efd7dc93ef76774f4e084cd73c9fb3ea4a86c570c56e"
dependencies = [
"getopts",
"log",
"phf_codegen",
"rand 0.8.5",
]
[[package]]
name = "unindent"
version = "0.2.4"
......
......@@ -141,7 +141,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
}
#[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None))]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None))]
#[allow(clippy::too_many_arguments)]
fn register_llm<'p>(
py: Python<'p>,
......@@ -155,6 +155,7 @@ fn register_llm<'p>(
migration_limit: u32,
runtime_config: Option<ModelRuntimeConfig>,
user_data: Option<&Bound<'p, PyDict>>,
custom_template_path: Option<&str>,
) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat,
......@@ -168,6 +169,19 @@ fn register_llm<'p>(
let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default());
// Early validation of custom template path
let custom_template_path_owned = custom_template_path
.map(|s| {
let path = PathBuf::from(s);
if !path.exists() {
return Err(PyErr::new::<pyo3::exceptions::PyFileNotFoundError, _>(
format!("Custom template file does not exist: {}", path.display()),
));
}
Ok(path)
})
.transpose()?;
let user_data_json = user_data
.map(|dict| pythonize::depythonize(dict))
.transpose()
......@@ -185,7 +199,8 @@ fn register_llm<'p>(
.router_config(Some(router_config))
.migration_limit(Some(migration_limit))
.runtime_config(runtime_config.unwrap_or_default().inner)
.user_data(user_data_json);
.user_data(user_data_json)
.custom_template_path(custom_template_path_owned);
// Download from HF, load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us
......
......@@ -18,7 +18,9 @@ impl ModelDeploymentCard {
#[staticmethod]
fn load(path: String, model_name: String, py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut card = RsModelDeploymentCard::load(&path).await.map_err(to_pyerr)?;
let mut card = RsModelDeploymentCard::load(&path, None)
.await
.map_err(to_pyerr)?;
card.set_name(&model_name);
Ok(ModelDeploymentCard { inner: card })
})
......
......@@ -859,7 +859,7 @@ class KvRouterConfig:
"""Values for KV router"""
...
async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str] = None, context_length: Optional[int] = None, kv_cache_block_size: Optional[int] = None, router_mode: Optional[RouterMode] = None) -> None:
async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str] = None, context_length: Optional[int] = None, kv_cache_block_size: Optional[int] = None, router_mode: Optional[RouterMode] = None, migration_limit: int = 0, runtime_config: Optional[ModelRuntimeConfig] = None, user_data: Optional[dict] = None, custom_template_path: Optional[str] = None) -> None:
"""Attach the model at path to the given endpoint, and advertise it as model_type"""
...
......
......@@ -282,7 +282,7 @@ mod tests {
#[tokio::test]
async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card
let card = ModelDeploymentCard::load(HF_PATH).await?;
let card = ModelDeploymentCard::load(HF_PATH, None).await?;
let engine = crate::engines::make_engine_core();
// Build pipeline for chat completions
......@@ -301,7 +301,7 @@ mod tests {
#[tokio::test]
async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card
let card = ModelDeploymentCard::load(HF_PATH).await?;
let card = ModelDeploymentCard::load(HF_PATH, None).await?;
let engine = crate::engines::make_engine_core();
// Build pipeline for completions
......
......@@ -59,6 +59,7 @@ pub struct LocalModelBuilder {
extra_engine_args: Option<PathBuf>,
runtime_config: ModelRuntimeConfig,
user_data: Option<serde_json::Value>,
custom_template_path: Option<PathBuf>,
namespace: Option<String>,
}
......@@ -82,6 +83,7 @@ impl Default for LocalModelBuilder {
extra_engine_args: Default::default(),
runtime_config: Default::default(),
user_data: Default::default(),
custom_template_path: Default::default(),
namespace: Default::default(),
}
}
......@@ -154,6 +156,11 @@ impl LocalModelBuilder {
self
}
pub fn custom_template_path(&mut self, custom_template_path: Option<PathBuf>) -> &mut Self {
self.custom_template_path = custom_template_path;
self
}
pub fn migration_limit(&mut self, migration_limit: Option<u32>) -> &mut Self {
self.migration_limit = migration_limit.unwrap_or(0);
self
......@@ -245,7 +252,9 @@ impl LocalModelBuilder {
// --model-config takes precedence over --model-path
let model_config_path = self.model_config.as_ref().unwrap_or(&full_path);
let mut card = ModelDeploymentCard::load(&model_config_path).await?;
let mut card =
ModelDeploymentCard::load(&model_config_path, self.custom_template_path.as_deref())
.await?;
// Usually we infer from the path, self.model_name is user override
let model_name = self.model_name.take().unwrap_or_else(|| {
......
......@@ -371,11 +371,19 @@ impl ModelDeploymentCard {
/// Build an in-memory ModelDeploymentCard from either:
/// - a folder containing config.json, tokenizer.json and token_config.json
/// - a GGUF file
pub async fn load(config_path: impl AsRef<Path>) -> anyhow::Result<ModelDeploymentCard> {
/// With an optional custom template
pub async fn load(
config_path: impl AsRef<Path>,
custom_template_path: Option<&Path>,
) -> anyhow::Result<ModelDeploymentCard> {
let config_path = config_path.as_ref();
if config_path.is_dir() {
Self::from_local_path(config_path).await
Self::from_local_path(config_path, custom_template_path).await
} else {
// GGUF files don't support custom templates yet
if custom_template_path.is_some() {
anyhow::bail!("Custom templates are not supported for GGUF files");
}
Self::from_gguf(config_path).await
}
}
......@@ -395,7 +403,10 @@ impl ModelDeploymentCard {
/// - The path doesn't exist or isn't a directory
/// - The path contains invalid Unicode characters
/// - Required model files are missing or invalid
async fn from_local_path(local_root_dir: impl AsRef<Path>) -> anyhow::Result<Self> {
async fn from_local_path(
local_root_dir: impl AsRef<Path>,
custom_template_path: Option<&Path>,
) -> anyhow::Result<Self> {
let local_root_dir = local_root_dir.as_ref();
check_valid_local_repo_path(local_root_dir)?;
let repo_id = local_root_dir
......@@ -407,7 +418,8 @@ impl ModelDeploymentCard {
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?;
Self::from_repo(&repo_id, model_name).await
Self::from_repo(&repo_id, model_name, custom_template_path).await
}
async fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
......@@ -456,7 +468,11 @@ impl ModelDeploymentCard {
))
}
async fn from_repo(repo_id: &str, model_name: &str) -> anyhow::Result<Self> {
async fn from_repo(
repo_id: &str,
model_name: &str,
custom_template_path: Option<&Path>,
) -> anyhow::Result<Self> {
// This is usually the right choice
let context_length = crate::file_json_field(
&PathBuf::from(repo_id).join("config.json"),
......@@ -472,6 +488,30 @@ impl ModelDeploymentCard {
// If neither of those are present let the engine default it
.unwrap_or(0);
// Load chat template - either custom or from repo
let chat_template_file = if let Some(template_path) = custom_template_path {
if !template_path.exists() {
anyhow::bail!(
"Custom template file does not exist: {}",
template_path.display()
);
}
// Verify the file is readable
let _template_content = std::fs::read_to_string(template_path).with_context(|| {
format!(
"Failed to read custom template file: {}",
template_path.display()
)
})?;
Some(PromptFormatterArtifact::HfChatTemplate(
template_path.display().to_string(),
))
} else {
PromptFormatterArtifact::chat_template_from_repo(repo_id).await?
};
Ok(Self {
display_name: model_name.to_string(),
slug: Slug::from_string(model_name),
......@@ -479,7 +519,7 @@ impl ModelDeploymentCard {
tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
chat_template_file: PromptFormatterArtifact::chat_template_from_repo(repo_id).await?,
chat_template_file,
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
last_published: None,
......
......@@ -26,9 +26,11 @@ impl PromptFormatter {
let content = std::fs::read_to_string(&file)
.with_context(|| format!("fs:read_to_string '{file}'"))?;
let mut config: ChatTemplate = serde_json::from_str(&content)?;
// Some HF model (i.e. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
// stores the chat template in a separate file, we check if the file exists and
// put the chat template into config as normalization.
// This may also be a custom template provided via CLI flag.
if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) =
mdc.chat_template_file
{
......
......@@ -6,7 +6,7 @@ use dynamo_llm::model_card::ModelDeploymentCard;
#[tokio::test]
async fn test_sequence_factory() {
let mdc = ModelDeploymentCard::load("tests/data/sample-models/TinyLlama_v1.1")
let mdc = ModelDeploymentCard::load("tests/data/sample-models/TinyLlama_v1.1", None)
.await
.unwrap();
......
......@@ -8,7 +8,7 @@ const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1";
#[tokio::test]
async fn test_model_info_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH).await.unwrap();
let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap();
let info = mdc.model_info.unwrap().get_model_info().await.unwrap();
assert_eq!(info.model_type(), "llama");
assert_eq!(info.bos_token_id(), 1);
......@@ -20,13 +20,13 @@ async fn test_model_info_from_hf_like_local_repo() {
#[tokio::test]
async fn test_model_info_from_non_existent_local_repo() {
let path = "tests/data/sample-models/this-model-does-not-exist";
let result = ModelDeploymentCard::load(path).await;
let result = ModelDeploymentCard::load(path, None).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_tokenizer_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH).await.unwrap();
let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap();
// Verify tokenizer file was found
match mdc.tokenizer.unwrap() {
TokenizerKind::HfTokenizerJson(_) => (),
......@@ -36,7 +36,7 @@ async fn test_tokenizer_from_hf_like_local_repo() {
#[tokio::test]
async fn test_prompt_formatter_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH).await.unwrap();
let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap();
// Verify prompt formatter was found
match mdc.prompt_formatter {
Some(PromptFormatterArtifact::HfTokenizerConfigJson(_)) => (),
......@@ -48,7 +48,7 @@ async fn test_prompt_formatter_from_hf_like_local_repo() {
async fn test_missing_required_files() {
// Create empty temp directory
let temp_dir = tempdir().unwrap();
let result = ModelDeploymentCard::load(temp_dir.path()).await;
let result = ModelDeploymentCard::load(temp_dir.path(), None).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
// Should fail because config.json is missing
......
......@@ -57,7 +57,9 @@ async fn make_mdc_from_repo(
//TODO: remove this once we have nim-hub support. See the NOTE above.
let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await;
let display_name = format!("{}--{}", hf_repo, hf_revision);
let mut mdc = ModelDeploymentCard::load(downloaded_path).await.unwrap();
let mut mdc = ModelDeploymentCard::load(downloaded_path, None)
.await
.unwrap();
mdc.set_name(&display_name);
mdc.prompt_context = mixins;
mdc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment