Unverified Commit bae25dc6 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: skip downloading model weights if using mocker (only tokenizer) (#2213)

parent 3bf22bb4
......@@ -40,7 +40,8 @@ pub async fn run(
.http_port(Some(flags.http_port))
.router_config(Some(flags.router_config()))
.request_template(flags.request_template.clone())
.migration_limit(flags.migration_limit);
.migration_limit(flags.migration_limit)
.is_mocker(matches!(out_opt, Some(Output::Mocker)));
// TODO: old, address this later:
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
......
......@@ -156,7 +156,8 @@ pub fn make_engine<'p>(
.request_template(args.template_file.clone())
.kv_cache_block_size(args.kv_cache_block_size)
.router_config(args.router_config.clone().map(|rc| rc.into()))
.http_port(args.http_port);
.http_port(args.http_port)
.is_mocker(matches!(args.engine_type, EngineType::Mocker));
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let local_model = builder.build().await.map_err(to_pyerr)?;
let inner = select_engine(distributed_runtime, args, local_model)
......
......@@ -27,9 +27,19 @@ const IGNORED: [&str; 5] = [
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
/// Checks if a file is a model weight file
fn is_weight_file(filename: &str) -> bool {
filename.ends_with(".bin")
|| filename.ends_with(".safetensors")
|| filename.ends_with(".h5")
|| filename.ends_with(".msgpack")
|| filename.ends_with(".ckpt.index")
}
/// Attempt to download a model from Hugging Face
/// Returns the directory it is in
pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
/// If ignore_weights is true, model weight files will be skipped
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
let name = name.as_ref();
let token = env::var(HF_TOKEN_ENV_VAR).ok();
let api = ApiBuilder::new()
......@@ -66,6 +76,11 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
continue;
}
// If ignore_weights is true, skip weight files
if ignore_weights && is_weight_file(&sib.rfilename) {
continue;
}
match repo.get(&sib.rfilename).await {
Ok(path) => {
p = path;
......@@ -83,8 +98,14 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
}
if !files_downloaded {
let file_type = if ignore_weights {
"non-weight"
} else {
"valid"
};
return Err(anyhow::anyhow!(
"No valid files found for model '{}'.",
"No {} files found for model '{}'.",
file_type,
model_name
));
}
......
......@@ -47,6 +47,7 @@ pub struct LocalModelBuilder {
kv_cache_block_size: u32,
http_port: u16,
migration_limit: u32,
is_mocker: bool,
}
impl Default for LocalModelBuilder {
......@@ -62,6 +63,7 @@ impl Default for LocalModelBuilder {
template_file: Default::default(),
router_config: Default::default(),
migration_limit: Default::default(),
is_mocker: Default::default(),
}
}
}
......@@ -119,6 +121,11 @@ impl LocalModelBuilder {
self
}
pub fn is_mocker(&mut self, is_mocker: bool) -> &mut Self {
self.is_mocker = is_mocker;
self
}
/// Make an LLM ready for use:
/// - Download it from Hugging Face (and NGC in future) if necessary
/// - Resolve the path
......@@ -169,7 +176,7 @@ impl LocalModelBuilder {
let relative_path = model_path.trim_start_matches(HF_SCHEME);
let full_path = if is_hf_repo {
// HF download if necessary
super::hub::from_hf(relative_path).await?
super::hub::from_hf(relative_path, self.is_mocker).await?
} else {
fs::canonicalize(relative_path)?
};
......
......@@ -41,11 +41,12 @@ TEST_MODELS = [
]
def download_models(model_list=None):
def download_models(model_list=None, ignore_weights=False):
"""Download models - can be called directly or via fixture
Args:
model_list: List of model IDs to download. If None, downloads TEST_MODELS.
ignore_weights: If True, skips downloading model weight files. Default is False.
"""
if model_list is None:
model_list = TEST_MODELS
......@@ -65,11 +66,29 @@ def download_models(model_list=None):
from huggingface_hub import snapshot_download
for model_id in model_list:
logging.info(f"Pre-downloading model: {model_id}")
logging.info(
f"Pre-downloading {'model (no weights)' if ignore_weights else 'model'}: {model_id}"
)
try:
if ignore_weights:
# Weight file patterns to exclude (based on hub.rs implementation)
weight_patterns = [
"*.bin",
"*.safetensors",
"*.h5",
"*.msgpack",
"*.ckpt.index",
]
# Download everything except weight files
snapshot_download(
repo_id=model_id,
token=hf_token,
ignore_patterns=weight_patterns,
)
else:
# Download the full model snapshot (includes all files)
# HuggingFace will handle caching automatically
snapshot_download(
repo_id=model_id,
token=hf_token,
......@@ -94,6 +113,13 @@ def predownload_models():
yield
@pytest.fixture(scope="session")
def predownload_tokenizers():
"""Fixture wrapper around download_models for all TEST_MODELS"""
download_models(ignore_weights=True)
yield
@pytest.fixture(autouse=True)
def logger(request):
log_path = os.path.join(request.node.name, "test.log.txt")
......@@ -127,14 +153,24 @@ def pytest_collection_modifyitems(config, items):
# Auto-inject predownload_models fixture for serve tests only (not router tests)
# Skip items that don't have fixturenames (like MypyFileItem)
if hasattr(item, "fixturenames"):
# Only apply to tests in the serve directory
# Guard clause: skip if already has the fixtures
if (
("serve" in str(item.path))
and ("predownload_models" not in item.fixturenames)
and (not item.get_closest_marker("skip_model_download"))
"predownload_models" in item.fixturenames
or "predownload_tokenizers" in item.fixturenames
):
continue
# Guard clause: skip if marked with skip_model_download
if item.get_closest_marker("skip_model_download"):
continue
# Add appropriate fixture based on test path
if "serve" in str(item.path):
item.fixturenames = list(item.fixturenames)
item.fixturenames.append("predownload_models")
elif "router" in str(item.path):
item.fixturenames = list(item.fixturenames)
item.fixturenames.append("predownload_tokenizers")
class EtcdServer(ManagedProcess):
......
......@@ -9,7 +9,6 @@ import os
import aiohttp
import pytest
from tests.conftest import download_models
from tests.utils.managed_process import ManagedProcess
pytestmark = pytest.mark.pre_merge
......@@ -96,9 +95,6 @@ def test_mocker_kv_router(request, runtime_services):
This test doesn't require GPUs and runs quickly for pre-merge validation.
"""
# Download only the Qwen model for this test
download_models([MODEL_NAME])
# runtime_services starts etcd and nats
logger.info("Starting mocker KV router test")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment