Unverified Commit 0a32b344 authored by Alec's avatar Alec Committed by GitHub
Browse files

fix: default to None initialization of routing config (#1713)

parent 54c21168
......@@ -15,12 +15,33 @@ use tokio::sync::Mutex;
use dynamo_runtime::{
self as rs, logging,
pipeline::{EngineStream, ManyOut, SingleIn},
pipeline::{
network::egress::push_router::RouterMode as RsRouterMode, EngineStream, ManyOut, SingleIn,
},
protocols::annotated::Annotated as RsAnnotated,
traits::DistributedRuntimeProvider,
};
use dynamo_llm::{self as llm_rs};
use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig};
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
pub enum RouterMode {
RoundRobin,
Random,
KV,
}
impl From<RouterMode> for RsRouterMode {
fn from(mode: RouterMode) -> Self {
match mode {
RouterMode::RoundRobin => Self::RoundRobin,
RouterMode::Random => Self::Random,
RouterMode::KV => Self::KV,
}
}
}
mod engine;
mod http;
......@@ -75,6 +96,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<EtcdKvCache>()?;
m.add_class::<ModelType>()?;
m.add_class::<RouterMode>()?;
engine::add_to_module(m)?;
......@@ -99,7 +121,8 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
}
#[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None))]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None))]
#[allow(clippy::too_many_arguments)]
fn register_llm<'p>(
py: Python<'p>,
model_type: ModelType,
......@@ -108,6 +131,7 @@ fn register_llm<'p>(
model_name: Option<&str>,
context_length: Option<u32>,
kv_cache_block_size: Option<u32>,
router_mode: Option<RouterMode>,
) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat,
......@@ -118,13 +142,17 @@ fn register_llm<'p>(
let inner_path = model_path.to_string();
let model_name = model_name.map(|n| n.to_string());
let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut builder = dynamo_llm::local_model::LocalModelBuilder::default();
builder
.model_path(Some(PathBuf::from(inner_path)))
.model_name(model_name)
.context_length(context_length)
.kv_cache_block_size(kv_cache_block_size);
.kv_cache_block_size(kv_cache_block_size)
.router_config(router_config);
// Download from HF, load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us
......
......@@ -777,7 +777,13 @@ class ModelType:
"""What type of request this model needs: Chat, Component or Backend (pre-processed)"""
...
async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str] = None, context_length: Optional[int] = None, kv_cache_block_size: Optional[int] = None) -> None:
class RouterMode:
"""Router mode for load balancing requests across workers"""
RoundRobin: 'RouterMode'
Random: 'RouterMode'
KV: 'RouterMode'
async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str] = None, context_length: Optional[int] = None, kv_cache_block_size: Optional[int] = None, router_mode: Optional[RouterMode] = None) -> None:
"""Attach the model at path to the given endpoint, and advertise it as model_type"""
...
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment