Unverified Commit 0b9915c1 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] update generate spec to align with sgl io struct (#11591)

parent 27ef1459
...@@ -28,15 +28,38 @@ fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) { ...@@ -28,15 +28,38 @@ fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
fn default_generate_request() -> GenerateRequest { fn default_generate_request() -> GenerateRequest {
GenerateRequest { GenerateRequest {
text: None, text: None,
prompt: None,
input_ids: None, input_ids: None,
stream: false, input_embeds: None,
image_data: None,
video_data: None,
audio_data: None,
sampling_params: None, sampling_params: None,
return_logprob: false, return_logprob: None,
// SGLang Extensions logprob_start_len: None,
lora_path: None, top_logprobs_num: None,
session_params: None, token_ids_logprob: None,
return_text_in_logprobs: false,
stream: false,
log_metrics: true,
return_hidden_states: false, return_hidden_states: false,
modalities: None,
session_params: None,
lora_path: None,
lora_id: None,
custom_logit_processor: None,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
bootstrap_pair_key: None,
data_parallel_rank: None,
background: false,
conversation_id: None,
priority: None,
extra_key: None,
no_logs: false,
custom_labels: None,
return_bytes: false,
return_entropy: false,
rid: None, rid: None,
} }
} }
...@@ -101,6 +124,7 @@ fn create_sample_generate_request() -> GenerateRequest { ...@@ -101,6 +124,7 @@ fn create_sample_generate_request() -> GenerateRequest {
GenerateRequest { GenerateRequest {
text: Some("Write a story about artificial intelligence".to_string()), text: Some("Write a story about artificial intelligence".to_string()),
sampling_params: Some(SamplingParams { sampling_params: Some(SamplingParams {
max_new_tokens: Some(100),
temperature: Some(0.8), temperature: Some(0.8),
top_p: Some(0.9), top_p: Some(0.9),
top_k: Some(50), top_k: Some(50),
......
...@@ -280,13 +280,13 @@ impl SglangSchedulerClient { ...@@ -280,13 +280,13 @@ impl SglangSchedulerClient {
input_ids: token_ids, input_ids: token_ids,
}), }),
sampling_params: Some(sampling_params), sampling_params: Some(sampling_params),
return_logprob: body.return_logprob, return_logprob: body.return_logprob.unwrap_or(false),
logprob_start_len: -1, logprob_start_len: body.logprob_start_len.unwrap_or(-1),
top_logprobs_num: 0, top_logprobs_num: body.top_logprobs_num.unwrap_or(0),
token_ids_logprob: vec![], token_ids_logprob: body.token_ids_logprob.clone().unwrap_or_default(),
return_hidden_states: body.return_hidden_states, return_hidden_states: body.return_hidden_states,
stream: body.stream, stream: body.stream,
log_metrics: true, log_metrics: body.log_metrics,
..Default::default() ..Default::default()
}; };
......
...@@ -356,7 +356,7 @@ pub struct ChatCompletionRequest { ...@@ -356,7 +356,7 @@ pub struct ChatCompletionRequest {
/// Path to LoRA adapter(s) for model customization /// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>, pub lora_path: Option<String>,
/// Session parameters for continual prompting /// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -905,7 +905,7 @@ pub struct CompletionRequest { ...@@ -905,7 +905,7 @@ pub struct CompletionRequest {
/// Path to LoRA adapter(s) for model customization /// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>, pub lora_path: Option<String>,
/// Session parameters for continual prompting /// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -2309,10 +2309,6 @@ fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::Va ...@@ -2309,10 +2309,6 @@ fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::Va
#[derive(Clone, Debug, Serialize, Deserialize, Validate)] #[derive(Clone, Debug, Serialize, Deserialize, Validate)]
#[validate(schema(function = "validate_generate_request"))] #[validate(schema(function = "validate_generate_request"))]
pub struct GenerateRequest { pub struct GenerateRequest {
/// The prompt to generate from (OpenAI style)
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<StringOrArray>,
/// Text input - SGLang native format /// Text input - SGLang native format
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>, pub text: Option<String>,
...@@ -2321,31 +2317,144 @@ pub struct GenerateRequest { ...@@ -2321,31 +2317,144 @@ pub struct GenerateRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub input_ids: Option<InputIds>, pub input_ids: Option<InputIds>,
/// Input embeddings for direct embedding input
/// Can be a 2D array (single request) or 3D array (batch of requests)
/// Placeholder for future use
#[serde(skip_serializing_if = "Option::is_none")]
pub input_embeds: Option<Value>,
/// Image input data
/// Can be an image instance, file name, URL, or base64 encoded string
/// Supports single images, lists of images, or nested lists for batch processing
/// Placeholder for future use
#[serde(skip_serializing_if = "Option::is_none")]
pub image_data: Option<Value>,
/// Video input data
/// Can be a file name, URL, or base64 encoded string
/// Supports single videos, lists of videos, or nested lists for batch processing
/// Placeholder for future use
#[serde(skip_serializing_if = "Option::is_none")]
pub video_data: Option<Value>,
/// Audio input data
/// Can be a file name, URL, or base64 encoded string
/// Supports single audio files, lists of audio, or nested lists for batch processing
/// Placeholder for future use
#[serde(skip_serializing_if = "Option::is_none")]
pub audio_data: Option<Value>,
/// Sampling parameters (sglang style) /// Sampling parameters (sglang style)
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub sampling_params: Option<SamplingParams>, pub sampling_params: Option<SamplingParams>,
/// Whether to return logprobs
#[serde(skip_serializing_if = "Option::is_none")]
pub return_logprob: Option<bool>,
/// If return logprobs, the start location in the prompt for returning logprobs.
#[serde(skip_serializing_if = "Option::is_none")]
pub logprob_start_len: Option<i32>,
/// If return logprobs, the number of top logprobs to return at each position.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs_num: Option<i32>,
/// If return logprobs, the token ids to return logprob for.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_ids_logprob: Option<Vec<u32>>,
/// Whether to detokenize tokens in text in the returned logprobs.
#[serde(default)]
pub return_text_in_logprobs: bool,
/// Whether to stream the response /// Whether to stream the response
#[serde(default)] #[serde(default)]
pub stream: bool, pub stream: bool,
/// Whether to return logprobs /// Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
#[serde(default = "default_true")]
pub log_metrics: bool,
/// Return model hidden states
#[serde(default)] #[serde(default)]
pub return_logprob: bool, pub return_hidden_states: bool,
/// Path to LoRA adapter(s) for model customization /// The modalities of the image data [image, multi-images, video]
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>, pub modalities: Option<Vec<String>>,
/// Session parameters for continual prompting /// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, Value>>, pub session_params: Option<HashMap<String, Value>>,
/// Return model hidden states /// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<String>,
/// LoRA adapter ID (if pre-loaded)
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_id: Option<String>,
/// Custom logit processor for advanced sampling control. Must be a serialized instance
/// of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
/// Use the processor's `to_str()` method to generate the serialized string.
#[serde(skip_serializing_if = "Option::is_none")]
pub custom_logit_processor: Option<String>,
/// For disaggregated inference
#[serde(skip_serializing_if = "Option::is_none")]
pub bootstrap_host: Option<String>,
/// For disaggregated inference
#[serde(skip_serializing_if = "Option::is_none")]
pub bootstrap_port: Option<i32>,
/// For disaggregated inference
#[serde(skip_serializing_if = "Option::is_none")]
pub bootstrap_room: Option<i32>,
/// For disaggregated inference
#[serde(skip_serializing_if = "Option::is_none")]
pub bootstrap_pair_key: Option<String>,
/// Data parallel rank routing
#[serde(skip_serializing_if = "Option::is_none")]
pub data_parallel_rank: Option<i32>,
/// Background response
#[serde(default)] #[serde(default)]
pub return_hidden_states: bool, pub background: bool,
/// Request ID for tracking /// Conversation ID for tracking
#[serde(skip_serializing_if = "Option::is_none")]
pub conversation_id: Option<String>,
/// Priority for the request
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
/// Extra key for classifying the request (e.g. cache_salt)
#[serde(skip_serializing_if = "Option::is_none")]
pub extra_key: Option<String>,
/// Whether to disallow logging for this request (e.g. due to ZDR)
#[serde(default)]
pub no_logs: bool,
/// Custom metric labels
#[serde(skip_serializing_if = "Option::is_none")]
pub custom_labels: Option<HashMap<String, String>>,
/// Whether to return bytes for image generation
#[serde(default)]
pub return_bytes: bool,
/// Whether to return entropy
#[serde(default)]
pub return_entropy: bool,
/// Request ID for tracking (inherited from BaseReq in Python)
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub rid: Option<String>, pub rid: Option<String>,
} }
...@@ -2358,7 +2467,7 @@ impl Normalizable for GenerateRequest { ...@@ -2358,7 +2467,7 @@ impl Normalizable for GenerateRequest {
fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> { fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> {
// Exactly one of text or input_ids must be provided // Exactly one of text or input_ids must be provided
// Note: input_embeds not yet supported in Rust implementation // Note: input_embeds not yet supported in Rust implementation
let has_text = req.text.is_some() || req.prompt.is_some(); let has_text = req.text.is_some();
let has_input_ids = req.input_ids.is_some(); let has_input_ids = req.input_ids.is_some();
let count = [has_text, has_input_ids].iter().filter(|&&x| x).count(); let count = [has_text, has_input_ids].iter().filter(|&&x| x).count();
...@@ -2389,18 +2498,11 @@ impl GenerationRequest for GenerateRequest { ...@@ -2389,18 +2498,11 @@ impl GenerationRequest for GenerateRequest {
} }
fn extract_text_for_routing(&self) -> String { fn extract_text_for_routing(&self) -> String {
// Check fields in priority order: text, prompt, inputs // Check fields in priority order: text, input_ids
if let Some(ref text) = self.text { if let Some(ref text) = self.text {
return text.clone(); return text.clone();
} }
if let Some(ref prompt) = self.prompt {
return match prompt {
StringOrArray::String(s) => s.clone(),
StringOrArray::Array(v) => v.join(" "),
};
}
if let Some(ref input_ids) = self.input_ids { if let Some(ref input_ids) = self.input_ids {
return match input_ids { return match input_ids {
InputIds::Single(ids) => ids InputIds::Single(ids) => ids
......
...@@ -877,7 +877,7 @@ impl ResponseProcessingStage { ...@@ -877,7 +877,7 @@ impl ResponseProcessingStage {
} }
// Non-streaming: Delegate to ResponseProcessor // Non-streaming: Delegate to ResponseProcessor
let request_logprobs = ctx.generate_request().return_logprob; let request_logprobs = ctx.generate_request().return_logprob.unwrap_or(false);
let generate_request = ctx.generate_request_arc(); let generate_request = ctx.generate_request_arc();
let stop_decoder = ctx let stop_decoder = ctx
......
...@@ -616,7 +616,7 @@ impl StreamingProcessor { ...@@ -616,7 +616,7 @@ impl StreamingProcessor {
generate_request: Arc<GenerateRequest>, generate_request: Arc<GenerateRequest>,
dispatch: context::DispatchMetadata, dispatch: context::DispatchMetadata,
) -> Response { ) -> Response {
let return_logprob = generate_request.return_logprob; let return_logprob = generate_request.return_logprob.unwrap_or(false);
// Create SSE channel // Create SSE channel
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>(); let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
......
...@@ -150,11 +150,6 @@ impl PDRouter { ...@@ -150,11 +150,6 @@ impl PDRouter {
} }
fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> { fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
if let Some(StringOrArray::Array(arr)) = &req.prompt {
if !arr.is_empty() {
return Some(arr.len());
}
}
if let Some(text) = &req.text { if let Some(text) = &req.text {
if text.contains("[") && text.contains("]") { if text.contains("[") && text.contains("]") {
return None; return None;
...@@ -1061,18 +1056,10 @@ impl RouterTrait for PDRouter { ...@@ -1061,18 +1056,10 @@ impl RouterTrait for PDRouter {
model_id: Option<&str>, model_id: Option<&str>,
) -> Response { ) -> Response {
let is_stream = body.stream; let is_stream = body.stream;
let return_logprob = body.return_logprob; let return_logprob = body.return_logprob.unwrap_or(false);
let request_text = if self.policies_need_request_text() { let request_text = if self.policies_need_request_text() {
body.text body.text.as_deref().map(|s| s.to_string())
.as_deref()
.or_else(|| {
body.prompt.as_ref().and_then(|p| match p {
StringOrArray::String(s) => Some(s.as_str()),
StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
})
})
.map(|s| s.to_string())
} else { } else {
None None
}; };
......
...@@ -598,15 +598,39 @@ async fn test_unsupported_endpoints() { ...@@ -598,15 +598,39 @@ async fn test_unsupported_endpoints() {
.unwrap(); .unwrap();
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
prompt: None,
text: Some("Hello world".to_string()), text: Some("Hello world".to_string()),
input_ids: None, input_ids: None,
input_embeds: None,
image_data: None,
video_data: None,
audio_data: None,
sampling_params: None, sampling_params: None,
stream: false, stream: false,
return_logprob: false, return_logprob: Some(false),
lora_path: None, logprob_start_len: None,
session_params: None, top_logprobs_num: None,
token_ids_logprob: None,
return_text_in_logprobs: false,
log_metrics: true,
return_hidden_states: false, return_hidden_states: false,
modalities: None,
session_params: None,
lora_path: None,
lora_id: None,
custom_logit_processor: None,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
bootstrap_pair_key: None,
data_parallel_rank: None,
background: false,
conversation_id: None,
priority: None,
extra_key: None,
no_logs: false,
custom_labels: None,
return_bytes: false,
return_entropy: false,
rid: None, rid: None,
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment