Unverified Commit f08729ae authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: python bindings for the entire KvPushRouter + per-request router configs (#2658)

parent 8e4d81f3
...@@ -6323,20 +6323,6 @@ dependencies = [ ...@@ -6323,20 +6323,6 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "router"
version = "0.4.1"
dependencies = [
"clap 4.5.42",
"dynamo-llm",
"dynamo-runtime",
"rand 0.9.2",
"serde",
"serde_json",
"tokio",
"tracing",
]
[[package]] [[package]]
name = "rstest" name = "rstest"
version = "0.18.2" version = "0.18.2"
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
[workspace] [workspace]
members = [ members = [
"components/metrics", "components/metrics",
"components/router",
"launch/*", "launch/*",
"lib/llm", "lib/llm",
"lib/runtime", "lib/runtime",
......
...@@ -49,15 +49,6 @@ The frontend component provides the HTTP API layer and request processing: ...@@ -49,15 +49,6 @@ The frontend component provides the HTTP API layer and request processing:
- **Router** - Routes requests to appropriate workers based on load and KV cache state - **Router** - Routes requests to appropriate workers based on load and KV cache state
- **Auto-discovery** - Automatically discovers and registers available workers - **Auto-discovery** - Automatically discovers and registers available workers
### [Router](router/)
A high-performance request router written in Rust that:
- Routes incoming requests to optimal workers based on KV cache state
- Implements KV-aware routing to minimize cache misses
- Provides load balancing across multiple worker instances
- Supports both aggregated and disaggregated serving patterns
### [Planner](planner/) ### [Planner](planner/)
The planner component monitors system state and dynamically adjusts worker allocation: The planner component monitors system state and dynamically adjusts worker allocation:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[package]
name = "router"
version.workspace = true
edition.workspace = true
description.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
keywords.workspace = true
[dependencies]
dynamo-runtime = { workspace = true}
dynamo-llm = { workspace = true}
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
clap = { version = "4.5", features = ["derive"] }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO(#400):
// Instead of passing in a block_size, we should get this data from the backend component's config.
// What changes need to be made:
// 1. Take as an argument the name of the backend component.
// 2. Update the backend component to produce a config in a standard location.
// 3. Update the KvRouter to read the config from the backend component.
use std::collections::HashMap;
use std::sync::Arc;
use clap::Parser;
use dynamo_llm::kv_router::{
KvRouter, WorkerSelector,
protocols::WorkerSelectionResult,
scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest},
};
use dynamo_llm::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::{
DistributedRuntime, Result, Runtime, Worker, logging, pipeline::network::Ingress,
};
// Command-line arguments for the standalone KV-aware router binary.
//
// NOTE: the `///` comments on the fields below are picked up by clap's derive
// macro as the `--help` text of each flag, so editing them changes CLI output.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Namespace for the distributed component
// Required flag (no default); used to look up the namespace on the runtime.
#[arg(long)]
namespace: String,
/// Component name for the service
// Optional flag; falls back to "kv_aware_router" when not given.
#[arg(long, default_value = "kv_aware_router")]
component: String,
/// Block size for the router
// Required flag (no default); forwarded unchanged to `KvRouter::new`.
#[arg(long)]
block_size: u32,
}
/// Process entry point.
///
/// Initializes logging, builds a [`Worker`] from settings and hands [`app`]
/// to it for execution. Any error from either step is returned to the caller.
fn main() -> Result<()> {
    logging::init();
    Worker::from_settings()?.execute(app)
}
async fn app(runtime: Runtime) -> Result<()> {
let args = Args::parse();
let runtime = DistributedRuntime::from_settings(runtime).await?;
let component = runtime
.namespace(&args.namespace)?
.component(&args.component)?;
let selector = Box::new(CustomWorkerSelector::default());
let router = KvRouter::new(component.clone(), args.block_size, Some(selector), None).await?;
let router = Ingress::for_engine(Arc::new(router))?;
component
.service_builder()
.create()
.await?
.endpoint("generate")
.endpoint_builder()
.handler(router)
.start()
.await
}
/// Newtype around [`DefaultWorkerSelector`] that serves as the extension point
/// for custom worker-selection logic in this binary.
///
/// `Default` builds it from the wrapped selector's own default.
#[derive(Default)]
pub struct CustomWorkerSelector(DefaultWorkerSelector);
impl WorkerSelector for CustomWorkerSelector {
    /// Picks a worker for `request` out of `workers`.
    ///
    /// Currently a pure pass-through: all three arguments are forwarded
    /// unchanged to the wrapped [`DefaultWorkerSelector`] and its result
    /// (or error) is returned as-is.
    fn select_worker(
        &self,
        workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
        request: &SchedulingRequest,
        block_size: u32,
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        // Put custom selection logic here. See `DefaultWorkerSelector` for the
        // stock implementation this currently delegates to.
        let Self(default_selector) = self;
        default_selector.select_worker(workers, request, block_size)
    }
}
...@@ -143,4 +143,71 @@ The `router_temperature` parameter controls routing randomness: ...@@ -143,4 +143,71 @@ The `router_temperature` parameter controls routing randomness:
3. Adjust `kv-overlap-score-weight` to meet your performance goals: 3. Adjust `kv-overlap-score-weight` to meet your performance goals:
- To reduce TTFT: Increase the weight - To reduce TTFT: Increase the weight
- To reduce ITL: Decrease the weight - To reduce ITL: Decrease the weight
4. If you observe severe load imbalance, increase the temperature setting 4. If you observe severe load imbalance, increase the temperature setting
\ No newline at end of file
## Using KvPushRouter Python API
Instead of launching the KV Router via command line, you can create a `KvPushRouter` object directly in Python. This allows per-request routing configuration overrides.
### Setup
First, launch your backend engines:
```bash
python -m dynamo.vllm --model meta-llama/Llama-2-7b-hf --endpoint dyn://inference.vllm.generate
```
### Example Script
```python
import asyncio
from dynamo._core import DistributedRuntime, KvPushRouter, KvRouterConfig
async def main():
# Get runtime and create endpoint
runtime = DistributedRuntime.detached()
namespace = runtime.namespace("inference")
component = namespace.component("vllm")
endpoint = component.endpoint("generate")
# Create KV router
kv_router_config = KvRouterConfig()
router = KvPushRouter(
endpoint=endpoint,
block_size=16,
kv_router_config=kv_router_config
)
# Your input tokens
token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Generate with per-request routing override
stream = await router.generate(
token_ids=token_ids,
model="meta-llama/Llama-2-7b-hf",
stop_conditions={
"max_tokens": 20, # Generate exactly 20 tokens
"ignore_eos": True, # Don't stop at EOS token
},
sampling_options={
"temperature": 0.7,
"top_p": 0.9,
},
router_config_override={
"overlap_score_weight": 2.0, # Prioritize cache hits for this request
"router_temperature": 0.5, # Add routing randomness
}
)
# Collect generated tokens
generated_tokens = []
async for response in stream:
if isinstance(response, dict) and "token_ids" in response:
generated_tokens.extend(response["token_ids"])
print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
if __name__ == "__main__":
asyncio.run(main())
```
The `router_config_override` parameter allows you to adjust routing behavior per request without recreating the router. This is useful for implementing different routing strategies based on request characteristics.
\ No newline at end of file
...@@ -39,7 +39,6 @@ ...@@ -39,7 +39,6 @@
components/backends/sglang/docs/multinode-examples.md components/backends/sglang/docs/multinode-examples.md
components/backends/sglang/docs/sgl-http-server.md components/backends/sglang/docs/sgl-http-server.md
components/backends/sglang/slurm_jobs/README.md components/backends/sglang/slurm_jobs/README.md
components/router/README.md
examples/README.md examples/README.md
guides/dynamo_deploy/create_deployment.md guides/dynamo_deploy/create_deployment.md
guides/dynamo_deploy/sla_planner_deployment.md guides/dynamo_deploy/sla_planner_deployment.md
......
...@@ -382,8 +382,6 @@ python -m dynamo.frontend \ ...@@ -382,8 +382,6 @@ python -m dynamo.frontend \
However, for maximum performance with shared prefixes and multi-turn conversations, KV routing provides significant advantages by minimizing redundant computation. However, for maximum performance with shared prefixes and multi-turn conversations, KV routing provides significant advantages by minimizing redundant computation.
For detailed router configuration and tuning options, see the [KV Router Documentation](../../../docs/components/router/README.md).
## Monitoring and Debugging ## Monitoring and Debugging
### Check Worker Registration ### Check Worker Registration
......
...@@ -1433,7 +1433,6 @@ dependencies = [ ...@@ -1433,7 +1433,6 @@ dependencies = [
"regex", "regex",
"serde", "serde",
"serde_json", "serde_json",
"tokio",
"tracing", "tracing",
"uuid", "uuid",
] ]
......
...@@ -111,6 +111,8 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -111,6 +111,8 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::WorkerStats>()?; m.add_class::<llm::kv::WorkerStats>()?;
m.add_class::<llm::kv::KvStats>()?; m.add_class::<llm::kv::KvStats>()?;
m.add_class::<llm::kv::SpecDecodeStats>()?; m.add_class::<llm::kv::SpecDecodeStats>()?;
m.add_class::<llm::kv::KvPushRouter>()?;
m.add_class::<llm::kv::KvPushRouterStream>()?;
m.add_class::<RouterMode>()?; m.add_class::<RouterMode>()?;
engine::add_to_module(m)?; engine::add_to_module(m)?;
......
...@@ -33,6 +33,12 @@ pub struct KvRouterConfig { ...@@ -33,6 +33,12 @@ pub struct KvRouterConfig {
inner: RsKvRouterConfig, inner: RsKvRouterConfig,
} }
impl KvRouterConfig {
    /// Returns a copy of the wrapped Rust-side router configuration.
    ///
    /// Rust-only accessor (not exposed through `#[pymethods]`) so sibling
    /// binding modules can pass the configuration on to the router.
    pub fn inner(&self) -> RsKvRouterConfig {
        let Self { inner, .. } = self;
        *inner
    }
}
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
......
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use pythonize::{depythonize, pythonize};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::AtomicU32; use std::sync::atomic::AtomicU32;
use tokio_stream::StreamExt;
use super::*; use super::*;
use llm_rs::kv_router::indexer::compute_block_hash_for_seq; use llm_rs::kv_router::indexer::compute_block_hash_for_seq;
...@@ -28,6 +30,7 @@ use tracing; ...@@ -28,6 +30,7 @@ use tracing;
use llm_rs::kv_router::protocols::*; use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig}; use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig};
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
#[pyfunction] #[pyfunction]
pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> { pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> {
...@@ -833,3 +836,188 @@ impl SpecDecodeStats { ...@@ -833,3 +836,188 @@ impl SpecDecodeStats {
}) })
} }
} }
/// Python-visible handle to the Rust `KvPushRouter`.
///
/// The router is kept behind an `Arc` so each `generate` call can clone the
/// handle into the future it returns to Python.
#[pyclass]
pub(crate) struct KvPushRouter {
// Shared reference to the underlying Rust router.
inner: Arc<llm_rs::kv_router::KvPushRouter>,
}
#[pymethods]
impl KvPushRouter {
#[new]
fn new(
endpoint: &Endpoint,
block_size: usize,
kv_router_config: &super::entrypoint::KvRouterConfig,
) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async move {
let client = endpoint.inner.client().await.map_err(to_pyerr)?;
// Create PushRouter with KV router mode
let push_router = rs::pipeline::PushRouter::<
llm_rs::protocols::common::preprocessor::PreprocessedRequest,
rs::protocols::annotated::Annotated<
llm_rs::protocols::common::llm_backend::LLMEngineOutput,
>,
>::from_client(
client,
rs::pipeline::network::egress::push_router::RouterMode::KV,
)
.await
.map_err(to_pyerr)?;
// Get component from endpoint
let component = endpoint.inner.component();
// Create KvRouter
let kv_router = llm_rs::kv_router::KvRouter::new(
component.clone(),
block_size as u32,
None, // default selector
Some(kv_router_config.inner()),
)
.await
.map_err(to_pyerr)?;
// Create KvPushRouter
let kv_push_router =
llm_rs::kv_router::KvPushRouter::new(push_router, Arc::new(kv_router));
Ok(Self {
inner: Arc::new(kv_push_router),
})
})
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None))]
fn generate<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
model: String,
stop_conditions: Option<PyObject>,
sampling_options: Option<PyObject>,
output_options: Option<PyObject>,
router_config_override: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the options with defaults
let (stop_conditions, sampling_options, output_options, router_config_override) =
Python::with_gil(|py| {
let stop_conditions: StopConditions = if let Some(obj) = stop_conditions {
depythonize(obj.bind(py)).map_err(to_pyerr)?
} else {
StopConditions::default()
};
let sampling_options: SamplingOptions = if let Some(obj) = sampling_options {
depythonize(obj.bind(py)).map_err(to_pyerr)?
} else {
SamplingOptions::default()
};
let output_options: OutputOptions = if let Some(obj) = output_options {
depythonize(obj.bind(py)).map_err(to_pyerr)?
} else {
OutputOptions::default()
};
let router_config_override: Option<llm_rs::kv_router::RouterConfigOverride> =
if let Some(obj) = router_config_override {
Some(depythonize(obj.bind(py)).map_err(to_pyerr)?)
} else {
None
};
Ok::<_, PyErr>((
stop_conditions,
sampling_options,
output_options,
router_config_override,
))
})?;
// Build the PreprocessedRequest
let request = llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder()
.model(model)
.token_ids(token_ids)
.stop_conditions(stop_conditions)
.sampling_options(sampling_options)
.output_options(output_options)
.router_config_override(router_config_override)
.build()
.map_err(to_pyerr)?;
let inner = self.inner.clone();
// Create a Python async generator that wraps the Rust stream
pyo3_async_runtimes::tokio::future_into_py(py, async move {
use rs::pipeline::{AsyncEngine, SingleIn};
use tokio_stream::StreamExt;
let single_in = SingleIn::new(request);
let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
let (tx, rx) = tokio::sync::mpsc::channel(100);
// Spawn a task to process the stream
tokio::spawn(async move {
let mut stream = stream;
while let Some(response) = stream.next().await {
// Convert LLMEngineOutput to PyObject
let py_response = Python::with_gil(|py| {
pythonize(py, &response.data)
.map(|obj| obj.unbind())
.map_err(|e| e.to_string())
});
match py_response {
Ok(obj) => {
if tx.send(obj).await.is_err() {
break; // Receiver dropped
}
}
Err(e) => {
tracing::error!("Failed to pythonize response: {}", e);
break;
}
}
}
});
// Return a Python async generator wrapper
Ok(KvPushRouterStream {
rx: Arc::new(tokio::sync::Mutex::new(rx)),
})
})
}
}
// Python async generator wrapper for the stream
/// Async-iterable returned to Python by `KvPushRouter.generate`.
///
/// Wraps the receiving half of the channel that the generation task writes
/// converted responses into.
#[pyclass]
pub(crate) struct KvPushRouterStream {
// Receiver behind `Arc<tokio::sync::Mutex<..>>` so each `__anext__` call can
// clone the handle into its own future and lock it there.
rx: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<PyObject>>>,
}
#[pymethods]
impl KvPushRouterStream {
#[pyo3(name = "__aiter__")]
fn aiter(slf: Bound<'_, Self>) -> PyResult<Py<PyAny>> {
Ok(slf.clone().into_any().unbind())
}
#[pyo3(name = "__anext__")]
fn anext<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let rx = self.rx.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut rx = rx.lock().await;
match rx.recv().await {
Some(obj) => Ok(obj),
None => Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
"Stream exhausted",
)),
}
})
}
}
...@@ -6,6 +6,7 @@ use std::sync::Arc; ...@@ -6,6 +6,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::{ use dynamo_runtime::{
component::{Component, InstanceSource}, component::{Component, InstanceSource},
pipeline::{ pipeline::{
...@@ -73,6 +74,16 @@ pub trait WorkerSelector { ...@@ -73,6 +74,16 @@ pub trait WorkerSelector {
) -> Result<WorkerSelectionResult, KvSchedulerError>; ) -> Result<WorkerSelectionResult, KvSchedulerError>;
} }
/// Override configuration for router settings that can be specified per-request
///
/// Every field is optional; a field left as `None` means the router's own
/// `KvRouterConfig` value is used for that setting.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
/// Replaces `KvRouterConfig::overlap_score_weight` (the multiplier applied to
/// the potential prefill blocks in the worker logit) for this request.
#[builder(default)]
pub overlap_score_weight: Option<f64>,
/// Replaces `KvRouterConfig::router_temperature` (the temperature passed to
/// softmax sampling when picking a worker) for this request.
#[builder(default)]
pub router_temperature: Option<f64>,
}
/// KV Router configuration parameters /// KV Router configuration parameters
#[derive(Debug, Clone, Copy, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct KvRouterConfig { pub struct KvRouterConfig {
...@@ -261,6 +272,7 @@ impl KvRouter { ...@@ -261,6 +272,7 @@ impl KvRouter {
&self, &self,
context_id: &str, context_id: &str,
tokens: &[u32], tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
) -> anyhow::Result<(i64, u32)> { ) -> anyhow::Result<(i64, u32)> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
...@@ -276,6 +288,7 @@ impl KvRouter { ...@@ -276,6 +288,7 @@ impl KvRouter {
isl_tokens, isl_tokens,
seq_hashes.clone(), seq_hashes.clone(),
overlap_scores.clone(), overlap_scores.clone(),
router_config_override,
) )
.await?; .await?;
...@@ -315,7 +328,9 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -315,7 +328,9 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
request: SingleIn<RouterRequest>, request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> { ) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts(); let (request, ctx) = request.into_parts();
let (worker_id, _) = self.find_best_match(ctx.id(), &request.tokens).await?; let (worker_id, _) = self
.find_best_match(ctx.id(), &request.tokens, None)
.await?;
let response = RouterResponse { worker_id }; let response = RouterResponse { worker_id };
let response = Annotated::from_data(response); let response = Annotated::from_data(response);
...@@ -357,7 +372,11 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -357,7 +372,11 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} else { } else {
// Otherwise, find the best match // Otherwise, find the best match
self.chooser self.chooser
.find_best_match(&context_id, &request.token_ids) .find_best_match(
&context_id,
&request.token_ids,
request.router_config_override.as_ref(),
)
.await? .await?
}; };
......
...@@ -13,6 +13,7 @@ use tokio::sync::watch; ...@@ -13,6 +13,7 @@ use tokio::sync::watch;
use super::KV_HIT_RATE_SUBJECT; use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig; use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector; use super::WorkerSelector;
use super::indexer::OverlapScores; use super::indexer::OverlapScores;
use super::protocols::WorkerSelectionResult; use super::protocols::WorkerSelectionResult;
...@@ -52,6 +53,8 @@ pub struct SchedulingRequest { ...@@ -52,6 +53,8 @@ pub struct SchedulingRequest {
pub overlaps: OverlapScores, pub overlaps: OverlapScores,
pub decode_blocks: HashMap<i64, usize>, pub decode_blocks: HashMap<i64, usize>,
pub prefill_tokens: HashMap<i64, usize>, pub prefill_tokens: HashMap<i64, usize>,
// Router config overrides for this specific request
pub router_config_override: Option<RouterConfigOverride>,
// Option to take it out to send the response without moving the struct // Option to take it out to send the response without moving the struct
resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>, resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>,
} }
...@@ -243,6 +246,7 @@ impl KvScheduler { ...@@ -243,6 +246,7 @@ impl KvScheduler {
isl_tokens: usize, isl_tokens: usize,
token_seq: Vec<SequenceHash>, token_seq: Vec<SequenceHash>,
overlaps: OverlapScores, overlaps: OverlapScores,
router_config_override: Option<&RouterConfigOverride>,
) -> Result<i64, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
...@@ -252,6 +256,7 @@ impl KvScheduler { ...@@ -252,6 +256,7 @@ impl KvScheduler {
overlaps, overlaps,
decode_blocks: HashMap::new(), decode_blocks: HashMap::new(),
prefill_tokens: HashMap::new(), prefill_tokens: HashMap::new(),
router_config_override: router_config_override.cloned(),
resp_tx: Some(resp_tx), // Wrap in Some() resp_tx: Some(resp_tx), // Wrap in Some()
}; };
...@@ -402,14 +407,19 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -402,14 +407,19 @@ impl WorkerSelector for DefaultWorkerSelector {
.unwrap_or(&(potential_prefill_block.floor() as usize)) .unwrap_or(&(potential_prefill_block.floor() as usize))
as f64; as f64;
// Use override if provided, otherwise use default config
let overlap_weight = request
.router_config_override
.as_ref()
.and_then(|cfg| cfg.overlap_score_weight)
.unwrap_or(self.kv_router_config.overlap_score_weight);
// Calculate logit (lower is better) // Calculate logit (lower is better)
let logit = let logit = overlap_weight * potential_prefill_block + decode_block;
self.kv_router_config.overlap_score_weight * potential_prefill_block + decode_block;
max_logit = max_logit.max(logit); max_logit = max_logit.max(logit);
worker_logits.insert(*worker_id, logit); worker_logits.insert(*worker_id, logit);
let overlap_weight = self.kv_router_config.overlap_score_weight;
tracing::info!( tracing::info!(
"Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \ "Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \ = {overlap_weight:.1} * prefill_blocks + decode_blocks \
...@@ -418,7 +428,12 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -418,7 +428,12 @@ impl WorkerSelector for DefaultWorkerSelector {
} }
// Use softmax sampling to select worker // Use softmax sampling to select worker
let temperature = self.kv_router_config.router_temperature; // Use override if provided, otherwise use default config
let temperature = request
.router_config_override
.as_ref()
.and_then(|cfg| cfg.router_temperature)
.unwrap_or(self.kv_router_config.router_temperature);
let best_worker_id = softmax_sample(&worker_logits, temperature); let best_worker_id = softmax_sample(&worker_logits, temperature);
let best_logit = worker_logits[&best_worker_id]; let best_logit = worker_logits[&best_worker_id];
......
...@@ -176,22 +176,19 @@ mod tests { ...@@ -176,22 +176,19 @@ mod tests {
// Helper to create a mock preprocessed request // Helper to create a mock preprocessed request
fn create_mock_request(max_tokens: u32) -> PreprocessedRequest { fn create_mock_request(max_tokens: u32) -> PreprocessedRequest {
PreprocessedRequest { PreprocessedRequest::builder()
model: "mock".to_string(), .model("mock".to_string())
token_ids: vec![1, 2, 3], .token_ids(vec![1, 2, 3])
batch_token_ids: None, .stop_conditions(StopConditions {
stop_conditions: StopConditions {
max_tokens: Some(max_tokens), max_tokens: Some(max_tokens),
..Default::default() ..Default::default()
}, })
sampling_options: SamplingOptions::default(), .sampling_options(SamplingOptions::default())
output_options: OutputOptions::default(), .output_options(OutputOptions::default())
eos_token_ids: vec![], .eos_token_ids(vec![])
mdc_sum: None, .annotations(vec![])
annotations: vec![], .build()
estimated_prefix_hit_num_blocks: None, .unwrap()
backend_instance_id: None,
}
} }
// Helper to create mock LLM engine output // Helper to create mock LLM engine output
......
...@@ -44,7 +44,7 @@ use std::collections::HashMap; ...@@ -44,7 +44,7 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::{Mutex, OnceCell, mpsc}; use tokio::sync::{Mutex, OnceCell, mpsc};
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use uuid::Uuid; use uuid::Uuid;
pub const MOCKER_COMPONENT: &str = "mocker"; pub const MOCKER_COMPONENT: &str = "mocker";
...@@ -366,7 +366,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -366,7 +366,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
self.direct(direct_request, dp_rank as usize); self.direct(direct_request, dp_rank as usize);
// Create a simple channel for the stream // Create a simple channel for the stream
let (stream_tx, stream_rx) = mpsc::channel::<LLMEngineOutput>(64); let (stream_tx, stream_rx) = mpsc::unbounded_channel::<LLMEngineOutput>();
let active_requests = self.active_requests.clone(); let active_requests = self.active_requests.clone();
let async_context = ctx.context(); let async_context = ctx.context();
...@@ -380,20 +380,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -380,20 +380,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
tokio::select! { tokio::select! {
maybe_signal = request_rx.recv() => { maybe_signal = request_rx.recv() => {
let Some(signal) = maybe_signal else { let Some(signal) = maybe_signal else {
let _ = stream_tx.send(LLMEngineOutput::error("All output transmitters closed".to_string())).await; let _ = stream_tx.send(LLMEngineOutput::error("All output transmitters closed".to_string()));
break; break;
}; };
if signal.completed && token_count < max_tokens - 1 {
let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string())).await;
break;
}
if signal.completed {
let _ = stream_tx.send(LLMEngineOutput::length()).await;
break;
}
// Generate a new token // Generate a new token
let token_id = generate_random_token(); let token_id = generate_random_token();
token_count += 1; token_count += 1;
...@@ -409,13 +399,25 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -409,13 +399,25 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
index: None, index: None,
}; };
if stream_tx.send(output).await.is_err() { if signal.completed && token_count < max_tokens {
let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string()));
break;
}
if signal.completed {
let _ = stream_tx.send(output);
let _ = stream_tx.send(LLMEngineOutput::length());
break;
}
if stream_tx.send(output).is_err() {
tracing::error!("Output stream receiver closed.");
break; break;
} }
} }
_ = async_context.stopped() => { _ = async_context.stopped() => {
let _ = stream_tx.send(LLMEngineOutput::cancelled()).await; let _ = stream_tx.send(LLMEngineOutput::cancelled());
break; break;
} }
} }
...@@ -426,8 +428,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -426,8 +428,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
active.remove(&request_uuid); active.remove(&request_uuid);
}); });
// Create a simple ReceiverStream which is naturally Send + Sync // Create a simple UnboundedReceiverStream which is naturally Send + Sync
let stream = ReceiverStream::new(stream_rx); let stream = UnboundedReceiverStream::new(stream_rx);
Ok(ResponseStream::new(Box::pin(stream), ctx.context())) Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
} }
} }
...@@ -632,21 +634,20 @@ mod integration_tests { ...@@ -632,21 +634,20 @@ mod integration_tests {
tracing::info!("✓ Router created"); tracing::info!("✓ Router created");
// Create test requests for both DP workers // Create test requests for both DP workers
let create_request = |tokens: Vec<TokenIdType>, dp_rank: u32| PreprocessedRequest { let create_request = |tokens: Vec<TokenIdType>, dp_rank: u32| {
model: "mock".to_string(), PreprocessedRequest::builder()
token_ids: tokens, .model("mock".to_string())
batch_token_ids: None, .token_ids(tokens)
stop_conditions: StopConditions { .stop_conditions(StopConditions {
max_tokens: Some(TOKENS_PER_REQUEST as u32), max_tokens: Some(TOKENS_PER_REQUEST as u32),
..Default::default() ..Default::default()
}, })
sampling_options: SamplingOptions::default(), .sampling_options(SamplingOptions::default())
output_options: OutputOptions::default(), .output_options(OutputOptions::default())
eos_token_ids: vec![], .eos_token_ids(vec![])
mdc_sum: None, .annotations(vec![format!("dp_rank:{dp_rank}")])
annotations: vec![format!("dp_rank:{dp_rank}")], .build()
estimated_prefix_hit_num_blocks: None, .unwrap()
backend_instance_id: None,
}; };
let requests = vec![ let requests = vec![
......
...@@ -5,6 +5,7 @@ use derive_builder::Builder; ...@@ -5,6 +5,7 @@ use derive_builder::Builder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::{OutputOptions, SamplingOptions, StopConditions}; use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
/// [`PreprocessedRequest`] is the internal representation of an LLM request. The [`dynamo.llm-preprocessor`] /// [`PreprocessedRequest`] is the internal representation of an LLM request. The [`dynamo.llm-preprocessor`]
...@@ -54,6 +55,10 @@ pub struct PreprocessedRequest { ...@@ -54,6 +55,10 @@ pub struct PreprocessedRequest {
/// Targeted backend instance ID for the request /// Targeted backend instance ID for the request
#[builder(default)] #[builder(default)]
pub backend_instance_id: Option<i64>, pub backend_instance_id: Option<i64>,
/// Router configuration overrides for this specific request
#[builder(default)]
pub router_config_override: Option<RouterConfigOverride>,
} }
impl PreprocessedRequest { impl PreprocessedRequest {
......
...@@ -11,7 +11,7 @@ from typing import Any, Dict ...@@ -11,7 +11,7 @@ from typing import Any, Dict
import aiohttp import aiohttp
import pytest import pytest
from dynamo._core import DistributedRuntime from dynamo._core import DistributedRuntime, KvPushRouter, KvRouterConfig
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
pytestmark = pytest.mark.pre_merge pytestmark = pytest.mark.pre_merge
...@@ -104,7 +104,7 @@ class KVRouterProcess(ManagedProcess): ...@@ -104,7 +104,7 @@ class KVRouterProcess(ManagedProcess):
super().__exit__(exc_type, exc_val, exc_tb) super().__exit__(exc_type, exc_val, exc_tb)
async def send_request_with_retry(url: str, payload: dict, max_retries: int = 4): async def send_request_with_retry(url: str, payload: dict, max_retries: int = 8):
"""Send a single request with exponential backoff retry""" """Send a single request with exponential backoff retry"""
wait_time = 1 # Start with 1 second wait_time = 1 # Start with 1 second
...@@ -550,3 +550,243 @@ def test_mocker_kv_router_overload_503(request, runtime_services): ...@@ -550,3 +550,243 @@ def test_mocker_kv_router_overload_503(request, runtime_services):
if os.path.exists(mocker_args_file): if os.path.exists(mocker_args_file):
os.unlink(mocker_args_file) os.unlink(mocker_args_file)
@pytest.mark.pre_merge
def test_kv_push_router_bindings(request, runtime_services):
    """
    Test KvPushRouter Python bindings with mocker engines.

    This test creates KvPushRouter as a Python object and verifies
    token streaming with ignore_eos=True and max_tokens=20.

    Three generations are checked against the same router:
      1. with a full ``router_config_override`` (expects 20 tokens),
      2. without any override (expects 10 tokens),
      3. with a partial override, temperature only (expects 5 tokens).

    Args:
        request: pytest request fixture; its node name is used as the
            directory for the mocker args file and is handed to MockerProcess.
        runtime_services: fixture requested for its side effects only
            (per the original comment it starts etcd and nats).
    """
    # runtime_services starts etcd and nats
    logger.info("Starting KvPushRouter bindings test")

    # Create mocker args file
    mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
    mocker_args_file = os.path.join(request.node.name, "mocker_args.json")
    with open(mocker_args_file, "w") as f:
        json.dump(mocker_args, f)

    def build_router():
        """Build a KvPushRouter with default config for the mocker endpoint."""
        runtime = get_runtime()
        namespace = runtime.namespace("test-namespace")
        component = namespace.component("mocker")
        endpoint = component.endpoint("generate")
        return KvPushRouter(
            endpoint=endpoint,
            block_size=BLOCK_SIZE,
            kv_router_config=KvRouterConfig(),
        )

    async def collect_tokens(stream):
        """Drain a generate() stream and return every token id it yielded."""
        collected = []
        async for response in stream:
            # Only dict responses are inspected; anything else is ignored.
            if not isinstance(response, dict):
                continue
            if "token_ids" in response:
                tokens = response["token_ids"]
                if isinstance(tokens, list):
                    collected.extend(tokens)
                    logger.debug(f"Received {len(tokens)} tokens: {tokens}")
            if "finish_reason" in response:
                logger.info(
                    f"Stream finished with reason: {response['finish_reason']}"
                )
        return collected

    # Start mocker instances
    mocker_processes = []

    try:
        # Create the mockers; every instance serves the same endpoint so the
        # router has NUM_MOCKERS workers to choose between.
        for i in range(NUM_MOCKERS):
            endpoint = "dyn://test-namespace.mocker.generate"
            logger.info(f"Starting mocker instance {i} on endpoint {endpoint}")

            mocker = MockerProcess(request, endpoint, mocker_args_file)
            mocker_processes.append(mocker)

        # Start all mockers
        for mocker in mocker_processes:
            mocker.__enter__()

        # Wait for mockers to be ready by sending a dummy request with retry
        async def wait_for_mockers_ready():
            """Send a dummy request, retrying with exponential backoff.

            Returns True as soon as one request streams to completion.
            Raises RuntimeError once max_retries + 1 attempts have failed.
            """
            kv_push_router = build_router()

            # Dummy request with minimal tokens
            dummy_token_ids = [1, 2, 3]  # Just a few tokens for testing
            max_retries = 8
            wait_time = 1

            for attempt in range(max_retries + 1):
                try:
                    logger.info(
                        f"Sending dummy request to check mocker readiness (attempt {attempt + 1})"
                    )
                    stream = await kv_push_router.generate(
                        token_ids=dummy_token_ids,
                        model=MODEL_NAME,
                        stop_conditions={"max_tokens": 1},  # Generate just 1 token
                        sampling_options={"temperature": 0.7},
                        output_options={
                            "include_input_tokens": False,
                            "return_full_text": False,
                        },
                    )

                    # Consume the stream to verify it works; the content of
                    # the probe response is irrelevant.
                    async for _ in stream:
                        pass

                    logger.info(
                        f"Mockers are ready! Dummy request succeeded on attempt {attempt + 1}"
                    )
                    return True
                except Exception as e:
                    logger.warning(f"Attempt {attempt + 1} failed with error: {e}")
                    if attempt >= max_retries:
                        # Out of attempts: surface the last failure as the cause.
                        raise RuntimeError(
                            f"Failed to connect to mockers after {max_retries + 1} attempts"
                        ) from e
                    await asyncio.sleep(wait_time)
                    wait_time *= 2  # Exponential backoff

        # Wait for mockers to be ready
        asyncio.run(wait_for_mockers_ready())

        # Run the async test
        async def test_kv_push_router():
            """Exercise generate() with full, absent and partial overrides."""
            # Create KvPushRouter Python object with default KvRouterConfig
            kv_push_router = build_router()

            logger.info("Created KvPushRouter Python object")

            # Generate random token IDs (100 to 200 tokens)
            num_input_tokens = random.randint(100, 200)
            token_ids = [random.randint(1, 10000) for _ in range(num_input_tokens)]

            logger.info(f"Generated {num_input_tokens} random token IDs")

            # Set up generation parameters
            stop_conditions = {
                "ignore_eos": True,  # Don't stop on EOS token
                "max_tokens": 20,  # Generate exactly 20 tokens
            }

            sampling_options = {"temperature": 0.7, "top_p": 0.9}

            output_options = {"include_input_tokens": False, "return_full_text": False}

            # Test with router config overrides
            router_config_override = {
                "overlap_score_weight": 0.5,  # Override the default weight
                "router_temperature": 0.5,  # Override the default temperature
            }

            # Call generate method
            logger.info(
                "Calling generate method on KvPushRouter with router config overrides"
            )
            logger.info(f"Router config overrides: {router_config_override}")

            stream = await kv_push_router.generate(
                token_ids=token_ids,
                model=MODEL_NAME,
                stop_conditions=stop_conditions,
                sampling_options=sampling_options,
                output_options=output_options,
                router_config_override=router_config_override,
            )

            # Collect tokens from the response stream
            generated_tokens = await collect_tokens(stream)

            # Verify we got exactly 20 tokens
            logger.info(f"Total generated tokens: {len(generated_tokens)}")
            assert len(generated_tokens) == 20, (
                f"Expected exactly 20 tokens but got {len(generated_tokens)}. "
                f"Tokens: {generated_tokens}"
            )

            logger.info(
                "Successfully verified 20 tokens generated via KvPushRouter with overrides"
            )

            # Test again without overrides
            logger.info("Testing again without router config overrides")
            stream = await kv_push_router.generate(
                token_ids=token_ids[:50],  # Use fewer tokens for second test
                model=MODEL_NAME,
                stop_conditions={"max_tokens": 10},
                sampling_options=sampling_options,
                output_options=output_options,
                # No router_config_override this time
            )

            generated_tokens_no_override = await collect_tokens(stream)

            assert (
                len(generated_tokens_no_override) == 10
            ), f"Expected 10 tokens but got {len(generated_tokens_no_override)}"
            logger.info("Successfully verified generation without overrides")

            # Test with partial override (only temperature)
            logger.info(
                "Testing with partial router config override (temperature only)"
            )
            partial_override = {"router_temperature": 0.1}
            stream = await kv_push_router.generate(
                token_ids=token_ids[:30],  # Use even fewer tokens
                model=MODEL_NAME,
                stop_conditions={"max_tokens": 5},
                sampling_options=sampling_options,
                output_options=output_options,
                router_config_override=partial_override,
            )

            generated_tokens_partial = await collect_tokens(stream)

            assert (
                len(generated_tokens_partial) == 5
            ), f"Expected 5 tokens but got {len(generated_tokens_partial)}"
            logger.info("Successfully verified generation with partial override")

        # Run the async test
        asyncio.run(test_kv_push_router())

        logger.info("KvPushRouter bindings test completed successfully")

    finally:
        try:
            # Clean up mockers
            for mocker in mocker_processes:
                mocker.__exit__(None, None, None)
        finally:
            # Remove the args file even if a mocker failed to shut down cleanly.
            if os.path.exists(mocker_args_file):
                os.unlink(mocker_args_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment