Unverified Commit ab0da582 authored by Graham King, committed by GitHub
Browse files

feat: Python binding to download a model. (#3593)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 6a1391eb
......@@ -16,6 +16,7 @@ from sglang.srt.server_args import ServerArgs
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.config_dump import register_encoder
from dynamo.llm import fetch_llm
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang import __version__
......@@ -203,8 +204,9 @@ def _set_parser(
return dynamo_str
def parse_args(args: list[str]) -> Config:
async def parse_args(args: list[str]) -> Config:
"""Parse CLI arguments and return combined configuration.
Download the model if necessary.
Args:
args: Command-line argument strings.
......@@ -339,6 +341,14 @@ def parse_args(args: list[str]) -> Config:
)
logging.debug(f"Dynamo args: {dynamo_args}")
# TODO: sglang downloads the model in `from_cli_args`, so we need to do it here.
# That's unfortunate because `parse_args` isn't the right place for this. Fix.
model_path = parsed_args.model_path
if not parsed_args.served_model_name:
parsed_args.served_model_name = model_path
if not os.path.exists(model_path):
parsed_args.model_path = await fetch_llm(model_path)
server_args = ServerArgs.from_cli_args(parsed_args)
if parsed_args.use_sglang_tokenizer:
......
......@@ -45,8 +45,9 @@ async def worker(runtime: DistributedRuntime):
logging.info("Signal handlers will trigger a graceful shutdown of the runtime")
config = parse_args(sys.argv[1:])
config = await parse_args(sys.argv[1:])
dump_config(config.dynamo_args.dump_config_to, config)
if config.dynamo_args.embedding_worker:
await init_embedding(runtime, config)
elif config.dynamo_args.multimodal_processor:
......
......@@ -21,6 +21,7 @@ from dynamo.llm import (
ModelType,
ZmqKvEventPublisher,
ZmqKvEventPublisherConfig,
fetch_llm,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
......@@ -82,6 +83,15 @@ async def worker(runtime: DistributedRuntime):
logging.debug("Signal handlers set up for graceful shutdown")
dump_config(config.dump_config_to, config)
# Download the model if necessary.
# register_llm would do this for us, but we want it on disk before we start vllm.
# Ensure the original HF name (e.g. "Qwen/Qwen3-0.6B") is used as the served_model_name.
if not config.served_model_name:
config.served_model_name = config.engine_args.served_model_name = config.model
if not os.path.exists(config.model):
config.model = config.engine_args.model = await fetch_llm(config.model)
if config.is_prefill_worker:
await init_prefill(runtime, config)
logger.debug("init_prefill completed")
......@@ -165,9 +175,11 @@ def setup_vllm_engine(config, stat_logger=None):
disable_log_stats=engine_args.disable_log_stats,
)
if ENABLE_LMCACHE:
logger.info(f"VllmWorker for {config.model} has been initialized with LMCache")
logger.info(
f"VllmWorker for {config.served_model_name} has been initialized with LMCache"
)
else:
logger.info(f"VllmWorker for {config.model} has been initialized")
logger.info(f"VllmWorker for {config.served_model_name} has been initialized")
return engine_client, vllm_config, default_sampling_params
......@@ -207,11 +219,13 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", config.model)],
# In practice config.served_model_name is always set, but mypy needs the "or" here.
metrics_labels=[("model", config.served_model_name or config.model)],
health_check_payload=health_check_payload,
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=[("model", config.model)]
handler.clear_kv_blocks,
metrics_labels=[("model", config.served_model_name)],
),
)
logger.debug("serve_endpoint completed for prefill worker")
......@@ -251,7 +265,7 @@ async def init(runtime: DistributedRuntime, config: Config):
factory = StatLoggerFactory(
component,
config.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", config.model)],
metrics_labels=[("model", config.served_model_name or config.model)],
)
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(
config, factory
......@@ -262,8 +276,6 @@ async def init(runtime: DistributedRuntime, config: Config):
factory.set_request_total_slots_all(vllm_config.scheduler_config.max_num_seqs)
factory.init_publish()
logger.info(f"VllmWorker for {config.model} has been initialized")
handler = DecodeWorkerHandler(
runtime,
component,
......@@ -321,11 +333,12 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=config.migration_limit <= 0,
metrics_labels=[("model", config.model)],
metrics_labels=[("model", config.served_model_name or config.model)],
health_check_payload=health_check_payload,
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=[("model", config.model)]
handler.clear_kv_blocks,
metrics_labels=[("model", config.served_model_name or config.model)],
),
)
logger.debug("serve_endpoint completed for decode worker")
......
......@@ -21,18 +21,32 @@ pub async fn run(
out_opt: Option<Output>,
mut flags: Flags,
) -> anyhow::Result<()> {
//
// Download
//
let maybe_remote_repo = flags
.model_path_pos
.clone()
.or_else(|| flags.model_path_flag.clone());
let model_path = match maybe_remote_repo {
None => None,
Some(p) if p.exists() => {
// Already a local path
Some(p)
}
Some(p) => {
// model_path might be an HF repo, not a local path. Resolve it by downloading.
Some(LocalModel::fetch(&p.display().to_string(), false).await?)
}
};
//
// Configure
//
let mut builder = LocalModelBuilder::default();
builder
.model_path(
flags
.model_path_pos
.clone()
.or(flags.model_path_flag.clone()),
)
.model_name(flags.model_name.clone())
.kv_cache_block_size(flags.kv_cache_block_size)
// Only set if user provides. Usually loaded from tokenizer_config.json
......@@ -45,6 +59,11 @@ pub async fn run(
.migration_limit(flags.migration_limit)
.is_mocker(matches!(out_opt, Some(Output::Mocker)));
// Only the worker has a model path
if let Some(model_path) = model_path {
builder.model_path(model_path);
}
// TODO: old, address this later:
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we let LocalModel invent one.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_llm::local_model::LocalModel;
use futures::StreamExt;
use once_cell::sync::OnceCell;
use pyo3::IntoPyObjectExt;
......@@ -9,6 +10,7 @@ use pyo3::types::{PyDict, PyString};
use pyo3::{exceptions::PyException, prelude::*};
use rand::seq::IteratorRandom as _;
use rs::pipeline::network::Ingress;
use std::fs;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::path::PathBuf;
use std::time::Duration;
......@@ -96,6 +98,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(llm::kv::compute_block_hash_for_seq_py, m)?)?;
m.add_function(wrap_pyfunction!(log_message, m)?)?;
m.add_function(wrap_pyfunction!(register_llm, m)?)?;
m.add_function(wrap_pyfunction!(fetch_llm, m)?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
......@@ -174,6 +177,8 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
logging::log_message(level, message, module, file, line);
}
/// Create an engine and attach it to an endpoint to make it visible to the frontend.
/// This is the main way you create a Dynamo worker / backend.
#[pyfunction]
#[pyo3(signature = (model_input, model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None))]
#[allow(clippy::too_many_arguments)]
......@@ -201,7 +206,7 @@ fn register_llm<'p>(
let model_type_obj = model_type.inner;
let inner_path = model_path.to_string();
let model_name = model_name.map(|n| n.to_string());
let mut model_name = model_name.map(|n| n.to_string());
let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default());
......@@ -226,9 +231,22 @@ fn register_llm<'p>(
})?;
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let model_path = if fs::exists(&inner_path)? {
PathBuf::from(inner_path)
} else {
// Preserve the model name
if model_name.is_none() {
model_name = Some(inner_path.clone());
}
// Likely it's a Hugging Face repo, download it
LocalModel::fetch(&inner_path, false)
.await
.map_err(to_pyerr)?
};
let mut builder = dynamo_llm::local_model::LocalModelBuilder::default();
builder
.model_path(Some(PathBuf::from(inner_path)))
.model_path(model_path)
.model_name(model_name)
.context_length(context_length)
.kv_cache_block_size(kv_cache_block_size)
......@@ -237,7 +255,7 @@ fn register_llm<'p>(
.runtime_config(runtime_config.unwrap_or_default().inner)
.user_data(user_data_json)
.custom_template_path(custom_template_path_owned);
// Download from HF, load the ModelDeploymentCard
// Load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us
local_model
......@@ -249,6 +267,17 @@ fn register_llm<'p>(
})
}
/// Download a model from Hugging Face and return its local path.
/// Example: `model_path = await fetch_llm("Qwen/Qwen3-0.6B")`
#[pyfunction]
#[pyo3(signature = (remote_name))]
fn fetch_llm<'p>(py: Python<'p>, remote_name: &str) -> PyResult<Bound<'p, PyAny>> {
    // Copy the borrowed name so the `async move` block below can take ownership of it.
    let owned_name = remote_name.to_owned();
    // `false` is passed as `ignore_weights`.
    let download = async move {
        LocalModel::fetch(&owned_name, false)
            .await
            .map_err(to_pyerr)
    };
    // Hand the future to the Python side as an awaitable.
    pyo3_async_runtimes::tokio::future_into_py(py, download)
}
#[pyclass]
#[derive(Clone)]
pub struct DistributedRuntime {
......
......@@ -180,6 +180,8 @@ pub(crate) struct EngineConfig {
inner: RsEngineConfig,
}
/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
......@@ -189,8 +191,11 @@ pub fn make_engine<'p>(
) -> PyResult<Bound<'p, PyAny>> {
let mut builder = LocalModelBuilder::default();
builder
.model_path(args.model_path.clone())
.model_name(args.model_name.clone())
.model_name(
args.model_name
.clone()
.or_else(|| args.model_path.clone().map(|p| p.display().to_string())),
)
.endpoint_id(args.endpoint_id.clone())
.context_length(args.context_length)
.request_template(args.template_file.clone())
......@@ -206,6 +211,17 @@ pub fn make_engine<'p>(
.custom_backend_metrics_endpoint(args.custom_backend_metrics_endpoint.clone())
.custom_backend_metrics_polling_interval(args.custom_backend_metrics_polling_interval);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
if let Some(model_path) = args.model_path.clone() {
let local_path = if model_path.exists() {
model_path
} else {
LocalModel::fetch(&model_path.display().to_string(), false)
.await
.map_err(to_pyerr)?
};
builder.model_path(local_path);
}
let local_model = builder.build().await.map_err(to_pyerr)?;
let inner = select_engine(distributed_runtime, args, local_model)
.await
......
......@@ -892,6 +892,13 @@ async def register_llm(
"""Attach the model at path to the given endpoint, and advertise it as model_type"""
...
async def fetch_llm(remote_name: str) -> str:
"""
Download a model from Hugging Face, returning its local path.
Example: `model_path = await fetch_llm("Qwen/Qwen3-0.6B")`
"""
...
class EngineConfig:
"""Holds internal configuration for a Dynamo engine."""
...
......
......@@ -42,6 +42,7 @@ from dynamo._core import ZmqKvEventListener as ZmqKvEventListener
from dynamo._core import ZmqKvEventPublisher as ZmqKvEventPublisher
from dynamo._core import ZmqKvEventPublisherConfig as ZmqKvEventPublisherConfig
from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for_seq_py
from dynamo._core import fetch_llm as fetch_llm
from dynamo._core import make_engine
from dynamo._core import register_llm as register_llm
from dynamo._core import run_input
......
......@@ -5,7 +5,6 @@ use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::Context as _;
use dynamo_runtime::protocols::EndpointId;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::storage::key_value_store::Key;
......@@ -25,9 +24,6 @@ pub mod runtime_config;
use runtime_config::ModelRuntimeConfig;
/// Prefix for Hugging Face model repository
const HF_SCHEME: &str = "hf://";
/// What we call a model if the user didn't provide a name. Usually this means the name
/// is invisible, for example in a text chat.
const DEFAULT_NAME: &str = "dynamo";
......@@ -90,8 +86,9 @@ impl Default for LocalModelBuilder {
}
impl LocalModelBuilder {
pub fn model_path(&mut self, model_path: Option<PathBuf>) -> &mut Self {
self.model_path = model_path;
/// The path must exist
pub fn model_path(&mut self, model_path: PathBuf) -> &mut Self {
self.model_path = Some(model_path);
self
}
......@@ -214,7 +211,7 @@ impl LocalModelBuilder {
.map(RequestTemplate::load)
.transpose()?;
// echo engine doesn't need a path. It's an edge case, move it out of the way.
// frontend and echo engine don't need a path.
if self.model_path.is_none() {
let mut card = ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
......@@ -243,40 +240,24 @@ impl LocalModelBuilder {
// Main logic. We are running a model.
let model_path = self.model_path.take().unwrap();
let model_path = model_path.to_str().context("Invalid UTF-8 in model path")?;
// Check for hf:// prefix first, in case we really want an HF repo but it conflicts
// with a relative path.
let is_hf_repo =
model_path.starts_with(HF_SCHEME) || !fs::exists(model_path).unwrap_or(false);
let relative_path = model_path.trim_start_matches(HF_SCHEME);
let full_path = if is_hf_repo {
// HF download if necessary
super::hub::from_hf(relative_path, self.is_mocker).await?
} else {
fs::canonicalize(relative_path)?
};
if !model_path.exists() {
anyhow::bail!(
"Path does not exist: '{}'. Use LocalModel::fetch to download it.",
model_path.display(),
);
}
let model_path = fs::canonicalize(model_path)?;
let mut card =
ModelDeploymentCard::load_from_disk(&full_path, self.custom_template_path.as_deref())?;
// Usually we infer from the path, self.model_name is user override
let model_name = self.model_name.take().unwrap_or_else(|| {
if is_hf_repo {
// HF repos use their full name ("org/name") not the folder name
relative_path.to_string()
} else {
full_path
.iter()
.next_back()
.map(|n| n.to_string_lossy().into_owned())
.unwrap_or_else(|| {
// Panic because we can't do anything without a model
panic!("Invalid model path, too short: '{}'", full_path.display())
})
}
});
card.set_name(&model_name);
ModelDeploymentCard::load_from_disk(&model_path, self.custom_template_path.as_deref())?;
// The served model name defaults to the full model path.
// This matches what vllm and sglang do.
card.set_name(
&self
.model_name
.clone()
.unwrap_or_else(|| model_path.display().to_string()),
);
card.kv_cache_block_size = self.kv_cache_block_size;
......@@ -303,7 +284,7 @@ impl LocalModelBuilder {
Ok(LocalModel {
card,
full_path,
full_path: model_path,
endpoint_id,
template,
http_host: self.http_host.take(),
......@@ -337,6 +318,15 @@ pub struct LocalModel {
}
impl LocalModel {
/// Make sure the model `remote_name` is available locally and return the path to its files.
/// The model is downloaded from Hugging Face if necessary.
/// When `ignore_weights` is true the model weight files are skipped and only the model
/// config is downloaded.
pub async fn fetch(remote_name: &str, ignore_weights: bool) -> anyhow::Result<PathBuf> {
    // All of the work is delegated to `hub::from_hf`; its error is propagated unchanged.
    let local_path = super::hub::from_hf(remote_name, ignore_weights).await?;
    Ok(local_path)
}
pub fn card(&self) -> &ModelDeploymentCard {
&self.card
}
......
......@@ -13,7 +13,7 @@
//! - Prompt formatter settings (PromptFormatterArtifact)
use std::fmt;
use std::path::{Path, PathBuf};
use std::path::Path;
use std::sync::{Arc, OnceLock};
use crate::common::checked_file::CheckedFile;
......@@ -485,43 +485,31 @@ impl ModelDeploymentCard {
/// - The path contains invalid Unicode characters
/// - Required model files are missing or invalid
fn from_local_path(
local_root_dir: impl AsRef<Path>,
local_path: impl AsRef<Path>,
custom_template_path: Option<&Path>,
) -> anyhow::Result<Self> {
let local_root_dir = local_root_dir.as_ref();
check_valid_local_repo_path(local_root_dir)?;
let repo_id = local_root_dir
.canonicalize()?
.to_str()
.ok_or_else(|| anyhow::anyhow!("Path contains invalid Unicode"))?
.to_string();
let model_name = local_root_dir
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?;
Self::from_repo(&repo_id, model_name, custom_template_path)
}
fn from_repo(
repo_id: &str,
model_name: &str,
check_valid_local_repo_path(&local_path)?;
Self::from_repo_checkout(&local_path, custom_template_path)
}
fn from_repo_checkout(
local_path: impl AsRef<Path>,
custom_template_path: Option<&Path>,
) -> anyhow::Result<Self> {
let local_path = local_path.as_ref();
// This is usually the right choice
let context_length = crate::file_json_field(
&PathBuf::from(repo_id).join("config.json"),
"max_position_embeddings",
)
// But sometimes this is
.or_else(|_| {
crate::file_json_field(
&PathBuf::from(repo_id).join("tokenizer_config.json"),
"model_max_length",
)
})
// If neither of those are present let the engine default it
.unwrap_or(0);
let context_length =
crate::file_json_field(&local_path.join("config.json"), "max_position_embeddings")
// But sometimes this is
.or_else(|_| {
crate::file_json_field(
&local_path.join("tokenizer_config.json"),
"model_max_length",
)
})
// If neither of those are present let the engine default it
.unwrap_or(0);
// Load chat template - either custom or from repo
let chat_template_file = if let Some(template_path) = custom_template_path {
......@@ -544,16 +532,17 @@ impl ModelDeploymentCard {
CheckedFile::from_disk(template_path)?,
))
} else {
PromptFormatterArtifact::chat_template_from_repo(repo_id)?
PromptFormatterArtifact::chat_template_from_disk(local_path)?
};
let display_name = local_path.display().to_string();
Ok(Self {
display_name: model_name.to_string(),
slug: Slug::from_string(model_name),
model_info: Some(ModelInfoType::from_repo(repo_id)?),
tokenizer: Some(TokenizerKind::from_repo(repo_id)?),
gen_config: GenerationConfig::from_repo(repo_id).ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id)?,
slug: Slug::from_string(&display_name),
display_name,
model_info: Some(ModelInfoType::from_disk(local_path)?),
tokenizer: Some(TokenizerKind::from_disk(local_path)?),
gen_config: GenerationConfig::from_disk(local_path).ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_disk(local_path)?,
chat_template_file,
prompt_context: None, // TODO - auto-detect prompt context
context_length,
......@@ -778,33 +767,43 @@ impl ModelInfo for HFConfig {
}
impl ModelInfoType {
pub fn from_repo(repo_id: &str) -> Result<Self> {
let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("config.json"))
.with_context(|| format!("unable to extract config.json from repo {repo_id}"))?;
pub fn from_disk(directory: &Path) -> Result<Self> {
let f = CheckedFile::from_disk(directory.join("config.json")).with_context(|| {
format!(
"unable to extract config.json from directory {}",
directory.display()
)
})?;
Ok(Self::HfConfigJson(f))
}
}
impl GenerationConfig {
pub fn from_repo(repo_id: &str) -> Result<Self> {
let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("generation_config.json"))
.with_context(|| format!("unable to extract generation_config from repo {repo_id}"))?;
pub fn from_disk(directory: &Path) -> Result<Self> {
let f = CheckedFile::from_disk(directory.join("generation_config.json")).with_context(
|| {
format!(
"unable to extract generation_config from directory {}",
directory.display()
)
},
)?;
Ok(Self::HfGenerationConfigJson(f))
}
}
impl PromptFormatterArtifact {
pub fn from_repo(repo_id: &str) -> Result<Option<Self>> {
pub fn from_disk(directory: &Path) -> Result<Option<Self>> {
// we should only error if we expect a prompt formatter and it's not found
// right now, we don't know when to expect it, so we just return Ok(Some/None)
match CheckedFile::from_disk(PathBuf::from(repo_id).join("tokenizer_config.json")) {
match CheckedFile::from_disk(directory.join("tokenizer_config.json")) {
Ok(f) => Ok(Some(Self::HfTokenizerConfigJson(f))),
Err(_) => Ok(None),
}
}
pub fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
match CheckedFile::from_disk(PathBuf::from(repo_id).join("chat_template.jinja")) {
pub fn chat_template_from_disk(directory: &Path) -> Result<Option<Self>> {
match CheckedFile::from_disk(directory.join("chat_template.jinja")) {
Ok(f) => Ok(Some(Self::HfChatTemplate(f))),
Err(_) => Ok(None),
}
......@@ -812,9 +811,13 @@ impl PromptFormatterArtifact {
}
impl TokenizerKind {
pub fn from_repo(repo_id: &str) -> Result<Self> {
let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("tokenizer.json"))
.with_context(|| format!("unable to extract tokenizer kind from repo {repo_id}"))?;
pub fn from_disk(directory: &Path) -> Result<Self> {
let f = CheckedFile::from_disk(directory.join("tokenizer.json")).with_context(|| {
format!(
"unable to extract tokenizer kind from directory {}",
directory.display()
)
})?;
Ok(Self::HfTokenizerJson(f))
}
}
......
......@@ -29,7 +29,8 @@ pytestmark = [
mock_sglang_cli = make_cli_args_fixture("dynamo.sglang")
def test_custom_jinja_template_invalid_path(mock_sglang_cli):
@pytest.mark.asyncio
async def test_custom_jinja_template_invalid_path(mock_sglang_cli):
"""Test that invalid file path raises FileNotFoundError."""
invalid_path = "/nonexistent/path/to/template.jinja"
mock_sglang_cli(
......@@ -40,14 +41,15 @@ def test_custom_jinja_template_invalid_path(mock_sglang_cli):
FileNotFoundError,
match=re.escape(f"Custom Jinja template file not found: {invalid_path}"),
):
parse_args(sys.argv[1:])
await parse_args(sys.argv[1:])
def test_custom_jinja_template_valid_path(mock_sglang_cli):
@pytest.mark.asyncio
async def test_custom_jinja_template_valid_path(mock_sglang_cli):
"""Test that valid absolute path is stored correctly."""
mock_sglang_cli(model="Qwen/Qwen3-0.6B", custom_jinja_template=JINJA_TEMPLATE_PATH)
config = parse_args(sys.argv[1:])
config = await parse_args(sys.argv[1:])
assert config.dynamo_args.custom_jinja_template == JINJA_TEMPLATE_PATH, (
f"Expected custom_jinja_template value to be {JINJA_TEMPLATE_PATH}, "
......@@ -55,7 +57,8 @@ def test_custom_jinja_template_valid_path(mock_sglang_cli):
)
def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_sglang_cli):
@pytest.mark.asyncio
async def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_sglang_cli):
"""Test that environment variables in paths are expanded by Python code."""
jinja_dir = str(TEST_DIR / "serve" / "fixtures")
monkeypatch.setenv("JINJA_DIR", jinja_dir)
......@@ -63,7 +66,7 @@ def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_sglang_cli):
cli_path = "$JINJA_DIR/custom_template.jinja"
mock_sglang_cli(model="Qwen/Qwen3-0.6B", custom_jinja_template=cli_path)
config = parse_args(sys.argv[1:])
config = await parse_args(sys.argv[1:])
assert "$JINJA_DIR" not in config.dynamo_args.custom_jinja_template
assert config.dynamo_args.custom_jinja_template == JINJA_TEMPLATE_PATH, (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment