Unverified Commit cc22114d authored by kornelcsernai-harmonic's avatar kornelcsernai-harmonic Committed by GitHub
Browse files

feat: add least-loaded router (#6314)


Signed-off-by: default avatarKornel Csernai <239206175+kornelcsernai-harmonic@users.noreply.github.com>
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 21ba0c4b
......@@ -186,10 +186,18 @@ class FrontendArgGroup(ArgGroup):
env_var="DYN_ROUTER_MODE",
default="round-robin",
help="How to route the request. power-of-two picks 2 random workers and "
"routes to the one with fewer in-flight requests. In disaggregated prefill "
"mode, power-of-two skips bootstrap optimization and falls back to the "
"synchronous prefill path.",
choices=["round-robin", "random", "power-of-two", "kv", "direct"],
"routes to the one with fewer in-flight requests. least-loaded routes to "
"the worker with the fewest active requests. In disaggregated prefill mode, "
"both power-of-two and least-loaded skip bootstrap optimization and fall "
"back to the synchronous prefill path.",
choices=[
"round-robin",
"random",
"power-of-two",
"kv",
"direct",
"least-loaded",
],
)
add_argument(
g,
......
......@@ -7,7 +7,8 @@
# - OpenAI HTTP server.
# - Auto-discovery: Watches etcd for engine/worker registration (via `register_model`).
# - Pre-processor: Prompt templating and tokenization.
# - Router, defaulting to round-robin. Use --router-mode to switch (round-robin, random, kv, direct).
# - Router, defaulting to round-robin. Use --router-mode to switch
# (round-robin, random, power-of-two, kv, direct, least-loaded).
#
# Pass `--interactive` or `-i` for text chat instead of HTTP server.
#
......@@ -230,6 +231,9 @@ async def async_main():
elif config.router_mode == "power-of-two":
router_mode = RouterMode.PowerOfTwoChoices
kv_router_config = None
elif config.router_mode == "least-loaded":
router_mode = RouterMode.LeastLoaded
kv_router_config = None
else:
router_mode = RouterMode.RoundRobin
kv_router_config = None
......
......@@ -85,7 +85,7 @@ spec:
|-----------|---------|-------------|
| `--http-port` | 8000 | HTTP server port |
| `--kserve-grpc-server` | false | Enable KServe gRPC server |
| `--router-mode` | `round-robin` | Routing strategy: `round-robin`, `random`, `kv`, `direct` |
| `--router-mode` | `round-robin` | Routing strategy: `round-robin`, `random`, `power-of-two`, `kv`, `direct`, `least-loaded` (`power-of-two` and `least-loaded` use synchronous prefill fallback in disaggregated prefill mode) |
See the [Frontend Guide](frontend-guide.md) for full configuration options.
......
......@@ -46,7 +46,7 @@ For all CLI arguments, environment variables, K8s deployment examples, and tunin
**Limitations:**
- Static endpoints not supported—KV router requires dynamic model discovery via etcd to track worker instances and their KV cache states
For basic model registration without KV routing, use `--router-mode round-robin` or `--router-mode random` with both static and dynamic endpoints.
For basic model registration without KV routing, use `--router-mode round-robin`, `--router-mode random`, or `--router-mode least-loaded` with both static and dynamic endpoints.
## Next Steps
......
......@@ -20,6 +20,7 @@ The Dynamo router can be deployed in several configurations. The table below sho
| **Frontend + Random** | `python -m dynamo.frontend --router-mode random` | Random worker selection | None | Aggregated | Stateless load balancing |
| **Frontend + KV (Aggregated)** | `python -m dynamo.frontend --router-mode kv` | KV cache overlap + load | NATS Core / JetStream / ZMQ / Approx | Aggregated | Production single-pool serving with cache reuse |
| **Frontend + KV (Disaggregated)** | `python -m dynamo.frontend --router-mode kv` with prefill + decode workers | KV cache overlap + load | NATS Core / JetStream / ZMQ / Approx | Disaggregated (prefill + decode pools) | Separate prefill/decode for large-scale serving |
| **Frontend + Least-Loaded** | `python -m dynamo.frontend --router-mode least-loaded` | Fewest active connections | None | Aggregated or disaggregated fallback | Simple load-aware balancing without KV awareness |
| **Frontend + Direct** | `python -m dynamo.frontend --router-mode direct` | Worker ID from request hints | None | Aggregated | External orchestrator (e.g., EPP/GAIE) selects workers |
| **Standalone Router** | `python -m dynamo.router` | KV cache overlap + load | NATS Core / JetStream / ZMQ | Any | Routing without the HTTP frontend (multi-tier, custom pipelines) |
......@@ -30,6 +31,7 @@ The Dynamo router can be deployed in several configurations. The table below sho
| **Round-Robin** | `round-robin` (default) | Cycles through available workers in order |
| **Random** | `random` | Selects a random worker for each request |
| **KV** | `kv` | Evaluates KV cache overlap and decode load per worker; picks lowest cost |
| **Least-Loaded** | `least-loaded` | Routes to the worker with fewest active connections; in disaggregated prefill paths it skips bootstrap optimization and falls back to synchronous prefill |
| **Direct** | `direct` | Reads the target `worker_id` from the request's routing hints; no selection logic |
### KV Event Transport Modes (within `--router-mode kv`)
......@@ -214,6 +216,8 @@ We can then use the default routing methods exposed by the client class to send
- **Random routing**: Default strategy, available via `client.generate()` or `client.random()`
- **Round-robin routing**: Cycles through available workers via `client.round_robin()`
- **Direct routing**: Explicitly targets a specific worker via `client.direct(input, component_id)`
- **Least-loaded routing**: Routes to the worker with fewest active connections via `--router-mode least-loaded`
In disaggregated prefill paths it skips bootstrap optimization and uses the synchronous prefill path, matching power-of-two routing.
KV Cache routing uses direct routing with a special worker selection algorithm.
......
......@@ -51,6 +51,7 @@ pub enum RouterMode {
/// Direct routing - reads worker ID from each request's routing hints.
/// Used when an external orchestrator (e.g., EPP) handles worker selection.
Direct,
LeastLoaded,
}
impl From<RouterMode> for RsRouterMode {
......@@ -61,6 +62,7 @@ impl From<RouterMode> for RsRouterMode {
RouterMode::PowerOfTwoChoices => Self::PowerOfTwoChoices,
RouterMode::KV => Self::KV,
RouterMode::Direct => Self::Direct,
RouterMode::LeastLoaded => Self::LeastLoaded,
}
}
}
......
......@@ -1129,6 +1129,7 @@ class RouterMode:
PowerOfTwoChoices: "RouterMode"
KV: "RouterMode"
Direct: "RouterMode"
LeastLoaded: "RouterMode"
...
class RouterConfig:
......@@ -1149,7 +1150,7 @@ class RouterConfig:
Create a RouterConfig.
Args:
mode: The router mode (RoundRobin, Random, KV, or Direct)
mode: The router mode (RoundRobin, Random, PowerOfTwoChoices, KV, Direct, or LeastLoaded)
config: Optional KV router configuration (used when mode is KV)
active_decode_blocks_threshold: Threshold percentage (0.0-1.0) for decode blocks busy detection
active_prefill_tokens_threshold: Literal token count threshold for prefill busy detection
......
......@@ -336,9 +336,10 @@ where
RouterMode::Direct => {
ServiceBackend::from_engine(Arc::new(DirectRoutingRouter::new(router)))
}
RouterMode::Random | RouterMode::RoundRobin | RouterMode::PowerOfTwoChoices => {
ServiceBackend::from_engine(Arc::new(router))
}
RouterMode::Random
| RouterMode::RoundRobin
| RouterMode::PowerOfTwoChoices
| RouterMode::LeastLoaded => ServiceBackend::from_engine(Arc::new(router)),
RouterMode::KV => {
let Some(chooser) = chooser else {
anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
......
......@@ -64,6 +64,8 @@ mod registry;
pub mod service;
pub use client::Client;
pub(crate) use client::RoutingOccupancyState;
pub(crate) use client::get_or_create_routing_occupancy_state;
pub use endpoint::build_transport_type;
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
......
......@@ -2,10 +2,15 @@
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::{collections::HashMap, time::Duration};
use std::sync::atomic::{AtomicU64, Ordering};
use std::{
collections::{HashMap, HashSet},
time::Duration,
};
use anyhow::Result;
use arc_swap::ArcSwap;
use dashmap::DashMap;
use futures::StreamExt;
use tokio::net::unix::pipe::Receiver;
......@@ -21,6 +26,70 @@ use crate::{
transports::etcd::Client as EtcdClient,
};
/// Shared occupancy state for routing modes that track per-worker in-flight requests.
///
/// The state is handed out behind an `Arc` by `get_or_create_routing_occupancy_state`,
/// so routers built for the same endpoint observe the same counters.
#[derive(Debug, Default)]
pub(crate) struct RoutingOccupancyState {
/// In-flight request count keyed by worker instance id. A missing key reads as zero.
counts: DashMap<u64, AtomicU64>,
/// Held for the duration of `select_exact_min_and_increment` so that finding the
/// minimum and incrementing it cannot interleave with another exact selection.
exact_selection_lock: tokio::sync::Mutex<()>,
}
impl RoutingOccupancyState {
    /// Record one more in-flight request for `instance_id`.
    ///
    /// The counter is created (starting at zero) the first time an id is seen.
    pub(crate) fn increment(&self, instance_id: u64) {
        let counter = self
            .counts
            .entry(instance_id)
            .or_insert_with(|| AtomicU64::new(0));
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Pick the id with the lowest in-flight count from `instance_ids`, increment
    /// it, and return it.
    ///
    /// The scan and the increment run under `exact_selection_lock`, so two calls to
    /// this method never observe the same pre-increment counts. Ties are resolved in
    /// favour of the id that appears first in the slice. Returns `None` when the
    /// slice is empty.
    pub(crate) async fn select_exact_min_and_increment(&self, instance_ids: &[u64]) -> Option<u64> {
        let _serialized = self.exact_selection_lock.lock().await;

        // (id, load) of the best candidate so far; strict `<` keeps the earliest id on ties.
        let mut best: Option<(u64, u64)> = None;
        for &candidate in instance_ids {
            let in_flight = self.load(candidate);
            let improves = match best {
                Some((_, lowest)) => in_flight < lowest,
                None => true,
            };
            if improves {
                best = Some((candidate, in_flight));
            }
        }

        let (chosen, _) = best?;
        self.increment(chosen);
        Some(chosen)
    }

    /// Record that one in-flight request for `instance_id` has finished.
    ///
    /// Does nothing when the id has no counter, and never goes below zero.
    pub(crate) fn decrement(&self, instance_id: u64) {
        let Some(counter) = self.counts.get(&instance_id) else {
            return;
        };
        // `checked_sub` yields `None` at zero, which makes `fetch_update` leave the
        // stored value alone; the resulting error is expected and ignored.
        let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
            current.checked_sub(1)
        });
    }

    /// Current in-flight count for `instance_id`; zero when the id is unknown.
    pub(crate) fn load(&self, instance_id: u64) -> u64 {
        match self.counts.get(&instance_id) {
            Some(counter) => counter.load(Ordering::Relaxed),
            None => 0,
        }
    }

    /// Drop the counters of every id that is not listed in `instance_ids`.
    pub(crate) fn retain(&self, instance_ids: &[u64]) {
        let mut live = HashSet::with_capacity(instance_ids.len());
        live.extend(instance_ids.iter().copied());
        self.counts.retain(|id, _| live.contains(id));
    }
}
/// Get or create the shared routing occupancy state for an endpoint.
///
/// The runtime's registry stores only `Weak` handles. If the handle for `endpoint`
/// can still be upgraded, that state is returned; otherwise a fresh default state is
/// created and registered in its place.
pub(crate) async fn get_or_create_routing_occupancy_state(
    endpoint: &Endpoint,
) -> Arc<RoutingOccupancyState> {
    let shared = endpoint.drt().routing_occupancy_states();
    let mut states = shared.lock().await;

    if let Some(live) = states.get(endpoint).and_then(|weak| weak.upgrade()) {
        return live;
    }

    // Either there was no entry or its `Weak` is dangling; `insert` replaces it.
    let fresh = Arc::new(RoutingOccupancyState::default());
    states.insert(endpoint.clone(), Arc::downgrade(&fresh));
    fresh
}
/// Default interval for periodic reconciliation of instance_avail with instance_source
const DEFAULT_RECONCILE_INTERVAL: Duration = Duration::from_secs(5);
......@@ -181,6 +250,15 @@ impl Client {
client.instance_avail.store(Arc::new(instance_ids.clone()));
client.instance_free.store(Arc::new(instance_ids.clone()));
// Clean up stale occupancy counters for instances that no longer exist.
let registry = client.endpoint.drt().routing_occupancy_states();
if let Ok(registry) = registry.try_lock()
&& let Some(weak) = registry.get(&client.endpoint)
&& let Some(state) = weak.upgrade()
{
state.retain(&instance_ids);
}
// Send update to watch channel subscribers
let _ = client.instance_avail_tx.send(instance_ids);
......@@ -396,4 +474,123 @@ mod tests {
rt.shutdown();
}
/// Test that concurrent select_and_increment distributes load correctly.
#[tokio::test]
async fn test_concurrent_select_and_increment() {
let state = Arc::new(RoutingOccupancyState::default());
let instance_ids: Vec<u64> = vec![100, 200, 300];
let num_requests = 90;
let mut handles = Vec::new();
for _ in 0..num_requests {
let state = state.clone();
let ids = instance_ids.clone();
handles.push(tokio::spawn(async move {
state.select_exact_min_and_increment(&ids).await
}));
}
for handle in handles {
handle.await.unwrap();
}
assert_eq!(state.load(100), 30);
assert_eq!(state.load(200), 30);
assert_eq!(state.load(300), 30);
}
#[tokio::test]
async fn test_connection_counts() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_ll_counts".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let state1 = get_or_create_routing_occupancy_state(&endpoint).await;
let state2 = get_or_create_routing_occupancy_state(&endpoint).await;
let picked1 = state1
.select_exact_min_and_increment(&[10, 20, 30])
.await
.unwrap();
assert_eq!(state1.load(picked1), 1);
let picked2 = state1
.select_exact_min_and_increment(&[10, 20, 30])
.await
.unwrap();
assert_ne!(picked1, picked2);
// state2 should see the same counts (same underlying Arc)
assert_eq!(state2.load(10), state1.load(10));
assert_eq!(state2.load(20), state1.load(20));
assert_eq!(state2.load(30), state1.load(30));
state2.decrement(picked1);
assert_eq!(state1.load(picked1), if picked1 == picked2 { 1 } else { 0 });
rt.shutdown();
}
/// `retain` must forget counters for workers that are no longer listed while
/// leaving the surviving workers' counts intact.
#[tokio::test]
async fn test_least_loaded_state_retain() {
    let state = RoutingOccupancyState::default();
    let workers: [u64; 3] = [1, 2, 3];

    // One exact-min selection per worker gives each of them a single connection.
    for _ in 0..workers.len() {
        state.select_exact_min_and_increment(&workers).await;
    }
    for id in workers {
        assert_eq!(state.load(id), 1);
    }

    // Worker 2 disappears; only 1 and 3 are kept.
    state.retain(&[1, 3]);

    assert_eq!(state.load(1), 1);
    assert_eq!(state.load(2), 0);
    assert_eq!(state.load(3), 1);
}
/// After a worker unregisters, its occupancy counter is expected to read zero
/// again within a bounded number of reconcile intervals.
#[tokio::test]
async fn test_monitor_instance_source_cleans_up_removed_worker_counts() {
// Short interval so the polling loop below finishes quickly.
const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_occupancy_cleanup".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
.await
.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
client.wait_for_instances().await.unwrap();
let worker_id = client.instance_ids_avail()[0];
// Simulate one in-flight request against the registered worker.
let state = get_or_create_routing_occupancy_state(&endpoint).await;
state.increment(worker_id);
assert_eq!(state.load(worker_id), 1);
endpoint.unregister_endpoint_instance().await.unwrap();
// Poll for up to 10 intervals; the counter is never decremented by this test,
// so reaching zero means the entry was removed elsewhere.
for _ in 0..10 {
if state.load(worker_id) == 0 {
break;
}
tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
}
assert_eq!(state.load(worker_id), 0);
rt.shutdown();
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::component::{Component, Instance};
use crate::component::{
self, Component, ComponentBuilder, Endpoint, Instance, Namespace, RoutingOccupancyState,
};
use crate::pipeline::PipelineError;
use crate::pipeline::network::manager::NetworkManager;
use crate::service::{ServiceClient, ServiceSet};
use crate::storage::kv;
use crate::{discovery, system_status_server, transports};
use crate::{
component::{self, ComponentBuilder, Endpoint, Namespace},
discovery::Discovery,
metrics::PrometheusUpdateCallback,
metrics::{MetricsHierarchy, MetricsRegistry},
transports::{etcd, nats, tcp},
};
use crate::{discovery, system_status_server, transports};
use super::utils::GracefulShutdownTracker;
use crate::SystemHealth;
......@@ -35,6 +36,7 @@ use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
type InstanceMap = HashMap<Endpoint, Weak<Receiver<Vec<Instance>>>>;
/// Per-endpoint registry of routing occupancy states. Values are `Weak`, so the
/// registry by itself does not keep a state alive.
type RoutingOccupancyMap = HashMap<Endpoint, Weak<RoutingOccupancyState>>;
/// Distributed [Runtime] which provides access to shared resources across the cluster, this includes
/// communication protocols and transports.
......@@ -64,6 +66,7 @@ pub struct DistributedRuntime {
component_registry: component::Registry,
instance_sources: Arc<tokio::sync::Mutex<InstanceMap>>,
routing_occupancy_states: Arc<tokio::sync::Mutex<RoutingOccupancyMap>>,
// Health Status
system_health: Arc<parking_lot::Mutex<SystemHealth>>,
......@@ -185,6 +188,7 @@ impl DistributedRuntime {
discovery_metadata,
component_registry,
instance_sources: Arc::new(Mutex::new(HashMap::new())),
routing_occupancy_states: Arc::new(Mutex::new(HashMap::new())),
metrics_registry: crate::MetricsRegistry::new(),
system_health,
request_plane,
......@@ -390,6 +394,10 @@ impl DistributedRuntime {
self.instance_sources.clone()
}
/// Returns another handle to the runtime-wide registry of per-endpoint routing
/// occupancy states.
pub(crate) fn routing_occupancy_states(&self) -> Arc<Mutex<RoutingOccupancyMap>> {
    Arc::clone(&self.routing_occupancy_states)
}
/// TODO: This is a temporary KV router measure for component/component.rs EventPublisher impl for
/// Component, to allow it to publish to NATS. KV Router is the only user.
///
......
......@@ -15,9 +15,9 @@ fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
match_error_chain(err, INHIBITED, &[])
}
use crate::{
component::{Client, Endpoint},
component::{Client, Endpoint, RoutingOccupancyState, get_or_create_routing_occupancy_state},
dynamo_nvtx_range,
engine::{AsyncEngine, Data},
engine::{AsyncEngine, AsyncEngineContext, Data},
metrics::frontend_perf::STAGE_DURATION_SECONDS,
pipeline::{
AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
......@@ -27,32 +27,59 @@ use crate::{
traits::DistributedRuntimeProvider,
};
use async_trait::async_trait;
use dashmap::DashMap;
use futures::Stream;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
task::Poll,
time::Instant,
};
use tokio_stream::StreamExt;
use tracing::Instrument;
/// RAII guard that decrements a per-instance in-flight counter on drop, unless the
/// count has been handed over to a tracked response stream first.
/// Used by PowerOfTwoChoices and LeastLoaded routing to track request occupancy.
struct P2CGuard {
in_flight_counts: Arc<DashMap<u64, AtomicU64>>,
struct OccupancyPermit {
state: Arc<RoutingOccupancyState>,
instance_id: u64,
armed: bool,
}
impl Drop for P2CGuard {
impl OccupancyPermit {
    /// Create an armed permit for `instance_id`.
    ///
    /// This only stores its arguments; the caller is responsible for having
    /// incremented the counter that the permit will later release.
    fn new(state: Arc<RoutingOccupancyState>, instance_id: u64) -> Self {
        let armed = true;
        Self {
            state,
            instance_id,
            armed,
        }
    }

    /// Hand the held count over to `stream`.
    ///
    /// The permit is disarmed and the returned stream wraps `stream` in an
    /// `OccupancyTrackedStream` carrying the same state and instance id. The
    /// original stream's engine context is reused for the wrapper.
    fn into_tracked_stream<U: Data>(mut self, stream: ManyOut<U>) -> ManyOut<U> {
        // The wrapper built below takes over the count, so this permit's own drop
        // must not release it as well.
        self.armed = false;
        let engine_ctx = stream.context();
        let tracked = OccupancyTrackedStream {
            inner: stream,
            state: Arc::clone(&self.state),
            instance_id: self.instance_id,
        };
        ResponseStream::new(Box::pin(tracked), engine_ctx)
    }

    /// Id of the worker instance this permit was created for.
    fn instance_id(&self) -> u64 {
        self.instance_id
    }
}
impl Drop for OccupancyPermit {
fn drop(&mut self) {
if let Some(counter) = self.in_flight_counts.get(&self.instance_id) {
counter.value().fetch_sub(1, Ordering::Relaxed);
if self.armed {
self.state.decrement(self.instance_id);
}
}
}
......@@ -87,9 +114,6 @@ where
/// Number of round robin requests handled. Used to decide which server is next.
round_robin_counter: Arc<AtomicU64>,
/// Per-instance in-flight request counts for PowerOfTwoChoices routing.
in_flight_counts: Arc<DashMap<u64, AtomicU64>>,
/// The next step in the chain. PushRouter (this object) picks an instances,
/// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>,
......@@ -104,6 +128,9 @@ where
/// where transient failures are expected.
fault_detection_enabled: bool,
/// Shared request occupancy state for tracked routing modes.
occupancy_state: Option<Arc<RoutingOccupancyState>>,
/// An internal Rust type. This says that PushRouter is generic over the T and U types,
/// which are the input and output types of it's `generate` function. It allows the
/// compiler to specialize us at compile time.
......@@ -118,6 +145,7 @@ pub enum RouterMode {
PowerOfTwoChoices,
KV,
Direct,
LeastLoaded,
}
impl RouterMode {
......@@ -132,7 +160,7 @@ impl RouterMode {
/// Pick the instance with lower in-flight count from two random candidates.
/// Returns the single instance if only one is available.
fn p2c_select_from(in_flight_counts: &DashMap<u64, AtomicU64>, instance_ids: &[u64]) -> u64 {
fn p2c_select_from(occupancy_state: &RoutingOccupancyState, instance_ids: &[u64]) -> u64 {
let count = instance_ids.len();
if count == 1 {
return instance_ids[0];
......@@ -142,14 +170,8 @@ fn p2c_select_from(in_flight_counts: &DashMap<u64, AtomicU64>, instance_ids: &[u
let idx2 = (idx1 + 1 + rng.random_range(0..count - 1)) % count;
let id1 = instance_ids[idx1];
let id2 = instance_ids[idx2];
let load1 = in_flight_counts
.get(&id1)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0);
let load2 = in_flight_counts
.get(&id2)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0);
let load1 = occupancy_state.load(id1);
let load2 = occupancy_state.load(id2);
let selected = if load1 <= load2 { id1 } else { id2 };
tracing::debug!(
candidate_a = id1,
......@@ -197,14 +219,23 @@ where
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
let occupancy_state = if matches!(
router_mode,
RouterMode::PowerOfTwoChoices | RouterMode::LeastLoaded
) {
Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
} else {
None
};
Ok(PushRouter {
client: client.clone(),
client,
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
in_flight_counts: Arc::new(DashMap::new()),
busy_threshold: None,
fault_detection_enabled: false,
occupancy_state,
_phantom: PhantomData,
})
}
......@@ -223,14 +254,23 @@ where
monitor.start_monitoring().await?;
}
let occupancy_state = if matches!(
router_mode,
RouterMode::PowerOfTwoChoices | RouterMode::LeastLoaded
) {
Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
} else {
None
};
let router = PushRouter {
client: client.clone(),
client,
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
in_flight_counts: Arc::new(DashMap::new()),
busy_threshold,
fault_detection_enabled: true,
occupancy_state,
_phantom: PhantomData,
};
......@@ -281,36 +321,32 @@ where
/// Issue a request using power-of-two-choices: pick 2 random healthy workers,
/// route to the one with fewer in-flight requests.
pub async fn power_of_two_choices(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let state = self.occupancy_state()?;
let instance_id = {
let instance_ids = self.client.instance_ids_avail();
let instance_ids = self
.client
.instance_ids_avail()
.iter()
.copied()
.collect::<Vec<_>>();
if instance_ids.is_empty() {
return Err(anyhow::anyhow!(
"no instances found for endpoint {}",
self.client.endpoint.id()
));
}
p2c_select_from(&self.in_flight_counts, &instance_ids)
};
// Guard created before the await so error paths also decrement.
self.in_flight_counts
.entry(instance_id)
.or_insert_with(|| AtomicU64::new(0))
.value()
.fetch_add(1, Ordering::Relaxed);
let guard = P2CGuard {
in_flight_counts: self.in_flight_counts.clone(),
instance_id,
p2c_select_from(state.as_ref(), &instance_ids)
};
state.increment(instance_id);
let permit = OccupancyPermit::new(state, instance_id);
let stream = self
match self
.generate_with_fault_detection(instance_id, request)
.await?;
let engine_ctx = stream.context();
let stream = stream.map(move |res| {
let _guard = &guard;
res
});
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
.await
{
Ok(stream) => Ok(permit.into_tracked_stream(stream)),
Err(err) => Err(err),
}
}
/// Issue a request to a specific endpoint
......@@ -339,10 +375,42 @@ where
.await
}
/// Issue a request to the instance with the fewest active connections.
pub async fn least_loaded(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let state = self.occupancy_state()?;
let instance_ids = self
.client
.instance_ids_avail()
.iter()
.copied()
.collect::<Vec<_>>();
let instance_id = state
.select_exact_min_and_increment(&instance_ids)
.await
.ok_or_else(|| {
anyhow::anyhow!(
"no instances found for endpoint {}",
self.client.endpoint.id()
)
})?;
let permit = OccupancyPermit::new(state.clone(), instance_id);
tracing::trace!(
"least loaded router selected {instance_id} (connections: {})",
state.load(instance_id)
);
match self
.generate_with_fault_detection(instance_id, request)
.await
{
Ok(stream) => Ok(permit.into_tracked_stream(stream)),
Err(err) => Err(err),
}
}
/// Select the next worker according to the routing mode.
/// Increments round-robin counter if applicable.
/// Returns None for Direct mode - requires explicit worker IDs via routing hints
/// Panics for KV mode which has its own selection via find_best_match.
/// Returns None for modes that require request lifecycle tracking or explicit routing hints.
pub fn select_next_worker(&self) -> Option<u64> {
let instance_ids = self.client.instance_ids_avail();
let count = instance_ids.len();
......@@ -359,9 +427,8 @@ where
let counter = rand::rng().random::<u64>() as usize;
Some(instance_ids[counter % count])
}
// P2C needs lifecycle tracking (P2CGuard); use generate() instead.
RouterMode::PowerOfTwoChoices | RouterMode::Direct => None,
_ => {
RouterMode::PowerOfTwoChoices | RouterMode::Direct | RouterMode::LeastLoaded => None,
RouterMode::KV => {
panic!(
"select_next_worker should not be called for {:?} routing mode",
self.router_mode
......@@ -372,7 +439,7 @@ where
/// Peek the next worker according to the routing mode without incrementing the counter.
/// Useful for checking if a worker is suitable before committing to it.
/// Returns None for Direct mode - requires explicit worker IDs via routing hints.
/// Returns None for modes that require request lifecycle tracking or explicit routing hints.
pub fn peek_next_worker(&self) -> Option<u64> {
let instance_ids = self.client.instance_ids_avail();
let count = instance_ids.len();
......@@ -392,9 +459,8 @@ where
let counter = rand::rng().random::<u64>() as usize;
Some(instance_ids[counter % count])
}
// P2C needs lifecycle tracking (P2CGuard); use generate() instead.
RouterMode::PowerOfTwoChoices | RouterMode::Direct => None,
_ => {
RouterMode::PowerOfTwoChoices | RouterMode::Direct | RouterMode::LeastLoaded => None,
RouterMode::KV => {
panic!(
"peek_next_worker should not be called for {:?} routing mode",
self.router_mode
......@@ -403,6 +469,15 @@ where
}
}
/// Returns a handle to the router's occupancy state.
///
/// # Errors
/// Fails when the router was built without occupancy tracking (the field is `None`).
fn occupancy_state(&self) -> anyhow::Result<Arc<RoutingOccupancyState>> {
    match &self.occupancy_state {
        Some(state) => Ok(Arc::clone(state)),
        None => Err(anyhow::anyhow!(
            "routing occupancy state not initialized for endpoint {}",
            self.client.endpoint.id()
        )),
    }
}
/*
pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let subject = self.client.endpoint.subject();
......@@ -555,101 +630,177 @@ where
"Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
);
}
RouterMode::LeastLoaded => self.least_loaded(request).await,
}
}
}
/// Response stream wrapper that holds one in-flight count for `instance_id` and
/// releases it when the wrapper is dropped.
struct OccupancyTrackedStream<U: Data> {
/// The wrapped response stream; polling is forwarded to it.
inner: ManyOut<U>,
/// Occupancy state that is decremented on drop.
state: Arc<RoutingOccupancyState>,
/// Worker instance whose count this stream holds.
instance_id: u64,
}
impl<U: Data> Drop for OccupancyTrackedStream<U> {
    /// Releases the in-flight count held for this stream's worker instance.
    fn drop(&mut self) {
        let worker = self.instance_id;
        self.state.decrement(worker);
    }
}
impl<U: Data> std::fmt::Debug for OccupancyTrackedStream<U> {
    /// Formats only `instance_id`; `inner` and `state` are intentionally left out,
    /// and the trailing `..` in the output signals that fields were omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OccupancyTrackedStream")
            .field("instance_id", &self.instance_id)
            .finish_non_exhaustive()
    }
}
impl<U: Data> Stream for OccupancyTrackedStream<U> {
    type Item = U;

    /// Forwards the poll to the wrapped stream without touching the items.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.inner.as_mut().poll_next(cx)
    }
}
impl<U: Data> AsyncEngineContextProvider for OccupancyTrackedStream<U> {
/// Returns the wrapped stream's engine context.
fn context(&self) -> Arc<dyn AsyncEngineContext> {
self.inner.context()
}
}
// Marker impl: no items are defined in the body.
impl<U: Data> crate::engine::AsyncEngineStream<U> for OccupancyTrackedStream<U> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
DistributedRuntime, Runtime,
distributed::DistributedConfig,
error::DynamoError,
pipeline::{ResponseStream, context::Controller},
};
use serde::{Deserialize, Serialize};
/// Minimal response type for router tests; it carries nothing but an optional error.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct TestResponse {
/// Error attached to the response, if any.
error: Option<DynamoError>,
}
impl MaybeError for TestResponse {
fn from_err(err: impl std::error::Error + 'static) -> Self {
Self {
error: Some(DynamoError::from(
Box::new(err) as Box<dyn std::error::Error + 'static>
)),
}
}
fn err(&self) -> Option<DynamoError> {
self.error.clone()
}
}
#[test]
fn p2c_selects_lower_load_worker() {
let counts = DashMap::new();
counts.insert(1, AtomicU64::new(10));
counts.insert(2, AtomicU64::new(1));
let state = RoutingOccupancyState::default();
for _ in 0..10 {
state.increment(1);
}
state.increment(2);
// With only two workers, p2c_select_from must pick both and choose id=2 (lower load).
let result = p2c_select_from(&counts, &[1, 2]);
let result = p2c_select_from(&state, &[1, 2]);
assert_eq!(result, 2);
}
#[test]
fn p2c_selects_single_worker() {
let counts = DashMap::new();
assert_eq!(p2c_select_from(&counts, &[42]), 42);
let state = RoutingOccupancyState::default();
assert_eq!(p2c_select_from(&state, &[42]), 42);
}
#[test]
fn p2c_treats_missing_counts_as_zero() {
let counts = DashMap::new();
counts.insert(1, AtomicU64::new(5));
let state = RoutingOccupancyState::default();
for _ in 0..5 {
state.increment(1);
}
// Worker 2 has no entry — should be treated as 0, so it wins.
let result = p2c_select_from(&counts, &[1, 2]);
let result = p2c_select_from(&state, &[1, 2]);
assert_eq!(result, 2);
}
#[test]
fn p2c_returns_valid_worker_on_tie() {
let counts = DashMap::new();
counts.insert(1, AtomicU64::new(3));
counts.insert(2, AtomicU64::new(3));
let state = RoutingOccupancyState::default();
for _ in 0..3 {
state.increment(1);
state.increment(2);
}
for _ in 0..100 {
let result = p2c_select_from(&counts, &[1, 2]);
let result = p2c_select_from(&state, &[1, 2]);
assert!(result == 1 || result == 2);
}
}
#[test]
fn p2c_lifecycle_tracks_inflight_counts() {
let counts = Arc::new(DashMap::new());
let mut guards = Vec::new();
fn occupancy_permit_decrements_before_stream_creation() {
let state = Arc::new(RoutingOccupancyState::default());
state.increment(42);
let permit = OccupancyPermit::new(state.clone(), 42);
assert_eq!(state.load(42), 1);
drop(permit);
assert_eq!(state.load(42), 0);
}
/// Converting a permit into a tracked stream must keep the count held until the
/// stream itself is dropped.
#[test]
fn occupancy_tracked_stream_decrements_on_drop() {
    const WORKER: u64 = 7;
    let state = Arc::new(RoutingOccupancyState::default());
    state.increment(WORKER);
    let permit = OccupancyPermit::new(Arc::clone(&state), WORKER);

    let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
    let tracked = permit.into_tracked_stream(ResponseStream::new(
        Box::pin(tokio_stream::iter(vec![1u64])),
        ctx,
    ));

    // The permit is gone, but the stream still owns the count.
    assert_eq!(state.load(WORKER), 1);
    drop(tracked);
    assert_eq!(state.load(WORKER), 0);
}
#[test]
fn p2c_lifecycle_tracks_inflight_counts_with_shared_tracker() {
let state = Arc::new(RoutingOccupancyState::default());
let mut permits = Vec::new();
for _ in 0..5 {
let selected = p2c_select_from(&counts, &[1, 2]);
counts
.entry(selected)
.or_insert_with(|| AtomicU64::new(0))
.value()
.fetch_add(1, Ordering::Relaxed);
guards.push(P2CGuard {
in_flight_counts: counts.clone(),
instance_id: selected,
});
let selected = p2c_select_from(&state, &[1, 2]);
state.increment(selected);
permits.push(OccupancyPermit::new(state.clone(), selected));
}
let total = counts
.get(&1)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0)
+ counts
.get(&2)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0);
let total = state.load(1) + state.load(2);
assert_eq!(total, 5, "5 in-flight requests should be tracked");
drop(guards);
let total = counts
.get(&1)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0)
+ counts
.get(&2)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0);
drop(permits);
let total = state.load(1) + state.load(2);
assert_eq!(total, 0, "All guards dropped, counts should be 0");
}
#[test]
fn p2c_never_selects_dominated_worker() {
let counts = DashMap::new();
counts.insert(1, AtomicU64::new(0));
counts.insert(2, AtomicU64::new(0));
counts.insert(3, AtomicU64::new(100));
let state = RoutingOccupancyState::default();
for _ in 0..100 {
state.increment(3);
}
let mut selected = [0u32; 3];
for _ in 0..1000 {
let result = p2c_select_from(&counts, &[1, 2, 3]);
let result = p2c_select_from(&state, &[1, 2, 3]);
match result {
1 => selected[0] += 1,
2 => selected[1] += 1,
......@@ -663,4 +814,49 @@ mod tests {
selected[2]
);
}
/// Exact-min selection must pick the idle worker, and dropping the permit for that
/// pick must release the count again.
#[tokio::test]
async fn least_loaded_selects_exact_min_and_tracks_counts() {
    let state = Arc::new(RoutingOccupancyState::default());
    // Preload: worker 1 has two requests, worker 2 has one, worker 3 has none.
    for busy in [1u64, 1, 2] {
        state.increment(busy);
    }

    let selected = state
        .select_exact_min_and_increment(&[1, 2, 3])
        .await
        .unwrap();
    assert_eq!(selected, 3);

    let permit = OccupancyPermit::new(Arc::clone(&state), selected);
    assert_eq!(state.load(selected), 1);
    drop(permit);
    assert_eq!(state.load(selected), 0);
}
/// In `LeastLoaded` mode, `select_next_worker` and `peek_next_worker` must return
/// `None` even when a worker instance is registered and available.
#[tokio::test]
async fn least_loaded_select_and_peek_return_none_with_available_worker() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_least_loaded_router".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
// Register an instance and wait for the client to see it, so a `None` result
// below cannot be explained by an empty instance list.
endpoint.register_endpoint_instance().await.unwrap();
client.wait_for_instances().await.unwrap();
let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::LeastLoaded)
.await
.unwrap();
assert_eq!(router.select_next_worker(), None);
assert_eq!(router.peek_next_worker(), None);
rt.shutdown();
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment