Unverified Commit 307d4039 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: router slot manager needs force expire requests (#2840)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 432e8290
...@@ -34,12 +34,17 @@ use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; ...@@ -34,12 +34,17 @@ use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt; use futures::StreamExt;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid; use uuid::Uuid;
use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData}; use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData};
use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT; use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT;
use dynamo_runtime::CancellationToken; use dynamo_runtime::CancellationToken;
/// Duration after which stale requests are forcibly expired (5 minutes).
///
/// This value is also the interval between expiry sweeps: `expiry_timer` is
/// armed this far into the future both at construction and after every sweep
/// performed by `force_expiry`. A request only becomes an expiry candidate at
/// the first sweep after it was added and is expired at the sweep after that,
/// so a request lives for at least one full period before it can be removed.
const EXPIRY_DURATION: Duration = Duration::from_secs(300);
// TODO: use the common request_id if it exists in the repo // TODO: use the common request_id if it exists in the repo
pub type RequestId = String; pub type RequestId = String;
...@@ -60,6 +65,12 @@ pub struct ActiveSequences { ...@@ -60,6 +65,12 @@ pub struct ActiveSequences {
#[getter(copy)] #[getter(copy)]
active_tokens: usize, active_tokens: usize,
/// Timer for when to force expiry of stale requests
expiry_timer: Instant,
/// Set of request IDs to check for expiry
expiry_requests: HashSet<RequestId>,
} }
impl ActiveSequences { impl ActiveSequences {
...@@ -75,6 +86,8 @@ impl ActiveSequences { ...@@ -75,6 +86,8 @@ impl ActiveSequences {
block_size, block_size,
active_blocks: 0, active_blocks: 0,
active_tokens: 0, active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION,
expiry_requests: HashSet::new(),
} }
} }
...@@ -105,13 +118,17 @@ impl ActiveSequences { ...@@ -105,13 +118,17 @@ impl ActiveSequences {
} }
/// Add a new request with its initial tokens /// Add a new request with its initial tokens
/// Returns the set of expired request IDs that were removed during cleanup
pub fn add_request( pub fn add_request(
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
token_sequence: Vec<SequenceHash>, token_sequence: Vec<SequenceHash>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
) -> usize { ) -> HashSet<RequestId> {
// Lazily check and clean up expired requests, capturing removed IDs
let removed_requests = self.force_expiry();
let prefill_tokens = self.new_tokens(isl, overlap); let prefill_tokens = self.new_tokens(isl, overlap);
self.prefill_tokens self.prefill_tokens
.insert(request_id.clone(), prefill_tokens); .insert(request_id.clone(), prefill_tokens);
...@@ -123,7 +140,7 @@ impl ActiveSequences { ...@@ -123,7 +140,7 @@ impl ActiveSequences {
self.active_seqs.insert(request_id.clone(), token_sequence); self.active_seqs.insert(request_id.clone(), token_sequence);
self.active_blocks removed_requests
} }
/// Mark prefill as completed for a request, removing it from prefill_tokens tracking /// Mark prefill as completed for a request, removing it from prefill_tokens tracking
...@@ -170,6 +187,8 @@ impl ActiveSequences { ...@@ -170,6 +187,8 @@ impl ActiveSequences {
pub fn free(&mut self, request_id: &RequestId) -> usize { pub fn free(&mut self, request_id: &RequestId) -> usize {
self.mark_prefill_completed(request_id); self.mark_prefill_completed(request_id);
self.expiry_requests.remove(request_id);
let Some(token_seq) = self.active_seqs.get(request_id) else { let Some(token_seq) = self.active_seqs.get(request_id) else {
tracing::warn!("Trying to free free non-existent request {request_id}"); tracing::warn!("Trying to free free non-existent request {request_id}");
return 0; return 0;
...@@ -183,6 +202,29 @@ impl ActiveSequences { ...@@ -183,6 +202,29 @@ impl ActiveSequences {
self.active_blocks self.active_blocks
} }
/// Force expiry of stale requests if the expiry timer has elapsed.
///
/// While the current time is before `expiry_timer` this is a no-op that
/// returns an empty set. Once the timer has elapsed:
///
/// 1. every request ID still present in `expiry_requests` (captured by the
///    previous sweep and not removed by `free` since) is passed to `free`;
/// 2. the timer is re-armed to `EXPIRY_DURATION` from now;
/// 3. the IDs remaining in `active_seqs` become the candidates that the next
///    sweep will expire unless they are freed first.
///
/// Returns the set of request IDs that were force-expired by this call, so
/// that callers can drop any bookkeeping they hold for those requests.
pub fn force_expiry(&mut self) -> HashSet<RequestId> {
    let now = Instant::now();

    // Early return if the timer hasn't elapsed yet.
    if now < self.expiry_timer {
        return HashSet::new();
    }

    // Move the candidate set out of `self` in one step. The field is rebuilt
    // from `active_seqs` below, so there is no point in draining and
    // re-hashing every ID into a second set. Leaving an empty set behind also
    // keeps the `expiry_requests.remove` inside `free` a harmless no-op.
    let expired_requests = std::mem::take(&mut self.expiry_requests);
    for request_id in &expired_requests {
        tracing::warn!("Force expiring stale request: {}", request_id);
        self.free(request_id);
    }

    // Re-arm the timer and snapshot the requests that are active right now;
    // any of them still un-freed at the next sweep will be expired then.
    self.expiry_timer = now + EXPIRY_DURATION;
    self.expiry_requests = self.active_seqs.keys().cloned().collect();

    expired_requests
}
} }
enum UpdateSequences { enum UpdateSequences {
...@@ -191,6 +233,7 @@ enum UpdateSequences { ...@@ -191,6 +233,7 @@ enum UpdateSequences {
token_sequence: Vec<SequenceHash>, token_sequence: Vec<SequenceHash>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
resp_tx: tokio::sync::oneshot::Sender<HashSet<RequestId>>,
}, },
Free { Free {
request_id: RequestId, request_id: RequestId,
...@@ -314,8 +357,10 @@ impl ActiveSequencesMultiWorker { ...@@ -314,8 +357,10 @@ impl ActiveSequencesMultiWorker {
token_sequence, token_sequence,
isl, isl,
overlap, overlap,
resp_tx,
} => { } => {
active_sequences.add_request(request_id, token_sequence, isl, overlap); let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap);
let _ = resp_tx.send(removed);
} }
UpdateSequences::Free { request_id } => { UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id); active_sequences.free(&request_id);
...@@ -415,11 +460,14 @@ impl ActiveSequencesMultiWorker { ...@@ -415,11 +460,14 @@ impl ActiveSequencesMultiWorker {
request_to_worker.insert(event.request_id.clone(), event.worker_id); request_to_worker.insert(event.request_id.clone(), event.worker_id);
if let Some(sender) = senders.get(&event.worker_id) { if let Some(sender) = senders.get(&event.worker_id) {
// For replicated events, we create a dummy response channel since we don't need to handle expired requests
let (resp_tx, _) = tokio::sync::oneshot::channel();
let _ = sender.send(UpdateSequences::AddRequest { let _ = sender.send(UpdateSequences::AddRequest {
request_id: event.request_id.clone(), request_id: event.request_id.clone(),
token_sequence: token_sequence.clone(), token_sequence: token_sequence.clone(),
isl: *isl, isl: *isl,
overlap: *overlap, overlap: *overlap,
resp_tx,
}); });
} else { } else {
tracing::warn!( tracing::warn!(
...@@ -501,6 +549,9 @@ impl ActiveSequencesMultiWorker { ...@@ -501,6 +549,9 @@ impl ActiveSequencesMultiWorker {
return Err(anyhow::anyhow!("Worker ID {worker_id} not found")); return Err(anyhow::anyhow!("Worker ID {worker_id} not found"));
} }
// Create response channel
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
// Publish event only if replica_sync is enabled // Publish event only if replica_sync is enabled
if self.replica_sync { if self.replica_sync {
let event = ActiveSequenceEvent { let event = ActiveSequenceEvent {
...@@ -529,9 +580,20 @@ impl ActiveSequencesMultiWorker { ...@@ -529,9 +580,20 @@ impl ActiveSequencesMultiWorker {
token_sequence, token_sequence,
isl, isl,
overlap, overlap,
resp_tx,
}) })
.map_err(|_| anyhow::anyhow!("Failed to send add_request command to worker"))?; .map_err(|_| anyhow::anyhow!("Failed to send add_request command to worker"))?;
// Wait for response and handle removed requests
let removed_requests = resp_rx
.await
.map_err(|_| anyhow::anyhow!("Failed to receive response from worker"))?;
// Remove expired requests from request_to_worker mapping
for expired_id in &removed_requests {
self.request_to_worker.remove(expired_id);
}
Ok(()) Ok(())
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment