Unverified Commit 29813508 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(dynamo-run): KV-aware routing (#1064)

Router:
```
dynamo-run in=http out=dyn://dynamo.endpoint.generate --router-mode kv
```

Worker (* N):
```
dynamo-run in=dyn://dynamo.endpoint.generate out=vllm /data/llms/Qwen/Qwen3-4B
```

You need a patched vLLM and the C bindings shared library (`.so`). Full documentation is in the updated guide: `docs/guides/dynamo_run.md`.

This gives us a pure-Rust ingress node: an OpenAI-compliant HTTP server, a pre-processor, and a KV-aware router.
parent b82e7327
......@@ -28,7 +28,6 @@ use dynamo_runtime::{
use futures::stream;
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing as log;
pub struct KvEventPublisher {
tx: mpsc::UnboundedSender<KvCacheEvent>,
......@@ -45,7 +44,7 @@ impl KvEventPublisher {
}
/// Queues `event` on this publisher's unbounded channel (`self.tx`).
///
/// The event is written to the debug log before it is sent. The result of the channel `send`
/// is returned unchanged, so an `Err` carries the `SendError` produced by that send.
pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
// NOTE(review): the same message is logged twice here, once through the `log` alias of
// `tracing` and once through `tracing` directly. This looks like the removed and added
// lines of the diff shown together; confirm only one of them remains in the file.
log::debug!("Publish event: {:?}", event);
tracing::debug!("Publish event: {:?}", event);
self.tx.send(event)
}
......@@ -60,7 +59,7 @@ fn start_publish_task(
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
) {
let component_clone = component.clone();
log::info!("Publishing KV Events to subject: {}", KV_EVENT_SUBJECT);
tracing::info!("Publishing KV Events to subject: {}", KV_EVENT_SUBJECT);
_ = component.drt().runtime().secondary().spawn(async move {
while let Some(event) = rx.recv().await {
......@@ -88,7 +87,7 @@ impl KvMetricsPublisher {
&self,
metrics: Arc<ForwardPassMetrics>,
) -> Result<(), tokio::sync::watch::error::SendError<Arc<ForwardPassMetrics>>> {
log::debug!("Publish metrics: {:?}", metrics);
tracing::trace!("Publish metrics: {metrics:?}");
self.tx.send(metrics)
}
......
......@@ -112,12 +112,11 @@ impl KvScheduler {
// Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
tracing::debug!("scheduler starting");
// Background task to handle scheduling requests
tokio::spawn(async move {
let mut request: SchedulingRequest;
let mut request_rx = request_rx;
tracing::debug!("scheduler background task started");
tracing::trace!("scheduler background task started");
'outer: loop {
request = tokio::select! {
......@@ -141,7 +140,6 @@ impl KvScheduler {
continue 'outer;
}
};
tracing::debug!("selected");
loop {
match selector.select_worker(&endpoints, &request, block_size) {
Ok(selection) => {
......@@ -189,17 +187,13 @@ impl KvScheduler {
overlap,
resp_tx,
};
tracing::debug!("before sending request");
self.request_tx
.send(request)
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
tracing::debug!("after sending request");
let res = resp_rx
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
tracing::debug!("after receiving response");
Ok(res)
}
}
......@@ -219,7 +213,10 @@ pub fn process_worker_selection(
// Will be overwritten on next polling of metrics
worker.data.num_requests_waiting += 1;
// Assumes radix attention so KV load is only incremented by uncached blocks
worker.data.kv_active_blocks += selection.required_blocks - selection.overlap_blocks as u64;
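// overlap_blocks can be bigger than required_blocks, so the subtraction saturates at zero
// instead of underflowing.
// NOTE(review): DefaultWorkerSelector computes `total_blocks` with `std::cmp::min(.., 1)`, which
// caps it at 1. If that value feeds `required_blocks`, it would explain the mismatch; confirm
// whether `std::cmp::max` was intended there.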
worker.data.kv_active_blocks += selection
.required_blocks
.saturating_sub(selection.overlap_blocks as u64);
// Emit event
if let Err(e) = event_tx.send(KVHitRateEvent {
......@@ -246,6 +243,10 @@ impl WorkerSelector for DefaultWorkerSelector {
) -> Result<WorkerSelectionResult, KvSchedulerError> {
assert!(request.isl_tokens > 0);
if workers.endpoints.is_empty() {
return Err(KvSchedulerError::NoEndpoints);
}
let mut worker_scores = HashMap::new();
let mut max_waiting = 0.0;
......@@ -261,10 +262,6 @@ impl WorkerSelector for DefaultWorkerSelector {
max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64);
}
if max_waiting == 0.0 {
return Err(KvSchedulerError::NoEndpoints);
}
// make immutable
let worker_scores = worker_scores;
let max_waiting = max_waiting;
......@@ -291,13 +288,8 @@ impl WorkerSelector for DefaultWorkerSelector {
// Calculate logit using same formula as Python
let logit = 2.0 * score - gpu_cache_usage - normalized_waiting;
tracing::info!(
"Formula for {}: {:.3} = 2.0 * {:.3} - {:.3} - {:.3}",
worker_id,
logit,
score,
gpu_cache_usage,
normalized_waiting
tracing::trace!(
"Formula for {worker_id}: {logit:.3} = 2.0 * {score:.3} - {gpu_cache_usage:.3} - {normalized_waiting:.3}",
);
// Track best workers
......@@ -318,7 +310,7 @@ impl WorkerSelector for DefaultWorkerSelector {
if best_workers.is_empty() {
return Err(KvSchedulerError::NoEndpoints);
} else if best_logit == 0.0 {
tracing::warn!("best worker logit is 0");
tracing::debug!("best worker logit is 0");
}
let worker_id = if best_workers.len() == 1 {
......@@ -329,9 +321,10 @@ impl WorkerSelector for DefaultWorkerSelector {
best_workers[rng.random_range(0..best_workers.len())]
};
// Log selection metrics
tracing::info!("Selected worker: {}, logit: {:.3}", worker_id, best_logit);
// Lower to trace level eventually. Nice to see KV routing working for now.
tracing::debug!("Selected worker: {worker_id}, logit: {best_logit:.3}");
// Log selection metrics
let total_blocks = std::cmp::min(request.isl_tokens / block_size, 1) as u64;
let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
......
......@@ -27,6 +27,7 @@ pub mod http;
pub mod hub;
pub mod key_value_store;
pub mod kv_router;
pub use kv_router::DEFAULT_KV_BLOCK_SIZE;
pub mod model_card;
pub mod model_type;
pub mod preprocessor;
......
......@@ -20,7 +20,8 @@ use crate::protocols::TokenIdType;
pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>;
pub use super::preprocessor::PreprocessedRequest as BackendInput;
pub use super::preprocessor::PreprocessedRequest as BackendInput; // TODO stop renaming this
pub use super::preprocessor::PreprocessedRequest;
pub use super::FinishReason;
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
......
......@@ -42,7 +42,12 @@ where
pub client: Client,
/// How we choose which endpoint to send traffic to.
router_mode: RouterMode,
///
/// Setting this to None means we never intend to call `generate` on this PushRouter. We are
/// not using it as an AsyncEngine.
/// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
/// dynamo-llm's KV Routing does this.
router_mode: Option<RouterMode>,
/// Number of round robin requests handled. Used to decide which server is next.
round_robin_counter: Arc<AtomicU64>,
......@@ -57,14 +62,13 @@ where
_phantom: PhantomData<(T, U)>,
}
#[derive(Default, Debug, Clone, Copy)]
// Note there's no KV router in here because we are in dynamo-runtime. The KvRouter is in
// dynamo-llm.
#[derive(Default, Debug, Clone)]
/// How a `PushRouter` picks the endpoint for a request when its endpoints are dynamic.
pub enum RouterMode {
/// Default mode. `generate` hands the request to the router's `random` method.
#[default]
Random,
/// `generate` hands the request to the router's `round_robin` method.
RoundRobin,
//KV,
//
/// Always and only go to the given endpoint ID. Used by Python bindings.
Direct(i64),
}
......@@ -80,7 +84,10 @@ where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
pub async fn from_client(
client: Client,
router_mode: Option<RouterMode>,
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
Ok(PushRouter {
client,
......@@ -182,9 +189,12 @@ where
match &self.client.endpoints {
EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
Some(RouterMode::Random) => self.random(request).await,
Some(RouterMode::RoundRobin) => self.round_robin(request).await,
Some(RouterMode::Direct(endpoint_id)) => self.direct(request, endpoint_id).await,
None => {
anyhow::bail!("KV routing should not call generate on PushRouter");
}
},
}
}
......
......@@ -173,6 +173,16 @@ impl FromStr for Endpoint {
}
}
impl Endpoint {
    /// Renders this endpoint as a URL string: the endpoint scheme followed by the
    /// dot-separated namespace, component and name (e.g. `dyn://dynamo.internal.worker`).
    pub fn as_url(&self) -> String {
        // Borrow the three parts up front so the format call reads as scheme + path.
        let (namespace, component, name) = (&self.namespace, &self.component, &self.name);
        format!("{ENDPOINT_SCHEME}{}.{}.{}", namespace, component, name)
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum RouterType {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment