Unverified Commit c920cbd9 authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: Add --custom-jinja-template argument to pass a custom chat template for vLLM (#2829)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent dea5f887
...@@ -47,6 +47,7 @@ class Config: ...@@ -47,6 +47,7 @@ class Config:
migration_limit: int = 0 migration_limit: int = 0
kv_port: Optional[int] = None kv_port: Optional[int] = None
port_range: DynamoPortRange port_range: DynamoPortRange
custom_jinja_template: Optional[str] = None
# mirror vLLM # mirror vLLM
model: str model: str
...@@ -100,7 +101,7 @@ def parse_args() -> Config: ...@@ -100,7 +101,7 @@ def parse_args() -> Config:
help="List of connectors to use in order (e.g., --connector nixl lmcache). " help="List of connectors to use in order (e.g., --connector nixl lmcache). "
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.", "Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
) )
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args # To avoid name conflicts with different backends, adopted prefix "dyn-" for dynamo specific args
parser.add_argument( parser.add_argument(
"--dyn-tool-call-parser", "--dyn-tool-call-parser",
type=str, type=str,
...@@ -115,6 +116,12 @@ def parse_args() -> Config: ...@@ -115,6 +116,12 @@ def parse_args() -> Config:
choices=get_reasoning_parser_names(), choices=get_reasoning_parser_names(),
help="Reasoning parser name for the model.", help="Reasoning parser name for the model.",
) )
parser.add_argument(
"--custom-jinja-template",
type=str,
default=None,
help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.",
)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
...@@ -148,6 +155,7 @@ def parse_args() -> Config: ...@@ -148,6 +155,7 @@ def parse_args() -> Config:
) )
config.tool_call_parser = args.dyn_tool_call_parser config.tool_call_parser = args.dyn_tool_call_parser
config.reasoning_parser = args.dyn_reasoning_parser config.reasoning_parser = args.dyn_reasoning_parser
config.custom_jinja_template = args.custom_jinja_template
# Check for conflicting flags # Check for conflicting flags
has_kv_transfer_config = ( has_kv_transfer_config = (
hasattr(engine_args, "kv_transfer_config") hasattr(engine_args, "kv_transfer_config")
......
...@@ -258,6 +258,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -258,6 +258,7 @@ async def init(runtime: DistributedRuntime, config: Config):
kv_cache_block_size=config.engine_args.block_size, kv_cache_block_size=config.engine_args.block_size,
migration_limit=config.migration_limit, migration_limit=config.migration_limit,
runtime_config=runtime_config, runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
) )
try: try:
......
...@@ -1210,6 +1210,27 @@ dependencies = [ ...@@ -1210,6 +1210,27 @@ dependencies = [
"syn 2.0.106", "syn 2.0.106",
] ]
[[package]]
name = "derive_more"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.106",
"unicode-xid",
]
[[package]] [[package]]
name = "dialoguer" name = "dialoguer"
version = "0.11.0" version = "0.11.0"
...@@ -1433,8 +1454,10 @@ dependencies = [ ...@@ -1433,8 +1454,10 @@ dependencies = [
"anyhow", "anyhow",
"dynamo-async-openai", "dynamo-async-openai",
"lazy_static", "lazy_static",
"num-traits",
"openai-harmony", "openai-harmony",
"regex", "regex",
"rustpython-parser",
"serde", "serde",
"serde_json", "serde_json",
"tracing", "tracing",
...@@ -1453,6 +1476,7 @@ dependencies = [ ...@@ -1453,6 +1476,7 @@ dependencies = [
"dlpark", "dlpark",
"dynamo-async-openai", "dynamo-async-openai",
"dynamo-llm", "dynamo-llm",
"dynamo-parsers",
"dynamo-runtime", "dynamo-runtime",
"either", "either",
"futures", "futures",
...@@ -2205,6 +2229,15 @@ dependencies = [ ...@@ -2205,6 +2229,15 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "getopts"
version = "0.2.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe4fbac503b8d1f88e6676011885f34b7174f46e59956bba534ba83abded4df"
dependencies = [
"unicode-width",
]
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.2.16" version = "0.2.16"
...@@ -2322,6 +2355,9 @@ name = "hashbrown" ...@@ -2322,6 +2355,9 @@ name = "hashbrown"
version = "0.14.5" version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash",
]
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
...@@ -2799,12 +2835,33 @@ dependencies = [ ...@@ -2799,12 +2835,33 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "is-macro"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d57a3e447e24c22647738e4607f1df1e0ec6f72e16182c4cd199f647cdfb0e4"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.106",
]
[[package]] [[package]]
name = "is_terminal_polyfill" name = "is_terminal_polyfill"
version = "1.70.1" version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.12.1" version = "0.12.1"
...@@ -2884,6 +2941,12 @@ dependencies = [ ...@@ -2884,6 +2941,12 @@ dependencies = [
"winapi-build", "winapi-build",
] ]
[[package]]
name = "lalrpop-util"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553"
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.5.0" version = "1.5.0"
...@@ -3021,6 +3084,64 @@ version = "0.2.2" ...@@ -3021,6 +3084,64 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30"
[[package]]
name = "malachite"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fbdf9cb251732db30a7200ebb6ae5d22fe8e11397364416617d2c2cf0c51cb5"
dependencies = [
"malachite-base",
"malachite-nz",
"malachite-q",
]
[[package]]
name = "malachite-base"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ea0ed76adf7defc1a92240b5c36d5368cfe9251640dcce5bd2d0b7c1fd87aeb"
dependencies = [
"hashbrown 0.14.5",
"itertools 0.11.0",
"libm",
"ryu",
]
[[package]]
name = "malachite-bigint"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d149aaa2965d70381709d9df4c7ee1fc0de1c614a4efc2ee356f5e43d68749f8"
dependencies = [
"derive_more",
"malachite",
"num-integer",
"num-traits",
"paste",
]
[[package]]
name = "malachite-nz"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34a79feebb2bc9aa7762047c8e5495269a367da6b5a90a99882a0aeeac1841f7"
dependencies = [
"itertools 0.11.0",
"libm",
"malachite-base",
]
[[package]]
name = "malachite-q"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f235d5747b1256b47620f5640c2a17a88c7569eebdf27cd9cb130e1a619191"
dependencies = [
"itertools 0.11.0",
"malachite-base",
"malachite-nz",
]
[[package]] [[package]]
name = "matchers" name = "matchers"
version = "0.1.0" version = "0.1.0"
...@@ -3718,6 +3839,44 @@ dependencies = [ ...@@ -3718,6 +3839,44 @@ dependencies = [
"indexmap 2.11.0", "indexmap 2.11.0",
] ]
[[package]]
name = "phf"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
dependencies = [
"phf_shared",
]
[[package]]
name = "phf_codegen"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a"
dependencies = [
"phf_generator",
"phf_shared",
]
[[package]]
name = "phf_generator"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d"
dependencies = [
"phf_shared",
"rand 0.8.5",
]
[[package]]
name = "phf_shared"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
dependencies = [
"siphasher",
]
[[package]] [[package]]
name = "pin-project" name = "pin-project"
version = "1.1.10" version = "1.1.10"
...@@ -4761,6 +4920,63 @@ dependencies = [ ...@@ -4761,6 +4920,63 @@ dependencies = [
"untrusted", "untrusted",
] ]
[[package]]
name = "rustpython-ast"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cdaf8ee5c1473b993b398c174641d3aa9da847af36e8d5eb8291930b72f31a5"
dependencies = [
"is-macro",
"malachite-bigint",
"rustpython-parser-core",
"static_assertions",
]
[[package]]
name = "rustpython-parser"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "868f724daac0caf9bd36d38caf45819905193a901e8f1c983345a68e18fb2abb"
dependencies = [
"anyhow",
"is-macro",
"itertools 0.11.0",
"lalrpop-util",
"log",
"malachite-bigint",
"num-traits",
"phf",
"phf_codegen",
"rustc-hash 1.1.0",
"rustpython-ast",
"rustpython-parser-core",
"tiny-keccak",
"unic-emoji-char",
"unic-ucd-ident",
"unicode_names2",
]
[[package]]
name = "rustpython-parser-core"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4b6c12fa273825edc7bccd9a734f0ad5ba4b8a2f4da5ff7efe946f066d0f4ad"
dependencies = [
"is-macro",
"memchr",
"rustpython-parser-vendored",
]
[[package]]
name = "rustpython-parser-vendored"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04fcea49a4630a3a5d940f4d514dc4f575ed63c14c3e3ed07146634aed7f67a6"
dependencies = [
"memchr",
"once_cell",
]
[[package]] [[package]]
name = "rustversion" name = "rustversion"
version = "1.0.22" version = "1.0.22"
...@@ -5100,6 +5316,12 @@ dependencies = [ ...@@ -5100,6 +5316,12 @@ dependencies = [
"quote", "quote",
] ]
[[package]]
name = "siphasher"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.11" version = "0.4.11"
...@@ -5426,6 +5648,15 @@ dependencies = [ ...@@ -5426,6 +5648,15 @@ dependencies = [
"time-core", "time-core",
] ]
[[package]]
name = "tiny-keccak"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237"
dependencies = [
"crunchy",
]
[[package]] [[package]]
name = "tinystr" name = "tinystr"
version = "0.8.1" version = "0.8.1"
...@@ -5889,6 +6120,58 @@ version = "0.2.2" ...@@ -5889,6 +6120,58 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c" checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unic-char-property"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8c57a407d9b6fa02b4795eb81c5b6652060a15a7903ea981f3d723e6c0be221"
dependencies = [
"unic-char-range",
]
[[package]]
name = "unic-char-range"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0398022d5f700414f6b899e10b8348231abf9173fa93144cbc1a43b9793c1fbc"
[[package]]
name = "unic-common"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d7ff825a6a654ee85a63e80f92f054f904f21e7d12da4e22f9834a4aaa35bc"
[[package]]
name = "unic-emoji-char"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b07221e68897210270a38bde4babb655869637af0f69407f96053a34f76494d"
dependencies = [
"unic-char-property",
"unic-char-range",
"unic-ucd-version",
]
[[package]]
name = "unic-ucd-ident"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e230a37c0381caa9219d67cf063aa3a375ffed5bf541a452db16e744bdab6987"
dependencies = [
"unic-char-property",
"unic-char-range",
"unic-ucd-version",
]
[[package]]
name = "unic-ucd-version"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96bd2f2237fe450fcd0a1d2f5f4e91711124f7857ba2e964247776ebeeb7b0c4"
dependencies = [
"unic-common",
]
[[package]] [[package]]
name = "unicase" name = "unicase"
version = "2.8.1" version = "2.8.1"
...@@ -5922,12 +6205,40 @@ version = "0.2.1" ...@@ -5922,12 +6205,40 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]] [[package]]
name = "unicode_categories" name = "unicode_categories"
version = "0.1.1" version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unicode_names2"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1673eca9782c84de5f81b82e4109dcfb3611c8ba0d52930ec4a9478f547b2dd"
dependencies = [
"phf",
"unicode_names2_generator",
]
[[package]]
name = "unicode_names2_generator"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91e5b84611016120197efd7dc93ef76774f4e084cd73c9fb3ea4a86c570c56e"
dependencies = [
"getopts",
"log",
"phf_codegen",
"rand 0.8.5",
]
[[package]] [[package]]
name = "unindent" name = "unindent"
version = "0.2.4" version = "0.2.4"
......
...@@ -141,7 +141,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32) ...@@ -141,7 +141,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
} }
#[pyfunction] #[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None))] #[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn register_llm<'p>( fn register_llm<'p>(
py: Python<'p>, py: Python<'p>,
...@@ -155,6 +155,7 @@ fn register_llm<'p>( ...@@ -155,6 +155,7 @@ fn register_llm<'p>(
migration_limit: u32, migration_limit: u32,
runtime_config: Option<ModelRuntimeConfig>, runtime_config: Option<ModelRuntimeConfig>,
user_data: Option<&Bound<'p, PyDict>>, user_data: Option<&Bound<'p, PyDict>>,
custom_template_path: Option<&str>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type { let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat, ModelType::Chat => llm_rs::model_type::ModelType::Chat,
...@@ -168,6 +169,19 @@ fn register_llm<'p>( ...@@ -168,6 +169,19 @@ fn register_llm<'p>(
let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin); let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default()); let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default());
// Early validation of custom template path
let custom_template_path_owned = custom_template_path
.map(|s| {
let path = PathBuf::from(s);
if !path.exists() {
return Err(PyErr::new::<pyo3::exceptions::PyFileNotFoundError, _>(
format!("Custom template file does not exist: {}", path.display()),
));
}
Ok(path)
})
.transpose()?;
let user_data_json = user_data let user_data_json = user_data
.map(|dict| pythonize::depythonize(dict)) .map(|dict| pythonize::depythonize(dict))
.transpose() .transpose()
...@@ -185,7 +199,8 @@ fn register_llm<'p>( ...@@ -185,7 +199,8 @@ fn register_llm<'p>(
.router_config(Some(router_config)) .router_config(Some(router_config))
.migration_limit(Some(migration_limit)) .migration_limit(Some(migration_limit))
.runtime_config(runtime_config.unwrap_or_default().inner) .runtime_config(runtime_config.unwrap_or_default().inner)
.user_data(user_data_json); .user_data(user_data_json)
.custom_template_path(custom_template_path_owned);
// Download from HF, load the ModelDeploymentCard // Download from HF, load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?; let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us // Advertise ourself on etcd so ingress can find us
......
...@@ -18,7 +18,9 @@ impl ModelDeploymentCard { ...@@ -18,7 +18,9 @@ impl ModelDeploymentCard {
#[staticmethod] #[staticmethod]
fn load(path: String, model_name: String, py: Python<'_>) -> PyResult<Bound<'_, PyAny>> { fn load(path: String, model_name: String, py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut card = RsModelDeploymentCard::load(&path).await.map_err(to_pyerr)?; let mut card = RsModelDeploymentCard::load(&path, None)
.await
.map_err(to_pyerr)?;
card.set_name(&model_name); card.set_name(&model_name);
Ok(ModelDeploymentCard { inner: card }) Ok(ModelDeploymentCard { inner: card })
}) })
......
...@@ -859,7 +859,7 @@ class KvRouterConfig: ...@@ -859,7 +859,7 @@ class KvRouterConfig:
"""Values for KV router""" """Values for KV router"""
... ...
async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str] = None, context_length: Optional[int] = None, kv_cache_block_size: Optional[int] = None, router_mode: Optional[RouterMode] = None) -> None: async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str] = None, context_length: Optional[int] = None, kv_cache_block_size: Optional[int] = None, router_mode: Optional[RouterMode] = None, migration_limit: int = 0, runtime_config: Optional[ModelRuntimeConfig] = None, user_data: Optional[dict] = None, custom_template_path: Optional[str] = None) -> None:
"""Attach the model at path to the given endpoint, and advertise it as model_type""" """Attach the model at path to the given endpoint, and advertise it as model_type"""
... ...
......
...@@ -282,7 +282,7 @@ mod tests { ...@@ -282,7 +282,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> { async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card // Create test model card
let card = ModelDeploymentCard::load(HF_PATH).await?; let card = ModelDeploymentCard::load(HF_PATH, None).await?;
let engine = crate::engines::make_engine_core(); let engine = crate::engines::make_engine_core();
// Build pipeline for chat completions // Build pipeline for chat completions
...@@ -301,7 +301,7 @@ mod tests { ...@@ -301,7 +301,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> { async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card // Create test model card
let card = ModelDeploymentCard::load(HF_PATH).await?; let card = ModelDeploymentCard::load(HF_PATH, None).await?;
let engine = crate::engines::make_engine_core(); let engine = crate::engines::make_engine_core();
// Build pipeline for completions // Build pipeline for completions
......
...@@ -59,6 +59,7 @@ pub struct LocalModelBuilder { ...@@ -59,6 +59,7 @@ pub struct LocalModelBuilder {
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
runtime_config: ModelRuntimeConfig, runtime_config: ModelRuntimeConfig,
user_data: Option<serde_json::Value>, user_data: Option<serde_json::Value>,
custom_template_path: Option<PathBuf>,
namespace: Option<String>, namespace: Option<String>,
} }
...@@ -82,6 +83,7 @@ impl Default for LocalModelBuilder { ...@@ -82,6 +83,7 @@ impl Default for LocalModelBuilder {
extra_engine_args: Default::default(), extra_engine_args: Default::default(),
runtime_config: Default::default(), runtime_config: Default::default(),
user_data: Default::default(), user_data: Default::default(),
custom_template_path: Default::default(),
namespace: Default::default(), namespace: Default::default(),
} }
} }
...@@ -154,6 +156,11 @@ impl LocalModelBuilder { ...@@ -154,6 +156,11 @@ impl LocalModelBuilder {
self self
} }
pub fn custom_template_path(&mut self, custom_template_path: Option<PathBuf>) -> &mut Self {
self.custom_template_path = custom_template_path;
self
}
pub fn migration_limit(&mut self, migration_limit: Option<u32>) -> &mut Self { pub fn migration_limit(&mut self, migration_limit: Option<u32>) -> &mut Self {
self.migration_limit = migration_limit.unwrap_or(0); self.migration_limit = migration_limit.unwrap_or(0);
self self
...@@ -245,7 +252,9 @@ impl LocalModelBuilder { ...@@ -245,7 +252,9 @@ impl LocalModelBuilder {
// --model-config takes precedence over --model-path // --model-config takes precedence over --model-path
let model_config_path = self.model_config.as_ref().unwrap_or(&full_path); let model_config_path = self.model_config.as_ref().unwrap_or(&full_path);
let mut card = ModelDeploymentCard::load(&model_config_path).await?; let mut card =
ModelDeploymentCard::load(&model_config_path, self.custom_template_path.as_deref())
.await?;
// Usually we infer from the path, self.model_name is user override // Usually we infer from the path, self.model_name is user override
let model_name = self.model_name.take().unwrap_or_else(|| { let model_name = self.model_name.take().unwrap_or_else(|| {
......
...@@ -371,11 +371,19 @@ impl ModelDeploymentCard { ...@@ -371,11 +371,19 @@ impl ModelDeploymentCard {
/// Build an in-memory ModelDeploymentCard from either: /// Build an in-memory ModelDeploymentCard from either:
/// - a folder containing config.json, tokenizer.json and token_config.json /// - a folder containing config.json, tokenizer.json and token_config.json
/// - a GGUF file /// - a GGUF file
pub async fn load(config_path: impl AsRef<Path>) -> anyhow::Result<ModelDeploymentCard> { /// With an optional custom template
pub async fn load(
config_path: impl AsRef<Path>,
custom_template_path: Option<&Path>,
) -> anyhow::Result<ModelDeploymentCard> {
let config_path = config_path.as_ref(); let config_path = config_path.as_ref();
if config_path.is_dir() { if config_path.is_dir() {
Self::from_local_path(config_path).await Self::from_local_path(config_path, custom_template_path).await
} else { } else {
// GGUF files don't support custom templates yet
if custom_template_path.is_some() {
anyhow::bail!("Custom templates are not supported for GGUF files");
}
Self::from_gguf(config_path).await Self::from_gguf(config_path).await
} }
} }
...@@ -395,7 +403,10 @@ impl ModelDeploymentCard { ...@@ -395,7 +403,10 @@ impl ModelDeploymentCard {
/// - The path doesn't exist or isn't a directory /// - The path doesn't exist or isn't a directory
/// - The path contains invalid Unicode characters /// - The path contains invalid Unicode characters
/// - Required model files are missing or invalid /// - Required model files are missing or invalid
async fn from_local_path(local_root_dir: impl AsRef<Path>) -> anyhow::Result<Self> { async fn from_local_path(
local_root_dir: impl AsRef<Path>,
custom_template_path: Option<&Path>,
) -> anyhow::Result<Self> {
let local_root_dir = local_root_dir.as_ref(); let local_root_dir = local_root_dir.as_ref();
check_valid_local_repo_path(local_root_dir)?; check_valid_local_repo_path(local_root_dir)?;
let repo_id = local_root_dir let repo_id = local_root_dir
...@@ -407,7 +418,8 @@ impl ModelDeploymentCard { ...@@ -407,7 +418,8 @@ impl ModelDeploymentCard {
.file_name() .file_name()
.and_then(|n| n.to_str()) .and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?; .ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?;
Self::from_repo(&repo_id, model_name).await
Self::from_repo(&repo_id, model_name, custom_template_path).await
} }
async fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> { async fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
...@@ -456,7 +468,11 @@ impl ModelDeploymentCard { ...@@ -456,7 +468,11 @@ impl ModelDeploymentCard {
)) ))
} }
async fn from_repo(repo_id: &str, model_name: &str) -> anyhow::Result<Self> { async fn from_repo(
repo_id: &str,
model_name: &str,
custom_template_path: Option<&Path>,
) -> anyhow::Result<Self> {
// This is usually the right choice // This is usually the right choice
let context_length = crate::file_json_field( let context_length = crate::file_json_field(
&PathBuf::from(repo_id).join("config.json"), &PathBuf::from(repo_id).join("config.json"),
...@@ -472,6 +488,30 @@ impl ModelDeploymentCard { ...@@ -472,6 +488,30 @@ impl ModelDeploymentCard {
// If neither of those are present let the engine default it // If neither of those are present let the engine default it
.unwrap_or(0); .unwrap_or(0);
// Load chat template - either custom or from repo
let chat_template_file = if let Some(template_path) = custom_template_path {
if !template_path.exists() {
anyhow::bail!(
"Custom template file does not exist: {}",
template_path.display()
);
}
// Verify the file is readable
let _template_content = std::fs::read_to_string(template_path).with_context(|| {
format!(
"Failed to read custom template file: {}",
template_path.display()
)
})?;
Some(PromptFormatterArtifact::HfChatTemplate(
template_path.display().to_string(),
))
} else {
PromptFormatterArtifact::chat_template_from_repo(repo_id).await?
};
Ok(Self { Ok(Self {
display_name: model_name.to_string(), display_name: model_name.to_string(),
slug: Slug::from_string(model_name), slug: Slug::from_string(model_name),
...@@ -479,7 +519,7 @@ impl ModelDeploymentCard { ...@@ -479,7 +519,7 @@ impl ModelDeploymentCard {
tokenizer: Some(TokenizerKind::from_repo(repo_id).await?), tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?, prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
chat_template_file: PromptFormatterArtifact::chat_template_from_repo(repo_id).await?, chat_template_file,
prompt_context: None, // TODO - auto-detect prompt context prompt_context: None, // TODO - auto-detect prompt context
revision: 0, revision: 0,
last_published: None, last_published: None,
......
...@@ -26,9 +26,11 @@ impl PromptFormatter { ...@@ -26,9 +26,11 @@ impl PromptFormatter {
let content = std::fs::read_to_string(&file) let content = std::fs::read_to_string(&file)
.with_context(|| format!("fs:read_to_string '{file}'"))?; .with_context(|| format!("fs:read_to_string '{file}'"))?;
let mut config: ChatTemplate = serde_json::from_str(&content)?; let mut config: ChatTemplate = serde_json::from_str(&content)?;
// Some HF model (i.e. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8) // Some HF model (i.e. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
// stores the chat template in a separate file, we check if the file exists and // stores the chat template in a separate file, we check if the file exists and
// put the chat template into config as normalization. // put the chat template into config as normalization.
// This may also be a custom template provided via CLI flag.
if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) = if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) =
mdc.chat_template_file mdc.chat_template_file
{ {
......
...@@ -6,7 +6,7 @@ use dynamo_llm::model_card::ModelDeploymentCard; ...@@ -6,7 +6,7 @@ use dynamo_llm::model_card::ModelDeploymentCard;
#[tokio::test] #[tokio::test]
async fn test_sequence_factory() { async fn test_sequence_factory() {
let mdc = ModelDeploymentCard::load("tests/data/sample-models/TinyLlama_v1.1") let mdc = ModelDeploymentCard::load("tests/data/sample-models/TinyLlama_v1.1", None)
.await .await
.unwrap(); .unwrap();
......
...@@ -8,7 +8,7 @@ const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1"; ...@@ -8,7 +8,7 @@ const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1";
#[tokio::test] #[tokio::test]
async fn test_model_info_from_hf_like_local_repo() { async fn test_model_info_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH).await.unwrap(); let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap();
let info = mdc.model_info.unwrap().get_model_info().await.unwrap(); let info = mdc.model_info.unwrap().get_model_info().await.unwrap();
assert_eq!(info.model_type(), "llama"); assert_eq!(info.model_type(), "llama");
assert_eq!(info.bos_token_id(), 1); assert_eq!(info.bos_token_id(), 1);
...@@ -20,13 +20,13 @@ async fn test_model_info_from_hf_like_local_repo() { ...@@ -20,13 +20,13 @@ async fn test_model_info_from_hf_like_local_repo() {
#[tokio::test] #[tokio::test]
async fn test_model_info_from_non_existent_local_repo() { async fn test_model_info_from_non_existent_local_repo() {
let path = "tests/data/sample-models/this-model-does-not-exist"; let path = "tests/data/sample-models/this-model-does-not-exist";
let result = ModelDeploymentCard::load(path).await; let result = ModelDeploymentCard::load(path, None).await;
assert!(result.is_err()); assert!(result.is_err());
} }
#[tokio::test] #[tokio::test]
async fn test_tokenizer_from_hf_like_local_repo() { async fn test_tokenizer_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH).await.unwrap(); let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap();
// Verify tokenizer file was found // Verify tokenizer file was found
match mdc.tokenizer.unwrap() { match mdc.tokenizer.unwrap() {
TokenizerKind::HfTokenizerJson(_) => (), TokenizerKind::HfTokenizerJson(_) => (),
...@@ -36,7 +36,7 @@ async fn test_tokenizer_from_hf_like_local_repo() { ...@@ -36,7 +36,7 @@ async fn test_tokenizer_from_hf_like_local_repo() {
#[tokio::test] #[tokio::test]
async fn test_prompt_formatter_from_hf_like_local_repo() { async fn test_prompt_formatter_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH).await.unwrap(); let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap();
// Verify prompt formatter was found // Verify prompt formatter was found
match mdc.prompt_formatter { match mdc.prompt_formatter {
Some(PromptFormatterArtifact::HfTokenizerConfigJson(_)) => (), Some(PromptFormatterArtifact::HfTokenizerConfigJson(_)) => (),
...@@ -48,7 +48,7 @@ async fn test_prompt_formatter_from_hf_like_local_repo() { ...@@ -48,7 +48,7 @@ async fn test_prompt_formatter_from_hf_like_local_repo() {
async fn test_missing_required_files() { async fn test_missing_required_files() {
// Create empty temp directory // Create empty temp directory
let temp_dir = tempdir().unwrap(); let temp_dir = tempdir().unwrap();
let result = ModelDeploymentCard::load(temp_dir.path()).await; let result = ModelDeploymentCard::load(temp_dir.path(), None).await;
assert!(result.is_err()); assert!(result.is_err());
let err = result.unwrap_err().to_string(); let err = result.unwrap_err().to_string();
// Should fail because config.json is missing // Should fail because config.json is missing
......
...@@ -57,7 +57,9 @@ async fn make_mdc_from_repo( ...@@ -57,7 +57,9 @@ async fn make_mdc_from_repo(
//TODO: remove this once we have nim-hub support. See the NOTE above. //TODO: remove this once we have nim-hub support. See the NOTE above.
let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await; let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await;
let display_name = format!("{}--{}", hf_repo, hf_revision); let display_name = format!("{}--{}", hf_repo, hf_revision);
let mut mdc = ModelDeploymentCard::load(downloaded_path).await.unwrap(); let mut mdc = ModelDeploymentCard::load(downloaded_path, None)
.await
.unwrap();
mdc.set_name(&display_name); mdc.set_name(&display_name);
mdc.prompt_context = mixins; mdc.prompt_context = mixins;
mdc mdc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment