Unverified Commit 1e261dbd authored by blarson-b10's avatar blarson-b10 Committed by GitHub
Browse files

fix: Improve active sequences request expiration (#7340)


Signed-off-by: default avatarBrian Larson <brian.larson@baseten.co>
parent a441aaf8
......@@ -15,6 +15,7 @@ use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use super::single::{ActiveSequences, RequestId};
......@@ -22,6 +23,11 @@ use crate::protocols::{
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, WorkerWithDpRank,
};
// How often we force expire stale requests across all workers. See the comment
// in ActiveSequencesMultiWorker::force_expire_requests_across_all_workers for
// more details.
const FORCE_EXPIRE_REQUESTS_ACROSS_ALL_WORKERS_INTERVAL: Duration = Duration::from_secs(60);
// ---------------------------------------------------------------------------
// Traits
// ---------------------------------------------------------------------------
......@@ -691,4 +697,62 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
counts
}
/// Force expire stale requests across all workers (one-shot).
///
/// This is necessary because worker expiration otherwise only runs as a side-effect
/// of `add_request`. If a worker has many expired active sequences and no new
/// requests are added, expiration never runs. This method forces it on all workers.
///
/// To run this periodically, use start_periodic_force_expiry_across_all_workers.
pub fn force_expire_requests_across_all_workers(&self) {
let now = Instant::now();
let table = self.workers.read();
let mut removed_request_count = 0;
for (worker, lock) in &table.slots {
let removed_requests = lock.write().force_expiry();
if !removed_requests.is_empty() {
for expired_id in &removed_requests {
self.request_to_worker.remove(expired_id);
self.request_to_lora.remove(expired_id);
removed_request_count += 1;
}
self.publish_active_load_for_worker(*worker);
}
}
let duration = now.elapsed();
tracing::debug!(
duration = duration.as_secs_f64(),
removed_request_count,
"Force expired stale requests across all workers"
);
}
/// Spawn a background task that calls `force_expire_requests_across_all_workers`
/// at the given interval until `cancel_token` is cancelled.
///
/// **Concurrency note:** This type is always used as `Arc<ActiveSequencesMultiWorker>`. All
/// mutation is via interior mutability (`RwLock<WorkerTable>`, `DashMap`), so the periodic
/// task only needs `&self` and does not block other callers.
pub fn start_periodic_force_expiry_across_all_workers(
self: &Arc<Self>,
cancel_token: CancellationToken,
) {
let this = Arc::clone(self);
tokio::spawn(async move {
let mut expiry_interval =
tokio::time::interval(FORCE_EXPIRE_REQUESTS_ACROSS_ALL_WORKERS_INTERVAL);
expiry_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
_ = expiry_interval.tick() => {
this.force_expire_requests_across_all_workers();
}
_ = cancel_token.cancelled() => {
break;
}
}
}
});
}
}
......@@ -26,9 +26,13 @@ use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid;
/// Duration after which stale requests are forcibly expired (5 minutes)
/// Duration after which stale requests may be expired (5 minutes).
const EXPIRY_DURATION: Duration = Duration::from_secs(300);
/// How often we *check* for stale requests (30 seconds). This is not
/// the expiration time, that is EXPIRY_DURATION.
const CHECK_EXPIRY_FREQUENCY: Duration = Duration::from_secs(30);
// TODO: use the common request_id if it exists in the repo
pub type RequestId = String;
......@@ -55,11 +59,10 @@ pub struct ActiveSequences {
#[getter(copy)]
active_tokens: usize,
/// Timer for when to force expiry of stale requests
expiry_timer: Instant,
// Request timestamps, for expiration.
request_timestamps: HashMap<RequestId, Instant>,
/// Set of request IDs to check for expiry
expiry_requests: HashSet<RequestId>,
last_expiry_check_time: Instant,
}
impl ActiveSequences {
......@@ -76,8 +79,8 @@ impl ActiveSequences {
fractional_blocks: HashMap::new(),
block_size,
active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION,
expiry_requests: HashSet::new(),
request_timestamps: HashMap::new(),
last_expiry_check_time: Instant::now(),
}
}
......@@ -172,6 +175,8 @@ impl ActiveSequences {
// dummy empty sequence
self.active_seqs.insert(request_id.clone(), Vec::new());
}
self.request_timestamps
.insert(request_id.clone(), Instant::now());
removed_requests
}
......@@ -231,12 +236,11 @@ impl ActiveSequences {
pub fn free(&mut self, request_id: &RequestId) -> usize {
self.mark_prefill_completed(request_id);
self.expiry_requests.remove(request_id);
// Remove expected output tokens tracking
self.expected_output_tokens.remove(request_id);
// Remove from active_seqs and get the token sequence
self.request_timestamps.remove(request_id);
let token_seq = match self.active_seqs.remove(request_id) {
Some(seq) => seq,
None => {
......@@ -299,21 +303,26 @@ impl ActiveSequences {
pub fn force_expiry(&mut self) -> HashSet<RequestId> {
let now = Instant::now();
// Early return if timer hasn't expired yet
if now < self.expiry_timer {
// Early return if timer hasn't expired yet.
if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY {
return HashSet::new();
}
// Process expired requests - drain to avoid clone
let expired_requests: HashSet<RequestId> = self.expiry_requests.drain().collect();
self.last_expiry_check_time = now;
let expired_requests_time = now - EXPIRY_DURATION;
let mut expired_requests: HashSet<RequestId> = HashSet::new();
for (request_id, timestamp) in self.request_timestamps.iter() {
if *timestamp < expired_requests_time {
expired_requests.insert(request_id.clone());
}
}
for request_id in &expired_requests {
tracing::warn!("Force expiring stale request: {}", request_id);
tracing::warn!("Expiring stale request: {}", request_id);
self.free(request_id);
}
self.expiry_timer = now + EXPIRY_DURATION;
self.expiry_requests = self.active_seqs.keys().cloned().collect();
expired_requests
}
}
......@@ -420,25 +429,38 @@ mod tests {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
// Add two requests
// Add two requests at time 0 (paused clock)
seq_manager.add_request("r1".to_string(), Some(vec![1, 2]), 8, 0, None);
seq_manager.add_request("r2".to_string(), Some(vec![3, 4]), 8, 0, None);
assert_eq!(seq_manager.active_blocks(), 4);
// First expiry cycle: advance past EXPIRY_DURATION.
// This populates expiry_requests with {r1, r2} but doesn't expire anything
// since expiry_requests started empty.
tokio::time::advance(Duration::from_secs(301)).await;
// Advance 20s: check interval (CHECK_EXPIRY_FREQUENCY = 30s) not reached,
// force_expiry returns without running the check.
tokio::time::advance(Duration::from_secs(20)).await;
let expired = seq_manager.force_expiry();
assert!(expired.is_empty());
assert!(expired.is_empty(), "no check before CHECK_EXPIRY_FREQUENCY");
assert_eq!(seq_manager.active_blocks(), 4);
// Second expiry cycle: advance again so the timer expires.
// Adding r3 triggers force_expiry which drains {r1, r2}.
tokio::time::advance(Duration::from_secs(301)).await;
let expired = seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None);
// Advance to 31s: first time we pass the check interval. Requests are 31s old,
// still under EXPIRY_DURATION (300s), so none are expired.
tokio::time::advance(Duration::from_secs(11)).await;
let expired = seq_manager.force_expiry();
assert!(expired.is_empty(), "requests not old enough to expire");
assert_eq!(seq_manager.active_blocks(), 4);
// Advance to 301s: requests are now older than EXPIRY_DURATION.
// force_expiry runs and expires r1, r2.
tokio::time::advance(Duration::from_secs(270)).await;
let expired = seq_manager.force_expiry();
assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()]));
assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0);
// Only r3's block remains
// add_request calls force_expiry internally. Add r3; no old requests remain,
// so expired set is empty and only r3 is active.
tokio::time::advance(Duration::from_secs(31)).await;
let expired = seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None);
assert!(expired.is_empty());
assert_eq!(seq_manager.active_blocks(), 1);
assert_eq!(seq_manager.active_tokens(), 4);
}
......
......@@ -133,6 +133,9 @@ pub async fn create_multi_worker_sequences(
arc.start_replica_sync(subscriber, cancel_token);
}
let expiry_cancel = component.drt().runtime().child_token();
arc.start_periodic_force_expiry_across_all_workers(expiry_cancel);
Ok(arc)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment