"lib/llm/src/vscode:/vscode.git/clone" did not exist on "49b7a0d99ec3c666df4e6a452062cdc72b88c903"
Unverified Commit 433f6012 authored by Chi's avatar Chi Committed by GitHub
Browse files

feat: Pass user_data to register_llm for LoRA support (#2286)

parent 4cbd4f38
...@@ -72,6 +72,7 @@ The `model_type` can be: ...@@ -72,6 +72,7 @@ The `model_type` can be:
- `context_length`: Max model length in tokens. Defaults to the model's set max. Only set this if you need to reduce KV cache allocation to fit into VRAM. - `context_length`: Max model length in tokens. Defaults to the model's set max. Only set this if you need to reduce KV cache allocation to fit into VRAM.
- `kv_cache_block_size`: Size of a KV block for the engine, in tokens. Defaults to 16. - `kv_cache_block_size`: Size of a KV block for the engine, in tokens. Defaults to 16.
- `migration_limit`: Maximum number of times a request may be [migrated to another Instance](../architecture/request_migration.md). Defaults to 0. - `migration_limit`: Maximum number of times a request may be [migrated to another Instance](../architecture/request_migration.md). Defaults to 0.
- `user_data`: Optional dictionary containing custom metadata for worker behavior (e.g., LoRA configuration). Defaults to None.
See `components/backends` for full code examples. See `components/backends` for full code examples.
......
...@@ -664,6 +664,7 @@ The `model_type` can be: ...@@ -664,6 +664,7 @@ The `model_type` can be:
- `model_name`: The name to call the model. Your incoming HTTP requests model name must match this. Defaults to the hugging face repo name, the folder name, or the GGUF file name. - `model_name`: The name to call the model. Your incoming HTTP requests model name must match this. Defaults to the hugging face repo name, the folder name, or the GGUF file name.
- `context_length`: Max model length in tokens. Defaults to the model's set max. Only set this if you need to reduce KV cache allocation to fit into VRAM. - `context_length`: Max model length in tokens. Defaults to the model's set max. Only set this if you need to reduce KV cache allocation to fit into VRAM.
- `kv_cache_block_size`: Size of a KV block for the engine, in tokens. Defaults to 16. - `kv_cache_block_size`: Size of a KV block for the engine, in tokens. Defaults to 16.
- `user_data`: Optional dictionary containing custom metadata for worker behavior (e.g., LoRA configuration). Defaults to None.
Here are some example engines: Here are some example engines:
......
...@@ -131,7 +131,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32) ...@@ -131,7 +131,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
} }
#[pyfunction] #[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0))] #[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, user_data=None))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn register_llm<'p>( fn register_llm<'p>(
py: Python<'p>, py: Python<'p>,
...@@ -143,6 +143,7 @@ fn register_llm<'p>( ...@@ -143,6 +143,7 @@ fn register_llm<'p>(
kv_cache_block_size: Option<u32>, kv_cache_block_size: Option<u32>,
router_mode: Option<RouterMode>, router_mode: Option<RouterMode>,
migration_limit: u32, migration_limit: u32,
user_data: Option<&Bound<'p, PyDict>>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type { let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat, ModelType::Chat => llm_rs::model_type::ModelType::Chat,
...@@ -156,6 +157,13 @@ fn register_llm<'p>( ...@@ -156,6 +157,13 @@ fn register_llm<'p>(
let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin); let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default()); let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default());
let user_data_json = user_data
.map(|dict| pythonize::depythonize(dict))
.transpose()
.map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to convert user_data: {}", err))
})?;
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut builder = dynamo_llm::local_model::LocalModelBuilder::default(); let mut builder = dynamo_llm::local_model::LocalModelBuilder::default();
builder builder
...@@ -164,7 +172,8 @@ fn register_llm<'p>( ...@@ -164,7 +172,8 @@ fn register_llm<'p>(
.context_length(context_length) .context_length(context_length)
.kv_cache_block_size(kv_cache_block_size) .kv_cache_block_size(kv_cache_block_size)
.router_config(Some(router_config)) .router_config(Some(router_config))
.migration_limit(Some(migration_limit)); .migration_limit(Some(migration_limit))
.user_data(user_data_json);
// Download from HF, load the ModelDeploymentCard // Download from HF, load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?; let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us // Advertise ourself on etcd so ingress can find us
......
...@@ -48,6 +48,7 @@ pub struct LocalModelBuilder { ...@@ -48,6 +48,7 @@ pub struct LocalModelBuilder {
http_port: u16, http_port: u16,
migration_limit: u32, migration_limit: u32,
is_mocker: bool, is_mocker: bool,
user_data: Option<serde_json::Value>,
} }
impl Default for LocalModelBuilder { impl Default for LocalModelBuilder {
...@@ -64,6 +65,7 @@ impl Default for LocalModelBuilder { ...@@ -64,6 +65,7 @@ impl Default for LocalModelBuilder {
router_config: Default::default(), router_config: Default::default(),
migration_limit: Default::default(), migration_limit: Default::default(),
is_mocker: Default::default(), is_mocker: Default::default(),
user_data: Default::default(),
} }
} }
} }
...@@ -126,6 +128,11 @@ impl LocalModelBuilder { ...@@ -126,6 +128,11 @@ impl LocalModelBuilder {
self self
} }
pub fn user_data(&mut self, user_data: Option<serde_json::Value>) -> &mut Self {
self.user_data = user_data;
self
}
/// Make an LLM ready for use: /// Make an LLM ready for use:
/// - Download it from Hugging Face (and NGC in future) if necessary /// - Download it from Hugging Face (and NGC in future) if necessary
/// - Resolve the path /// - Resolve the path
...@@ -155,6 +162,7 @@ impl LocalModelBuilder { ...@@ -155,6 +162,7 @@ impl LocalModelBuilder {
self.model_name.as_deref().unwrap_or(DEFAULT_NAME), self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
); );
card.migration_limit = self.migration_limit; card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
return Ok(LocalModel { return Ok(LocalModel {
card, card,
full_path: PathBuf::new(), full_path: PathBuf::new(),
...@@ -211,6 +219,7 @@ impl LocalModelBuilder { ...@@ -211,6 +219,7 @@ impl LocalModelBuilder {
} }
card.migration_limit = self.migration_limit; card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
Ok(LocalModel { Ok(LocalModel {
card, card,
......
...@@ -93,6 +93,7 @@ impl ModelDeploymentCard { ...@@ -93,6 +93,7 @@ impl ModelDeploymentCard {
context_length, context_length,
kv_cache_block_size: 0, kv_cache_block_size: 0,
migration_limit: 0, migration_limit: 0,
user_data: None,
}) })
} }
...@@ -133,6 +134,7 @@ impl ModelDeploymentCard { ...@@ -133,6 +134,7 @@ impl ModelDeploymentCard {
context_length, context_length,
kv_cache_block_size: 0, // set later kv_cache_block_size: 0, // set later
migration_limit: 0, migration_limit: 0,
user_data: None,
}) })
} }
} }
......
...@@ -131,6 +131,10 @@ pub struct ModelDeploymentCard { ...@@ -131,6 +131,10 @@ pub struct ModelDeploymentCard {
/// How many times a request can be migrated to another worker if the HTTP server lost /// How many times a request can be migrated to another worker if the HTTP server lost
/// connection to the current worker. /// connection to the current worker.
pub migration_limit: u32, pub migration_limit: u32,
/// User-defined metadata for custom worker behavior
#[serde(default, skip_serializing_if = "Option::is_none")]
pub user_data: Option<serde_json::Value>,
} }
impl ModelDeploymentCard { impl ModelDeploymentCard {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment