Unverified commit 118323f2, authored by Neal Vaidya, committed by GitHub
Browse files

fix: skip HuggingFace download for non-llms (#4686)


Signed-off-by: Neal Vaidya <nealv@nvidia.com>
parent 00b64ae0
...@@ -276,6 +276,8 @@ fn register_llm<'p>( ...@@ -276,6 +276,8 @@ fn register_llm<'p>(
ModelInput::Tensor => llm_rs::model_type::ModelInput::Tensor, ModelInput::Tensor => llm_rs::model_type::ModelInput::Tensor,
}; };
let is_tensor_based = model_type.inner.supports_tensor();
let model_type_obj = model_type.inner; let model_type_obj = model_type.inner;
let inner_path = model_path.to_string(); let inner_path = model_path.to_string();
...@@ -323,7 +325,33 @@ fn register_llm<'p>( ...@@ -323,7 +325,33 @@ fn register_llm<'p>(
.or_else(|| Some(source_path.clone())); .or_else(|| Some(source_path.clone()));
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Resolve the model path (local or fetch from HuggingFace) // For TensorBased models, skip HuggingFace downloads and register directly
if is_tensor_based {
let model_name = model_name.unwrap_or_else(|| source_path.clone());
let mut card = llm_rs::model_card::ModelDeploymentCard::with_name_only(&model_name);
card.model_type = model_type_obj;
card.model_input = model_input;
card.user_data = user_data_json;
if let Some(cfg) = runtime_config {
card.runtime_config = cfg.inner;
}
// Register the Model Deployment Card via discovery interface
let discovery = endpoint.inner.drt().discovery();
let spec = rs::discovery::DiscoverySpec::from_model(
endpoint.inner.component().namespace().name().to_string(),
endpoint.inner.component().name().to_string(),
endpoint.inner.name().to_string(),
&card,
)
.map_err(to_pyerr)?;
discovery.register(spec).await.map_err(to_pyerr)?;
return Ok(());
}
// For non-TensorBased models, resolve the model path (local or fetch from HuggingFace)
let model_path = if fs::exists(&source_path)? { let model_path = if fs::exists(&source_path)? {
PathBuf::from(&source_path) PathBuf::from(&source_path)
} else { } else {
......
...@@ -1077,6 +1077,10 @@ async def register_llm( ...@@ -1077,6 +1077,10 @@ async def register_llm(
Providing only one of these parameters will raise a ValueError. Providing only one of these parameters will raise a ValueError.
- `lora_name`: The served model name for the LoRA model - `lora_name`: The served model name for the LoRA model
- `base_model_path`: Path to the base model that the LoRA extends - `base_model_path`: Path to the base model that the LoRA extends
For models whose model_type supports tensors (e.g. ModelType.TensorBased), HuggingFace
downloads are skipped and a minimal model card is registered directly. If model_name is
not provided, model_path is used as the display name for these models.
""" """
... ...
......
...@@ -34,15 +34,12 @@ async def test_register(runtime: DistributedRuntime): ...@@ -34,15 +34,12 @@ async def test_register(runtime: DistributedRuntime):
assert model_config == runtime_config.get_tensor_model_config() assert model_config == runtime_config.get_tensor_model_config()
# [gluo FIXME] register_llm will attempt to load a LLM model, # Use register_llm for tensor-based backends (skips HuggingFace downloads)
# which is not well-defined for Tensor yet. Currently provide
# a valid model name to pass the registration.
await register_llm( await register_llm(
ModelInput.Tensor, ModelInput.Tensor,
ModelType.TensorBased, ModelType.TensorBased,
endpoint, endpoint,
"Qwen/Qwen3-0.6B", "tensor", # model_path (used as display name for tensor-based models)
"tensor",
runtime_config=runtime_config, runtime_config=runtime_config,
) )
......
...@@ -385,6 +385,15 @@ impl ModelDeploymentCard { ...@@ -385,6 +385,15 @@ impl ModelDeploymentCard {
return Ok(()); return Ok(());
} }
// For TensorBased models, config files are not used - they handle everything in the backend
if self.model_type.supports_tensor() {
tracing::debug!(
display_name = %self.display_name,
"Skipping config download for TensorBased model"
);
return Ok(());
}
let ignore_weights = true; let ignore_weights = true;
let local_path = crate::hub::from_hf(&self.display_name, ignore_weights).await?; let local_path = crate::hub::from_hf(&self.display_name, ignore_weights).await?;
......
...@@ -53,15 +53,12 @@ async def echo_tensor_worker(runtime: DistributedRuntime): ...@@ -53,15 +53,12 @@ async def echo_tensor_worker(runtime: DistributedRuntime):
) )
assert model_config == retrieved_model_config assert model_config == retrieved_model_config
# [gluo FIXME] register_llm will attempt to load a LLM model, # Use register_llm for tensor-based backends (skips HuggingFace downloads)
# which is not well-defined for Tensor yet. Currently provide
# a valid model name to pass the registration.
await register_llm( await register_llm(
ModelInput.Tensor, ModelInput.Tensor,
ModelType.TensorBased, ModelType.TensorBased,
endpoint, endpoint,
"Qwen/Qwen3-0.6B", "echo", # model_path (used as display name for tensor-based models)
"echo",
runtime_config=runtime_config, runtime_config=runtime_config,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment