Unverified Commit fc124360 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(frontend): router-mode settings (#2001)

parent dc75cf18
...@@ -14,7 +14,15 @@ import asyncio ...@@ -14,7 +14,15 @@ import asyncio
import uvloop import uvloop
from dynamo.llm import EngineType, EntrypointArgs, make_engine, run_input from dynamo.llm import (
EngineType,
EntrypointArgs,
KvRouterConfig,
RouterConfig,
RouterMode,
make_engine,
run_input,
)
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
...@@ -32,6 +40,39 @@ def parse_args(): ...@@ -32,6 +40,39 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--http-port", type=int, default=8080, help="HTTP port for the engine (u16)." "--http-port", type=int, default=8080, help="HTTP port for the engine (u16)."
) )
parser.add_argument(
"--router-mode",
type=str,
choices=["round-robin", "random", "kv"],
default="round-robin",
help="How to route the request",
)
parser.add_argument(
"--kv-overlap-score-weight",
type=float,
default=1.0,
help="KV Router: Weight for overlap score in worker selection. Higher values prioritize KV cache reuse.",
)
parser.add_argument(
"--router-temperature",
type=float,
default=0.0,
help="KV Router: Temperature for worker sampling via softmax. Higher values promote more randomness, and 0 fallbacks to deterministic.",
)
parser.add_argument(
"--kv-events",
action="store_true",
dest="use_kv_events",
help=" KV Router: Whether to use KV events to maintain the view of cached blocks. If false, would use ApproxKvRouter for predicting block creation / deletion based only on incoming requests at a timer.",
)
parser.add_argument(
"--no-kv-events",
action="store_false",
dest="use_kv_events",
help=" KV Router. Disable KV events.",
)
parser.set_defaults(use_kv_events=True)
return parser.parse_args() return parser.parse_args()
...@@ -39,12 +80,28 @@ async def async_main(): ...@@ -39,12 +80,28 @@ async def async_main():
runtime = DistributedRuntime(asyncio.get_running_loop(), False) runtime = DistributedRuntime(asyncio.get_running_loop(), False)
flags = parse_args() flags = parse_args()
if flags.router_mode == "kv":
router_mode = RouterMode.KV
kv_router_config = KvRouterConfig(
overlap_score_weight=flags.kv_overlap_score_weight,
router_temperature=flags.router_temperature,
use_kv_events=flags.use_kv_events,
)
elif flags.router_mode == "random":
router_mode = RouterMode.Random
kv_router_config = None
else:
router_mode = RouterMode.RoundRobin
kv_router_config = None
kwargs = {
"http_port": flags.http_port,
"kv_cache_block_size": flags.kv_cache_block_size,
"router_config": RouterConfig(router_mode, kv_router_config),
}
# out=dyn # out=dyn
e = EntrypointArgs( e = EntrypointArgs(EngineType.Dynamic, **kwargs)
EngineType.Dynamic,
http_port=flags.http_port,
kv_cache_block_size=flags.kv_cache_block_size,
)
engine = await make_engine(runtime, e) engine = await make_engine(runtime, e)
try: try:
......
...@@ -44,7 +44,7 @@ pub async fn run( ...@@ -44,7 +44,7 @@ pub async fn run(
// Only set if user provides. Usually loaded from tokenizer_config.json // Only set if user provides. Usually loaded from tokenizer_config.json
.context_length(flags.context_length) .context_length(flags.context_length)
.http_port(Some(flags.http_port)) .http_port(Some(flags.http_port))
.router_config(flags.router_config()) .router_config(Some(flags.router_config()))
.request_template(flags.request_template.clone()); .request_template(flags.request_template.clone());
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint. // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
......
...@@ -78,6 +78,8 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -78,6 +78,8 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::entrypoint::EntrypointArgs>()?; m.add_class::<llm::entrypoint::EntrypointArgs>()?;
m.add_class::<llm::entrypoint::EngineConfig>()?; m.add_class::<llm::entrypoint::EngineConfig>()?;
m.add_class::<llm::entrypoint::EngineType>()?; m.add_class::<llm::entrypoint::EngineType>()?;
m.add_class::<llm::entrypoint::RouterConfig>()?;
m.add_class::<llm::entrypoint::KvRouterConfig>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?; m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?; m.add_class::<llm::model_card::ModelDeploymentCard>()?;
m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?; m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
...@@ -160,7 +162,7 @@ fn register_llm<'p>( ...@@ -160,7 +162,7 @@ fn register_llm<'p>(
.model_name(model_name) .model_name(model_name)
.context_length(context_length) .context_length(context_length)
.kv_cache_block_size(kv_cache_block_size) .kv_cache_block_size(kv_cache_block_size)
.router_config(router_config); .router_config(Some(router_config));
// Download from HF, load the ModelDeploymentCard // Download from HF, load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?; let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us // Advertise ourself on etcd so ingress can find us
......
...@@ -8,10 +8,14 @@ use pyo3::{exceptions::PyException, prelude::*}; ...@@ -8,10 +8,14 @@ use pyo3::{exceptions::PyException, prelude::*};
use dynamo_llm::entrypoint::input::Input; use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig; use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder}; use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_llm::mocker::protocols::MockEngineArgs; use dynamo_llm::mocker::protocols::MockEngineArgs;
use dynamo_runtime::protocols::Endpoint as EndpointId; use dynamo_runtime::protocols::Endpoint as EndpointId;
use crate::RouterMode;
#[pyclass(eq, eq_int)] #[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
#[repr(i32)] #[repr(i32)]
...@@ -21,6 +25,56 @@ pub enum EngineType { ...@@ -21,6 +25,56 @@ pub enum EngineType {
Mocker = 3, Mocker = 3,
} }
#[pyclass]
#[derive(Default, Clone, Debug, Copy)]
pub struct KvRouterConfig {
inner: RsKvRouterConfig,
}
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true))]
fn new(overlap_score_weight: f64, router_temperature: f64, use_kv_events: bool) -> Self {
KvRouterConfig {
inner: RsKvRouterConfig {
overlap_score_weight,
router_temperature,
use_kv_events,
..Default::default()
},
}
}
}
#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
router_mode: RouterMode,
kv_router_config: KvRouterConfig,
}
#[pymethods]
impl RouterConfig {
#[new]
#[pyo3(signature = (mode, config=None))]
pub fn new(mode: RouterMode, config: Option<KvRouterConfig>) -> Self {
Self {
router_mode: mode,
kv_router_config: config.unwrap_or_default(),
}
}
}
impl From<RouterConfig> for RsRouterConfig {
fn from(rc: RouterConfig) -> RsRouterConfig {
RsRouterConfig {
router_mode: rc.router_mode.into(),
kv_router_config: rc.kv_router_config.inner,
}
}
}
#[pyclass] #[pyclass]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs { pub(crate) struct EntrypointArgs {
...@@ -31,7 +85,7 @@ pub(crate) struct EntrypointArgs { ...@@ -31,7 +85,7 @@ pub(crate) struct EntrypointArgs {
endpoint_id: Option<EndpointId>, endpoint_id: Option<EndpointId>,
context_length: Option<u32>, context_length: Option<u32>,
template_file: Option<PathBuf>, template_file: Option<PathBuf>,
//router_config: Option<RouterConfig>, router_config: Option<RouterConfig>,
kv_cache_block_size: Option<u32>, kv_cache_block_size: Option<u32>,
http_port: Option<u16>, http_port: Option<u16>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
...@@ -41,7 +95,7 @@ pub(crate) struct EntrypointArgs { ...@@ -41,7 +95,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs { impl EntrypointArgs {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[new] #[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, kv_cache_block_size=None, http_port=None, extra_engine_args=None))] #[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_port=None, extra_engine_args=None))]
pub fn new( pub fn new(
engine_type: EngineType, engine_type: EngineType,
model_path: Option<PathBuf>, model_path: Option<PathBuf>,
...@@ -50,7 +104,7 @@ impl EntrypointArgs { ...@@ -50,7 +104,7 @@ impl EntrypointArgs {
endpoint_id: Option<String>, endpoint_id: Option<String>,
context_length: Option<u32>, context_length: Option<u32>,
template_file: Option<PathBuf>, template_file: Option<PathBuf>,
//router_config: Option<RouterConfig>, router_config: Option<RouterConfig>,
kv_cache_block_size: Option<u32>, kv_cache_block_size: Option<u32>,
http_port: Option<u16>, http_port: Option<u16>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
...@@ -71,7 +125,7 @@ impl EntrypointArgs { ...@@ -71,7 +125,7 @@ impl EntrypointArgs {
endpoint_id: endpoint_id_obj, endpoint_id: endpoint_id_obj,
context_length, context_length,
template_file, template_file,
//router_config, router_config,
kv_cache_block_size, kv_cache_block_size,
http_port, http_port,
extra_engine_args, extra_engine_args,
...@@ -101,6 +155,7 @@ pub fn make_engine<'p>( ...@@ -101,6 +155,7 @@ pub fn make_engine<'p>(
.context_length(args.context_length) .context_length(args.context_length)
.request_template(args.template_file.clone()) .request_template(args.template_file.clone())
.kv_cache_block_size(args.kv_cache_block_size) .kv_cache_block_size(args.kv_cache_block_size)
.router_config(args.router_config.clone().map(|rc| rc.into()))
.http_port(args.http_port); .http_port(args.http_port);
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let local_model = builder.build().await.map_err(to_pyerr)?; let local_model = builder.build().await.map_err(to_pyerr)?;
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ( from typing import (
Any, Any,
...@@ -831,9 +819,15 @@ class ModelType: ...@@ -831,9 +819,15 @@ class ModelType:
class RouterMode: class RouterMode:
"""Router mode for load balancing requests across workers""" """Router mode for load balancing requests across workers"""
RoundRobin: 'RouterMode' ...
Random: 'RouterMode'
KV: 'RouterMode' class RouterConfig:
"""How to route the request"""
...
class KvRouterConfig:
"""Values for KV router"""
...
async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str] = None, context_length: Optional[int] = None, kv_cache_block_size: Optional[int] = None, router_mode: Optional[RouterMode] = None) -> None: async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str] = None, context_length: Optional[int] = None, kv_cache_block_size: Optional[int] = None, router_mode: Optional[RouterMode] = None) -> None:
"""Attach the model at path to the given endpoint, and advertise it as model_type""" """Attach the model at path to the given endpoint, and advertise it as model_type"""
......
...@@ -24,10 +24,13 @@ from dynamo._core import KvEventPublisher as KvEventPublisher ...@@ -24,10 +24,13 @@ from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvRecorder as KvRecorder from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import KvStats as KvStats from dynamo._core import KvStats as KvStats
from dynamo._core import ModelType as ModelType from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import RadixTree as RadixTree from dynamo._core import RadixTree as RadixTree
from dynamo._core import RouterConfig as RouterConfig
from dynamo._core import RouterMode as RouterMode
from dynamo._core import SpecDecodeStats as SpecDecodeStats from dynamo._core import SpecDecodeStats as SpecDecodeStats
from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher
from dynamo._core import WorkerStats as WorkerStats from dynamo._core import WorkerStats as WorkerStats
......
...@@ -212,7 +212,7 @@ impl ModelManager { ...@@ -212,7 +212,7 @@ impl ModelManager {
kv_cache_block_size: u32, kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> { ) -> anyhow::Result<Arc<KvRouter>> {
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config.clone())); let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new( let chooser = KvRouter::new(
component.clone(), component.clone(),
kv_cache_block_size, kv_cache_block_size,
......
...@@ -220,7 +220,7 @@ impl ModelWatcher { ...@@ -220,7 +220,7 @@ impl ModelWatcher {
&model_entry.name, &model_entry.name,
&component, &component,
card.kv_cache_block_size, card.kv_cache_block_size,
self.kv_router_config.clone(), self.kv_router_config,
) )
.await?; .await?;
let kv_push_router = KvPushRouter::new(router, chooser); let kv_push_router = KvPushRouter::new(router, chooser);
...@@ -261,7 +261,7 @@ impl ModelWatcher { ...@@ -261,7 +261,7 @@ impl ModelWatcher {
&model_entry.name, &model_entry.name,
&component, &component,
card.kv_cache_block_size, card.kv_cache_block_size,
self.kv_router_config.clone(), self.kv_router_config,
) )
.await?; .await?;
let kv_push_router = KvPushRouter::new(router, chooser); let kv_push_router = KvPushRouter::new(router, chooser);
......
...@@ -42,7 +42,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -42,7 +42,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
etcd_client.clone(), etcd_client.clone(),
MODEL_ROOT_PATH, MODEL_ROOT_PATH,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config.clone()), Some(router_config.kv_router_config),
) )
.await?; .await?;
} }
......
...@@ -62,7 +62,7 @@ pub trait WorkerSelector { ...@@ -62,7 +62,7 @@ pub trait WorkerSelector {
} }
/// KV Router configuration parameters /// KV Router configuration parameters
#[derive(Debug, Clone)] #[derive(Debug, Clone, Copy)]
pub struct KvRouterConfig { pub struct KvRouterConfig {
pub overlap_score_weight: f64, pub overlap_score_weight: f64,
......
...@@ -102,8 +102,8 @@ impl LocalModelBuilder { ...@@ -102,8 +102,8 @@ impl LocalModelBuilder {
self self
} }
pub fn router_config(&mut self, router_config: RouterConfig) -> &mut Self { pub fn router_config(&mut self, router_config: Option<RouterConfig>) -> &mut Self {
self.router_config = Some(router_config); self.router_config = router_config;
self self
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment