Unverified Commit 29813508 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(dynamo-run): KV-aware routing (#1064)

Router:
```
dynamo-run in=http out=dyn://dynamo.endpoint.generate --router-mode kv
```

Worker (* N):
```
dynamo-run in=dyn://dynamo.endpoint.generate out=vllm /data/llms/Qwen/Qwen3-4B
```

You need a patched vLLM and the C bindings shared library (`.so`). Full docs are in the updated guide: `docs/guides/dynamo_run.md`.

This gives us a pure-Rust ingress node: OpenAI-compliant HTTP server + pre-processor + KV-aware router.
parent b82e7327
...@@ -28,7 +28,6 @@ use dynamo_runtime::{ ...@@ -28,7 +28,6 @@ use dynamo_runtime::{
use futures::stream; use futures::stream;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing as log;
pub struct KvEventPublisher { pub struct KvEventPublisher {
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<KvCacheEvent>,
...@@ -45,7 +44,7 @@ impl KvEventPublisher { ...@@ -45,7 +44,7 @@ impl KvEventPublisher {
} }
pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> { pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
log::debug!("Publish event: {:?}", event); tracing::debug!("Publish event: {:?}", event);
self.tx.send(event) self.tx.send(event)
} }
...@@ -60,7 +59,7 @@ fn start_publish_task( ...@@ -60,7 +59,7 @@ fn start_publish_task(
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>, mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
) { ) {
let component_clone = component.clone(); let component_clone = component.clone();
log::info!("Publishing KV Events to subject: {}", KV_EVENT_SUBJECT); tracing::info!("Publishing KV Events to subject: {}", KV_EVENT_SUBJECT);
_ = component.drt().runtime().secondary().spawn(async move { _ = component.drt().runtime().secondary().spawn(async move {
while let Some(event) = rx.recv().await { while let Some(event) = rx.recv().await {
...@@ -88,7 +87,7 @@ impl KvMetricsPublisher { ...@@ -88,7 +87,7 @@ impl KvMetricsPublisher {
&self, &self,
metrics: Arc<ForwardPassMetrics>, metrics: Arc<ForwardPassMetrics>,
) -> Result<(), tokio::sync::watch::error::SendError<Arc<ForwardPassMetrics>>> { ) -> Result<(), tokio::sync::watch::error::SendError<Arc<ForwardPassMetrics>>> {
log::debug!("Publish metrics: {:?}", metrics); tracing::trace!("Publish metrics: {metrics:?}");
self.tx.send(metrics) self.tx.send(metrics)
} }
......
...@@ -112,12 +112,11 @@ impl KvScheduler { ...@@ -112,12 +112,11 @@ impl KvScheduler {
// Channel to accept new scheduling requests // Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024); let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
tracing::debug!("scheduler starting");
// Background task to handle scheduling requests // Background task to handle scheduling requests
tokio::spawn(async move { tokio::spawn(async move {
let mut request: SchedulingRequest; let mut request: SchedulingRequest;
let mut request_rx = request_rx; let mut request_rx = request_rx;
tracing::debug!("scheduler background task started"); tracing::trace!("scheduler background task started");
'outer: loop { 'outer: loop {
request = tokio::select! { request = tokio::select! {
...@@ -141,7 +140,6 @@ impl KvScheduler { ...@@ -141,7 +140,6 @@ impl KvScheduler {
continue 'outer; continue 'outer;
} }
}; };
tracing::debug!("selected");
loop { loop {
match selector.select_worker(&endpoints, &request, block_size) { match selector.select_worker(&endpoints, &request, block_size) {
Ok(selection) => { Ok(selection) => {
...@@ -189,17 +187,13 @@ impl KvScheduler { ...@@ -189,17 +187,13 @@ impl KvScheduler {
overlap, overlap,
resp_tx, resp_tx,
}; };
tracing::debug!("before sending request");
self.request_tx self.request_tx
.send(request) .send(request)
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
tracing::debug!("after sending request");
let res = resp_rx let res = resp_rx
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
tracing::debug!("after receiving response");
Ok(res) Ok(res)
} }
} }
...@@ -219,7 +213,10 @@ pub fn process_worker_selection( ...@@ -219,7 +213,10 @@ pub fn process_worker_selection(
// Will be overwritten on next polling of metrics // Will be overwritten on next polling of metrics
worker.data.num_requests_waiting += 1; worker.data.num_requests_waiting += 1;
// Assumes radix attention so KV load is only incremented by uncached blocks // Assumes radix attention so KV load is only incremented by uncached blocks
worker.data.kv_active_blocks += selection.required_blocks - selection.overlap_blocks as u64; // overlap_blocks can be bigger than required_blocks. I don't know if that's a bug or not.
worker.data.kv_active_blocks += selection
.required_blocks
.saturating_sub(selection.overlap_blocks as u64);
// Emit event // Emit event
if let Err(e) = event_tx.send(KVHitRateEvent { if let Err(e) = event_tx.send(KVHitRateEvent {
...@@ -246,6 +243,10 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -246,6 +243,10 @@ impl WorkerSelector for DefaultWorkerSelector {
) -> Result<WorkerSelectionResult, KvSchedulerError> { ) -> Result<WorkerSelectionResult, KvSchedulerError> {
assert!(request.isl_tokens > 0); assert!(request.isl_tokens > 0);
if workers.endpoints.is_empty() {
return Err(KvSchedulerError::NoEndpoints);
}
let mut worker_scores = HashMap::new(); let mut worker_scores = HashMap::new();
let mut max_waiting = 0.0; let mut max_waiting = 0.0;
...@@ -261,10 +262,6 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -261,10 +262,6 @@ impl WorkerSelector for DefaultWorkerSelector {
max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64); max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64);
} }
if max_waiting == 0.0 {
return Err(KvSchedulerError::NoEndpoints);
}
// make immutable // make immutable
let worker_scores = worker_scores; let worker_scores = worker_scores;
let max_waiting = max_waiting; let max_waiting = max_waiting;
...@@ -291,13 +288,8 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -291,13 +288,8 @@ impl WorkerSelector for DefaultWorkerSelector {
// Calculate logit using same formula as Python // Calculate logit using same formula as Python
let logit = 2.0 * score - gpu_cache_usage - normalized_waiting; let logit = 2.0 * score - gpu_cache_usage - normalized_waiting;
tracing::info!( tracing::trace!(
"Formula for {}: {:.3} = 2.0 * {:.3} - {:.3} - {:.3}", "Formula for {worker_id}: {logit:.3} = 2.0 * {score:.3} - {gpu_cache_usage:.3} - {normalized_waiting:.3}",
worker_id,
logit,
score,
gpu_cache_usage,
normalized_waiting
); );
// Track best workers // Track best workers
...@@ -318,7 +310,7 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -318,7 +310,7 @@ impl WorkerSelector for DefaultWorkerSelector {
if best_workers.is_empty() { if best_workers.is_empty() {
return Err(KvSchedulerError::NoEndpoints); return Err(KvSchedulerError::NoEndpoints);
} else if best_logit == 0.0 { } else if best_logit == 0.0 {
tracing::warn!("best worker logit is 0"); tracing::debug!("best worker logit is 0");
} }
let worker_id = if best_workers.len() == 1 { let worker_id = if best_workers.len() == 1 {
...@@ -329,9 +321,10 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -329,9 +321,10 @@ impl WorkerSelector for DefaultWorkerSelector {
best_workers[rng.random_range(0..best_workers.len())] best_workers[rng.random_range(0..best_workers.len())]
}; };
// Log selection metrics // Lower to trace level eventually. Nice to see KV routing working for now.
tracing::info!("Selected worker: {}, logit: {:.3}", worker_id, best_logit); tracing::debug!("Selected worker: {worker_id}, logit: {best_logit:.3}");
// Log selection metrics
let total_blocks = std::cmp::min(request.isl_tokens / block_size, 1) as u64; let total_blocks = std::cmp::min(request.isl_tokens / block_size, 1) as u64;
let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize; let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
......
...@@ -27,6 +27,7 @@ pub mod http; ...@@ -27,6 +27,7 @@ pub mod http;
pub mod hub; pub mod hub;
pub mod key_value_store; pub mod key_value_store;
pub mod kv_router; pub mod kv_router;
pub use kv_router::DEFAULT_KV_BLOCK_SIZE;
pub mod model_card; pub mod model_card;
pub mod model_type; pub mod model_type;
pub mod preprocessor; pub mod preprocessor;
......
...@@ -20,7 +20,8 @@ use crate::protocols::TokenIdType; ...@@ -20,7 +20,8 @@ use crate::protocols::TokenIdType;
pub type TokenType = Option<String>; pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>; pub type LogProbs = Vec<f64>;
pub use super::preprocessor::PreprocessedRequest as BackendInput; pub use super::preprocessor::PreprocessedRequest as BackendInput; // TODO stop renaming this
pub use super::preprocessor::PreprocessedRequest;
pub use super::FinishReason; pub use super::FinishReason;
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
......
...@@ -42,7 +42,12 @@ where ...@@ -42,7 +42,12 @@ where
pub client: Client, pub client: Client,
/// How we choose which endpoint to send traffic to. /// How we choose which endpoint to send traffic to.
router_mode: RouterMode, ///
/// Setting this to None means we never intend to call `generate` on this PushRouter. We are
/// not using it as an AsyncEngine.
/// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
/// dynamo-llm's KV Routing does this.
router_mode: Option<RouterMode>,
/// Number of round robin requests handled. Used to decide which server is next. /// Number of round robin requests handled. Used to decide which server is next.
round_robin_counter: Arc<AtomicU64>, round_robin_counter: Arc<AtomicU64>,
...@@ -57,14 +62,13 @@ where ...@@ -57,14 +62,13 @@ where
_phantom: PhantomData<(T, U)>, _phantom: PhantomData<(T, U)>,
} }
#[derive(Default, Debug, Clone, Copy)] // Note there's no KV router in here because we are in dynamo-runtime. The KvRouter is in
// dynamo-llm.
#[derive(Default, Debug, Clone)]
pub enum RouterMode { pub enum RouterMode {
#[default] #[default]
Random, Random,
RoundRobin, RoundRobin,
//KV,
//
// Always and only go to the given endpoint ID. Used by Python bindings.
Direct(i64), Direct(i64),
} }
...@@ -80,7 +84,10 @@ where ...@@ -80,7 +84,10 @@ where
T: Data + Serialize, T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>, U: Data + for<'de> Deserialize<'de>,
{ {
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> { pub async fn from_client(
client: Client,
router_mode: Option<RouterMode>,
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?; let addressed = addressed_router(&client.endpoint).await?;
Ok(PushRouter { Ok(PushRouter {
client, client,
...@@ -182,9 +189,12 @@ where ...@@ -182,9 +189,12 @@ where
match &self.client.endpoints { match &self.client.endpoints {
EndpointSource::Static => self.r#static(request).await, EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => match self.router_mode { EndpointSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await, Some(RouterMode::Random) => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await, Some(RouterMode::RoundRobin) => self.round_robin(request).await,
RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await, Some(RouterMode::Direct(endpoint_id)) => self.direct(request, endpoint_id).await,
None => {
anyhow::bail!("KV routing should not call generate on PushRouter");
}
}, },
} }
} }
......
...@@ -173,6 +173,16 @@ impl FromStr for Endpoint { ...@@ -173,6 +173,16 @@ impl FromStr for Endpoint {
} }
} }
impl Endpoint {
    /// Renders this endpoint as a URL-style string: `ENDPOINT_SCHEME` followed by the
    /// namespace, component and endpoint name, joined with dots.
    ///
    /// Example shape: `dyn://dynamo.internal.worker` (assuming `ENDPOINT_SCHEME` is
    /// `dyn://` - confirm against the constant's definition).
    pub fn as_url(&self) -> String {
        // Borrow the three parts so they can be captured inline by `format!`.
        let (namespace, component, name) = (&self.namespace, &self.component, &self.name);
        format!("{ENDPOINT_SCHEME}{namespace}.{component}.{name}")
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum RouterType { pub enum RouterType {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment