Unverified Commit 01c9ee1a authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor generate to use new pipeline arch (#11323)

parent d6837aea
...@@ -2066,39 +2066,64 @@ impl GenerationRequest for GenerateRequest { ...@@ -2066,39 +2066,64 @@ impl GenerationRequest for GenerateRequest {
} }
} }
// TODO(generate): Define GenerateResponse and GenerateChoice structs // ============================================================================
// // SGLang Generate Response Types
// Required for pipeline generate response processing (see grpc/pipeline.rs:931-964) // ============================================================================
//
// #[derive(Debug, Clone, Serialize, Deserialize)] /// SGLang generate response (single completion or array for n>1)
// pub struct GenerateResponse { ///
// pub id: String, /// Format for n=1:
// pub object: String, // "text.completion" /// ```json
// pub created: u64, /// {
// pub model: String, /// "text": "...",
// pub choices: Vec<GenerateChoice>, /// "output_ids": [...],
// #[serde(skip_serializing_if = "Option::is_none")] /// "meta_info": { ... }
// pub usage: Option<Usage>, /// }
// #[serde(skip_serializing_if = "Option::is_none")] /// ```
// pub system_fingerprint: Option<String>, ///
// } /// Format for n>1:
// /// ```json
// #[derive(Debug, Clone, Serialize, Deserialize)] /// [
// pub struct GenerateChoice { /// {"text": "...", "output_ids": [...], "meta_info": {...}},
// pub index: u32, /// {"text": "...", "output_ids": [...], "meta_info": {...}}
// pub text: String, /// ]
// #[serde(skip_serializing_if = "Option::is_none")] /// ```
// pub output_ids: Option<Vec<u32>>, #[derive(Debug, Clone, Serialize, Deserialize)]
// #[serde(skip_serializing_if = "Option::is_none")] pub struct GenerateResponse {
// pub finish_reason: Option<String>, pub text: String,
// #[serde(skip_serializing_if = "Option::is_none")] pub output_ids: Vec<u32>,
// pub logprobs: Option<TopLogprobs>, pub meta_info: GenerateMetaInfo,
// #[serde(skip_serializing_if = "Option::is_none")] }
// pub matched_stop: Option<Value>,
// } /// Metadata for a single generate completion
// #[derive(Debug, Clone, Serialize, Deserialize)]
// Note: Verify if similar structs already exist elsewhere before implementing. pub struct GenerateMetaInfo {
// May need streaming variant (GenerateStreamResponse) as well. pub id: String,
pub finish_reason: GenerateFinishReason,
pub prompt_tokens: u32,
pub weight_version: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub input_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
pub completion_tokens: u32,
pub cached_tokens: u32,
pub e2e_latency: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<Value>,
}
/// Finish reason for generate endpoint
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum GenerateFinishReason {
Length {
length: u32,
},
Stop,
#[serde(untagged)]
Other(Value),
}
// Constants for rerank API // Constants for rerank API
pub const DEFAULT_MODEL_NAME: &str = "default"; pub const DEFAULT_MODEL_NAME: &str = "default";
......
...@@ -12,7 +12,9 @@ use serde_json::Value; ...@@ -12,7 +12,9 @@ use serde_json::Value;
use crate::core::Worker; use crate::core::Worker;
use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::spec::{ChatCompletionRequest, ChatCompletionResponse, GenerateRequest}; use crate::protocols::spec::{
ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, GenerateResponse,
};
use crate::reasoning_parser::ReasoningParserFactory; use crate::reasoning_parser::ReasoningParserFactory;
use crate::tokenizer::stop::StopSequenceDecoder; use crate::tokenizer::stop::StopSequenceDecoder;
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
...@@ -226,14 +228,6 @@ impl RequestContext { ...@@ -226,14 +228,6 @@ impl RequestContext {
} }
} }
/// Try to get chat request
pub fn try_chat_request(&self) -> Option<&ChatCompletionRequest> {
match &self.input.request_type {
RequestType::Chat(req) => Some(req.as_ref()),
_ => None,
}
}
/// Get generate request (panics if not generate) /// Get generate request (panics if not generate)
pub fn generate_request(&self) -> &GenerateRequest { pub fn generate_request(&self) -> &GenerateRequest {
match &self.input.request_type { match &self.input.request_type {
...@@ -242,14 +236,6 @@ impl RequestContext { ...@@ -242,14 +236,6 @@ impl RequestContext {
} }
} }
/// Try to get generate request
pub fn try_generate_request(&self) -> Option<&GenerateRequest> {
match &self.input.request_type {
RequestType::Generate(req) => Some(req.as_ref()),
_ => None,
}
}
/// Check if request is streaming /// Check if request is streaming
pub fn is_streaming(&self) -> bool { pub fn is_streaming(&self) -> bool {
match &self.input.request_type { match &self.input.request_type {
...@@ -257,16 +243,6 @@ impl RequestContext { ...@@ -257,16 +243,6 @@ impl RequestContext {
RequestType::Generate(req) => req.stream, RequestType::Generate(req) => req.stream,
} }
} }
/// Check if request is chat
pub fn is_chat(&self) -> bool {
matches!(&self.input.request_type, RequestType::Chat(_))
}
/// Check if request is generate
pub fn is_generate(&self) -> bool {
matches!(&self.input.request_type, RequestType::Generate(_))
}
} }
// ============================================================================ // ============================================================================
...@@ -394,5 +370,6 @@ pub enum ExecutionResult { ...@@ -394,5 +370,6 @@ pub enum ExecutionResult {
/// Final processed response /// Final processed response
pub enum FinalResponse { pub enum FinalResponse {
Chat(ChatCompletionResponse), Chat(ChatCompletionResponse),
Generate(Box<GenerateRequest>), /// Generate response is a Vec of GenerateResponse (n=1 returns single item, n>1 returns multiple)
Generate(Vec<GenerateResponse>),
} }
This diff is collapsed.
...@@ -11,15 +11,20 @@ use super::context::*; ...@@ -11,15 +11,20 @@ use super::context::*;
use super::processing; use super::processing;
use super::streaming; use super::streaming;
use super::utils; use super::utils;
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType}; use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
use crate::grpc_client::proto; use crate::grpc_client::proto;
use crate::policies::PolicyRegistry; use crate::policies::PolicyRegistry;
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, InputIds, Usage, ChatCompletionRequest, ChatCompletionResponse, GenerateMetaInfo, GenerateRequest,
GenerateResponse, InputIds, Usage,
}; };
use crate::tokenizer::stop::SequenceDecoderOutput;
use crate::tokenizer::traits::Tokenizer;
use proto::generate_complete::MatchedStop;
use proto::DisaggregatedParams;
use rand::Rng; use rand::Rng;
use std::sync::Arc; use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{Instant, SystemTime, UNIX_EPOCH};
use uuid::Uuid; use uuid::Uuid;
// ============================================================================ // ============================================================================
...@@ -208,7 +213,7 @@ impl PreparationStage { ...@@ -208,7 +213,7 @@ impl PreparationStage {
fn tokenize_single_text( fn tokenize_single_text(
&self, &self,
tokenizer: &Arc<dyn crate::tokenizer::traits::Tokenizer>, tokenizer: &Arc<dyn Tokenizer>,
text: &str, text: &str,
) -> Result<(String, Vec<u32>), String> { ) -> Result<(String, Vec<u32>), String> {
let encoding = tokenizer let encoding = tokenizer
...@@ -302,7 +307,7 @@ impl WorkerSelectionStage { ...@@ -302,7 +307,7 @@ impl WorkerSelectionStage {
&self, &self,
model_id: Option<&str>, model_id: Option<&str>,
text: Option<&str>, text: Option<&str>,
) -> Option<Arc<dyn crate::core::Worker>> { ) -> Option<Arc<dyn Worker>> {
// Get workers for the specified model, filtered by connection mode // Get workers for the specified model, filtered by connection mode
let workers = self.worker_registry.get_workers_filtered( let workers = self.worker_registry.get_workers_filtered(
model_id, model_id,
...@@ -312,7 +317,7 @@ impl WorkerSelectionStage { ...@@ -312,7 +317,7 @@ impl WorkerSelectionStage {
); );
// Filter by availability (health + circuit breaker) // Filter by availability (health + circuit breaker)
let available: Vec<Arc<dyn crate::core::Worker>> = workers let available: Vec<Arc<dyn Worker>> = workers
.iter() .iter()
.filter(|w| w.is_available()) .filter(|w| w.is_available())
.cloned() .cloned()
...@@ -337,7 +342,7 @@ impl WorkerSelectionStage { ...@@ -337,7 +342,7 @@ impl WorkerSelectionStage {
&self, &self,
model_id: Option<&str>, model_id: Option<&str>,
text: Option<&str>, text: Option<&str>,
) -> Option<(Arc<dyn crate::core::Worker>, Arc<dyn crate::core::Worker>)> { ) -> Option<(Arc<dyn Worker>, Arc<dyn Worker>)> {
// Get prefill workers - use None for WorkerType filter to get all types, // Get prefill workers - use None for WorkerType filter to get all types,
// then filter manually (since Prefill is a struct variant) // then filter manually (since Prefill is a struct variant)
let all_workers = self.worker_registry.get_workers_filtered( let all_workers = self.worker_registry.get_workers_filtered(
...@@ -537,10 +542,8 @@ impl RequestBuildingStage { ...@@ -537,10 +542,8 @@ impl RequestBuildingStage {
fn inject_bootstrap_metadata( fn inject_bootstrap_metadata(
&self, &self,
request: &mut proto::GenerateRequest, request: &mut proto::GenerateRequest,
prefill_worker: &Arc<dyn crate::core::Worker>, prefill_worker: &Arc<dyn Worker>,
) { ) {
use proto::DisaggregatedParams;
let hostname = prefill_worker.bootstrap_host(); let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998); let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
...@@ -935,40 +938,183 @@ impl ResponseProcessingStage { ...@@ -935,40 +938,183 @@ impl ResponseProcessingStage {
async fn process_generate_response( async fn process_generate_response(
&self, &self,
_ctx: &mut RequestContext, ctx: &mut RequestContext,
) -> Result<Option<Response>, Response> { ) -> Result<Option<Response>, Response> {
// TODO(generate): Implement generate response processing let start_time = Instant::now();
// let is_streaming = ctx.is_streaming();
// Required implementation:
// 1. Extract execution_result from ctx // Extract execution result
// 2. Check is_streaming flag let execution_result = ctx
// 3. For streaming: .state
// - Add StreamingProcessor::process_streaming_generate() method .response
// - Similar to process_streaming_response but WITHOUT tool/reasoning parsing .execution_result
// - Return Err(sse_response) for early exit .take()
// 4. For non-streaming: .ok_or_else(|| utils::internal_error_static("No execution result"))?;
// - Collect stream responses using utils::collect_stream_responses()
// - Process through stop decoder (sequential with reset for n>1, like chat) if is_streaming {
// - Build GenerateResponse struct (see TODO in protocols/spec.rs) // Get dispatch metadata for consistent response fields
// - Set ctx.state.response.final_response = Some(FinalResponse::Generate(response)) let dispatch = ctx
// .state
// Reference implementation: router.rs:297-595 .dispatch
// Key differences from chat: .as_ref()
// - No tool parsing .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// - No reasoning parsing
// - Different response format (GenerateResponse instead of ChatCompletionResponse) let generate_request = ctx.generate_request().clone();
// - Still needs: stop decoder, logprobs, finish_reason, matched_stop
Err(( // Streaming: Use StreamingProcessor and return SSE response (done)
axum::http::StatusCode::NOT_IMPLEMENTED, return Ok(Some(
axum::Json(serde_json::json!({ self.streaming_processor.clone().process_streaming_generate(
"error": { execution_result,
"message": "Generate response processing not yet implemented in pipeline", generate_request,
"type": "not_implemented", dispatch.clone(),
"code": 501 ),
} ));
})), }
)
.into_response()) // Non-streaming: Collect all responses
let request_logprobs = ctx.generate_request().return_logprob;
let all_responses = match execution_result {
ExecutionResult::Single { stream } => {
utils::collect_stream_responses(stream, "Single").await?
}
ExecutionResult::Dual { prefill, decode } => {
// Collect prefill for input_logprobs
let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
// Collect decode for actual output
let mut decode_responses =
utils::collect_stream_responses(*decode, "Decode").await?;
// Merge prefill input_logprobs if requested
if request_logprobs {
if let Some(prefill_input_logprobs) = prefill_responses
.first()
.and_then(|r| r.input_logprobs.clone())
{
for response in &mut decode_responses {
response.input_logprobs = Some(prefill_input_logprobs.clone());
}
}
}
decode_responses
}
};
if all_responses.is_empty() {
return Err(utils::internal_error_static("No responses from server"));
}
// Get stop decoder for processing
let stop_decoder = ctx
.state
.response
.stop_decoder
.as_mut()
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
// Get dispatch metadata
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// Process each completion (similar to router.rs:336-400)
let mut result_array = Vec::new();
for mut complete in all_responses {
stop_decoder.reset();
// Process tokens through stop decoder
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
Ok(outputs) => outputs,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Failed to process tokens: {}",
e
)))
}
};
// Accumulate text with early breaks
let mut decoded_text = String::new();
for output in outputs {
match output {
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
SequenceDecoderOutput::StoppedWithText(t) => {
decoded_text.push_str(&t);
break;
}
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
}
}
// Flush remaining text
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
decoded_text.push_str(&t);
}
let output_ids = std::mem::take(&mut complete.output_ids);
let finish_reason_str = std::mem::take(&mut complete.finish_reason);
// Parse finish_reason from string to proper type
let finish_reason =
utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens);
// Handle matched_stop if present
let matched_stop = complete.matched_stop.take().map(|matched| match matched {
MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
});
// Extract logprobs if requested (convert proto types to Generate format)
let input_token_logprobs = if request_logprobs {
complete
.input_logprobs
.as_ref()
.map(utils::convert_generate_input_logprobs)
} else {
None
};
let output_token_logprobs = if request_logprobs {
complete
.output_logprobs
.as_ref()
.map(utils::convert_generate_output_logprobs)
} else {
None
};
// Build GenerateResponse struct
let meta_info = GenerateMetaInfo {
id: dispatch.request_id.clone(),
finish_reason,
prompt_tokens: complete.prompt_tokens as u32,
weight_version: dispatch
.weight_version
.clone()
.unwrap_or_else(|| "default".to_string()),
input_token_logprobs,
output_token_logprobs,
completion_tokens: complete.completion_tokens as u32,
cached_tokens: complete.cached_tokens as u32,
e2e_latency: start_time.elapsed().as_secs_f64(),
matched_stop,
};
result_array.push(GenerateResponse {
text: decoded_text,
output_ids,
meta_info,
});
}
// Store the final response
ctx.state.response.final_response = Some(FinalResponse::Generate(result_array));
Ok(None)
} }
} }
...@@ -1136,7 +1282,7 @@ impl ChatCompletionPipeline { ...@@ -1136,7 +1282,7 @@ impl ChatCompletionPipeline {
// Extract final response // Extract final response
match ctx.state.response.final_response { match ctx.state.response.final_response {
Some(FinalResponse::Generate(response)) => axum::Json(*response).into_response(), Some(FinalResponse::Generate(response)) => axum::Json(response).into_response(),
Some(FinalResponse::Chat(_)) => { Some(FinalResponse::Chat(_)) => {
utils::internal_error_static("Internal error: wrong response type") utils::internal_error_static("Internal error: wrong response type")
} }
......
...@@ -8,28 +8,21 @@ use axum::{ ...@@ -8,28 +8,21 @@ use axum::{
extract::Request, extract::Request,
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json,
}; };
use tracing::debug; use tracing::debug;
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; use crate::core::WorkerRegistry;
use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::policies::PolicyRegistry; use crate::policies::PolicyRegistry;
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, InputIds, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
RerankRequest, ResponsesGetParams, ResponsesRequest, ResponsesGetParams, ResponsesRequest,
}; };
use crate::reasoning_parser::ReasoningParserFactory; use crate::reasoning_parser::ReasoningParserFactory;
use crate::routers::{grpc, RouterTrait}; use crate::routers::RouterTrait;
use crate::server::AppContext; use crate::server::AppContext;
use crate::tokenizer::stop::SequenceDecoderOutput;
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory; use crate::tool_parser::ToolParserFactory;
use grpc::utils;
use serde_json::json;
use std::time::Instant;
use uuid::Uuid;
/// gRPC router implementation for SGLang /// gRPC router implementation for SGLang
#[derive(Clone)] #[derive(Clone)]
...@@ -45,9 +38,7 @@ pub struct GrpcRouter { ...@@ -45,9 +38,7 @@ pub struct GrpcRouter {
retry_config: RetryConfig, retry_config: RetryConfig,
configured_reasoning_parser: Option<String>, configured_reasoning_parser: Option<String>,
configured_tool_parser: Option<String>, configured_tool_parser: Option<String>,
// Pipeline for non-streaming requests
pipeline: super::pipeline::ChatCompletionPipeline, pipeline: super::pipeline::ChatCompletionPipeline,
// Shared components for pipeline
shared_components: Arc<super::context::SharedComponents>, shared_components: Arc<super::context::SharedComponents>,
} }
...@@ -149,420 +140,21 @@ impl GrpcRouter { ...@@ -149,420 +140,21 @@ impl GrpcRouter {
/// Main route_generate implementation /// Main route_generate implementation
async fn route_generate_impl( async fn route_generate_impl(
&self, &self,
_headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &GenerateRequest, body: &GenerateRequest,
model_id: Option<&str>, model_id: Option<&str>,
) -> Response { ) -> Response {
debug!("Processing generate request for model: {:?}", model_id); debug!("Processing generate request for model: {:?}", model_id);
// Step 1: Resolve input (text, prompt, or input_ids) // Use pipeline for ALL requests (streaming and non-streaming)
let (original_text, token_ids) = match self.resolve_generate_input(body) { self.pipeline
Ok(res) => res, .execute_generate(
Err(msg) => { body.clone(),
return utils::bad_request_error(msg); headers.cloned(),
} model_id.map(|s| s.to_string()),
}; self.shared_components.clone(),
debug!("Resolved input with {} tokens", token_ids.len());
// Step 2: Select worker (fail fast if no workers available)
let worker = match self.select_worker_for_request(model_id, original_text.as_deref()) {
Some(w) => w,
None => {
return utils::service_unavailable_error(format!(
"No available workers for model: {:?}",
model_id
));
}
};
debug!("Selected worker: {}", worker.url());
// Step 3: Get gRPC client from worker
let client = match utils::get_grpc_client_from_worker(&worker).await {
Ok(client) => client,
Err(response) => return response,
};
// Step 4: Build the gRPC request
let request_id = body
.rid
.clone()
.unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
let request = match client.build_plain_generate_request(
request_id.clone(),
body,
original_text.clone(),
token_ids,
) {
Ok(req) => req,
Err(e) => {
return utils::bad_request_error(e);
}
};
// Step 5: Get weight version for response metadata
let weight_version = worker
.metadata()
.labels
.get("weight_version")
.cloned()
.unwrap_or_else(|| "default".to_string());
// Step 6: Handle streaming vs non-streaming
if body.stream {
self.handle_streaming_generate(client, request, body, request_id, weight_version)
.await
} else {
self.handle_non_streaming_generate(client, request, body, request_id, weight_version)
.await
}
}
/// Select a worker for the request
fn select_worker_for_request(
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<Arc<dyn Worker>> {
// Get workers for the specified model, filtered by connection mode
let workers = self.worker_registry.get_workers_filtered(
model_id,
Some(WorkerType::Regular),
Some(ConnectionMode::Grpc { port: None }),
false, // get all workers, we'll filter by is_available() next
);
// Filter by availability (health + circuit breaker)
let available: Vec<Arc<dyn Worker>> = workers
.iter()
.filter(|w| w.is_available())
.cloned()
.collect();
if available.is_empty() {
return None;
}
// Get the appropriate policy for this model
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
// Select worker using the policy
let idx = policy.select_worker(&available, text)?;
Some(available[idx].clone())
}
/// Resolve the generate input into optional original text and token IDs
fn resolve_generate_input(
&self,
request: &GenerateRequest,
) -> Result<(Option<String>, Vec<u32>), String> {
if let Some(text) = &request.text {
return self
.tokenize_single_text(text)
.map(|(original, ids)| (Some(original), ids));
}
// Handle input_ids - validate and convert
if let Some(input_ids) = &request.input_ids {
return match input_ids {
InputIds::Single(ids) => ids
.iter()
.map(|&id| u32::try_from(id))
.collect::<Result<Vec<u32>, _>>()
.map(|converted| (None, converted))
.map_err(|_| "input_ids must be non-negative".to_string()),
InputIds::Batch(_) => {
Err("Batch input_ids are not supported over gRPC generate yet".to_string())
}
};
}
Err("Either `text` or `input_ids` must be provided".to_string())
}
fn tokenize_single_text(&self, text: &str) -> Result<(String, Vec<u32>), String> {
let encoding = self
.tokenizer
.encode(text)
.map_err(|e| format!("Tokenization failed: {}", e))?;
Ok((text.to_string(), encoding.token_ids().to_vec()))
}
/// Submit request and handle non-streaming response for the `/generate` endpoint
async fn handle_non_streaming_generate(
&self,
mut client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &GenerateRequest,
request_id: String,
weight_version: String,
) -> Response {
let start_time = Instant::now();
let stream = match client.generate(request).await {
Ok(stream) => stream,
Err(e) => {
return utils::internal_error_message(format!("Failed to start generation: {}", e))
}
};
// Collect all responses using utils helper
let responses = match utils::collect_stream_responses(stream, "Generate").await {
Ok(responses) => responses,
Err(error_response) => return error_response,
};
if responses.is_empty() {
return utils::internal_error_static("No completion received from scheduler");
}
// Create stop decoder from sampling params
let params = original_request.sampling_params.as_ref();
let mut stop_decoder = utils::create_stop_decoder(
&self.tokenizer,
params.and_then(|p| p.stop.as_ref()),
params.and_then(|p| p.stop_token_ids.as_ref()),
params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
params.and_then(|p| p.no_stop_trim).unwrap_or(false),
);
// Process each completion
let mut result_array = Vec::new();
for mut complete in responses {
stop_decoder.reset();
// Process tokens through stop decoder
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
Ok(outputs) => outputs,
Err(e) => {
return utils::internal_error_message(format!(
"Failed to process tokens: {}",
e
))
}
};
// Accumulate text with early breaks
let mut decoded_text = String::new();
for output in outputs {
match output {
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
SequenceDecoderOutput::StoppedWithText(t) => {
decoded_text.push_str(&t);
break;
}
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
}
}
// Flush remaining text
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
decoded_text.push_str(&t);
}
let output_ids = std::mem::take(&mut complete.output_ids);
let finish_reason = std::mem::take(&mut complete.finish_reason);
// Build base meta_info using json! macro
let mut meta_info = json!({
"id": request_id.clone(),
"finish_reason": finish_reason,
"prompt_tokens": complete.prompt_tokens,
"weight_version": weight_version.clone(),
"completion_tokens": complete.completion_tokens,
"cached_tokens": complete.cached_tokens,
"e2e_latency": start_time.elapsed().as_secs_f64(),
});
let meta_obj = meta_info.as_object_mut().unwrap();
// Add matched_stop if present
if let Some(matched) = complete.matched_stop.take() {
use proto::generate_complete::MatchedStop;
let matched_value = match matched {
MatchedStop::MatchedTokenId(id) => json!(id),
MatchedStop::MatchedStopStr(s) => json!(s),
};
meta_obj.insert("matched_stop".to_string(), matched_value);
}
result_array.push(json!({
"text": decoded_text,
"output_ids": output_ids,
"meta_info": meta_info,
}));
}
Json(result_array).into_response()
}
/// Submit request and handle streaming response for the `/generate` endpoint
async fn handle_streaming_generate(
&self,
mut client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &GenerateRequest,
request_id: String,
weight_version: String,
) -> Response {
let tokenizer = self.tokenizer.clone();
let return_logprob = original_request.return_logprob;
// Create channel for SSE streaming
let (tx, rx) =
tokio::sync::mpsc::unbounded_channel::<Result<bytes::Bytes, std::io::Error>>();
// Start the stream
let stream = match client.generate(request).await {
Ok(stream) => stream,
Err(e) => {
return utils::internal_error_message(format!("Failed to start generation: {}", e))
}
};
// Spawn async task to process stream
tokio::spawn(async move {
let result = Self::process_generate_streaming(
tokenizer,
stream,
request_id,
weight_version,
return_logprob,
&tx,
) )
.await; .await
if let Err(e) = result {
let error_chunk = format!("data: {{\"error\": \"{}\"}}\n\n", e);
let _ = tx.send(Ok(bytes::Bytes::from(error_chunk)));
}
// Send [DONE] marker
let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n")));
});
// Create SSE response stream
let body_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(axum::body::Body::from_stream(body_stream))
.unwrap()
}
/// Process streaming chunks for generate endpoint
async fn process_generate_streaming(
tokenizer: Arc<dyn Tokenizer>,
mut stream: impl tokio_stream::Stream<Item = Result<proto::GenerateResponse, tonic::Status>>
+ Unpin,
request_id: String,
weight_version: String,
_include_logprobs: bool,
tx: &tokio::sync::mpsc::UnboundedSender<Result<bytes::Bytes, std::io::Error>>,
) -> Result<(), String> {
use proto::generate_response::Response::{Chunk, Complete, Error};
use std::time::Instant;
use tokio_stream::StreamExt;
let start_time = Instant::now();
// Track state per index for n>1 case
use std::collections::HashMap;
let mut accumulated_texts: HashMap<u32, String> = HashMap::new();
let mut completion_tokens_map: HashMap<u32, u32> = HashMap::new();
while let Some(response) = stream.next().await {
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
match gen_response.response {
Some(Chunk(chunk)) => {
let index = chunk.index;
// Update completion tokens for this index
let completion_tokens = completion_tokens_map.entry(index).or_insert(0);
*completion_tokens += chunk.token_ids.len() as u32;
// Decode tokens to text (skip_special_tokens=true to handle newlines correctly)
let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default();
// Accumulate text for this index
let accumulated_text = accumulated_texts.entry(index).or_default();
accumulated_text.push_str(&chunk_text);
// Generate unique ID per index
let index_id = format!("{}-{}", request_id, index);
// Build streaming response chunk (SGLang format)
let chunk_response = serde_json::json!({
"text": accumulated_text.clone(),
"output_ids": chunk.token_ids,
"meta_info": {
"id": index_id,
"finish_reason": null,
"prompt_tokens": chunk.prompt_tokens,
"weight_version": weight_version,
"completion_tokens": *completion_tokens,
"cached_tokens": chunk.cached_tokens
},
"index": index
});
let sse_chunk = format!(
"data: {}\n\n",
serde_json::to_string(&chunk_response).unwrap()
);
tx.send(Ok(bytes::Bytes::from(sse_chunk)))
.map_err(|_| "Failed to send chunk".to_string())?;
}
Some(Complete(complete)) => {
let index = complete.index;
let accumulated_text =
accumulated_texts.get(&index).cloned().unwrap_or_default();
let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0);
let index_id = format!("{}-{}", request_id, index);
let e2e_latency = start_time.elapsed().as_secs_f64();
// Send final chunk with finish_reason (no new tokens in Complete, they were already sent in Chunks)
let finish_response = serde_json::json!({
"text": accumulated_text,
"output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(),
"meta_info": {
"id": index_id,
"finish_reason": complete.finish_reason,
"prompt_tokens": complete.prompt_tokens,
"weight_version": weight_version,
"completion_tokens": completion_tokens,
"cached_tokens": complete.cached_tokens,
"e2e_latency": e2e_latency
},
"index": index
});
let sse_chunk = format!(
"data: {}\n\n",
serde_json::to_string(&finish_response).unwrap()
);
tx.send(Ok(bytes::Bytes::from(sse_chunk)))
.map_err(|_| "Failed to send finish chunk".to_string())?;
// Continue to process all completions if n>1
}
Some(Error(error)) => {
return Err(error.message);
}
None => continue,
}
}
Ok(())
} }
} }
......
This diff is collapsed.
...@@ -5,7 +5,7 @@ use crate::core::Worker; ...@@ -5,7 +5,7 @@ use crate::core::Worker;
use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse, ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb, GenerateFinishReason, StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
}; };
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
...@@ -809,6 +809,70 @@ pub fn convert_proto_to_openai_logprobs( ...@@ -809,6 +809,70 @@ pub fn convert_proto_to_openai_logprobs(
}) })
} }
/// Convert proto::OutputLogProbs to Generate format Vec<Vec<Option<f64>>>
///
/// Generate format: [[logprob, token_id, ...], [logprob, token_id, ...], ...]
/// Each inner vec contains [logprob (f64), token_id (i32), ...]
pub fn convert_generate_output_logprobs(
proto_logprobs: &proto::OutputLogProbs,
) -> Vec<Vec<Option<f64>>> {
proto_logprobs
.token_logprobs
.iter()
.zip(proto_logprobs.token_ids.iter())
.map(|(&logprob, &token_id)| vec![Some(logprob as f64), Some(token_id as f64)])
.collect()
}
/// Convert proto::InputLogProbs to Generate format Vec<Vec<Option<f64>>>
///
/// Generate format: [[logprob, token_id, ...], [logprob, token_id, ...], ...]
/// First token has null logprob: [[null, token_id], [logprob, token_id], ...]
pub fn convert_generate_input_logprobs(
proto_logprobs: &proto::InputLogProbs,
) -> Vec<Vec<Option<f64>>> {
proto_logprobs
.token_logprobs
.iter()
.zip(proto_logprobs.token_ids.iter())
.map(|(token_logprob, &token_id)| {
// InputTokenLogProb has optional value field
let logprob_value = token_logprob.value.map(|v| v as f64);
vec![logprob_value, Some(token_id as f64)]
})
.collect()
}
/// Parse finish_reason string into GenerateFinishReason enum
///
/// Uses serde to deserialize the finish_reason, which handles all tagged variants automatically.
/// The GenerateFinishReason enum is tagged with `#[serde(tag = "type", rename_all = "lowercase")]`,
/// so it expects JSON objects like:
/// - `{"type":"stop"}` -> Stop
/// - `{"type":"length","length":100}` -> Length { length: 100 }
/// - Any other JSON -> Other(...)
///
/// For backward compatibility, also handles simple string "stop" -> Stop
pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> GenerateFinishReason {
if reason_str == "stop" {
return GenerateFinishReason::Stop;
}
if reason_str == "length" {
return GenerateFinishReason::Length {
length: completion_tokens.max(0) as u32,
};
}
match serde_json::from_str::<GenerateFinishReason>(reason_str) {
Ok(finish_reason) => finish_reason,
Err(_) => match serde_json::from_str::<Value>(reason_str) {
Ok(json_value) => GenerateFinishReason::Other(json_value),
Err(_) => GenerateFinishReason::Other(Value::String(reason_str.to_string())),
},
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment