Unverified Commit f6e07f27 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] fix pd model completion request (#8303)

parent 5dd0f870
...@@ -97,6 +97,7 @@ fn create_sample_completion_request() -> CompletionRequest { ...@@ -97,6 +97,7 @@ fn create_sample_completion_request() -> CompletionRequest {
logit_bias: None, logit_bias: None,
user: None, user: None,
seed: None, seed: None,
other: serde_json::Map::new(),
} }
} }
......
...@@ -91,6 +91,10 @@ pub struct CompletionRequest { ...@@ -91,6 +91,10 @@ pub struct CompletionRequest {
/// If specified, our system will make a best effort to sample deterministically /// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>, pub seed: Option<i64>,
/// Additional fields including bootstrap info for PD routing
#[serde(flatten)]
pub other: serde_json::Map<String, serde_json::Value>,
} }
impl GenerationRequest for CompletionRequest { impl GenerationRequest for CompletionRequest {
......
...@@ -420,6 +420,77 @@ impl PDRouter { ...@@ -420,6 +420,77 @@ impl PDRouter {
.await .await
} }
/// Route a completion request through the PD (prefill/decode) pair while
/// preserving the OpenAI request format on the wire.
///
/// Flow: select workers -> log the decision -> inject bootstrap metadata ->
/// serialize -> dispatch to both prefill and decode servers.
pub async fn route_completion(
    &self,
    client: &reqwest::Client,
    req: &HttpRequest,
    mut typed_req: CompletionRequest,
    route: &str,
) -> HttpResponse {
    let timer = Instant::now();

    // Capture these flags up front, before the request is mutated/serialized.
    let is_stream = typed_req.stream;
    let return_logprob = typed_req.logprobs.is_some();

    // Text used for cache-aware routing; for a batch, the first prompt is used.
    let routing_text = match &typed_req.prompt {
        crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
        crate::openai_api_types::StringOrArray::Array(items) => {
            items.first().map(|s| s.as_str())
        }
    };

    // Pick a prefill/decode worker pair; bail out with 503 if none available.
    let (prefill, decode) = match self.select_pd_pair(client, routing_text).await {
        Ok(pair) => pair,
        Err(e) => {
            error!("Failed to select PD pair: {}", e);
            RouterMetrics::record_pd_error("server_selection");
            return HttpResponse::ServiceUnavailable()
                .body(format!("No available servers: {}", e));
        }
    };

    info!(
        "PD routing: {} -> prefill={}, decode={}",
        route,
        prefill.url(),
        decode.url()
    );

    // Inject bootstrap info through the Bootstrap trait implementation.
    if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
        error!("Failed to add bootstrap info: {}", e);
        RouterMetrics::record_pd_error("bootstrap_injection");
        return HttpResponse::InternalServerError()
            .body(format!("Bootstrap injection failed: {}", e));
    }

    // Serialize only after the bootstrap fields have been added.
    let payload = match serde_json::to_value(&typed_req) {
        Ok(json) => json,
        Err(e) => {
            error!("Failed to serialize request: {}", e);
            return HttpResponse::InternalServerError().body("Failed to serialize request");
        }
    };

    self.execute_dual_dispatch(
        client,
        req,
        payload,
        route,
        prefill.as_ref(),
        decode.as_ref(),
        is_stream,
        return_logprob,
        timer,
    )
    .await
}
// Execute the dual dispatch to prefill and decode servers // Execute the dual dispatch to prefill and decode servers
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
async fn execute_dual_dispatch( async fn execute_dual_dispatch(
...@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter { ...@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter {
req: &HttpRequest, req: &HttpRequest,
body: serde_json::Value, body: serde_json::Value,
) -> HttpResponse { ) -> HttpResponse {
match serde_json::from_value::<CompletionRequest>(body.clone()) { match serde_json::from_value::<CompletionRequest>(body) {
Ok(openai_req) => { Ok(openai_req) => {
// Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput) // Use the new method that preserves OpenAI format
let pd_req = openai_req.to_pd_request(); PDRouter::route_completion(self, client, req, openai_req, "/v1/completions").await
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
}
Err(_) => {
// If that fails, try to deserialize directly as PD format (for backwards compatibility)
match serde_json::from_value::<GenerateReqInput>(body) {
Ok(pd_req) => {
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
}
Err(e) => {
HttpResponse::BadRequest().body(format!("Invalid request format: {}", e))
}
}
} }
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)),
} }
} }
......
// Essential PDLB types extracted for PD routing // Essential PDLB types extracted for PD routing
use crate::core::{Worker, WorkerType}; use crate::core::{Worker, WorkerType};
use crate::openai_api_types::{CompletionRequest, StringOrArray};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
...@@ -233,3 +234,235 @@ impl Bootstrap for ChatReqInput { ...@@ -233,3 +234,235 @@ impl Bootstrap for ChatReqInput {
self.bootstrap_room = Some(bootstrap_room); self.bootstrap_room = Some(bootstrap_room);
} }
} }
/// Bootstrap support for `CompletionRequest`: keeps the OpenAI wire format
/// intact by stashing PD bootstrap fields in the flattened `other` map.
impl Bootstrap for CompletionRequest {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_batch_size(&self) -> Result<Option<usize>, String> {
        match &self.prompt {
            // An empty batch is a caller error.
            StringOrArray::Array(prompts) if prompts.is_empty() => {
                Err("Batch prompt array is empty".to_string())
            }
            StringOrArray::Array(prompts) => Ok(Some(prompts.len())),
            // A single string prompt is not a batch.
            StringOrArray::String(_) => Ok(None),
        }
    }

    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    ) {
        // Each value serializes correctly whether it is Single or Batch.
        // On a (practically impossible) serialization failure the key is
        // simply left absent, matching the previous behavior.
        let entries = [
            ("bootstrap_host", serde_json::to_value(&bootstrap_host)),
            ("bootstrap_port", serde_json::to_value(&bootstrap_port)),
            ("bootstrap_room", serde_json::to_value(&bootstrap_room)),
        ];
        for (key, value) in entries {
            if let Ok(value) = value {
                self.other.insert(key.to_string(), value);
            }
        }
    }
}
#[cfg(test)]
mod bootstrap_tests {
    use super::*;
    use crate::openai_api_types::StringOrArray;

    /// Builds a minimal `CompletionRequest` for the given prompt; every
    /// optional field is unset so each test only varies what it cares about.
    /// Removes the ~20-field struct literal previously duplicated in all
    /// five tests.
    fn make_request(prompt: StringOrArray) -> CompletionRequest {
        CompletionRequest {
            model: "test".to_string(),
            prompt,
            n: None,
            other: serde_json::Map::new(),
            suffix: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            stream: false,
            stream_options: None,
            logprobs: None,
            echo: false,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
            seed: None,
        }
    }

    /// Convenience: a two-element batch prompt used by several tests.
    fn two_prompt_batch() -> StringOrArray {
        StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()])
    }

    #[test]
    fn test_completion_batch_size_with_array_prompt() {
        let req = make_request(two_prompt_batch());
        // Should return batch size for array prompt
        assert_eq!(req.get_batch_size().unwrap(), Some(2));
    }

    #[test]
    fn test_completion_batch_size_with_single_prompt() {
        let req = make_request(StringOrArray::String("single prompt".to_string()));
        // Should return None for single prompt
        assert_eq!(req.get_batch_size().unwrap(), None);
    }

    #[test]
    fn test_completion_batch_size_with_n_parameter() {
        let mut req = make_request(StringOrArray::String("single prompt".to_string()));
        req.n = Some(3);
        // Should return None for single string prompt, even with n > 1
        // SGLang handles n parameter differently than batch requests
        assert_eq!(req.get_batch_size().unwrap(), None);
    }

    #[test]
    fn test_completion_bootstrap_single_values() {
        let mut req = make_request(two_prompt_batch());

        // Set bootstrap info - should always use single values
        req.set_bootstrap_info(
            BootstrapHost::Single("test-server".to_string()),
            BootstrapPort::Single(Some(5678)),
            BootstrapRoom::Single(12345),
        );

        // Verify single values were created
        assert!(req.other.get("bootstrap_host").unwrap().is_string());
        assert!(req.other.get("bootstrap_port").unwrap().is_number());
        assert!(req.other.get("bootstrap_room").unwrap().is_number());
        assert_eq!(
            req.other.get("bootstrap_host").unwrap().as_str().unwrap(),
            "test-server"
        );
        assert_eq!(
            req.other.get("bootstrap_port").unwrap().as_u64().unwrap(),
            5678
        );
        assert_eq!(
            req.other.get("bootstrap_room").unwrap().as_u64().unwrap(),
            12345
        );
    }

    #[test]
    fn test_completion_bootstrap_array_values() {
        let mut req = make_request(two_prompt_batch());

        // Set bootstrap info with arrays
        req.set_bootstrap_info(
            BootstrapHost::Batch(vec!["test-server".to_string(); 2]),
            BootstrapPort::Batch(vec![Some(5678); 2]),
            BootstrapRoom::Batch(vec![12345, 67890]),
        );

        // Verify arrays were created correctly
        assert!(req.other.get("bootstrap_host").unwrap().is_array());
        assert!(req.other.get("bootstrap_port").unwrap().is_array());
        assert!(req.other.get("bootstrap_room").unwrap().is_array());

        let hosts = req.other.get("bootstrap_host").unwrap().as_array().unwrap();
        assert_eq!(hosts.len(), 2);
        assert_eq!(hosts[0].as_str().unwrap(), "test-server");

        let ports = req.other.get("bootstrap_port").unwrap().as_array().unwrap();
        assert_eq!(ports.len(), 2);
        assert_eq!(ports[0].as_u64().unwrap(), 5678);

        let rooms = req.other.get("bootstrap_room").unwrap().as_array().unwrap();
        assert_eq!(rooms.len(), 2);
        assert_eq!(rooms[0].as_u64().unwrap(), 12345);
        assert_eq!(rooms[1].as_u64().unwrap(), 67890);
    }
}
...@@ -648,6 +648,7 @@ mod tests { ...@@ -648,6 +648,7 @@ mod tests {
user: None, user: None,
seed: None, seed: None,
suffix: None, suffix: None,
other: serde_json::Map::new(),
}; };
let pd_req = req.to_pd_request(); let pd_req = req.to_pd_request();
...@@ -687,6 +688,7 @@ mod tests { ...@@ -687,6 +688,7 @@ mod tests {
user: None, user: None,
seed: None, seed: None,
suffix: None, suffix: None,
other: serde_json::Map::new(),
}; };
let pd_req = req.to_pd_request(); let pd_req = req.to_pd_request();
...@@ -725,6 +727,7 @@ mod tests { ...@@ -725,6 +727,7 @@ mod tests {
user: Some("user123".to_string()), user: Some("user123".to_string()),
seed: Some(42), seed: Some(42),
suffix: Some("...".to_string()), suffix: Some("...".to_string()),
other: serde_json::Map::new(),
}; };
let pd_req = req.to_pd_request(); let pd_req = req.to_pd_request();
...@@ -768,6 +771,7 @@ mod tests { ...@@ -768,6 +771,7 @@ mod tests {
user: None, user: None,
seed: None, seed: None,
suffix: None, suffix: None,
other: serde_json::Map::new(),
}; };
let pd_req = req.to_pd_request(); let pd_req = req.to_pd_request();
...@@ -799,6 +803,7 @@ mod tests { ...@@ -799,6 +803,7 @@ mod tests {
user: None, user: None,
seed: None, seed: None,
suffix: None, suffix: None,
other: serde_json::Map::new(),
}; };
let pd_req = req.to_pd_request(); let pd_req = req.to_pd_request();
......
...@@ -86,6 +86,7 @@ fn test_benchmark_request_creation() { ...@@ -86,6 +86,7 @@ fn test_benchmark_request_creation() {
logit_bias: None, logit_bias: None,
user: None, user: None,
seed: None, seed: None,
other: serde_json::Map::new(),
}; };
// Test serialization works // Test serialization works
...@@ -181,6 +182,7 @@ fn test_benchmark_request_adaptation() { ...@@ -181,6 +182,7 @@ fn test_benchmark_request_adaptation() {
logit_bias: None, logit_bias: None,
user: None, user: None,
seed: None, seed: None,
other: serde_json::Map::new(),
}; };
// Test PD adaptation (should not panic) // Test PD adaptation (should not panic)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment