"...speculative-decoding/speculative-decoding-vllm.md" did not exist on "d59b9d72ed98eed50e8863de8259d0871efad336"
Unverified Commit c7986b35 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

refactor: move per-item stream tracking into RequestGuard (#6355)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarCursor <cursoragent@cursor.com>
parent 61889a14
......@@ -40,8 +40,8 @@ struct WorkerSelection {
overlap_amount: u32,
}
/// Drop guard that ensures `free()` and final metrics are recorded even if the
/// response stream is dropped without being polled to completion.
/// Drop guard that manages the full lifecycle of a routed request:
/// per-item tracking (prefill, first token, output blocks) and final cleanup (free + metrics).
///
/// In the happy path, `finish().await` runs cleanup inline in the async context.
/// If the stream is dropped early (e.g., client disconnect, consumer drop), the
......@@ -54,9 +54,83 @@ struct RequestGuard {
cumulative_osl: usize,
metrics_recorded: bool,
freed: bool,
prefill_marked: bool,
first_token_recorded: bool,
track_output_blocks: bool,
current_total_blocks: usize,
isl_tokens: usize,
block_size: usize,
expected_output_tokens: Option<u32>,
}
impl RequestGuard {
async fn on_item(&mut self, item: &Annotated<LLMEngineOutput>) {
if !self.prefill_marked {
let has_tokens = item
.data
.as_ref()
.map(|d| !d.token_ids.is_empty())
.unwrap_or(false);
if has_tokens {
if let Err(e) = self.chooser.mark_prefill_completed(&self.context_id).await {
tracing::warn!(
"Failed to mark prefill completed for request {}: {e}",
self.context_id
);
}
self.prefill_marked = true;
}
}
let new_tokens = item.data.as_ref().map(|d| d.token_ids.len()).unwrap_or(0);
if !self.first_token_recorded && new_tokens > 0 {
if let Some(ref tracker) = self.tracker {
tracker.record_first_token();
if let Some(ttft) = tracker.ttft_ms() {
self.request_metrics
.time_to_first_token_seconds
.observe(ttft / 1000.0);
}
}
self.first_token_recorded = true;
}
self.cumulative_osl += new_tokens;
if self.track_output_blocks {
let new_total_blocks =
(self.isl_tokens + self.cumulative_osl).div_ceil(self.block_size);
if new_total_blocks > self.current_total_blocks {
let decay_fraction = self
.expected_output_tokens
.map(|eot| (1.0 - (self.cumulative_osl as f64 / eot.max(1) as f64)).max(0.0));
if let Err(e) = self
.chooser
.add_output_block(&self.context_id, decay_fraction)
.await
{
tracing::warn!(
"Failed to add output block for request {}: {e}",
self.context_id
);
}
if let Some(ref tracker) = self.tracker {
tracker.record_osl(self.cumulative_osl);
tracker.record_finish();
if let Some(avg_itl) = tracker.avg_itl_ms() {
self.request_metrics
.inter_token_latency_seconds
.observe(avg_itl / 1000.0);
}
}
self.current_total_blocks = new_total_blocks;
}
}
}
async fn finish(&mut self) {
self.record_metrics();
if let Err(e) = self.chooser.free(&self.context_id).await {
......@@ -384,9 +458,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = response_stream.context();
let context_for_monitoring = stream_context.clone();
// Wrap stream with lifecycle management (mark_prefill_completed, free).
// RequestGuard ensures free() and final metrics run even if the stream is
// dropped without being polled to completion (e.g., client disconnect).
let wrapped_stream = Box::pin(async_stream::stream! {
let mut guard = RequestGuard {
chooser: chooser.clone(),
......@@ -396,10 +467,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
cumulative_osl: 0,
metrics_recorded: false,
freed: false,
prefill_marked: false,
first_token_recorded: false,
track_output_blocks,
current_total_blocks: isl_tokens.div_ceil(block_size),
isl_tokens,
block_size,
expected_output_tokens,
};
let mut prefill_marked = false;
let mut first_token_recorded = false;
let mut current_total_blocks = isl_tokens.div_ceil(block_size);
loop {
tokio::select! {
......@@ -414,65 +489,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let Some(item) = item else {
break;
};
if !prefill_marked {
// Only mark prefill completed when we receive actual tokens,
// not empty bootstrap info (token_ids: []) from disaggregated prefill
let has_tokens = item.data.as_ref()
.map(|d| !d.token_ids.is_empty())
.unwrap_or(false);
if has_tokens {
if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
}
prefill_marked = true;
}
}
let new_tokens = item.data.as_ref()
.map(|d| d.token_ids.len())
.unwrap_or(0);
if !first_token_recorded && new_tokens > 0 {
if let Some(ref tracker) = tracker {
tracker.record_first_token();
if let Some(ttft) = tracker.ttft_ms() {
request_metrics
.time_to_first_token_seconds
.observe(ttft / 1000.0);
}
}
first_token_recorded = true;
}
guard.cumulative_osl += new_tokens;
if track_output_blocks {
let new_total_blocks = (isl_tokens + guard.cumulative_osl).div_ceil(block_size);
if new_total_blocks > current_total_blocks {
let decay_fraction = expected_output_tokens.map(|eot| {
(1.0 - (guard.cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
});
if let Err(e) = chooser.add_output_block(&context_id, decay_fraction).await {
tracing::warn!(
"Failed to add output block for request {context_id}: {e}"
);
}
if let Some(ref tracker) = tracker {
tracker.record_osl(guard.cumulative_osl);
tracker.record_finish();
if let Some(avg_itl) = tracker.avg_itl_ms() {
request_metrics
.inter_token_latency_seconds
.observe(avg_itl / 1000.0);
}
}
current_total_blocks = new_total_blocks;
}
}
guard.on_item(&item).await;
yield item;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment