Unverified Commit cc22114d authored by kornelcsernai-harmonic's avatar kornelcsernai-harmonic Committed by GitHub
Browse files

feat: add least-loaded router (#6314)


Signed-off-by: default avatarKornel Csernai <239206175+kornelcsernai-harmonic@users.noreply.github.com>
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 21ba0c4b
......@@ -186,10 +186,18 @@ class FrontendArgGroup(ArgGroup):
env_var="DYN_ROUTER_MODE",
default="round-robin",
help="How to route the request. power-of-two picks 2 random workers and "
"routes to the one with fewer in-flight requests. In disaggregated prefill "
"mode, power-of-two skips bootstrap optimization and falls back to the "
"synchronous prefill path.",
choices=["round-robin", "random", "power-of-two", "kv", "direct"],
"routes to the one with fewer in-flight requests. least-loaded routes to "
"the worker with the fewest active requests. In disaggregated prefill mode, "
"both power-of-two and least-loaded skip bootstrap optimization and fall "
"back to the synchronous prefill path.",
choices=[
"round-robin",
"random",
"power-of-two",
"kv",
"direct",
"least-loaded",
],
)
add_argument(
g,
......
......@@ -7,7 +7,8 @@
# - OpenAI HTTP server.
# - Auto-discovery: Watches etcd for engine/worker registration (via `register_model`).
# - Pre-processor: Prompt templating and tokenization.
# - Router, defaulting to round-robin. Use --router-mode to switch (round-robin, random, kv, direct).
# - Router, defaulting to round-robin. Use --router-mode to switch
# (round-robin, random, power-of-two, kv, direct, least-loaded).
#
# Pass `--interactive` or `-i` for text chat instead of HTTP server.
#
......@@ -230,6 +231,9 @@ async def async_main():
elif config.router_mode == "power-of-two":
router_mode = RouterMode.PowerOfTwoChoices
kv_router_config = None
elif config.router_mode == "least-loaded":
router_mode = RouterMode.LeastLoaded
kv_router_config = None
else:
router_mode = RouterMode.RoundRobin
kv_router_config = None
......
......@@ -85,7 +85,7 @@ spec:
|-----------|---------|-------------|
| `--http-port` | 8000 | HTTP server port |
| `--kserve-grpc-server` | false | Enable KServe gRPC server |
| `--router-mode` | `round-robin` | Routing strategy: `round-robin`, `random`, `kv`, `direct` |
| `--router-mode` | `round-robin` | Routing strategy: `round-robin`, `random`, `power-of-two`, `kv`, `direct`, `least-loaded` (`power-of-two` and `least-loaded` use synchronous prefill fallback in disaggregated prefill mode) |
See the [Frontend Guide](frontend-guide.md) for full configuration options.
......
......@@ -46,7 +46,7 @@ For all CLI arguments, environment variables, K8s deployment examples, and tunin
**Limitations:**
- Static endpoints not supported—KV router requires dynamic model discovery via etcd to track worker instances and their KV cache states
For basic model registration without KV routing, use `--router-mode round-robin` or `--router-mode random` with both static and dynamic endpoints.
For basic model registration without KV routing, use `--router-mode round-robin`, `--router-mode random`, or `--router-mode least-loaded` with both static and dynamic endpoints.
## Next Steps
......
......@@ -20,6 +20,7 @@ The Dynamo router can be deployed in several configurations. The table below sho
| **Frontend + Random** | `python -m dynamo.frontend --router-mode random` | Random worker selection | None | Aggregated | Stateless load balancing |
| **Frontend + KV (Aggregated)** | `python -m dynamo.frontend --router-mode kv` | KV cache overlap + load | NATS Core / JetStream / ZMQ / Approx | Aggregated | Production single-pool serving with cache reuse |
| **Frontend + KV (Disaggregated)** | `python -m dynamo.frontend --router-mode kv` with prefill + decode workers | KV cache overlap + load | NATS Core / JetStream / ZMQ / Approx | Disaggregated (prefill + decode pools) | Separate prefill/decode for large-scale serving |
| **Frontend + Least-Loaded** | `python -m dynamo.frontend --router-mode least-loaded` | Fewest active connections | None | Aggregated or disaggregated fallback | Simple load-aware balancing without KV awareness |
| **Frontend + Direct** | `python -m dynamo.frontend --router-mode direct` | Worker ID from request hints | None | Aggregated | External orchestrator (e.g., EPP/GAIE) selects workers |
| **Standalone Router** | `python -m dynamo.router` | KV cache overlap + load | NATS Core / JetStream / ZMQ | Any | Routing without the HTTP frontend (multi-tier, custom pipelines) |
......@@ -30,6 +31,7 @@ The Dynamo router can be deployed in several configurations. The table below sho
| **Round-Robin** | `round-robin` (default) | Cycles through available workers in order |
| **Random** | `random` | Selects a random worker for each request |
| **KV** | `kv` | Evaluates KV cache overlap and decode load per worker; picks lowest cost |
| **Least-Loaded** | `least-loaded` | Routes to the worker with fewest active connections; in disaggregated prefill paths it skips bootstrap optimization and falls back to synchronous prefill |
| **Direct** | `direct` | Reads the target `worker_id` from the request's routing hints; no selection logic |
### KV Event Transport Modes (within `--router-mode kv`)
......@@ -214,6 +216,8 @@ We can then use the default routing methods exposed by the client class to send
- **Random routing**: Default strategy, available via `client.generate()` or `client.random()`
- **Round-robin routing**: Cycles through available workers via `client.round_robin()`
- **Direct routing**: Explicitly targets a specific worker via `client.direct(input, component_id)`
- **Least-loaded routing**: Routes to the worker with fewest active connections via `--router-mode least-loaded`
In disaggregated prefill paths it skips bootstrap optimization and uses the synchronous prefill path, matching power-of-two routing.
KV Cache routing uses direct routing with a special worker selection algorithm.
......
......@@ -51,6 +51,7 @@ pub enum RouterMode {
/// Direct routing - reads worker ID from each request's routing hints.
/// Used when an external orchestrator (e.g., EPP) handles worker selection.
Direct,
LeastLoaded,
}
impl From<RouterMode> for RsRouterMode {
......@@ -61,6 +62,7 @@ impl From<RouterMode> for RsRouterMode {
RouterMode::PowerOfTwoChoices => Self::PowerOfTwoChoices,
RouterMode::KV => Self::KV,
RouterMode::Direct => Self::Direct,
RouterMode::LeastLoaded => Self::LeastLoaded,
}
}
}
......
......@@ -1129,6 +1129,7 @@ class RouterMode:
PowerOfTwoChoices: "RouterMode"
KV: "RouterMode"
Direct: "RouterMode"
LeastLoaded: "RouterMode"
...
class RouterConfig:
......@@ -1149,7 +1150,7 @@ class RouterConfig:
Create a RouterConfig.
Args:
mode: The router mode (RoundRobin, Random, KV, or Direct)
mode: The router mode (RoundRobin, Random, PowerOfTwoChoices, KV, Direct, or LeastLoaded)
config: Optional KV router configuration (used when mode is KV)
active_decode_blocks_threshold: Threshold percentage (0.0-1.0) for decode blocks busy detection
active_prefill_tokens_threshold: Literal token count threshold for prefill busy detection
......
......@@ -336,9 +336,10 @@ where
RouterMode::Direct => {
ServiceBackend::from_engine(Arc::new(DirectRoutingRouter::new(router)))
}
RouterMode::Random | RouterMode::RoundRobin | RouterMode::PowerOfTwoChoices => {
ServiceBackend::from_engine(Arc::new(router))
}
RouterMode::Random
| RouterMode::RoundRobin
| RouterMode::PowerOfTwoChoices
| RouterMode::LeastLoaded => ServiceBackend::from_engine(Arc::new(router)),
RouterMode::KV => {
let Some(chooser) = chooser else {
anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
......
......@@ -64,6 +64,8 @@ mod registry;
pub mod service;
pub use client::Client;
pub(crate) use client::RoutingOccupancyState;
pub(crate) use client::get_or_create_routing_occupancy_state;
pub use endpoint::build_transport_type;
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
......
......@@ -2,10 +2,15 @@
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::{collections::HashMap, time::Duration};
use std::sync::atomic::{AtomicU64, Ordering};
use std::{
collections::{HashMap, HashSet},
time::Duration,
};
use anyhow::Result;
use arc_swap::ArcSwap;
use dashmap::DashMap;
use futures::StreamExt;
use tokio::net::unix::pipe::Receiver;
......@@ -21,6 +26,70 @@ use crate::{
transports::etcd::Client as EtcdClient,
};
/// Shared occupancy state for routing modes that track per-worker in-flight requests.
///
/// Counters are keyed by worker instance id. They are raised by `increment` /
/// `select_exact_min_and_increment`, lowered by `decrement`, and dropped for
/// instances that are no longer listed by `retain`.
#[derive(Debug, Default)]
pub(crate) struct RoutingOccupancyState {
    /// In-flight request count per instance id. An id with no entry is reported
    /// as zero by `load`.
    counts: DashMap<u64, AtomicU64>,
    /// Held for the whole of `select_exact_min_and_increment` so that the
    /// "find the minimum, then increment it" sequence of one caller is not
    /// interleaved with the same sequence of another caller.
    exact_selection_lock: tokio::sync::Mutex<()>,
}
impl RoutingOccupancyState {
    /// Adds one in-flight request to `instance_id`.
    ///
    /// The counter is created at zero first if this instance id has no entry yet.
    pub(crate) fn increment(&self, instance_id: u64) {
        let counter = self
            .counts
            .entry(instance_id)
            .or_insert_with(|| AtomicU64::new(0));
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Picks the instance in `instance_ids` with the lowest current count, increments
    /// that count, and returns the chosen id.
    ///
    /// The selection lock is held across the scan and the increment, so concurrent
    /// callers of this method observe each other's increments. When several ids share
    /// the lowest count, the one that appears first in `instance_ids` wins.
    ///
    /// Returns `None` when `instance_ids` is empty.
    pub(crate) async fn select_exact_min_and_increment(&self, instance_ids: &[u64]) -> Option<u64> {
        let _selection_guard = self.exact_selection_lock.lock().await;

        // Track (id, load) of the best candidate seen so far. Only a strictly
        // smaller load replaces it, which keeps the earliest id on ties.
        let mut best: Option<(u64, u64)> = None;
        for &candidate in instance_ids {
            let candidate_load = self.load(candidate);
            let is_better = match best {
                Some((_, best_load)) => candidate_load < best_load,
                None => true,
            };
            if is_better {
                best = Some((candidate, candidate_load));
            }
        }

        let (chosen, _) = best?;
        self.increment(chosen);
        Some(chosen)
    }

    /// Removes one in-flight request from `instance_id`.
    ///
    /// Does nothing when the id has no counter. A counter that is already zero
    /// stays at zero instead of wrapping around.
    pub(crate) fn decrement(&self, instance_id: u64) {
        let Some(counter) = self.counts.get(&instance_id) else {
            return;
        };
        // The closure always yields `Some`, so the update cannot be rejected;
        // the returned previous value is not needed.
        let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |in_flight| {
            Some(in_flight.saturating_sub(1))
        });
    }

    /// Returns the current in-flight count for `instance_id`, or zero when the id
    /// has no counter.
    pub(crate) fn load(&self, instance_id: u64) -> u64 {
        match self.counts.get(&instance_id) {
            Some(counter) => counter.load(Ordering::Relaxed),
            None => 0,
        }
    }

    /// Drops the counters of every instance id that is not listed in `instance_ids`.
    pub(crate) fn retain(&self, instance_ids: &[u64]) {
        let still_present = instance_ids.iter().copied().collect::<HashSet<u64>>();
        self.counts
            .retain(|instance_id, _count| still_present.contains(instance_id));
    }
}
/// Returns the routing occupancy state registered for `endpoint`, creating and
/// registering a new one when none is alive.
///
/// The runtime's registry stores only weak references. If the entry for this
/// endpoint can still be upgraded, that existing state is returned; if the entry
/// is dead it is removed and replaced by a fresh, empty state.
pub(crate) async fn get_or_create_routing_occupancy_state(
    endpoint: &Endpoint,
) -> Arc<RoutingOccupancyState> {
    let registry_handle = endpoint.drt().routing_occupancy_states();
    let mut states = registry_handle.lock().await;

    // Outer `Option`: is there an entry at all? Inner `Option`: is it still alive?
    match states.get(endpoint).map(|weak| weak.upgrade()) {
        Some(Some(live)) => return live,
        Some(None) => {
            states.remove(endpoint);
        }
        None => {}
    }

    let fresh = Arc::new(RoutingOccupancyState::default());
    states.insert(endpoint.clone(), Arc::downgrade(&fresh));
    fresh
}
/// Default interval for periodic reconciliation of instance_avail with instance_source
const DEFAULT_RECONCILE_INTERVAL: Duration = Duration::from_secs(5);
......@@ -181,6 +250,15 @@ impl Client {
client.instance_avail.store(Arc::new(instance_ids.clone()));
client.instance_free.store(Arc::new(instance_ids.clone()));
// Clean up stale occupancy counters for instances that no longer exist.
let registry = client.endpoint.drt().routing_occupancy_states();
if let Ok(registry) = registry.try_lock()
&& let Some(weak) = registry.get(&client.endpoint)
&& let Some(state) = weak.upgrade()
{
state.retain(&instance_ids);
}
// Send update to watch channel subscribers
let _ = client.instance_avail_tx.send(instance_ids);
......@@ -396,4 +474,123 @@ mod tests {
rt.shutdown();
}
/// Test that concurrent select_and_increment distributes load correctly.
#[tokio::test]
async fn test_concurrent_select_and_increment() {
let state = Arc::new(RoutingOccupancyState::default());
let instance_ids: Vec<u64> = vec![100, 200, 300];
let num_requests = 90;
let mut handles = Vec::new();
for _ in 0..num_requests {
let state = state.clone();
let ids = instance_ids.clone();
handles.push(tokio::spawn(async move {
state.select_exact_min_and_increment(&ids).await
}));
}
for handle in handles {
handle.await.unwrap();
}
assert_eq!(state.load(100), 30);
assert_eq!(state.load(200), 30);
assert_eq!(state.load(300), 30);
}
#[tokio::test]
async fn test_connection_counts() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_ll_counts".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let state1 = get_or_create_routing_occupancy_state(&endpoint).await;
let state2 = get_or_create_routing_occupancy_state(&endpoint).await;
let picked1 = state1
.select_exact_min_and_increment(&[10, 20, 30])
.await
.unwrap();
assert_eq!(state1.load(picked1), 1);
let picked2 = state1
.select_exact_min_and_increment(&[10, 20, 30])
.await
.unwrap();
assert_ne!(picked1, picked2);
// state2 should see the same counts (same underlying Arc)
assert_eq!(state2.load(10), state1.load(10));
assert_eq!(state2.load(20), state1.load(20));
assert_eq!(state2.load(30), state1.load(30));
state2.decrement(picked1);
assert_eq!(state1.load(picked1), if picked1 == picked2 { 1 } else { 0 });
rt.shutdown();
}
/// Checks that `retain` drops the counter of an instance that is no longer
/// listed while leaving the counters of the remaining instances untouched.
#[tokio::test]
async fn test_least_loaded_state_retain() {
    let state = RoutingOccupancyState::default();
    let all_instances: [u64; 3] = [1, 2, 3];

    // One selection per instance: each pick goes to a still-idle instance.
    for _ in 0..all_instances.len() {
        state.select_exact_min_and_increment(&all_instances).await;
    }

    // Each instance should have 1 connection
    for id in all_instances {
        assert_eq!(state.load(id), 1);
    }

    // Retain only instances 1 and 3 (instance 2 was removed)
    state.retain(&[1, 3]);

    let expected_after_retain: [(u64, u64); 3] = [(1, 1), (2, 0), (3, 1)];
    for (id, expected) in expected_after_retain {
        assert_eq!(state.load(id), expected);
    }
}
/// Registers an endpoint instance, records one in-flight request for it, then
/// unregisters the instance and waits for the occupancy counter to read zero.
#[tokio::test]
async fn test_monitor_instance_source_cleans_up_removed_worker_counts() {
    const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
    const MAX_POLLS: u32 = 10;

    let rt = Runtime::from_current().unwrap();
    let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
        .await
        .unwrap();
    let endpoint = drt
        .namespace("test_occupancy_cleanup".to_string())
        .unwrap()
        .component("test_component".to_string())
        .unwrap()
        .endpoint("test_endpoint".to_string());

    let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
        .await
        .unwrap();
    endpoint.register_endpoint_instance().await.unwrap();
    client.wait_for_instances().await.unwrap();

    let worker_id = client.instance_ids_avail()[0];
    let state = get_or_create_routing_occupancy_state(&endpoint).await;
    state.increment(worker_id);
    assert_eq!(state.load(worker_id), 1);

    endpoint.unregister_endpoint_instance().await.unwrap();

    // Poll until the counter reads zero, sleeping one reconcile interval between
    // checks and giving up after MAX_POLLS sleeps.
    let mut polls = 0;
    while polls < MAX_POLLS && state.load(worker_id) != 0 {
        tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        polls += 1;
    }

    assert_eq!(state.load(worker_id), 0);
    rt.shutdown();
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::component::{Component, Instance};
use crate::component::{
self, Component, ComponentBuilder, Endpoint, Instance, Namespace, RoutingOccupancyState,
};
use crate::pipeline::PipelineError;
use crate::pipeline::network::manager::NetworkManager;
use crate::service::{ServiceClient, ServiceSet};
use crate::storage::kv;
use crate::{discovery, system_status_server, transports};
use crate::{
component::{self, ComponentBuilder, Endpoint, Namespace},
discovery::Discovery,
metrics::PrometheusUpdateCallback,
metrics::{MetricsHierarchy, MetricsRegistry},
transports::{etcd, nats, tcp},
};
use crate::{discovery, system_status_server, transports};
use super::utils::GracefulShutdownTracker;
use crate::SystemHealth;
......@@ -35,6 +36,7 @@ use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
type InstanceMap = HashMap<Endpoint, Weak<Receiver<Vec<Instance>>>>;
/// Registry of routing occupancy state keyed by endpoint. Values are weak
/// references, so the registry alone does not keep a state alive.
type RoutingOccupancyMap = HashMap<Endpoint, Weak<RoutingOccupancyState>>;
/// Distributed [Runtime] which provides access to shared resources across the cluster, this includes
/// communication protocols and transports.
......@@ -64,6 +66,7 @@ pub struct DistributedRuntime {
component_registry: component::Registry,
instance_sources: Arc<tokio::sync::Mutex<InstanceMap>>,
routing_occupancy_states: Arc<tokio::sync::Mutex<RoutingOccupancyMap>>,
// Health Status
system_health: Arc<parking_lot::Mutex<SystemHealth>>,
......@@ -185,6 +188,7 @@ impl DistributedRuntime {
discovery_metadata,
component_registry,
instance_sources: Arc::new(Mutex::new(HashMap::new())),
routing_occupancy_states: Arc::new(Mutex::new(HashMap::new())),
metrics_registry: crate::MetricsRegistry::new(),
system_health,
request_plane,
......@@ -390,6 +394,10 @@ impl DistributedRuntime {
self.instance_sources.clone()
}
/// Returns another handle to the runtime's registry of per-endpoint routing
/// occupancy states. The registry itself is shared, not copied.
pub(crate) fn routing_occupancy_states(&self) -> Arc<Mutex<RoutingOccupancyMap>> {
    Arc::clone(&self.routing_occupancy_states)
}
/// TODO: This is a temporary KV router measure for component/component.rs EventPublisher impl for
/// Component, to allow it to publish to NATS. KV Router is the only user.
///
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment