Unverified Commit 433f6012 authored by Chi's avatar Chi Committed by GitHub
Browse files

feat: Pass user_data to register_llm for LoRA support (#2286)

parent 4cbd4f38
......@@ -72,6 +72,7 @@ The `model_type` can be:
- `context_length`: Max model length in tokens. Defaults to the model's set max. Only set this if you need to reduce KV cache allocation to fit into VRAM.
- `kv_cache_block_size`: Size of a KV block for the engine, in tokens. Defaults to 16.
- `migration_limit`: Maximum number of times a request may be [migrated to another Instance](../architecture/request_migration.md). Defaults to 0.
- `user_data`: Optional dictionary containing custom metadata for worker behavior (e.g., LoRA configuration). Defaults to None.
See `components/backends` for full code examples.
......
......@@ -664,6 +664,7 @@ The `model_type` can be:
- `model_name`: The name to call the model. The model name in incoming HTTP requests must match this. Defaults to the Hugging Face repo name, the folder name, or the GGUF file name.
- `context_length`: Max model length in tokens. Defaults to the model's set max. Only set this if you need to reduce KV cache allocation to fit into VRAM.
- `kv_cache_block_size`: Size of a KV block for the engine, in tokens. Defaults to 16.
- `user_data`: Optional dictionary containing custom metadata for worker behavior (e.g., LoRA configuration). Defaults to None.
Here are some example engines:
......
......@@ -131,7 +131,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
}
#[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0))]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, user_data=None))]
#[allow(clippy::too_many_arguments)]
fn register_llm<'p>(
py: Python<'p>,
......@@ -143,6 +143,7 @@ fn register_llm<'p>(
kv_cache_block_size: Option<u32>,
router_mode: Option<RouterMode>,
migration_limit: u32,
user_data: Option<&Bound<'p, PyDict>>,
) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat,
......@@ -156,6 +157,13 @@ fn register_llm<'p>(
let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default());
let user_data_json = user_data
.map(|dict| pythonize::depythonize(dict))
.transpose()
.map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to convert user_data: {}", err))
})?;
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut builder = dynamo_llm::local_model::LocalModelBuilder::default();
builder
......@@ -164,7 +172,8 @@ fn register_llm<'p>(
.context_length(context_length)
.kv_cache_block_size(kv_cache_block_size)
.router_config(Some(router_config))
.migration_limit(Some(migration_limit));
.migration_limit(Some(migration_limit))
.user_data(user_data_json);
// Download from HF, load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us
......
......@@ -48,6 +48,7 @@ pub struct LocalModelBuilder {
http_port: u16,
migration_limit: u32,
is_mocker: bool,
user_data: Option<serde_json::Value>,
}
impl Default for LocalModelBuilder {
......@@ -64,6 +65,7 @@ impl Default for LocalModelBuilder {
router_config: Default::default(),
migration_limit: Default::default(),
is_mocker: Default::default(),
user_data: Default::default(),
}
}
}
......@@ -126,6 +128,11 @@ impl LocalModelBuilder {
self
}
/// Set the optional user-defined metadata carried by this builder.
///
/// The value replaces whatever was stored before; passing `None` clears it.
/// Elsewhere in this impl the stored value is moved (`take()`) onto the
/// resulting card's `user_data` field.
///
/// Returns `&mut Self` so the call can be chained with the other setters.
pub fn user_data(&mut self, user_data: Option<serde_json::Value>) -> &mut Self {
self.user_data = user_data;
self
}
/// Make an LLM ready for use:
/// - Download it from Hugging Face (and NGC in future) if necessary
/// - Resolve the path
......@@ -155,6 +162,7 @@ impl LocalModelBuilder {
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
);
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
return Ok(LocalModel {
card,
full_path: PathBuf::new(),
......@@ -211,6 +219,7 @@ impl LocalModelBuilder {
}
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
Ok(LocalModel {
card,
......
......@@ -93,6 +93,7 @@ impl ModelDeploymentCard {
context_length,
kv_cache_block_size: 0,
migration_limit: 0,
user_data: None,
})
}
......@@ -133,6 +134,7 @@ impl ModelDeploymentCard {
context_length,
kv_cache_block_size: 0, // set later
migration_limit: 0,
user_data: None,
})
}
}
......
......@@ -131,6 +131,10 @@ pub struct ModelDeploymentCard {
/// How many times a request can be migrated to another worker if the HTTP server lost
/// connection to the current worker.
pub migration_limit: u32,
/// User-defined metadata for custom worker behavior
#[serde(default, skip_serializing_if = "Option::is_none")]
pub user_data: Option<serde_json::Value>,
}
impl ModelDeploymentCard {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment