Unverified commit db49dfd2, authored by Aaron Batilo, committed by GitHub
Browse files

feat: add power-of-two-choices router mode (#7614)


Signed-off-by: Aaron Batilo <abatilo@coreweave.com>
parent 2fe37a51
......@@ -181,8 +181,11 @@ class FrontendArgGroup(ArgGroup):
flag_name="--router-mode",
env_var="DYN_ROUTER_MODE",
default="round-robin",
help="How to route the request.",
choices=["round-robin", "random", "kv", "direct"],
help="How to route the request. power-of-two picks 2 random workers and "
"routes to the one with fewer in-flight requests. In disaggregated prefill "
"mode, power-of-two skips bootstrap optimization and falls back to the "
"synchronous prefill path.",
choices=["round-robin", "random", "power-of-two", "kv", "direct"],
)
# KV router options (shared with dynamo.router)
......
......@@ -225,6 +225,9 @@ async def async_main():
elif config.router_mode == "direct":
router_mode = RouterMode.Direct
kv_router_config = None
elif config.router_mode == "power-of-two":
router_mode = RouterMode.PowerOfTwoChoices
kv_router_config = None
else:
router_mode = RouterMode.RoundRobin
kv_router_config = None
......
......@@ -46,6 +46,7 @@ use crate::llm::preprocessor::{MediaDecoder, MediaFetcher};
pub enum RouterMode {
RoundRobin,
Random,
PowerOfTwoChoices,
KV,
/// Direct routing - reads worker ID from each request's routing hints.
/// Used when an external orchestrator (e.g., EPP) handles worker selection.
......@@ -57,6 +58,7 @@ impl From<RouterMode> for RsRouterMode {
match mode {
RouterMode::RoundRobin => Self::RoundRobin,
RouterMode::Random => Self::Random,
RouterMode::PowerOfTwoChoices => Self::PowerOfTwoChoices,
RouterMode::KV => Self::KV,
RouterMode::Direct => Self::Direct,
}
......
......@@ -1079,6 +1079,7 @@ class RouterMode:
"""Router mode for load balancing requests across workers"""
RoundRobin: "RouterMode"
Random: "RouterMode"
PowerOfTwoChoices: "RouterMode"
KV: "RouterMode"
Direct: "RouterMode"
...
......
......@@ -292,7 +292,7 @@ where
RouterMode::Direct => {
ServiceBackend::from_engine(Arc::new(DirectRoutingRouter::new(router)))
}
RouterMode::Random | RouterMode::RoundRobin => {
RouterMode::Random | RouterMode::RoundRobin | RouterMode::PowerOfTwoChoices => {
ServiceBackend::from_engine(Arc::new(router))
}
RouterMode::KV => {
......
......@@ -27,6 +27,7 @@ use crate::{
traits::DistributedRuntimeProvider,
};
use async_trait::async_trait;
use dashmap::DashMap;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::{
......@@ -41,6 +42,21 @@ use std::{
use tokio_stream::StreamExt;
use tracing::Instrument;
/// Drop guard that releases one unit of in-flight occupancy for a worker.
///
/// PowerOfTwoChoices routing keeps a per-instance counter of requests that are
/// currently in flight. This guard only performs the *decrement*: it never
/// increments on construction, so the code that builds it is expected to have
/// bumped the counter for `instance_id` first.
struct P2CGuard {
    /// Shared table mapping instance id to its in-flight request counter.
    in_flight_counts: Arc<DashMap<u64, AtomicU64>>,
    /// Instance whose counter is decremented when this guard is dropped.
    instance_id: u64,
}

impl Drop for P2CGuard {
    /// Subtracts one from the counter of `instance_id`.
    ///
    /// If the table holds no entry for that instance, nothing happens.
    fn drop(&mut self) {
        let Some(entry) = self.in_flight_counts.get(&self.instance_id) else {
            return;
        };
        entry.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Trait for monitoring worker load and determining busy state.
/// Implementations can define custom load metrics and busy thresholds.
#[async_trait]
......@@ -71,6 +87,9 @@ where
/// Number of round robin requests handled. Used to decide which server is next.
round_robin_counter: Arc<AtomicU64>,
/// Per-instance in-flight request counts for PowerOfTwoChoices routing.
in_flight_counts: Arc<DashMap<u64, AtomicU64>>,
/// The next step in the chain. PushRouter (this object) picks an instance,
/// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>,
......@@ -96,6 +115,7 @@ pub enum RouterMode {
#[default]
RoundRobin,
Random,
PowerOfTwoChoices,
KV,
Direct,
}
......@@ -110,6 +130,38 @@ impl RouterMode {
}
}
/// Chooses a worker by sampling two distinct candidates at random and keeping
/// the one with the smaller in-flight count.
///
/// * `in_flight_counts` - per-instance in-flight counters; an instance with no
///   entry is treated as having a load of 0.
/// * `instance_ids` - candidate instance ids to choose from.
///
/// When `instance_ids` holds exactly one id, that id is returned without
/// sampling. On equal load the first sampled candidate is kept.
///
/// # Panics
///
/// Panics if `instance_ids` is empty, because an index cannot be sampled from
/// an empty range.
fn p2c_select_from(in_flight_counts: &DashMap<u64, AtomicU64>, instance_ids: &[u64]) -> u64 {
    if let [only] = instance_ids {
        return *only;
    }
    let total = instance_ids.len();

    // Counter lookup shared by both candidates; absent entries count as idle.
    let current_load = |id: u64| -> u64 {
        in_flight_counts
            .get(&id)
            .map_or(0, |entry| entry.value().load(Ordering::Relaxed))
    };

    let mut rng = rand::rng();
    let first_idx = rng.random_range(0..total);
    // The offset lies in 1..total, so after wrapping it can never land back on
    // `first_idx`: the two candidates are always different slots.
    let offset = 1 + rng.random_range(0..total - 1);
    let second_idx = (first_idx + offset) % total;

    let first_id = instance_ids[first_idx];
    let second_id = instance_ids[second_idx];
    let first_load = current_load(first_id);
    let second_load = current_load(second_id);

    let selected = if second_load < first_load {
        second_id
    } else {
        first_id
    };
    tracing::debug!(
        candidate_a = first_id,
        candidate_a_load = first_load,
        candidate_b = second_id,
        candidate_b_load = second_load,
        selected = selected,
        "p2c selection"
    );
    selected
}
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
// Get network manager and create client (no mode checks!)
let manager = endpoint.drt().network_manager();
......@@ -150,6 +202,7 @@ where
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
in_flight_counts: Arc::new(DashMap::new()),
busy_threshold: None,
fault_detection_enabled: false,
_phantom: PhantomData,
......@@ -175,6 +228,7 @@ where
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
in_flight_counts: Arc::new(DashMap::new()),
busy_threshold,
fault_detection_enabled: true,
_phantom: PhantomData,
......@@ -224,6 +278,41 @@ where
.await
}
/// Issue a request using power-of-two-choices: pick 2 random healthy workers,
/// route to the one with fewer in-flight requests.
pub async fn power_of_two_choices(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let instance_id = {
let instance_ids = self.client.instance_ids_avail();
if instance_ids.is_empty() {
return Err(anyhow::anyhow!(
"no instances found for endpoint {}",
self.client.endpoint.id()
));
}
p2c_select_from(&self.in_flight_counts, &instance_ids)
};
// Guard created before the await so error paths also decrement.
self.in_flight_counts
.entry(instance_id)
.or_insert_with(|| AtomicU64::new(0))
.value()
.fetch_add(1, Ordering::Relaxed);
let guard = P2CGuard {
in_flight_counts: self.in_flight_counts.clone(),
instance_id,
};
let stream = self
.generate_with_fault_detection(instance_id, request)
.await?;
let engine_ctx = stream.context();
let stream = stream.map(move |res| {
let _guard = &guard;
res
});
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
}
/// Issue a request to a specific endpoint
pub async fn direct(
&self,
......@@ -270,7 +359,8 @@ where
let counter = rand::rng().random::<u64>() as usize;
Some(instance_ids[counter % count])
}
RouterMode::Direct => None,
// P2C needs lifecycle tracking (P2CGuard); use generate() instead.
RouterMode::PowerOfTwoChoices | RouterMode::Direct => None,
_ => {
panic!(
"select_next_worker should not be called for {:?} routing mode",
......@@ -302,7 +392,8 @@ where
let counter = rand::rng().random::<u64>() as usize;
Some(instance_ids[counter % count])
}
RouterMode::Direct => None,
// P2C needs lifecycle tracking (P2CGuard); use generate() instead.
RouterMode::PowerOfTwoChoices | RouterMode::Direct => None,
_ => {
panic!(
"peek_next_worker should not be called for {:?} routing mode",
......@@ -455,6 +546,7 @@ where
match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::PowerOfTwoChoices => self.power_of_two_choices(request).await,
RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter");
}
......@@ -466,3 +558,109 @@ where
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-flight table from `(instance_id, load)` pairs.
    fn load_table(entries: &[(u64, u64)]) -> DashMap<u64, AtomicU64> {
        let table = DashMap::new();
        for &(id, load) in entries {
            table.insert(id, AtomicU64::new(load));
        }
        table
    }

    /// Reads the recorded in-flight count for `id`, or 0 when there is no entry.
    fn recorded_load(table: &DashMap<u64, AtomicU64>, id: u64) -> u64 {
        table
            .get(&id)
            .map_or(0, |entry| entry.value().load(Ordering::Relaxed))
    }

    #[test]
    fn p2c_selects_lower_load_worker() {
        let counts = load_table(&[(1, 10), (2, 1)]);
        // With only two workers, both are always sampled, so the less loaded
        // one (id=2) has to win.
        assert_eq!(p2c_select_from(&counts, &[1, 2]), 2);
    }

    #[test]
    fn p2c_selects_single_worker() {
        let counts = load_table(&[]);
        assert_eq!(p2c_select_from(&counts, &[42]), 42);
    }

    #[test]
    fn p2c_treats_missing_counts_as_zero() {
        // Worker 2 has no entry, so it is treated as load 0 and wins.
        let counts = load_table(&[(1, 5)]);
        assert_eq!(p2c_select_from(&counts, &[1, 2]), 2);
    }

    #[test]
    fn p2c_returns_valid_worker_on_tie() {
        let counts = load_table(&[(1, 3), (2, 3)]);
        for _ in 0..100 {
            let result = p2c_select_from(&counts, &[1, 2]);
            assert!(result == 1 || result == 2);
        }
    }

    #[test]
    fn p2c_lifecycle_tracks_inflight_counts() {
        let counts = Arc::new(load_table(&[]));

        // Mirror the router: select, increment, then arm a guard per request.
        let guards: Vec<P2CGuard> = (0..5)
            .map(|_| {
                let selected = p2c_select_from(&counts, &[1, 2]);
                counts
                    .entry(selected)
                    .or_insert_with(|| AtomicU64::new(0))
                    .value()
                    .fetch_add(1, Ordering::Relaxed);
                P2CGuard {
                    in_flight_counts: Arc::clone(&counts),
                    instance_id: selected,
                }
            })
            .collect();

        let in_flight_total = || recorded_load(&counts, 1) + recorded_load(&counts, 2);

        assert_eq!(
            in_flight_total(),
            5,
            "5 in-flight requests should be tracked"
        );

        drop(guards);

        assert_eq!(
            in_flight_total(),
            0,
            "All guards dropped, counts should be 0"
        );
    }

    #[test]
    fn p2c_never_selects_dominated_worker() {
        let counts = load_table(&[(1, 0), (2, 0), (3, 100)]);

        // Tally of how often each worker (1, 2, 3) was chosen.
        let mut selected = [0u32; 3];
        for _ in 0..1000 {
            let slot = match p2c_select_from(&counts, &[1, 2, 3]) {
                1 => 0,
                2 => 1,
                3 => 2,
                _ => panic!("unexpected worker id"),
            };
            selected[slot] += 1;
        }

        assert_eq!(
            selected[2], 0,
            "Worker 3 (load=100) should never be selected against load=0 workers, but got {} times",
            selected[2]
        );
    }
}
......@@ -76,7 +76,7 @@ def _test_router_basic(
frontend_timeout: Timeout for frontend readiness check (default: 120s)
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "nats".
router_mode: Router mode ("kv", "round-robin", "random", "direct"). Defaults to "kv".
router_mode: Router mode ("kv", "round-robin", "random", "power-of-two", "direct"). Defaults to "kv".
enforce_disagg: Whether to pass --enforce-disagg to the frontend. Defaults to False.
Raises:
......
......@@ -647,6 +647,7 @@ class DisaggMockerProcess:
pytest.param("kv", True, id="kv-durable"),
pytest.param("round-robin", False, id="roundrobin"),
pytest.param("random", False, id="random"),
pytest.param("power-of-two", False, id="power-of-two"),
],
indirect=["durable_kv_events"],
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment