Commit c4958331 authored by Simo Lin, committed by GitHub

[router] leverage RAII to actively cancel request during client disconnect (#11399)

parent 2eeb2751
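The core idea, in miniature: tie request cancellation to object lifetime so cleanup runs on every exit path. Below is a minimal, self-contained sketch of the pattern (illustrative only, with hypothetical names; the real implementation is the `AbortOnDropStream` wrapper in the router diff further down):

```rust
use std::sync::atomic::{AtomicBool, Ordering};

/// Illustrative guard: "armed" until the happy path disarms it.
struct AbortGuard {
    armed: AtomicBool,
}

impl AbortGuard {
    fn new() -> Self {
        Self { armed: AtomicBool::new(true) }
    }

    /// Disarm on success so Drop becomes a no-op.
    fn disarm(&self) {
        self.armed.store(false, Ordering::Release);
    }
}

impl Drop for AbortGuard {
    fn drop(&mut self) {
        if self.armed.load(Ordering::Acquire) {
            // The real code spawns an async abort RPC here.
            eprintln!("request dropped before completion: sending abort");
        }
    }
}

fn main() {
    let guard = AbortGuard::new();
    // ... do work; a panic or early return here still runs Drop and aborts ...
    guard.disarm(); // reached only when the request finished cleanly
}
```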
@@ -319,13 +319,8 @@ class GrpcRequestManager:
         is_stream = getattr(obj, "stream", False)
 
         while True:
-            # Client cancelled - notify scheduler and exit
-            if grpc_context and grpc_context.cancelled():
-                await self.abort_request(request_id)
-                return
-
             try:
-                response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
+                response = await state.out_queue.get()
 
                 if is_stream:
                     yield response
@@ -338,10 +333,11 @@ class GrpcRequestManager:
                         yield final_response
                     break
-            except asyncio.TimeoutError:
-                # Timeout is for periodic client cancellation check
-                # Continue waiting for scheduler response
-                continue
+            except asyncio.CancelledError:
+                # Task was cancelled by gRPC framework when client disconnected
+                logger.info(f"Request {request_id} cancelled by client")
+                await self.abort_request(request_id)
+                raise  # Re-raise to let gRPC server handle cleanup
         finally:
             # Always clean up request state when exiting
@@ -409,32 +405,32 @@ class GrpcRequestManager:
         return future
 
     async def abort_request(self, request_id: str) -> bool:
-        """Abort a running request."""
+        """Abort a running request.
+
+        Sends abort request to scheduler and marks local state as finished
+        to stop processing any further outputs from the scheduler.
+        """
         # Skip aborting health check requests (they clean themselves up)
         if request_id.startswith("HEALTH_CHECK"):
             return False
 
-        if request_id not in self.rid_to_state:
-            return False
+        # Mark state as finished immediately to stop processing scheduler outputs
+        state = self.rid_to_state.get(request_id)
+        if state:
+            state.finished = True
+            state.stream_finished = True
+            logger.debug(f"Marked request {request_id} as aborted locally")
 
-        # Send abort to scheduler
+        # Send abort to scheduler - the scheduler will send AbortReq back
+        # which will be handled by _handle_abort_req
        abort_req = AbortReq(rid=request_id)
        try:
            await self._send_to_scheduler(abort_req)
+            logger.debug(f"Sent abort to scheduler for request {request_id}")
        except Exception as e:
-            logger.error(f"Failed to send abort request: {e}")
+            logger.error(f"Failed to send abort request to scheduler: {e}")
            return False
 
-        # Mark as finished
-        state = self.rid_to_state.get(request_id)
-        if state:
-            state.finished = True
-            state.stream_finished = True
-            state.event.set()
-            # Send abort notification to output queue
-            await state.out_queue.put({"error": "Request aborted", "abort": True})
-
         return True
 
     async def handle_loop(self):
@@ -460,6 +456,8 @@ class GrpcRequestManager:
                     await self._handle_embedding_output(recv_obj)
                 elif isinstance(recv_obj, HealthCheckOutput):
                     await self._handle_health_check_output(recv_obj)
+                elif isinstance(recv_obj, AbortReq):
+                    await self._handle_abort_req(recv_obj)
                 else:
                     logger.warning(f"Unknown output type: {type(recv_obj)}")
@@ -541,6 +539,11 @@ class GrpcRequestManager:
             state = self.rid_to_state[rid]
 
+            # Skip if already aborted/finished locally (client cancelled)
+            if state.finished:
+                logger.debug(f"Skipping output for aborted request {rid}")
+                continue
+
             # Update metrics
             now = time.time()
             if state.first_token_time == 0.0:
@@ -713,6 +716,67 @@ class GrpcRequestManager:
             state.finished_time = time.time()
             state.event.set()
 
+    async def _handle_abort_req(self, recv_obj: AbortReq):
+        """Handle abort request from scheduler.
+
+        The scheduler sends AbortReq back to notify us that a request was aborted,
+        either due to explicit abort_request() call or scheduler-initiated abort
+        (priority preemption, queue full, KV cache pressure, etc).
+        """
+        # Skip health check requests
+        if recv_obj.rid.startswith("HEALTH_CHECK"):
+            return
+
+        # Check if request still exists
+        if recv_obj.rid not in self.rid_to_state:
+            logger.debug(
+                f"Abort request for {recv_obj.rid} not in local state (may have already finished or not started yet)"
+            )
+            return
+
+        state = self.rid_to_state[recv_obj.rid]
+
+        # Mark as finished
+        state.finished = True
+        state.stream_finished = True
+
+        # Create abort response
+        if recv_obj.finished_reason:
+            # Scheduler provided a specific finish reason (e.g., priority preemption, queue full)
+            abort_response = {
+                "request_id": recv_obj.rid,
+                "error": recv_obj.finished_reason.get("message", "Request aborted"),
+                "finished": True,
+                "meta_info": {
+                    "id": recv_obj.rid,
+                    "finish_reason": recv_obj.finished_reason,
+                },
+            }
+        else:
+            # Generic abort (e.g., explicit abort_request call)
+            abort_response = {
+                "request_id": recv_obj.rid,
+                "error": "Request aborted",
+                "finished": True,
+                "meta_info": {
+                    "id": recv_obj.rid,
+                    "finish_reason": {
+                        "type": "abort",
+                        "message": "Abort before prefill",
+                    },
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                },
+            }
+
+        # Send abort notification to output queue
+        await state.out_queue.put(abort_response)
+
+        # Wake up any waiting coroutines
+        state.event.set()
+
+        logger.debug(f"Handled abort request for {recv_obj.rid}")
+
     async def _send_to_scheduler(self, obj):
         """Send an object to the scheduler via ZMQ."""
         try:
...
@@ -211,13 +211,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
             )
 
             async for output in response_generator:
-                # Check if client cancelled before processing/yielding
-                if context.cancelled():
-                    logger.info(f"Client cancelled request {request.request_id}")
-                    # Explicitly abort the request to notify scheduler
-                    await self.request_manager.abort_request(request.request_id)
-                    break
-
                 # Handle batch responses (for n>1 non-streaming)
                 if isinstance(output, list):
                     for batch_output in output:
...
 use std::convert::TryFrom;
+use std::pin::Pin;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use std::task::{Context, Poll};
 use std::time::Duration;
 use tonic::{transport::Channel, Request, Streaming};
-use tracing::debug;
+use tracing::{debug, warn};
 
 use crate::protocols::spec::{
     ChatCompletionRequest, GenerateRequest, ResponseFormat,
@@ -16,6 +20,92 @@ pub mod proto {
 // The generated module structure depends on the package name in the .proto file
 // package sglang.grpc.scheduler; generates a nested module structure
+/// A smart wrapper around Streaming<GenerateResponse> that automatically
+/// sends abort when dropped (e.g., due to client disconnection or early termination).
+///
+/// This leverages Rust's RAII pattern to ensure cleanup happens automatically,
+/// regardless of how the stream is dropped (panic, early return, client disconnect, etc.).
+pub struct AbortOnDropStream {
+    inner: Streaming<proto::GenerateResponse>,
+    request_id: String,
+    client: SglangSchedulerClient,
+    aborted: Arc<AtomicBool>,
+}
+
+impl AbortOnDropStream {
+    /// Create a new auto-aborting stream wrapper
+    pub fn new(
+        stream: Streaming<proto::GenerateResponse>,
+        request_id: String,
+        client: SglangSchedulerClient,
+    ) -> Self {
+        debug!("Created AbortOnDropStream for request {}", request_id);
+        Self {
+            inner: stream,
+            request_id,
+            client,
+            aborted: Arc::new(AtomicBool::new(false)),
+        }
+    }
+
+    /// Manually mark the request as completed to prevent abort on drop.
+    /// Call this when the request completes successfully to avoid unnecessary abort RPC.
+    pub fn mark_completed(&self) {
+        // Use Release ordering to ensure that this write is visible to other threads
+        // that use Acquire on the same atomic variable
+        self.aborted.store(true, Ordering::Release);
+        debug!("Request {} marked as completed", self.request_id);
+    }
+}
+
+impl Drop for AbortOnDropStream {
+    fn drop(&mut self) {
+        // Atomically check and set the aborted flag using compare_exchange.
+        // If compare_exchange fails, it means the flag was already true (from mark_completed),
+        // so we don't need to send abort. AcqRel is used for success to synchronize with
+        // mark_completed's Release, and Acquire for failure to see writes from mark_completed.
+        if self
+            .aborted
+            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
+            .is_err()
+        {
+            return;
+        }
+
+        let client = self.client.clone();
+        let request_id = self.request_id.clone();
+
+        // Spawn a background task to send abort (since Drop is sync but abort_request is async)
+        tokio::spawn(async move {
+            debug!(
+                "Stream dropped without completion for request {}, sending abort",
+                request_id
+            );
+            // Clone request_id for the error message since abort_request takes ownership
+            let request_id_for_log = request_id.clone();
+            if let Err(e) = client
+                .abort_request(request_id, "Stream dropped".to_string())
+                .await
+            {
+                warn!(
+                    "Failed to send abort on drop for request {}: {}",
+                    request_id_for_log, e
+                );
+            }
+        });
+    }
+}
+
+// Implement Stream trait to make AbortOnDropStream work like the original Streaming
+impl futures::Stream for AbortOnDropStream {
+    type Item = Result<proto::GenerateResponse, tonic::Status>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        // Delegate to the inner stream
+        Pin::new(&mut self.inner).poll_next(cx)
+    }
+}
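A hypothetical caller sketch (not part of this diff) showing how the wrapper is intended to be driven, using the `generate()` change shown further down; `client` is assumed to be a connected `SglangSchedulerClient`. Because `AbortOnDropStream` implements `futures::Stream`, the usual combinators work unchanged:

```rust
use futures::StreamExt;

async fn run_one_request(
    client: SglangSchedulerClient,
    req: proto::GenerateRequest,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let mut stream = client.generate(req).await?;
    while let Some(item) = stream.next().await {
        let _chunk = item?; // an early `?` drops `stream`, and Drop sends the abort RPC
        // ... forward the chunk to the downstream HTTP client ...
    }
    // The stream drained cleanly: disarm the guard so Drop skips the abort.
    // The compare_exchange in Drop makes this race-free against a concurrent drop.
    stream.mark_completed();
    Ok(())
}
```

One caveat worth noting: the `tokio::spawn` in `Drop` requires a Tokio runtime to be active at drop time, which holds for streams dropped inside the router's async handlers; dropping the guard outside a runtime would panic.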
 /// gRPC client for SGLang scheduler
 #[derive(Clone)]
 pub struct SglangSchedulerClient {
@@ -35,7 +125,7 @@ impl SglangSchedulerClient {
         };
 
         let channel = Channel::from_shared(http_endpoint)?
-            .timeout(Duration::from_secs(3600))
+            .timeout(Duration::from_secs(600)) // 10 minute timeout for connection
             .http2_keep_alive_interval(Duration::from_secs(30))
             .keep_alive_timeout(Duration::from_secs(10))
             .keep_alive_while_idle(true)
@@ -52,15 +142,26 @@ impl SglangSchedulerClient {
         Ok(Self { client })
     }
 
-    /// Submit a generation request (returns streaming response)
+    /// Submit a generation request (returns auto-aborting streaming response)
+    ///
+    /// The returned stream automatically sends an abort request when dropped,
+    /// ensuring proper cleanup even if the HTTP client disconnects or an error occurs.
+    /// Call `mark_completed()` on the stream after successful completion to prevent
+    /// unnecessary abort RPCs.
     pub async fn generate(
         &self,
         req: proto::GenerateRequest,
-    ) -> Result<Streaming<proto::GenerateResponse>, Box<dyn std::error::Error + Send + Sync>> {
+    ) -> Result<AbortOnDropStream, Box<dyn std::error::Error + Send + Sync>> {
+        let request_id = req.request_id.clone();
         let mut client = self.client.clone();
         let request = Request::new(req);
         let response = client.generate(request).await?;
-        Ok(response.into_inner())
+
+        Ok(AbortOnDropStream::new(
+            response.into_inner(),
+            request_id,
+            self.clone(),
+        ))
     }
 
     /// Perform health check
@@ -68,12 +169,8 @@ impl SglangSchedulerClient {
         &self,
     ) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error + Send + Sync>> {
         debug!("Sending health check request");
-        let request = Request::new(proto::HealthCheckRequest {
-            tokenized: Some(proto::TokenizedInput {
-                original_text: "Hello".to_string(),
-                input_ids: vec![9906], // Mock token ID for "Hello"
-            }),
-        });
+        // Server ignores the request body and creates its own health check internally
+        let request = Request::new(proto::HealthCheckRequest { tokenized: None });
 
         let mut client = self.client.clone();
         let response = client.health_check(request).await?;
@@ -87,10 +184,23 @@ impl SglangSchedulerClient {
         request_id: String,
         reason: String,
     ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-        let request = Request::new(proto::AbortRequest { request_id, reason });
+        debug!(
+            "Sending abort request for {} (reason: {})",
+            request_id, reason
+        );
+        let request = Request::new(proto::AbortRequest {
+            request_id: request_id.clone(),
+            reason,
+        });
         let mut client = self.client.clone();
-        client.abort(request).await?;
+        let response = client.abort(request).await?;
+        debug!(
+            "Abort response for {}: success={}, message={}",
+            request_id,
+            response.get_ref().success,
+            response.get_ref().message
+        );
         Ok(())
     }
...
@@ -371,16 +371,17 @@ impl ClientSelection {
 // Execution and Response Types
 // ============================================================================
 
-use tonic::codec::Streaming;
+use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
 
 /// Result of request execution (streams from workers)
+/// Uses AbortOnDropStream to automatically abort on cancellation
 pub enum ExecutionResult {
     Single {
-        stream: Streaming<proto::GenerateResponse>,
+        stream: AbortOnDropStream,
     },
     Dual {
-        prefill: Streaming<proto::GenerateResponse>,
-        decode: Box<Streaming<proto::GenerateResponse>>,
+        prefill: AbortOnDropStream,
+        decode: Box<AbortOnDropStream>,
     },
 }
...
@@ -816,16 +816,27 @@ impl ResponseProcessingStage {
         // Collect all responses from the execution result
         let all_responses = match execution_result {
-            ExecutionResult::Single { stream } => {
-                utils::collect_stream_responses(stream, "Single").await?
+            ExecutionResult::Single { mut stream } => {
+                let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
+                stream.mark_completed();
+                responses
             }
-            ExecutionResult::Dual { prefill, decode } => {
-                // Collect prefill for input_logprobs
-                let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
+            ExecutionResult::Dual {
+                mut prefill,
+                decode,
+            } => {
+                // Collect prefill for input_logprobs (don't mark completed yet)
+                let prefill_responses =
+                    utils::collect_stream_responses(&mut prefill, "Prefill").await?;
 
-                // Collect decode for actual output
+                // Collect decode for actual output (don't mark completed yet)
+                let mut decode_stream = *decode;
                 let mut decode_responses =
-                    utils::collect_stream_responses(*decode, "Decode").await?;
+                    utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
+
+                // Mark both streams as completed now that both succeeded
+                prefill.mark_completed();
+                decode_stream.mark_completed();
 
                 // Merge prefill input_logprobs if requested
                 if request_logprobs {
@@ -952,16 +963,27 @@ impl ResponseProcessingStage {
         // Non-streaming: Collect all responses
         let request_logprobs = ctx.generate_request().return_logprob;
         let all_responses = match execution_result {
-            ExecutionResult::Single { stream } => {
-                utils::collect_stream_responses(stream, "Single").await?
+            ExecutionResult::Single { mut stream } => {
+                let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
+                stream.mark_completed();
+                responses
             }
-            ExecutionResult::Dual { prefill, decode } => {
-                // Collect prefill for input_logprobs
-                let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
+            ExecutionResult::Dual {
+                mut prefill,
+                decode,
+            } => {
+                // Collect prefill for input_logprobs (don't mark completed yet)
+                let prefill_responses =
+                    utils::collect_stream_responses(&mut prefill, "Prefill").await?;
 
-                // Collect decode for actual output
+                // Collect decode for actual output (don't mark completed yet)
+                let mut decode_stream = *decode;
                 let mut decode_responses =
-                    utils::collect_stream_responses(*decode, "Decode").await?;
+                    utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
+
+                // Mark both streams as completed now that both succeeded
+                prefill.mark_completed();
+                decode_stream.mark_completed();
 
                 // Merge prefill input_logprobs if requested
                 if request_logprobs {
...
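The completion ordering above is the load-bearing detail, so here it is restated as a compact sketch (hypothetical helper names; `drain` stands in for `utils::collect_stream_responses`): no guard is disarmed until every stream has fully drained, so a failure or client disconnect anywhere aborts both halves of a PD (prefill/decode) request.

```rust
use futures::StreamExt;

// Hypothetical stand-in for utils::collect_stream_responses.
async fn drain(stream: &mut AbortOnDropStream) -> Result<(), tonic::Status> {
    while let Some(item) = stream.next().await {
        item?;
    }
    Ok(())
}

async fn drain_dual(
    mut prefill: AbortOnDropStream,
    mut decode: AbortOnDropStream,
) -> Result<(), tonic::Status> {
    drain(&mut prefill).await?; // an early return drops BOTH guards -> two abort RPCs
    drain(&mut decode).await?;  // prefill is still armed if decode fails mid-stream
    prefill.mark_completed();   // only now is it safe to disarm
    decode.mark_completed();
    Ok(())
}
```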
@@ -14,7 +14,6 @@ use std::sync::Arc;
 use tokio::sync::mpsc::UnboundedSender;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
-use tonic::codec::Streaming;
 use tracing::{debug, error, warn};
 
 use super::context;
@@ -153,7 +152,7 @@ impl StreamingProcessor {
     /// Process streaming chunks from a single stream (Regular mode)
     pub async fn process_streaming_chunks(
         &self,
-        mut grpc_stream: Streaming<proto::GenerateResponse>,
+        mut grpc_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
         dispatch: context::DispatchMetadata,
         stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
         original_request: Arc<ChatCompletionRequest>,
@@ -571,14 +570,17 @@ impl StreamingProcessor {
             }
         }
 
+        // Mark stream as completed successfully to prevent abort on drop
+        grpc_stream.mark_completed();
+
         Ok(())
     }
 
     /// Process dual streaming chunks (prefill + decode) - PD mode
     pub async fn process_dual_streaming_chunks(
         &self,
-        mut prefill_stream: Streaming<proto::GenerateResponse>,
-        decode_stream: Streaming<proto::GenerateResponse>,
+        mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
+        decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
         dispatch: context::DispatchMetadata,
         stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
         original_request: Arc<ChatCompletionRequest>,
@@ -603,8 +605,18 @@ impl StreamingProcessor {
         }
 
         // Phase 2-5: Process decode stream (same as single mode)
-        self.process_streaming_chunks(decode_stream, dispatch, stop_params, original_request, tx)
-            .await
+        // Note: decode_stream will be marked completed inside process_streaming_chunks
+        let result = self
+            .process_streaming_chunks(decode_stream, dispatch, stop_params, original_request, tx)
+            .await;
+
+        // Mark prefill stream as completed AFTER decode completes successfully
+        // This ensures that if client disconnects during decode, BOTH streams send abort
+        if result.is_ok() {
+            prefill_stream.mark_completed();
+        }
+
+        result
     }
 
     /// Process streaming generate response and return SSE response
@@ -687,7 +699,7 @@ impl StreamingProcessor {
     /// Process streaming chunks for generate endpoint (no tool/reasoning parsing)
     async fn process_generate_streaming(
         tokenizer: Arc<dyn Tokenizer>,
-        mut stream: Streaming<proto::GenerateResponse>,
+        mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
         request_id: String,
         weight_version: String,
         _include_logprobs: bool,
@@ -782,14 +794,17 @@ impl StreamingProcessor {
             }
         }
 
+        // Mark stream as completed successfully to prevent abort on drop
+        stream.mark_completed();
+
         Ok(())
     }
 
     /// Process dual streaming for generate endpoint (PD mode with logprobs support)
     async fn process_generate_streaming_dual(
         tokenizer: Arc<dyn Tokenizer>,
-        mut prefill_stream: Streaming<proto::GenerateResponse>,
-        decode_stream: Streaming<proto::GenerateResponse>,
+        mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
+        decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
         request_id: String,
         weight_version: String,
         return_logprob: bool,
@@ -821,7 +836,8 @@ impl StreamingProcessor {
         };
 
         // Process decode stream with input_logprobs prepended
-        Self::process_generate_streaming_with_input_logprobs(
+        // Note: decode_stream will be marked completed inside the function
+        let result = Self::process_generate_streaming_with_input_logprobs(
             tokenizer,
             decode_stream,
             request_id,
@@ -830,13 +846,21 @@ impl StreamingProcessor {
             input_token_logprobs,
             tx,
         )
-        .await
+        .await;
+
+        // Mark prefill stream as completed AFTER decode completes successfully
+        // This ensures that if client disconnects during decode, BOTH streams send abort
+        if result.is_ok() {
+            prefill_stream.mark_completed();
+        }
+
+        result
     }
 
     /// Process generate streaming with optional input_logprobs
     async fn process_generate_streaming_with_input_logprobs(
         tokenizer: Arc<dyn Tokenizer>,
-        mut stream: Streaming<proto::GenerateResponse>,
+        mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
         request_id: String,
         weight_version: String,
         _include_logprobs: bool,
@@ -957,6 +981,9 @@ impl StreamingProcessor {
             }
         }
 
+        // Mark stream as completed successfully to prevent abort on drop
+        stream.mark_completed();
+
         Ok(())
     }
...
@@ -2,6 +2,7 @@
 
 use super::ProcessedMessages;
 use crate::core::Worker;
+use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
 use crate::grpc_client::{proto, SglangSchedulerClient};
 use crate::protocols::spec::{
     ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
@@ -20,7 +21,6 @@ use futures::StreamExt;
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
 use std::sync::Arc;
-use tonic::codec::Streaming;
 use tracing::{error, warn};
 use uuid::Uuid;
@@ -590,7 +590,7 @@ pub fn parse_json_schema_response(
 /// * `Ok(Vec<GenerateComplete>)` - All complete responses collected from the stream
 /// * `Err(Response)` - Error response if the stream fails or returns an error
 pub async fn collect_stream_responses(
-    mut stream: Streaming<proto::GenerateResponse>,
+    stream: &mut AbortOnDropStream,
     worker_name: &str,
 ) -> Result<Vec<proto::GenerateComplete>, Response> {
     use proto::generate_response::Response::*;
@@ -606,6 +606,7 @@ pub async fn collect_stream_responses(
             }
             Some(Error(err)) => {
                 error!("{} error: {}", worker_name, err.message);
+                // Don't mark as completed - let Drop send abort for error cases
                 return Err(internal_error_message(format!(
                     "{} generation failed: {}",
                     worker_name, err.message
@@ -621,6 +622,7 @@ pub async fn collect_stream_responses(
             }
             Err(e) => {
                 error!("{} stream error: {:?}", worker_name, e);
+                // Don't mark as completed - let Drop send abort for error cases
                 return Err(internal_error_message(format!(
                     "{} stream failed: {}",
                     worker_name, e
...
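Why `collect_stream_responses` switched from taking the stream by value to `&mut`: ownership has to stay with the caller so `mark_completed()` can be withheld until the entire collection step succeeds. A sketch of the resulting call-site pattern (hypothetical wrapper function; error handling simplified):

```rust
async fn collect_then_disarm(
    client: SglangSchedulerClient,
    req: proto::GenerateRequest,
) -> Result<Vec<proto::GenerateComplete>, Box<dyn std::error::Error + Send + Sync>> {
    let mut stream = client.generate(req).await?;
    // Borrow, don't move: the callee can fail without consuming the guard.
    let responses = collect_stream_responses(&mut stream, "Single")
        .await
        .map_err(|_| "stream failed")?; // early return drops `stream` -> abort is sent
    stream.mark_completed(); // happy path only: disarm before returning
    Ok(responses)
}
```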