Unverified Commit 6d4fd882 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] minor code clean up and refactoring (#8711)

parent f9f0138f
......@@ -35,7 +35,7 @@ pub struct PDRouter {
pub interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub http_client: reqwest::Client,
pub http_client: Client,
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
}
......@@ -206,51 +206,17 @@ impl PDRouter {
}
// Initialize cache-aware components if needed for prefill policy
let prefill_tree = if prefill_policy.name() == "cache_aware" {
// Initialize the policy's internal tree with prefill workers
if let Some(cache_policy) = prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.init_workers(&prefill_workers);
}
let tree = Arc::new(Mutex::new(Tree::new()));
// Initialize tree with prefill workers
for worker in &prefill_workers {
tree.lock().unwrap().insert("", worker.url());
}
Some(tree)
} else {
None
};
let prefill_tree = Self::initialize_radix_tree(&prefill_policy, &prefill_workers)?;
// Initialize cache-aware components if needed for decode policy
let decode_tree = if decode_policy.name() == "cache_aware" {
// Initialize the policy's internal tree with decode workers
if let Some(cache_policy) = decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.init_workers(&decode_workers);
}
let tree = Arc::new(Mutex::new(Tree::new()));
// Initialize tree with decode workers
for worker in &decode_workers {
tree.lock().unwrap().insert("", worker.url());
}
Some(tree)
} else {
None
};
let decode_tree = Self::initialize_radix_tree(&decode_policy, &decode_workers)?;
// Set up background load monitoring for power-of-two selection
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
// Create a shared HTTP client for all operations
let http_client = reqwest::Client::builder()
let http_client = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
......@@ -304,6 +270,35 @@ impl PDRouter {
})
}
/// Builds the radix tree that accompanies a cache-aware load-balancing policy.
///
/// If `policy` downcasts to `CacheAwarePolicy`, the policy's own worker state is
/// initialized through `init_workers`, and a fresh `Tree` is created in which
/// every worker's URL is inserted under the empty-string key. For any other
/// policy type nothing is initialized and no tree is produced.
///
/// # Returns
///
/// * `Ok(Some(tree))` - the populated tree, shared behind `Arc<Mutex<_>>`.
/// * `Ok(None)` - the policy is not a `CacheAwarePolicy`.
///
/// # Errors
///
/// This function does not currently produce an `Err`. The `Result` return type
/// is kept so that existing callers using `?` continue to compile unchanged.
fn initialize_radix_tree(
    policy: &Arc<dyn LoadBalancingPolicy>,
    workers: &[Box<dyn Worker>],
) -> Result<Option<Arc<Mutex<Tree>>>, String> {
    let cache_policy = match policy
        .as_any()
        .downcast_ref::<crate::policies::CacheAwarePolicy>()
    {
        Some(found) => found,
        // Non cache-aware policies do not use a radix tree.
        None => return Ok(None),
    };

    // Initialize the policy's internal tree with the workers.
    cache_policy.init_workers(workers);

    // Populate the tree before it is wrapped in `Arc<Mutex<_>>`: while it is
    // still exclusively owned here there is nothing to lock, so no lock-failure
    // path is needed.
    let tree = Tree::new();
    for worker in workers {
        tree.insert("", worker.url());
    }

    Ok(Some(Arc::new(Mutex::new(tree))))
}
// Route a typed generate request
pub async fn route_generate(
&self,
......@@ -329,7 +324,7 @@ impl PDRouter {
});
// Select servers
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
......@@ -417,7 +412,7 @@ impl PDRouter {
.and_then(|content| content.as_str());
// Select servers
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
......@@ -498,7 +493,7 @@ impl PDRouter {
};
// Select servers
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
......@@ -833,7 +828,6 @@ impl PDRouter {
// Select a pair of prefill and decode servers
async fn select_pd_pair(
&self,
_client: &Client,
request_text: Option<&str>,
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
// Get read locks for both worker lists
......@@ -998,7 +992,7 @@ impl PDRouter {
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
// Select a random worker pair using the policy
let (prefill, decode) = match self.select_pd_pair(client, None).await {
let (prefill, decode) = match self.select_pd_pair(None).await {
Ok(pair) => pair,
Err(e) => {
return (
......@@ -1921,8 +1915,7 @@ mod tests {
router.prefill_workers.write().unwrap().push(healthy_worker);
router.decode_workers.write().unwrap().push(decode_worker);
let client = reqwest::Client::new();
let result = router.select_pd_pair(&client, None).await;
let result = router.select_pd_pair(None).await;
assert!(result.is_ok());
let (prefill, _decode) = result.unwrap();
......@@ -1936,8 +1929,7 @@ mod tests {
async fn test_empty_worker_lists() {
let router = create_test_pd_router();
let client = reqwest::Client::new();
let result = router.select_pd_pair(&client, None).await;
let result = router.select_pd_pair(None).await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("No prefill workers available"));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment