Unverified Commit bae25dc6 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: skip downloading model weights if using mocker (only tokenizer) (#2213)

parent 3bf22bb4
...@@ -40,7 +40,8 @@ pub async fn run( ...@@ -40,7 +40,8 @@ pub async fn run(
.http_port(Some(flags.http_port)) .http_port(Some(flags.http_port))
.router_config(Some(flags.router_config())) .router_config(Some(flags.router_config()))
.request_template(flags.request_template.clone()) .request_template(flags.request_template.clone())
.migration_limit(flags.migration_limit); .migration_limit(flags.migration_limit)
.is_mocker(matches!(out_opt, Some(Output::Mocker)));
// TODO: old, address this later: // TODO: old, address this later:
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint. // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
......
...@@ -156,7 +156,8 @@ pub fn make_engine<'p>( ...@@ -156,7 +156,8 @@ pub fn make_engine<'p>(
.request_template(args.template_file.clone()) .request_template(args.template_file.clone())
.kv_cache_block_size(args.kv_cache_block_size) .kv_cache_block_size(args.kv_cache_block_size)
.router_config(args.router_config.clone().map(|rc| rc.into())) .router_config(args.router_config.clone().map(|rc| rc.into()))
.http_port(args.http_port); .http_port(args.http_port)
.is_mocker(matches!(args.engine_type, EngineType::Mocker));
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let local_model = builder.build().await.map_err(to_pyerr)?; let local_model = builder.build().await.map_err(to_pyerr)?;
let inner = select_engine(distributed_runtime, args, local_model) let inner = select_engine(distributed_runtime, args, local_model)
......
...@@ -27,9 +27,19 @@ const IGNORED: [&str; 5] = [ ...@@ -27,9 +27,19 @@ const IGNORED: [&str; 5] = [
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN"; const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
/// Reports whether `filename` names a model weight file.
///
/// A name counts as a weight file when it ends with one of the suffixes
/// `.bin`, `.safetensors`, `.h5`, `.msgpack`, or `.ckpt.index`. The match is
/// a plain case-sensitive suffix comparison on the full string, so it works
/// for bare file names and for repo-relative paths alike.
fn is_weight_file(filename: &str) -> bool {
    const WEIGHT_SUFFIXES: [&str; 5] = [
        ".bin",
        ".safetensors",
        ".h5",
        ".msgpack",
        ".ckpt.index",
    ];
    WEIGHT_SUFFIXES
        .iter()
        .any(|suffix| filename.ends_with(suffix))
}
/// Attempt to download a model from Hugging Face /// Attempt to download a model from Hugging Face
/// Returns the directory it is in /// Returns the directory it is in
pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> { /// If ignore_weights is true, model weight files will be skipped
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
let name = name.as_ref(); let name = name.as_ref();
let token = env::var(HF_TOKEN_ENV_VAR).ok(); let token = env::var(HF_TOKEN_ENV_VAR).ok();
let api = ApiBuilder::new() let api = ApiBuilder::new()
...@@ -66,6 +76,11 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> { ...@@ -66,6 +76,11 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
continue; continue;
} }
// If ignore_weights is true, skip weight files
if ignore_weights && is_weight_file(&sib.rfilename) {
continue;
}
match repo.get(&sib.rfilename).await { match repo.get(&sib.rfilename).await {
Ok(path) => { Ok(path) => {
p = path; p = path;
...@@ -83,8 +98,14 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> { ...@@ -83,8 +98,14 @@ pub async fn from_hf(name: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
} }
if !files_downloaded { if !files_downloaded {
let file_type = if ignore_weights {
"non-weight"
} else {
"valid"
};
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"No valid files found for model '{}'.", "No {} files found for model '{}'.",
file_type,
model_name model_name
)); ));
} }
......
...@@ -47,6 +47,7 @@ pub struct LocalModelBuilder { ...@@ -47,6 +47,7 @@ pub struct LocalModelBuilder {
kv_cache_block_size: u32, kv_cache_block_size: u32,
http_port: u16, http_port: u16,
migration_limit: u32, migration_limit: u32,
is_mocker: bool,
} }
impl Default for LocalModelBuilder { impl Default for LocalModelBuilder {
...@@ -62,6 +63,7 @@ impl Default for LocalModelBuilder { ...@@ -62,6 +63,7 @@ impl Default for LocalModelBuilder {
template_file: Default::default(), template_file: Default::default(),
router_config: Default::default(), router_config: Default::default(),
migration_limit: Default::default(), migration_limit: Default::default(),
is_mocker: Default::default(),
} }
} }
} }
...@@ -119,6 +121,11 @@ impl LocalModelBuilder { ...@@ -119,6 +121,11 @@ impl LocalModelBuilder {
self self
} }
/// Sets whether the model is being prepared for the mocker engine.
///
/// The stored flag is later handed to `super::hub::from_hf` as its
/// `ignore_weights` argument, so weight files are skipped when the model
/// is fetched from Hugging Face.
///
/// Returns the builder so calls can be chained.
pub fn is_mocker(&mut self, is_mocker: bool) -> &mut Self {
    let builder = self;
    builder.is_mocker = is_mocker;
    builder
}
/// Make an LLM ready for use: /// Make an LLM ready for use:
/// - Download it from Hugging Face (and NGC in future) if necessary /// - Download it from Hugging Face (and NGC in future) if necessary
/// - Resolve the path /// - Resolve the path
...@@ -169,7 +176,7 @@ impl LocalModelBuilder { ...@@ -169,7 +176,7 @@ impl LocalModelBuilder {
let relative_path = model_path.trim_start_matches(HF_SCHEME); let relative_path = model_path.trim_start_matches(HF_SCHEME);
let full_path = if is_hf_repo { let full_path = if is_hf_repo {
// HF download if necessary // HF download if necessary
super::hub::from_hf(relative_path).await? super::hub::from_hf(relative_path, self.is_mocker).await?
} else { } else {
fs::canonicalize(relative_path)? fs::canonicalize(relative_path)?
}; };
......
...@@ -41,11 +41,12 @@ TEST_MODELS = [ ...@@ -41,11 +41,12 @@ TEST_MODELS = [
] ]
def download_models(model_list=None): def download_models(model_list=None, ignore_weights=False):
"""Download models - can be called directly or via fixture """Download models - can be called directly or via fixture
Args: Args:
model_list: List of model IDs to download. If None, downloads TEST_MODELS. model_list: List of model IDs to download. If None, downloads TEST_MODELS.
ignore_weights: If True, skips downloading model weight files. Default is False.
""" """
if model_list is None: if model_list is None:
model_list = TEST_MODELS model_list = TEST_MODELS
...@@ -65,15 +66,33 @@ def download_models(model_list=None): ...@@ -65,15 +66,33 @@ def download_models(model_list=None):
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
for model_id in model_list: for model_id in model_list:
logging.info(f"Pre-downloading model: {model_id}") logging.info(
f"Pre-downloading {'model (no weights)' if ignore_weights else 'model'}: {model_id}"
)
try: try:
# Download the full model snapshot (includes all files) if ignore_weights:
# HuggingFace will handle caching automatically # Weight file patterns to exclude (based on hub.rs implementation)
snapshot_download( weight_patterns = [
repo_id=model_id, "*.bin",
token=hf_token, "*.safetensors",
) "*.h5",
"*.msgpack",
"*.ckpt.index",
]
# Download everything except weight files
snapshot_download(
repo_id=model_id,
token=hf_token,
ignore_patterns=weight_patterns,
)
else:
# Download the full model snapshot (includes all files)
snapshot_download(
repo_id=model_id,
token=hf_token,
)
logging.info(f"Successfully pre-downloaded: {model_id}") logging.info(f"Successfully pre-downloaded: {model_id}")
except Exception as e: except Exception as e:
...@@ -94,6 +113,13 @@ def predownload_models(): ...@@ -94,6 +113,13 @@ def predownload_models():
yield yield
@pytest.fixture(scope="session")
def predownload_tokenizers():
    """Session-scoped fixture that pre-downloads TEST_MODELS without weights.

    Delegates to download_models with ignore_weights=True, so model weight
    files are skipped, then yields to the tests.
    """
    download_models(model_list=None, ignore_weights=True)
    yield
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def logger(request): def logger(request):
log_path = os.path.join(request.node.name, "test.log.txt") log_path = os.path.join(request.node.name, "test.log.txt")
...@@ -127,14 +153,24 @@ def pytest_collection_modifyitems(config, items): ...@@ -127,14 +153,24 @@ def pytest_collection_modifyitems(config, items):
# Auto-inject predownload_models fixture for serve tests only (not router tests) # Auto-inject predownload_models fixture for serve tests only (not router tests)
# Skip items that don't have fixturenames (like MypyFileItem) # Skip items that don't have fixturenames (like MypyFileItem)
if hasattr(item, "fixturenames"): if hasattr(item, "fixturenames"):
# Only apply to tests in the serve directory # Guard clause: skip if already has the fixtures
if ( if (
("serve" in str(item.path)) "predownload_models" in item.fixturenames
and ("predownload_models" not in item.fixturenames) or "predownload_tokenizers" in item.fixturenames
and (not item.get_closest_marker("skip_model_download"))
): ):
continue
# Guard clause: skip if marked with skip_model_download
if item.get_closest_marker("skip_model_download"):
continue
# Add appropriate fixture based on test path
if "serve" in str(item.path):
item.fixturenames = list(item.fixturenames) item.fixturenames = list(item.fixturenames)
item.fixturenames.append("predownload_models") item.fixturenames.append("predownload_models")
elif "router" in str(item.path):
item.fixturenames = list(item.fixturenames)
item.fixturenames.append("predownload_tokenizers")
class EtcdServer(ManagedProcess): class EtcdServer(ManagedProcess):
......
...@@ -9,7 +9,6 @@ import os ...@@ -9,7 +9,6 @@ import os
import aiohttp import aiohttp
import pytest import pytest
from tests.conftest import download_models
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
pytestmark = pytest.mark.pre_merge pytestmark = pytest.mark.pre_merge
...@@ -96,9 +95,6 @@ def test_mocker_kv_router(request, runtime_services): ...@@ -96,9 +95,6 @@ def test_mocker_kv_router(request, runtime_services):
This test doesn't require GPUs and runs quickly for pre-merge validation. This test doesn't require GPUs and runs quickly for pre-merge validation.
""" """
# Download only the Qwen model for this test
download_models([MODEL_NAME])
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
logger.info("Starting mocker KV router test") logger.info("Starting mocker KV router test")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment