"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "47825b194e226d17cfb6ffa4a9f6c8409ac563ee"
Unverified commit 2e901e89, authored by Simo Lin and committed by GitHub
Browse files

[router] dedicated prefill HTTP client and request-path optimizations (#8923)

parent d3be9710
...@@ -38,6 +38,8 @@ pub struct PDRouter { ...@@ -38,6 +38,8 @@ pub struct PDRouter {
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>, pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>, pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub client: Client, pub client: Client,
// Dedicated client for prefill fire-and-forget (non-logprob) requests
pub prefill_client: Client,
pub retry_config: RetryConfig, pub retry_config: RetryConfig,
_prefill_health_checker: Option<HealthChecker>, _prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>, _decode_health_checker: Option<HealthChecker>,
...@@ -255,6 +257,15 @@ impl PDRouter { ...@@ -255,6 +257,15 @@ impl PDRouter {
let decode_health_checker = let decode_health_checker =
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs); crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
// Build a dedicated prefill client for fire-and-forget semantics
let prefill_client = reqwest::Client::builder()
.pool_max_idle_per_host(0)
.http1_only()
.connect_timeout(Duration::from_millis(300))
.timeout(Duration::from_secs(2))
.build()
.map_err(|e| format!("Failed to build prefill client: {}", e))?;
Ok(PDRouter { Ok(PDRouter {
prefill_workers, prefill_workers,
decode_workers, decode_workers,
...@@ -267,6 +278,7 @@ impl PDRouter { ...@@ -267,6 +278,7 @@ impl PDRouter {
worker_loads, worker_loads,
load_monitor_handle, load_monitor_handle,
client, client,
prefill_client,
retry_config, retry_config,
_prefill_health_checker: Some(prefill_health_checker), _prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker), _decode_health_checker: Some(decode_health_checker),
...@@ -365,41 +377,69 @@ impl PDRouter { ...@@ -365,41 +377,69 @@ impl PDRouter {
None None
} }
// Helper to create request with bootstrap fields // Helper to inject bootstrap fields into an existing JSON request value
fn create_request_with_bootstrap<T: serde::Serialize>( fn inject_bootstrap_into_value(
request: &T, mut original: Value,
prefill_worker: &dyn Worker, prefill_worker: &dyn Worker,
batch_size: Option<usize>, batch_size: Option<usize>,
) -> Result<serde_json::Value, serde_json::Error> { ) -> Result<Value, String> {
// Get bootstrap port from prefill worker
let bootstrap_port = match prefill_worker.worker_type() { let bootstrap_port = match prefill_worker.worker_type() {
crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port, crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None, _ => None,
}; };
let hostname = super::pd_types::get_hostname(prefill_worker.url()); let hostname = super::pd_types::get_hostname(prefill_worker.url());
// Create optimized request with bootstrap fields let obj = original
if let Some(batch_size) = batch_size { .as_object_mut()
// Batch request .ok_or_else(|| "Request must be a JSON object".to_string())?;
let request_with_bootstrap = super::pd_types::BatchRequestWithBootstrap {
original: request, if let Some(n) = batch_size {
bootstrap_host: vec![hostname; batch_size], let mut hosts = Vec::with_capacity(n);
bootstrap_port: vec![bootstrap_port; batch_size], let mut ports = Vec::with_capacity(n);
bootstrap_room: (0..batch_size) let mut rooms = Vec::with_capacity(n);
.map(|_| super::pd_types::generate_room_id()) for _ in 0..n {
.collect(), hosts.push(hostname.clone());
}; ports.push(bootstrap_port);
serde_json::to_value(&request_with_bootstrap) rooms.push(super::pd_types::generate_room_id());
}
obj.insert(
"bootstrap_host".to_string(),
Value::Array(hosts.into_iter().map(serde_json::Value::from).collect()),
);
obj.insert(
"bootstrap_port".to_string(),
Value::Array(
ports
.into_iter()
.map(|p| match p {
Some(v) => serde_json::Value::from(v),
None => Value::Null,
})
.collect(),
),
);
obj.insert(
"bootstrap_room".to_string(),
Value::Array(rooms.into_iter().map(serde_json::Value::from).collect()),
);
} else { } else {
// Single request obj.insert(
let request_with_bootstrap = super::pd_types::RequestWithBootstrap { "bootstrap_host".to_string(),
original: request, serde_json::Value::from(hostname),
bootstrap_host: hostname, );
bootstrap_port, obj.insert(
bootstrap_room: super::pd_types::generate_room_id(), "bootstrap_port".to_string(),
}; match bootstrap_port {
serde_json::to_value(&request_with_bootstrap) Some(v) => serde_json::Value::from(v),
None => Value::Null,
},
);
obj.insert(
"bootstrap_room".to_string(),
serde_json::Value::from(super::pd_types::generate_room_id()),
);
} }
Ok(original)
} }
// Execute the dual dispatch to prefill and decode servers // Execute the dual dispatch to prefill and decode servers
...@@ -417,12 +457,15 @@ impl PDRouter { ...@@ -417,12 +457,15 @@ impl PDRouter {
// Update load tracking for both workers // Update load tracking for both workers
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]); let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
// Build requests with headers // Build decode request with shared client
let prefill_request = let decode_request = self.build_post_with_headers(
self.build_request_with_headers(prefill.url(), route, &json_request, headers); &self.client,
decode.url(),
let decode_request = route,
self.build_request_with_headers(decode.url(), route, &json_request, headers); &json_request,
headers,
false,
);
// Send both requests concurrently // Send both requests concurrently
debug!( debug!(
...@@ -432,6 +475,15 @@ impl PDRouter { ...@@ -432,6 +475,15 @@ impl PDRouter {
); );
if return_logprob { if return_logprob {
// Build prefill request with shared client when we need response body
let prefill_request = self.build_post_with_headers(
&self.client,
prefill.url(),
route,
&json_request,
headers,
false,
);
// When we need logprobs, wait for both responses // When we need logprobs, wait for both responses
let (prefill_result, decode_result) = let (prefill_result, decode_result) =
tokio::join!(prefill_request.send(), decode_request.send()); tokio::join!(prefill_request.send(), decode_request.send());
...@@ -525,19 +577,27 @@ impl PDRouter { ...@@ -525,19 +577,27 @@ impl PDRouter {
} else { } else {
// When we don't need logprobs, only wait for decode response // When we don't need logprobs, only wait for decode response
// Send both requests concurrently but don't wait for prefill // Send both requests concurrently but don't wait for prefill
// Add headers to minimize response size when we don't need the body // Use dedicated prefill client with Connection: close
let prefill_future = prefill_request.header("Connection", "close").send(); let prefill_future = self
.build_post_with_headers(
&self.prefill_client,
prefill.url(),
route,
&json_request,
headers,
true,
)
.send();
let decode_future = decode_request.send(); let decode_future = decode_request.send();
tokio::spawn(async move { tokio::spawn(async move {
if let Ok(response) = prefill_future.await { if let Ok(response) = prefill_future.await {
// Consume with a short timeout to free connection quickly // Consume at most one small chunk with a very short timeout to advance flow control
let consume_future = async { let _ = tokio::time::timeout(Duration::from_millis(20), async {
let _ = response.bytes().await; let mut s = response.bytes_stream();
}; let _ = s.next().await;
})
// Give it 100ms to consume, then abandon .await;
let _ = tokio::time::timeout(Duration::from_millis(100), consume_future).await;
} }
}); });
...@@ -879,29 +939,34 @@ impl PDRouter { ...@@ -879,29 +939,34 @@ impl PDRouter {
Ok((prefill_status, prefill_body)) Ok((prefill_status, prefill_body))
} }
// Helper to build a request with headers copied from the original request fn build_post_with_headers(
fn build_request_with_headers(
&self, &self,
client: &reqwest::Client,
url: &str, url: &str,
route: &str, route: &str,
json_request: &Value, json_request: &serde_json::Value,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
connection_close: bool,
) -> reqwest::RequestBuilder { ) -> reqwest::RequestBuilder {
let mut request = self.client.post(api_path(url, route)).json(json_request); let mut request = client.post(api_path(url, route)).json(json_request);
if connection_close {
// Copy headers from original request (excluding content-type and content-length which are set by .json()) request = request.header("Connection", "close");
}
if let Some(headers) = headers { if let Some(headers) = headers {
for (name, value) in headers.iter() { for (name, value) in headers.iter() {
let name_str = name.as_str(); let name_lc = name.as_str().to_ascii_lowercase();
if name_str != "content-type" && name_str != "content-length" { // Whitelist important end-to-end headers, skip hop-by-hop
// Skip headers with non-ASCII values let forward = matches!(
if value.to_str().is_ok() { name_lc.as_str(),
request = request.header(name, value); "authorization" | "x-request-id" | "x-correlation-id"
) || name_lc.starts_with("x-request-id-");
if forward {
if let Ok(val) = value.to_str() {
request = request.header(name, val);
} }
} }
} }
} }
request request
} }
...@@ -1109,11 +1174,12 @@ impl RouterTrait for PDRouter { ...@@ -1109,11 +1174,12 @@ impl RouterTrait for PDRouter {
// Test prefill server's health_generate // Test prefill server's health_generate
let prefill_url = format!("{}/health_generate", prefill.url()); let prefill_url = format!("{}/health_generate", prefill.url());
let prefill_result = self.client.get(&prefill_url).send().await; let (prefill_result, decode_result) = tokio::join!(
self.client.get(&prefill_url).send(),
// Test decode server's health_generate self.client
let decode_url = format!("{}/health_generate", decode.url()); .get(&format!("{}/health_generate", decode.url()))
let decode_result = self.client.get(&decode_url).send().await; .send()
);
// Check results // Check results
let mut errors = Vec::new(); let mut errors = Vec::new();
...@@ -1399,10 +1465,13 @@ impl RouterTrait for PDRouter { ...@@ -1399,10 +1465,13 @@ impl RouterTrait for PDRouter {
decode.url() decode.url()
); );
// Create optimized request with bootstrap fields
let batch_size = Self::get_generate_batch_size(body); let batch_size = Self::get_generate_batch_size(body);
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) { let original = match serde_json::to_value(body) {
Ok(json) => json, Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e), Err(e) => return Self::handle_serialization_error(e),
}; };
...@@ -1464,10 +1533,13 @@ impl RouterTrait for PDRouter { ...@@ -1464,10 +1533,13 @@ impl RouterTrait for PDRouter {
decode.url() decode.url()
); );
// Create optimized request with bootstrap fields
let batch_size = Self::get_chat_batch_size(body); let batch_size = Self::get_chat_batch_size(body);
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) { let original = match serde_json::to_value(body) {
Ok(json) => json, Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e), Err(e) => return Self::handle_serialization_error(e),
}; };
...@@ -1519,10 +1591,13 @@ impl RouterTrait for PDRouter { ...@@ -1519,10 +1591,13 @@ impl RouterTrait for PDRouter {
decode.url() decode.url()
); );
// Create optimized request with bootstrap fields
let batch_size = Self::get_completion_batch_size(body); let batch_size = Self::get_completion_batch_size(body);
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) { let original = match serde_json::to_value(body) {
Ok(json) => json, Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e), Err(e) => return Self::handle_serialization_error(e),
}; };
...@@ -1771,6 +1846,7 @@ mod tests { ...@@ -1771,6 +1846,7 @@ mod tests {
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
load_monitor_handle: None, load_monitor_handle: None,
client: Client::new(), client: Client::new(),
prefill_client: Client::new(),
retry_config: RetryConfig::default(), retry_config: RetryConfig::default(),
_prefill_health_checker: None, _prefill_health_checker: None,
_decode_health_checker: None, _decode_health_checker: None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment