Unverified commit 73e0a2f2, authored by Yan Ru Pei, committed by GitHub
Browse files

chore(mocker): fold workload concurrency cap into WorkloadDriver (#8446)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 4833e29b
...@@ -79,6 +79,7 @@ pub struct WorkloadDriver { ...@@ -79,6 +79,7 @@ pub struct WorkloadDriver {
sessions: Vec<SessionRuntime>, sessions: Vec<SessionRuntime>,
in_flight: FxHashMap<Uuid, InFlightTurn>, in_flight: FxHashMap<Uuid, InFlightTurn>,
ready_sessions: BinaryHeap<ReadySession>, ready_sessions: BinaryHeap<ReadySession>,
max_in_flight: Option<usize>,
} }
impl WorkloadDriver { impl WorkloadDriver {
...@@ -140,16 +141,52 @@ impl WorkloadDriver { ...@@ -140,16 +141,52 @@ impl WorkloadDriver {
sessions, sessions,
in_flight: FxHashMap::default(), in_flight: FxHashMap::default(),
ready_sessions, ready_sessions,
max_in_flight: None,
}) })
} }
/// Installs a global ceiling on the number of turns that may be in flight at once.
///
/// After this call `pop_ready` admits no more than the remaining headroom
/// (`cap` minus the current in-flight count, further bounded by the caller's
/// own `limit`), and `next_ready_time_ms` yields `None` for as long as the
/// in-flight count is at or above `cap`. A `cap` of `0` consequently admits
/// nothing.
///
/// Meant to be called before any turn has been dispatched: debug builds
/// assert that the in-flight table is empty.
pub fn set_max_in_flight(&mut self, cap: usize) {
    let has_pending_work = !self.in_flight.is_empty();
    debug_assert!(
        !has_pending_work,
        "set_max_in_flight called on a driver with pending work"
    );
    self.max_in_flight = Some(cap);
}
/// Recovery path for a request that never reached `on_complete`, e.g. because
/// its task was cancelled or panicked.
///
/// Drops `request_uuid` from the in-flight table (freeing its slot under the
/// cap) and retires the session that owned it: the session's in-flight marker
/// is cleared, its turn cursor is moved past the last turn, and its next-ready
/// time is cleared. Calling this after `on_complete` has already handled the
/// request is a no-op.
///
/// The session is retired rather than left as-is so that `run_workload` cannot
/// deadlock: `pop_ready` skips sessions with `in_flight.is_some()`, so a leaked
/// session would keep `is_drained` at `false` forever.
pub fn release_cap_slot(&mut self, request_uuid: Uuid) {
    let Some(released) = self.in_flight.remove(&request_uuid) else {
        return;
    };
    // Only touch the session if it exists and still points at this request.
    let owner = self
        .sessions
        .get_mut(released.session_index)
        .filter(|session| session.in_flight == Some(request_uuid));
    if let Some(session) = owner {
        session.in_flight = None;
        // Park the cursor past the final turn so nothing further is scheduled.
        session.next_turn_index = session.turns.len();
        session.next_ready_at_ms = None;
    }
}
pub fn pop_ready(&mut self, now_ms: f64, limit: usize) -> Vec<ReadyTurn> { pub fn pop_ready(&mut self, now_ms: f64, limit: usize) -> Vec<ReadyTurn> {
if limit == 0 { let effective_limit = match self.max_in_flight {
Some(cap) => limit.min(cap.saturating_sub(self.in_flight.len())),
None => limit,
};
if effective_limit == 0 {
return Vec::new(); return Vec::new();
} }
let mut emitted = Vec::new(); let mut emitted = Vec::new();
while emitted.len() < limit { while emitted.len() < effective_limit {
let Some(ready_session) = self.ready_sessions.pop() else { let Some(ready_session) = self.ready_sessions.pop() else {
break; break;
}; };
...@@ -240,6 +277,11 @@ impl WorkloadDriver { ...@@ -240,6 +277,11 @@ impl WorkloadDriver {
} }
pub fn next_ready_time_ms(&mut self) -> Option<f64> { pub fn next_ready_time_ms(&mut self) -> Option<f64> {
if let Some(cap) = self.max_in_flight
&& self.in_flight.len() >= cap
{
return None;
}
loop { loop {
let ready_session = *self.ready_sessions.peek()?; let ready_session = *self.ready_sessions.peek()?;
let session = &self.sessions[ready_session.session_index]; let session = &self.sessions[ready_session.session_index];
...@@ -269,3 +311,162 @@ impl WorkloadDriver { ...@@ -269,3 +311,162 @@ impl WorkloadDriver {
.sum() .sum()
} }
} }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::loadgen::{SessionTrace, Trace, TurnTrace};

    /// Fixture: two sessions that both arrive at t = 0 ms. Session "a" has two
    /// turns (the second delayed 5 ms after the first); session "b" has one.
    fn two_session_trace() -> Trace {
        let session_a = SessionTrace {
            session_id: "a".into(),
            first_arrival_timestamp_ms: Some(0.0),
            turns: vec![
                TurnTrace {
                    input_length: 2,
                    max_output_tokens: 1,
                    hash_ids: vec![1, 2],
                    delay_after_previous_ms: 0.0,
                },
                TurnTrace {
                    input_length: 2,
                    max_output_tokens: 1,
                    hash_ids: vec![3, 4],
                    delay_after_previous_ms: 5.0,
                },
            ],
        };
        let session_b = SessionTrace {
            session_id: "b".into(),
            first_arrival_timestamp_ms: Some(0.0),
            turns: vec![TurnTrace {
                input_length: 2,
                max_output_tokens: 1,
                hash_ids: vec![5, 6],
                delay_after_previous_ms: 0.0,
            }],
        };
        Trace {
            block_size: 1,
            sessions: vec![session_a, session_b],
        }
    }

    /// Driver over the two-session fixture with no in-flight cap configured.
    fn uncapped_driver() -> WorkloadDriver {
        WorkloadDriver::new_concurrency(two_session_trace(), 1).unwrap()
    }

    /// Driver over the two-session fixture with the in-flight cap set to 1.
    fn driver_capped_at_one() -> WorkloadDriver {
        let mut driver = uncapped_driver();
        driver.set_max_in_flight(1);
        driver
    }

    /// An unbounded caller limit must still be clamped by the driver's cap.
    #[test]
    fn cap_clamps_pop_ready_when_limit_is_unbounded() {
        let mut driver = driver_capped_at_one();

        let held = driver.pop_ready(0.0, usize::MAX);
        assert_eq!(held.len(), 1);

        let blocked = driver.pop_ready(0.0, usize::MAX);
        assert!(
            blocked.is_empty(),
            "cap should block dispatch while slot is held"
        );
    }

    /// Completing the in-flight turn frees the slot for a different request.
    #[test]
    fn pop_ready_admits_next_turn_after_on_complete() {
        let mut driver = driver_capped_at_one();

        let admitted = driver.pop_ready(0.0, usize::MAX);
        assert_eq!(admitted.len(), 1);
        let finished = admitted[0].request_uuid;

        driver.on_complete(finished, 10.0).unwrap();

        let next = driver.pop_ready(10.0, usize::MAX);
        assert_eq!(next.len(), 1);
        assert_ne!(next[0].request_uuid, finished);
    }

    /// While the cap is saturated no ready time is reported; it reappears once
    /// a slot is freed.
    #[test]
    fn next_ready_time_ms_returns_none_at_cap() {
        let mut driver = driver_capped_at_one();

        let admitted = driver.pop_ready(0.0, usize::MAX);
        assert_eq!(admitted.len(), 1);
        assert!(
            driver.next_ready_time_ms().is_none(),
            "expected None while at cap even with ready sessions queued"
        );

        driver.on_complete(admitted[0].request_uuid, 10.0).unwrap();
        assert!(
            driver.next_ready_time_ms().is_some(),
            "expected readiness after a slot is freed"
        );
    }

    /// Without a cap, only the caller-supplied limit bounds admission.
    #[test]
    fn no_cap_preserves_caller_limit_behavior() {
        let mut driver = uncapped_driver();

        let admitted = driver.pop_ready(0.0, 5);
        assert_eq!(admitted.len(), 2, "both sessions should admit with no cap");
        assert!(driver.next_ready_time_ms().is_none());
    }

    /// Releasing a slot for a request that already completed changes nothing.
    #[test]
    fn release_cap_slot_is_noop_after_on_complete() {
        let mut driver = driver_capped_at_one();

        let admitted = driver.pop_ready(0.0, usize::MAX);
        let finished = admitted[0].request_uuid;

        driver.on_complete(finished, 5.0).unwrap();
        driver.release_cap_slot(finished);

        let next = driver.pop_ready(5.0, usize::MAX);
        assert_eq!(next.len(), 1);
        assert_ne!(next[0].request_uuid, finished);
    }

    /// A request that never completed can still hand its slot back.
    #[test]
    fn release_cap_slot_recovers_cap_when_on_complete_was_skipped() {
        let mut driver = driver_capped_at_one();

        let admitted = driver.pop_ready(0.0, usize::MAX);
        assert_eq!(admitted.len(), 1);

        driver.release_cap_slot(admitted[0].request_uuid);

        let next = driver.pop_ready(0.0, usize::MAX);
        assert_eq!(
            next.len(),
            1,
            "cap slot should be available after release_cap_slot"
        );
    }

    /// Releasing a stuck request retires its session so the driver can drain.
    ///
    /// NOTE(review): the final `is_drained` assertion appears to rely on the
    /// two-turn session "a" being the one admitted (and then terminated) first;
    /// confirm against the ready-queue ordering.
    #[test]
    fn release_cap_slot_terminates_session_so_is_drained_completes() {
        let mut driver = driver_capped_at_one();

        let admitted = driver.pop_ready(0.0, usize::MAX);
        assert_eq!(admitted.len(), 1);
        let stuck_uuid = admitted[0].request_uuid;

        driver.release_cap_slot(stuck_uuid);

        let neighbor = driver.pop_ready(0.0, usize::MAX);
        assert_eq!(
            neighbor.len(),
            1,
            "other session must still be admissible after its neighbor was terminated"
        );

        driver.on_complete(neighbor[0].request_uuid, 1.0).unwrap();
        assert!(
            driver.is_drained(),
            "is_drained must become true so run_workload can exit"
        );
    }
}
...@@ -22,7 +22,9 @@ use super::state::{ ...@@ -22,7 +22,9 @@ use super::state::{
LiveReplayMode, LiveRuntimeStats, SharedLiveRuntimeStats, WorkloadDispatchState, now_ms, LiveReplayMode, LiveRuntimeStats, SharedLiveRuntimeStats, WorkloadDispatchState, now_ms,
record_arrival, record_arrival,
}; };
use super::task::{RequestTaskContext, run_request_task, wait_for_workload_progress}; use super::task::{
InFlightGuard, RequestTaskContext, run_request_task, wait_for_workload_progress,
};
pub(super) struct LiveRuntime { pub(super) struct LiveRuntime {
pending: std::collections::VecDeque<DirectRequest>, pending: std::collections::VecDeque<DirectRequest>,
...@@ -140,7 +142,11 @@ impl LiveRuntime { ...@@ -140,7 +142,11 @@ impl LiveRuntime {
.await .await
.map_err(|_| anyhow!("online replay concurrency semaphore closed"))?; .map_err(|_| anyhow!("online replay concurrency semaphore closed"))?;
record_arrival(&arrival_tx, &request, now_ms(start))?; record_arrival(&arrival_tx, &request, now_ms(start))?;
tasks.spawn(run_request_task(task_ctx.clone(), request, Some(permit))); let task_ctx = task_ctx.clone();
tasks.spawn(async move {
let _permit = permit;
run_request_task(task_ctx, request, None).await
});
} }
} }
} }
...@@ -163,7 +169,7 @@ impl LiveRuntime { ...@@ -163,7 +169,7 @@ impl LiveRuntime {
/// Drive a multi-turn workload driver until it is drained and all spawned request tasks finish. /// Drive a multi-turn workload driver until it is drained and all spawned request tasks finish.
pub(super) async fn run_workload( pub(super) async fn run_workload(
mut self, mut self,
driver: WorkloadDriver, mut driver: WorkloadDriver,
total_turns: usize, total_turns: usize,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> { ) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let requests = Arc::new(DashMap::with_capacity(total_turns.max(1))); let requests = Arc::new(DashMap::with_capacity(total_turns.max(1)));
...@@ -189,6 +195,13 @@ impl LiveRuntime { ...@@ -189,6 +195,13 @@ impl LiveRuntime {
) )
.await .await
}); });
let cap_enabled = match self.mode {
LiveReplayMode::Trace => false,
LiveReplayMode::Concurrency { max_in_flight } => {
driver.set_max_in_flight(max_in_flight);
true
}
};
let workload = Arc::new(WorkloadDispatchState { let workload = Arc::new(WorkloadDispatchState {
driver: std::sync::Mutex::new(driver), driver: std::sync::Mutex::new(driver),
wakeup: Notify::new(), wakeup: Notify::new(),
...@@ -202,38 +215,15 @@ impl LiveRuntime { ...@@ -202,38 +215,15 @@ impl LiveRuntime {
stats: Arc::clone(&stats), stats: Arc::clone(&stats),
workload: Some(Arc::clone(&workload)), workload: Some(Arc::clone(&workload)),
}; };
let semaphore = match self.mode {
LiveReplayMode::Trace => None,
LiveReplayMode::Concurrency { max_in_flight } => {
Some(Arc::new(Semaphore::new(max_in_flight)))
}
};
loop { loop {
let now = now_ms(start); let now = now_ms(start);
let dispatch_limit = match &semaphore { let ready_turns = workload.driver.lock().unwrap().pop_ready(now, usize::MAX);
Some(semaphore) => semaphore.available_permits(),
None => usize::MAX,
};
if dispatch_limit > 0 {
let ready_turns = workload
.driver
.lock()
.unwrap()
.pop_ready(now, dispatch_limit);
if !ready_turns.is_empty() { if !ready_turns.is_empty() {
for ready_turn in ready_turns { for ready_turn in ready_turns {
let permit = match &semaphore { let guard = cap_enabled.then(|| {
Some(semaphore) => { InFlightGuard::new(Arc::clone(&workload), ready_turn.request_uuid)
Some(semaphore.clone().try_acquire_owned().map_err(|_| { });
anyhow!(
"online replay concurrency semaphore unexpectedly closed"
)
})?)
}
None => None,
};
let arrival_at_ms = match self.mode { let arrival_at_ms = match self.mode {
LiveReplayMode::Trace => ready_turn.scheduled_ready_at_ms, LiveReplayMode::Trace => ready_turn.scheduled_ready_at_ms,
LiveReplayMode::Concurrency { .. } => now_ms(start), LiveReplayMode::Concurrency { .. } => now_ms(start),
...@@ -242,12 +232,11 @@ impl LiveRuntime { ...@@ -242,12 +232,11 @@ impl LiveRuntime {
tasks.spawn(run_request_task( tasks.spawn(run_request_task(
task_ctx.clone(), task_ctx.clone(),
ready_turn.request, ready_turn.request,
permit, guard,
)); ));
} }
continue; continue;
} }
}
let wake = workload.wakeup.notified(); let wake = workload.wakeup.notified();
tokio::pin!(wake); tokio::pin!(wake);
...@@ -259,14 +248,7 @@ impl LiveRuntime { ...@@ -259,14 +248,7 @@ impl LiveRuntime {
break; break;
} }
wait_for_workload_progress( wait_for_workload_progress(next_ready_ms, start, wake.as_mut()).await;
self.mode,
semaphore.as_deref(),
next_ready_ms,
start,
wake.as_mut(),
)
.await;
} }
while let Some(result) = tasks.join_next().await { while let Some(result) = tasks.join_next().await {
......
...@@ -6,15 +6,16 @@ use std::pin::Pin; ...@@ -6,15 +6,16 @@ use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use anyhow::{Result, anyhow, bail}; use anyhow::{Result, anyhow, bail};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc}; use tokio::sync::mpsc;
use tokio::time::Instant; use tokio::time::Instant;
use uuid::Uuid;
use crate::common::protocols::DirectRequest; use crate::common::protocols::DirectRequest;
use super::ReplayRouter; use super::ReplayRouter;
use super::state::{ use super::state::{
LiveReplayMode, RequestRegistry, RequestState, SharedLiveRuntimeStats, WorkloadDispatchState, RequestRegistry, RequestState, SharedLiveRuntimeStats, WorkloadDispatchState, now_ms,
now_ms, request_uuid, request_uuid,
}; };
#[derive(Clone)] #[derive(Clone)]
...@@ -26,50 +27,66 @@ pub(super) struct RequestTaskContext { ...@@ -26,50 +27,66 @@ pub(super) struct RequestTaskContext {
pub(super) workload: Option<Arc<WorkloadDispatchState>>, pub(super) workload: Option<Arc<WorkloadDispatchState>>,
} }
/// Releases a `WorkloadDriver` cap slot on drop if `mark_completed` was not called.
/// Preserves the drop-safety of the old `OwnedSemaphorePermit` so a cancelled or
/// panicking request task can't leak capacity.
pub(super) struct InFlightGuard {
    /// Shared dispatch state; on drop the guard locks its `driver` and
    /// notifies its `wakeup`.
    dispatch: Arc<WorkloadDispatchState>,
    /// Identifier of the request whose cap slot this guard protects.
    uuid: Uuid,
    /// Set by `mark_completed`; when `true`, dropping the guard does nothing.
    completed: bool,
}
impl InFlightGuard {
pub(super) fn new(dispatch: Arc<WorkloadDispatchState>, uuid: Uuid) -> Self {
Self {
dispatch,
uuid,
completed: false,
}
}
pub(super) fn mark_completed(&mut self) {
self.completed = true;
}
}
impl Drop for InFlightGuard {
    /// If the guard was never marked completed, hands the request's cap slot
    /// back to the driver and notifies waiters on the dispatch wakeup.
    fn drop(&mut self) {
        if !self.completed {
            // Best effort: if the driver mutex cannot be locked (it is
            // poisoned), skip the release but still notify below. The lock
            // guard is scoped to this `if let`, so it is released before the
            // notification is sent.
            if let Ok(mut driver) = self.dispatch.driver.lock() {
                driver.release_cap_slot(self.uuid);
            }
            self.dispatch.wakeup.notify_waiters();
        }
    }
}
pub(super) async fn wait_for_workload_progress<F>( pub(super) async fn wait_for_workload_progress<F>(
mode: LiveReplayMode,
semaphore: Option<&Semaphore>,
next_ready_ms: Option<f64>, next_ready_ms: Option<f64>,
start: Instant, start: Instant,
mut wake: Pin<&mut F>, mut wake: Pin<&mut F>,
) where ) where
F: Future<Output = ()>, F: Future<Output = ()>,
{ {
match (mode, semaphore, next_ready_ms) { match next_ready_ms {
(LiveReplayMode::Trace, _, Some(next_ready_ms)) => { Some(next_ready_ms) => {
let deadline = start + tokio::time::Duration::from_secs_f64(next_ready_ms / 1000.0); let deadline = start + tokio::time::Duration::from_secs_f64(next_ready_ms / 1000.0);
tokio::select! { tokio::select! {
_ = tokio::time::sleep_until(deadline) => {} _ = tokio::time::sleep_until(deadline) => {}
_ = wake.as_mut() => {} _ = wake.as_mut() => {}
} }
} }
(LiveReplayMode::Trace, _, None) => { None => {
wake.as_mut().await; wake.as_mut().await;
} }
(LiveReplayMode::Concurrency { .. }, Some(semaphore), Some(next_ready_ms)) => {
if semaphore.available_permits() == 0 {
wake.as_mut().await;
} else {
let deadline = start + tokio::time::Duration::from_secs_f64(next_ready_ms / 1000.0);
tokio::select! {
_ = tokio::time::sleep_until(deadline) => {}
_ = wake.as_mut() => {}
}
}
}
(LiveReplayMode::Concurrency { .. }, Some(_semaphore), None) => {
wake.as_mut().await;
}
(LiveReplayMode::Concurrency { .. }, None, _) => {
unreachable!("concurrency mode must have a semaphore");
}
} }
} }
pub(super) async fn run_request_task( pub(super) async fn run_request_task(
ctx: RequestTaskContext, ctx: RequestTaskContext,
request: DirectRequest, request: DirectRequest,
permit: Option<OwnedSemaphorePermit>, mut guard: Option<InFlightGuard>,
) -> Result<()> { ) -> Result<()> {
let uuid = request_uuid(&request)?; let uuid = request_uuid(&request)?;
...@@ -102,7 +119,9 @@ pub(super) async fn run_request_task( ...@@ -102,7 +119,9 @@ pub(super) async fn run_request_task(
.unwrap() .unwrap()
.on_complete(uuid, completion_ms)?; .on_complete(uuid, completion_ms)?;
workload.wakeup.notify_waiters(); workload.wakeup.notify_waiters();
if let Some(guard) = guard.as_mut() {
guard.mark_completed();
}
} }
drop(permit);
Ok(()) Ok(())
} }
...@@ -9,7 +9,7 @@ use std::time::Duration; ...@@ -9,7 +9,7 @@ use std::time::Duration;
use dashmap::DashMap; use dashmap::DashMap;
use dynamo_kv_router::PrefillLoadEstimator; use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_kv_router::config::{KvRouterConfig, RouterPrefillLoadModel}; use dynamo_kv_router::config::{KvRouterConfig, RouterPrefillLoadModel};
use tokio::sync::{Notify, Semaphore, mpsc}; use tokio::sync::{Notify, mpsc};
use tokio::task::JoinSet; use tokio::task::JoinSet;
use tokio::time::Instant; use tokio::time::Instant;
use uuid::Uuid; use uuid::Uuid;
...@@ -24,7 +24,7 @@ use super::entrypoints::{ ...@@ -24,7 +24,7 @@ use super::entrypoints::{
simulate_trace_requests, simulate_trace_requests_with_stats, simulate_trace_requests, simulate_trace_requests_with_stats,
simulate_trace_workload_with_stats, simulate_trace_workload_with_stats,
}; };
use super::state::{LiveReplayMode, SharedLiveRuntimeStats, WorkloadDispatchState, record_arrival}; use super::state::{SharedLiveRuntimeStats, WorkloadDispatchState, record_arrival};
use super::task::{RequestTaskContext, run_request_task, wait_for_workload_progress}; use super::task::{RequestTaskContext, run_request_task, wait_for_workload_progress};
fn replay_args() -> MockEngineArgs { fn replay_args() -> MockEngineArgs {
...@@ -349,7 +349,6 @@ async fn test_workload_wakeup_is_not_lost_when_completion_happens_before_await() ...@@ -349,7 +349,6 @@ async fn test_workload_wakeup_is_not_lost_when_completion_happens_before_await()
#[tokio::test] #[tokio::test]
async fn test_concurrency_workload_waits_for_wakeup_when_next_turn_is_completion_gated() { async fn test_concurrency_workload_waits_for_wakeup_when_next_turn_is_completion_gated() {
let semaphore = Arc::new(Semaphore::new(1));
let notify = Arc::new(Notify::new()); let notify = Arc::new(Notify::new());
let wake = notify.notified(); let wake = notify.notified();
tokio::pin!(wake); tokio::pin!(wake);
...@@ -357,13 +356,7 @@ async fn test_concurrency_workload_waits_for_wakeup_when_next_turn_is_completion ...@@ -357,13 +356,7 @@ async fn test_concurrency_workload_waits_for_wakeup_when_next_turn_is_completion
assert!( assert!(
tokio::time::timeout( tokio::time::timeout(
tokio::time::Duration::from_millis(20), tokio::time::Duration::from_millis(20),
wait_for_workload_progress( wait_for_workload_progress(None, Instant::now(), wake.as_mut()),
LiveReplayMode::Concurrency { max_in_flight: 1 },
Some(semaphore.as_ref()),
None,
Instant::now(),
wake.as_mut(),
),
) )
.await .await
.is_err(), .is_err(),
...@@ -372,13 +365,7 @@ async fn test_concurrency_workload_waits_for_wakeup_when_next_turn_is_completion ...@@ -372,13 +365,7 @@ async fn test_concurrency_workload_waits_for_wakeup_when_next_turn_is_completion
let wake = notify.notified(); let wake = notify.notified();
tokio::pin!(wake); tokio::pin!(wake);
let wait = wait_for_workload_progress( let wait = wait_for_workload_progress(None, Instant::now(), wake.as_mut());
LiveReplayMode::Concurrency { max_in_flight: 1 },
Some(semaphore.as_ref()),
None,
Instant::now(),
wake.as_mut(),
);
let notify_task = { let notify_task = {
let notify = Arc::clone(&notify); let notify = Arc::clone(&notify);
tokio::spawn(async move { tokio::spawn(async move {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment