Unverified Commit 4673e47f authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: to make sure free is always run on stream drop (#6246)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 2d517e77
......@@ -21,7 +21,10 @@ use crate::{
protocols::{TokensWithHashes, WorkerWithDpRank},
},
preprocessor::PreprocessedRequest,
protocols::common::{llm_backend::LLMEngineOutput, timing::RequestPhase},
protocols::common::{
llm_backend::LLMEngineOutput,
timing::{RequestPhase, RequestTracker},
},
};
pub struct KvPushRouter {
......@@ -36,6 +39,69 @@ struct WorkerSelection {
overlap_amount: u32,
}
/// Drop guard that ensures `free()` and final metrics are recorded even if the
/// response stream is dropped without being polled to completion.
///
/// In the happy path, `finish().await` runs cleanup inline in the async context.
/// If the stream is dropped early (e.g., client disconnect, consumer drop), the
/// `Drop` impl fires and spawns a task to call `free()`.
struct RequestGuard {
chooser: Arc<KvRouter>,
context_id: String,
handle_local_updates: bool,
tracker: Option<Arc<RequestTracker>>,
request_metrics: Arc<RouterRequestMetrics>,
cumulative_osl: usize,
metrics_recorded: bool,
freed: bool,
}
impl RequestGuard {
async fn finish(&mut self) {
self.record_metrics();
if self.handle_local_updates
&& let Err(e) = self.chooser.free(&self.context_id).await
{
tracing::warn!("Failed to free request {}: {e}", self.context_id);
}
self.freed = true;
}
fn record_metrics(&mut self) {
if self.metrics_recorded {
return;
}
self.metrics_recorded = true;
if let Some(ref tracker) = self.tracker {
tracker.record_finish();
tracker.record_osl(self.cumulative_osl);
self.request_metrics
.output_sequence_tokens
.observe(self.cumulative_osl as f64);
}
self.request_metrics.requests_total.inc();
}
}
impl Drop for RequestGuard {
fn drop(&mut self) {
self.record_metrics();
if !self.freed && self.handle_local_updates {
let chooser = self.chooser.clone();
let context_id = self.context_id.clone();
let Ok(handle) = tokio::runtime::Handle::try_current() else {
tracing::warn!("No tokio runtime for drop guard free of request {context_id}");
return;
};
handle.spawn(async move {
if let Err(e) = chooser.free(&context_id).await {
tracing::warn!("Failed to free request {context_id} (drop guard): {e}");
}
});
}
}
}
impl KvPushRouter {
pub fn new(
inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
......@@ -285,15 +351,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = response_stream.context();
let context_for_monitoring = stream_context.clone();
// Wrap stream with lifecycle management (mark_prefill_completed, free)
// Only perform these operations if handle_local_updates is true.
// When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
// Wrap stream with lifecycle management (mark_prefill_completed, free).
// RequestGuard ensures free() and final metrics run even if the stream is
// dropped without being polled to completion (e.g., client disconnect).
let wrapped_stream = Box::pin(async_stream::stream! {
let mut guard = RequestGuard {
chooser: chooser.clone(),
context_id: context_id.clone(),
handle_local_updates,
tracker: tracker.clone(),
request_metrics: request_metrics.clone(),
cumulative_osl: 0,
metrics_recorded: false,
freed: false,
};
let mut prefill_marked = false;
let mut first_token_recorded = false;
// Output block tracking state
let mut cumulative_osl: usize = 0;
let mut current_total_blocks = isl_tokens.div_ceil(block_size);
loop {
......@@ -311,8 +384,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
};
if handle_local_updates && !prefill_marked {
// Only mark prefill completed when we receive actual tokens,
// not empty bootstrap info (token_ids: []) from disaggregated prefill
let has_tokens = item.data.as_ref()
.map(|d| !d.token_ids.is_empty())
.unwrap_or(false);
......@@ -328,7 +399,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
.map(|d| d.token_ids.len())
.unwrap_or(0);
// Record first token time on tracker when actual tokens arrive
if !first_token_recorded && new_tokens > 0 {
if let Some(ref tracker) = tracker {
tracker.record_first_token();
......@@ -341,16 +411,13 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
first_token_recorded = true;
}
cumulative_osl += new_tokens;
guard.cumulative_osl += new_tokens;
// Track output blocks if enabled
if track_output_blocks {
let new_total_blocks = (isl_tokens + cumulative_osl).div_ceil(block_size);
let new_total_blocks = (isl_tokens + guard.cumulative_osl).div_ceil(block_size);
if new_total_blocks > current_total_blocks {
// New block boundary crossed - add output block with decay
// Clamp eot to min 1 to avoid division by zero, and result to min 0.0
let decay_fraction = expected_output_tokens.map(|eot| {
(1.0 - (cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
(1.0 - (guard.cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
});
if let Err(e) = chooser.add_output_block(&context_id, decay_fraction).await {
tracing::warn!(
......@@ -358,9 +425,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
);
}
// Update tracker and observe avg ITL at each block boundary
if let Some(ref tracker) = tracker {
tracker.record_osl(cumulative_osl);
tracker.record_osl(guard.cumulative_osl);
tracker.record_finish();
if let Some(avg_itl) = tracker.avg_itl_ms() {
request_metrics
......@@ -378,24 +444,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
}
// Record final aggregate metrics (histograms sampled once per request)
if let Some(ref tracker) = tracker {
tracker.record_finish();
tracker.record_osl(cumulative_osl);
request_metrics
.output_sequence_tokens
.observe(cumulative_osl as f64);
}
request_metrics.requests_total.inc();
// Only call free() if we handle local updates.
// When handle_local_updates=false, external caller handles cleanup via C FFI.
if handle_local_updates
&& let Err(e) = chooser.free(&context_id).await
{
tracing::warn!("Failed to free request {context_id}: {e}");
}
guard.finish().await;
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment