Commit 1e261dbd (unverified), authored by blarson-b10, committed by GitHub
Browse files

fix: Improve active sequences request expiration (#7340)


Signed-off-by: Brian Larson <brian.larson@baseten.co>
parent a441aaf8
...@@ -15,6 +15,7 @@ use parking_lot::RwLock; ...@@ -15,6 +15,7 @@ use parking_lot::RwLock;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::future::Future; use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use super::single::{ActiveSequences, RequestId}; use super::single::{ActiveSequences, RequestId};
...@@ -22,6 +23,11 @@ use crate::protocols::{ ...@@ -22,6 +23,11 @@ use crate::protocols::{
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, WorkerWithDpRank, ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, WorkerWithDpRank,
}; };
// Period of the background sweep that force-expires stale requests on every
// worker (60 seconds). The sweep task is spawned by
// ActiveSequencesMultiWorker::start_periodic_force_expiry_across_all_workers,
// which calls force_expire_requests_across_all_workers on each tick; see the
// comment on that method for why the sweep is needed.
const FORCE_EXPIRE_REQUESTS_ACROSS_ALL_WORKERS_INTERVAL: Duration = Duration::from_secs(60);
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Traits // Traits
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
...@@ -691,4 +697,62 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -691,4 +697,62 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
} }
counts counts
} }
/// Run stale-request expiry once on every worker.
///
/// A worker's expiry otherwise only runs as a side effect of `add_request`, so a
/// worker that stops receiving new requests would never drop its stale active
/// sequences. This sweep invokes `force_expiry` on each worker slot regardless of
/// traffic.
///
/// For every request a worker reports as expired, the request is also removed from
/// the `request_to_worker` and `request_to_lora` maps, and the worker's active load
/// is re-published once after its expired requests have been cleaned up.
///
/// The worker table is held under a read lock for the whole sweep; each worker slot
/// is write-locked only for the duration of its own `force_expiry` call.
///
/// This is a one-shot call. To run it on a timer, use
/// `start_periodic_force_expiry_across_all_workers`.
pub fn force_expire_requests_across_all_workers(&self) {
    // Start timing before taking any lock so the logged duration includes lock waits.
    let started_at = Instant::now();
    let worker_table = self.workers.read();
    let mut removed_request_count = 0;

    for (worker, slot) in &worker_table.slots {
        // The slot's write guard is a temporary and is released at the end of this statement.
        let expired_ids = slot.write().force_expiry();
        if expired_ids.is_empty() {
            // Nothing expired on this worker: no bookkeeping to undo, nothing to publish.
            continue;
        }

        for request_id in &expired_ids {
            self.request_to_worker.remove(request_id);
            self.request_to_lora.remove(request_id);
            removed_request_count += 1;
        }
        self.publish_active_load_for_worker(*worker);
    }

    let elapsed = started_at.elapsed();
    tracing::debug!(
        duration = elapsed.as_secs_f64(),
        removed_request_count,
        "Force expired stale requests across all workers"
    );
}
/// Spawn a detached tokio task that periodically expires stale requests.
///
/// The task calls `force_expire_requests_across_all_workers` on every tick of a
/// timer whose period is `FORCE_EXPIRE_REQUESTS_ACROSS_ALL_WORKERS_INTERVAL`, and
/// exits once `cancel_token` is cancelled. Ticks that were missed are skipped
/// rather than replayed (`MissedTickBehavior::Skip`).
///
/// The task holds its own `Arc` clone of `self`; the sweep itself only needs
/// `&self`. The returned `JoinHandle` is dropped, so cancelling `cancel_token` is
/// the only way this function offers to stop the task.
pub fn start_periodic_force_expiry_across_all_workers(
    self: &Arc<Self>,
    cancel_token: CancellationToken,
) {
    let sweeper = Arc::clone(self);
    let sweep_loop = async move {
        let mut ticker = tokio::time::interval(FORCE_EXPIRE_REQUESTS_ACROSS_ALL_WORKERS_INTERVAL);
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            tokio::select! {
                // Cancellation ends the task.
                _ = cancel_token.cancelled() => return,
                // Timer tick: run one sweep over all workers.
                _ = ticker.tick() => sweeper.force_expire_requests_across_all_workers(),
            }
        }
    };
    tokio::spawn(sweep_loop);
}
} }
...@@ -26,9 +26,13 @@ use std::time::Duration; ...@@ -26,9 +26,13 @@ use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
use uuid::Uuid; use uuid::Uuid;
/// Duration after which stale requests are forcibly expired (5 minutes) /// Duration after which stale requests may be expired (5 minutes).
const EXPIRY_DURATION: Duration = Duration::from_secs(300); const EXPIRY_DURATION: Duration = Duration::from_secs(300);
/// Minimum time between two stale-request scans in `force_expiry` (30 seconds).
///
/// This only limits how often the scan runs. How old a request must be before
/// it is expired is governed by `EXPIRY_DURATION`, not by this value.
const CHECK_EXPIRY_FREQUENCY: Duration = Duration::from_secs(30);
// TODO: use the common request_id if it exists in the repo // TODO: use the common request_id if it exists in the repo
pub type RequestId = String; pub type RequestId = String;
...@@ -55,11 +59,10 @@ pub struct ActiveSequences { ...@@ -55,11 +59,10 @@ pub struct ActiveSequences {
#[getter(copy)] #[getter(copy)]
active_tokens: usize, active_tokens: usize,
/// Timer for when to force expiry of stale requests // Request timestamps, for expiration.
expiry_timer: Instant, request_timestamps: HashMap<RequestId, Instant>,
/// Set of request IDs to check for expiry last_expiry_check_time: Instant,
expiry_requests: HashSet<RequestId>,
} }
impl ActiveSequences { impl ActiveSequences {
...@@ -76,8 +79,8 @@ impl ActiveSequences { ...@@ -76,8 +79,8 @@ impl ActiveSequences {
fractional_blocks: HashMap::new(), fractional_blocks: HashMap::new(),
block_size, block_size,
active_tokens: 0, active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION, request_timestamps: HashMap::new(),
expiry_requests: HashSet::new(), last_expiry_check_time: Instant::now(),
} }
} }
...@@ -172,6 +175,8 @@ impl ActiveSequences { ...@@ -172,6 +175,8 @@ impl ActiveSequences {
// dummy empty sequence // dummy empty sequence
self.active_seqs.insert(request_id.clone(), Vec::new()); self.active_seqs.insert(request_id.clone(), Vec::new());
} }
self.request_timestamps
.insert(request_id.clone(), Instant::now());
removed_requests removed_requests
} }
...@@ -231,12 +236,11 @@ impl ActiveSequences { ...@@ -231,12 +236,11 @@ impl ActiveSequences {
pub fn free(&mut self, request_id: &RequestId) -> usize { pub fn free(&mut self, request_id: &RequestId) -> usize {
self.mark_prefill_completed(request_id); self.mark_prefill_completed(request_id);
self.expiry_requests.remove(request_id);
// Remove expected output tokens tracking // Remove expected output tokens tracking
self.expected_output_tokens.remove(request_id); self.expected_output_tokens.remove(request_id);
// Remove from active_seqs and get the token sequence // Remove from active_seqs and get the token sequence
self.request_timestamps.remove(request_id);
let token_seq = match self.active_seqs.remove(request_id) { let token_seq = match self.active_seqs.remove(request_id) {
Some(seq) => seq, Some(seq) => seq,
None => { None => {
...@@ -299,21 +303,26 @@ impl ActiveSequences { ...@@ -299,21 +303,26 @@ impl ActiveSequences {
pub fn force_expiry(&mut self) -> HashSet<RequestId> { pub fn force_expiry(&mut self) -> HashSet<RequestId> {
let now = Instant::now(); let now = Instant::now();
// Early return if timer hasn't expired yet // Early return if timer hasn't expired yet.
if now < self.expiry_timer { if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY {
return HashSet::new(); return HashSet::new();
} }
// Process expired requests - drain to avoid clone self.last_expiry_check_time = now;
let expired_requests: HashSet<RequestId> = self.expiry_requests.drain().collect(); let expired_requests_time = now - EXPIRY_DURATION;
let mut expired_requests: HashSet<RequestId> = HashSet::new();
for (request_id, timestamp) in self.request_timestamps.iter() {
if *timestamp < expired_requests_time {
expired_requests.insert(request_id.clone());
}
}
for request_id in &expired_requests { for request_id in &expired_requests {
tracing::warn!("Force expiring stale request: {}", request_id); tracing::warn!("Expiring stale request: {}", request_id);
self.free(request_id); self.free(request_id);
} }
self.expiry_timer = now + EXPIRY_DURATION;
self.expiry_requests = self.active_seqs.keys().cloned().collect();
expired_requests expired_requests
} }
} }
...@@ -420,25 +429,38 @@ mod tests { ...@@ -420,25 +429,38 @@ mod tests {
let block_size = 4; let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
// Add two requests // Add two requests at time 0 (paused clock)
seq_manager.add_request("r1".to_string(), Some(vec![1, 2]), 8, 0, None); seq_manager.add_request("r1".to_string(), Some(vec![1, 2]), 8, 0, None);
seq_manager.add_request("r2".to_string(), Some(vec![3, 4]), 8, 0, None); seq_manager.add_request("r2".to_string(), Some(vec![3, 4]), 8, 0, None);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
// First expiry cycle: advance past EXPIRY_DURATION. // Advance 20s: check interval (CHECK_EXPIRY_FREQUENCY = 30s) not reached,
// This populates expiry_requests with {r1, r2} but doesn't expire anything // force_expiry returns without running the check.
// since expiry_requests started empty. tokio::time::advance(Duration::from_secs(20)).await;
tokio::time::advance(Duration::from_secs(301)).await;
let expired = seq_manager.force_expiry(); let expired = seq_manager.force_expiry();
assert!(expired.is_empty()); assert!(expired.is_empty(), "no check before CHECK_EXPIRY_FREQUENCY");
assert_eq!(seq_manager.active_blocks(), 4);
// Second expiry cycle: advance again so the timer expires. // Advance to 31s: first time we pass the check interval. Requests are 31s old,
// Adding r3 triggers force_expiry which drains {r1, r2}. // still under EXPIRY_DURATION (300s), so none are expired.
tokio::time::advance(Duration::from_secs(301)).await; tokio::time::advance(Duration::from_secs(11)).await;
let expired = seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None); let expired = seq_manager.force_expiry();
assert!(expired.is_empty(), "requests not old enough to expire");
assert_eq!(seq_manager.active_blocks(), 4);
// Advance to 301s: requests are now older than EXPIRY_DURATION.
// force_expiry runs and expires r1, r2.
tokio::time::advance(Duration::from_secs(270)).await;
let expired = seq_manager.force_expiry();
assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()])); assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()]));
assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0);
// Only r3's block remains // add_request calls force_expiry internally. Add r3; no old requests remain,
// so expired set is empty and only r3 is active.
tokio::time::advance(Duration::from_secs(31)).await;
let expired = seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None);
assert!(expired.is_empty());
assert_eq!(seq_manager.active_blocks(), 1); assert_eq!(seq_manager.active_blocks(), 1);
assert_eq!(seq_manager.active_tokens(), 4); assert_eq!(seq_manager.active_tokens(), 4);
} }
......
...@@ -133,6 +133,9 @@ pub async fn create_multi_worker_sequences( ...@@ -133,6 +133,9 @@ pub async fn create_multi_worker_sequences(
arc.start_replica_sync(subscriber, cancel_token); arc.start_replica_sync(subscriber, cancel_token);
} }
let expiry_cancel = component.drt().runtime().child_token();
arc.start_periodic_force_expiry_across_all_workers(expiry_cancel);
Ok(arc) Ok(arc)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment