Unverified Commit 2e901e89 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] dedicated prefill HTTP client and request-path optimizations (#8923)

parent d3be9710
......@@ -38,6 +38,8 @@ pub struct PDRouter {
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub client: Client,
// Dedicated client for prefill fire-and-forget (non-logprob) requests
pub prefill_client: Client,
pub retry_config: RetryConfig,
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
......@@ -255,6 +257,15 @@ impl PDRouter {
let decode_health_checker =
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
// Build a dedicated prefill client for fire-and-forget semantics
let prefill_client = reqwest::Client::builder()
.pool_max_idle_per_host(0)
.http1_only()
.connect_timeout(Duration::from_millis(300))
.timeout(Duration::from_secs(2))
.build()
.map_err(|e| format!("Failed to build prefill client: {}", e))?;
Ok(PDRouter {
prefill_workers,
decode_workers,
......@@ -267,6 +278,7 @@ impl PDRouter {
worker_loads,
load_monitor_handle,
client,
prefill_client,
retry_config,
_prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker),
......@@ -365,41 +377,69 @@ impl PDRouter {
None
}
// Helper to create request with bootstrap fields
fn create_request_with_bootstrap<T: serde::Serialize>(
request: &T,
// Helper to inject bootstrap fields into an existing JSON request value
fn inject_bootstrap_into_value(
mut original: Value,
prefill_worker: &dyn Worker,
batch_size: Option<usize>,
) -> Result<serde_json::Value, serde_json::Error> {
// Get bootstrap port from prefill worker
) -> Result<Value, String> {
let bootstrap_port = match prefill_worker.worker_type() {
crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = super::pd_types::get_hostname(prefill_worker.url());
// Create optimized request with bootstrap fields
if let Some(batch_size) = batch_size {
// Batch request
let request_with_bootstrap = super::pd_types::BatchRequestWithBootstrap {
original: request,
bootstrap_host: vec![hostname; batch_size],
bootstrap_port: vec![bootstrap_port; batch_size],
bootstrap_room: (0..batch_size)
.map(|_| super::pd_types::generate_room_id())
.collect(),
};
serde_json::to_value(&request_with_bootstrap)
let obj = original
.as_object_mut()
.ok_or_else(|| "Request must be a JSON object".to_string())?;
if let Some(n) = batch_size {
let mut hosts = Vec::with_capacity(n);
let mut ports = Vec::with_capacity(n);
let mut rooms = Vec::with_capacity(n);
for _ in 0..n {
hosts.push(hostname.clone());
ports.push(bootstrap_port);
rooms.push(super::pd_types::generate_room_id());
}
obj.insert(
"bootstrap_host".to_string(),
Value::Array(hosts.into_iter().map(serde_json::Value::from).collect()),
);
obj.insert(
"bootstrap_port".to_string(),
Value::Array(
ports
.into_iter()
.map(|p| match p {
Some(v) => serde_json::Value::from(v),
None => Value::Null,
})
.collect(),
),
);
obj.insert(
"bootstrap_room".to_string(),
Value::Array(rooms.into_iter().map(serde_json::Value::from).collect()),
);
} else {
// Single request
let request_with_bootstrap = super::pd_types::RequestWithBootstrap {
original: request,
bootstrap_host: hostname,
bootstrap_port,
bootstrap_room: super::pd_types::generate_room_id(),
};
serde_json::to_value(&request_with_bootstrap)
obj.insert(
"bootstrap_host".to_string(),
serde_json::Value::from(hostname),
);
obj.insert(
"bootstrap_port".to_string(),
match bootstrap_port {
Some(v) => serde_json::Value::from(v),
None => Value::Null,
},
);
obj.insert(
"bootstrap_room".to_string(),
serde_json::Value::from(super::pd_types::generate_room_id()),
);
}
Ok(original)
}
// Execute the dual dispatch to prefill and decode servers
......@@ -417,12 +457,15 @@ impl PDRouter {
// Update load tracking for both workers
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
// Build requests with headers
let prefill_request =
self.build_request_with_headers(prefill.url(), route, &json_request, headers);
let decode_request =
self.build_request_with_headers(decode.url(), route, &json_request, headers);
// Build decode request with shared client
let decode_request = self.build_post_with_headers(
&self.client,
decode.url(),
route,
&json_request,
headers,
false,
);
// Send both requests concurrently
debug!(
......@@ -432,6 +475,15 @@ impl PDRouter {
);
if return_logprob {
// Build prefill request with shared client when we need response body
let prefill_request = self.build_post_with_headers(
&self.client,
prefill.url(),
route,
&json_request,
headers,
false,
);
// When we need logprobs, wait for both responses
let (prefill_result, decode_result) =
tokio::join!(prefill_request.send(), decode_request.send());
......@@ -525,19 +577,27 @@ impl PDRouter {
} else {
// When we don't need logprobs, only wait for decode response
// Send both requests concurrently but don't wait for prefill
// Add headers to minimize response size when we don't need the body
let prefill_future = prefill_request.header("Connection", "close").send();
// Use dedicated prefill client with Connection: close
let prefill_future = self
.build_post_with_headers(
&self.prefill_client,
prefill.url(),
route,
&json_request,
headers,
true,
)
.send();
let decode_future = decode_request.send();
tokio::spawn(async move {
if let Ok(response) = prefill_future.await {
// Consume with a short timeout to free connection quickly
let consume_future = async {
let _ = response.bytes().await;
};
// Give it 100ms to consume, then abandon
let _ = tokio::time::timeout(Duration::from_millis(100), consume_future).await;
// Consume at most one small chunk with a very short timeout to advance flow control
let _ = tokio::time::timeout(Duration::from_millis(20), async {
let mut s = response.bytes_stream();
let _ = s.next().await;
})
.await;
}
});
......@@ -879,29 +939,34 @@ impl PDRouter {
Ok((prefill_status, prefill_body))
}
// Helper to build a request with headers copied from the original request
fn build_request_with_headers(
fn build_post_with_headers(
&self,
client: &reqwest::Client,
url: &str,
route: &str,
json_request: &Value,
json_request: &serde_json::Value,
headers: Option<&HeaderMap>,
connection_close: bool,
) -> reqwest::RequestBuilder {
let mut request = self.client.post(api_path(url, route)).json(json_request);
// Copy headers from original request (excluding content-type and content-length which are set by .json())
let mut request = client.post(api_path(url, route)).json(json_request);
if connection_close {
request = request.header("Connection", "close");
}
if let Some(headers) = headers {
for (name, value) in headers.iter() {
let name_str = name.as_str();
if name_str != "content-type" && name_str != "content-length" {
// Skip headers with non-ASCII values
if value.to_str().is_ok() {
request = request.header(name, value);
let name_lc = name.as_str().to_ascii_lowercase();
// Whitelist important end-to-end headers, skip hop-by-hop
let forward = matches!(
name_lc.as_str(),
"authorization" | "x-request-id" | "x-correlation-id"
) || name_lc.starts_with("x-request-id-");
if forward {
if let Ok(val) = value.to_str() {
request = request.header(name, val);
}
}
}
}
request
}
......@@ -1109,11 +1174,12 @@ impl RouterTrait for PDRouter {
// Test prefill server's health_generate
let prefill_url = format!("{}/health_generate", prefill.url());
let prefill_result = self.client.get(&prefill_url).send().await;
// Test decode server's health_generate
let decode_url = format!("{}/health_generate", decode.url());
let decode_result = self.client.get(&decode_url).send().await;
let (prefill_result, decode_result) = tokio::join!(
self.client.get(&prefill_url).send(),
self.client
.get(&format!("{}/health_generate", decode.url()))
.send()
);
// Check results
let mut errors = Vec::new();
......@@ -1399,10 +1465,13 @@ impl RouterTrait for PDRouter {
decode.url()
);
// Create optimized request with bootstrap fields
let batch_size = Self::get_generate_batch_size(body);
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
Ok(json) => json,
let original = match serde_json::to_value(body) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
......@@ -1464,10 +1533,13 @@ impl RouterTrait for PDRouter {
decode.url()
);
// Create optimized request with bootstrap fields
let batch_size = Self::get_chat_batch_size(body);
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
Ok(json) => json,
let original = match serde_json::to_value(body) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
......@@ -1519,10 +1591,13 @@ impl RouterTrait for PDRouter {
decode.url()
);
// Create optimized request with bootstrap fields
let batch_size = Self::get_completion_batch_size(body);
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
Ok(json) => json,
let original = match serde_json::to_value(body) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
......@@ -1771,6 +1846,7 @@ mod tests {
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
load_monitor_handle: None,
client: Client::new(),
prefill_client: Client::new(),
retry_config: RetryConfig::default(),
_prefill_health_checker: None,
_decode_health_checker: None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment