Unverified Commit 5dccf697 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] create worker removal step and clean up worker manager (#11921)

parent eec9e471
...@@ -85,6 +85,8 @@ def _popen_launch_router( ...@@ -85,6 +85,8 @@ def _popen_launch_router(
str(prom_port), str(prom_port),
"--router-prometheus-host", "--router-prometheus-host",
"127.0.0.1", "127.0.0.1",
"--router-log-level",
"warn",
] ]
proc = subprocess.Popen(cmd) proc = subprocess.Popen(cmd)
......
import time
from types import SimpleNamespace from types import SimpleNamespace
import pytest import pytest
import requests import requests
def _wait_for_workers(
    base_url: str,
    expected_count: int,
    timeout: float = 60.0,
    headers: "dict | None" = None,
) -> None:
    """Poll the router's /workers endpoint until enough workers are registered.

    Args:
        base_url: Router base URL, e.g. "http://127.0.0.1:8000".
        expected_count: Minimum number of entries that must appear under the
            "workers" key of the JSON response.
        timeout: Overall polling budget in seconds.
        headers: Optional HTTP headers (e.g. an Authorization bearer token).

    Raises:
        TimeoutError: If the expected worker count is not reached in time.
    """
    start = time.perf_counter()
    with requests.Session() as session:
        while time.perf_counter() - start < timeout:
            try:
                r = session.get(f"{base_url}/workers", headers=headers, timeout=5)
                if r.status_code == 200:
                    workers = r.json().get("workers", [])
                    if len(workers) >= expected_count:
                        return
            except requests.RequestException:
                # Router may not be reachable yet; keep polling until timeout.
                pass
            time.sleep(0.5)
    raise TimeoutError(
        f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s"
    )
@pytest.mark.e2e @pytest.mark.e2e
def test_embeddings_basic( def test_embeddings_basic(
e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model
...@@ -12,8 +34,11 @@ def test_embeddings_basic( ...@@ -12,8 +34,11 @@ def test_embeddings_basic(
worker_url = e2e_primary_embedding_worker.url worker_url = e2e_primary_embedding_worker.url
# Attach embedding worker to router-only instance # Attach embedding worker to router-only instance
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180) r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180)
r.raise_for_status() assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for worker to be registered
_wait_for_workers(base, expected_count=1, timeout=60.0)
# Simple embedding request with two inputs # Simple embedding request with two inputs
payload = { payload = {
......
...@@ -198,6 +198,8 @@ def pd_cluster(e2e_model: str): ...@@ -198,6 +198,8 @@ def pd_cluster(e2e_model: str):
"--policy", "--policy",
"round_robin", "round_robin",
"--pd-disaggregation", "--pd-disaggregation",
"--log-level",
"warn",
] ]
for url, bport in prefill: for url, bport in prefill:
cmd += ["--prefill", url, str(bport)] cmd += ["--prefill", url, str(bport)]
......
...@@ -8,13 +8,39 @@ import requests ...@@ -8,13 +8,39 @@ import requests
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
def _wait_for_workers(
    base_url: str, expected_count: int, timeout: float = 60.0, headers: dict = None
) -> None:
    """Block until /workers lists at least `expected_count` workers.

    Raises TimeoutError if the count is not reached within `timeout` seconds.
    """
    deadline = time.perf_counter() + timeout
    with requests.Session() as session:
        while time.perf_counter() < deadline:
            try:
                resp = session.get(f"{base_url}/workers", headers=headers, timeout=5)
                if resp.status_code == 200:
                    if len(resp.json().get("workers", [])) >= expected_count:
                        return
            except requests.RequestException:
                # Router not reachable yet; retry until the deadline.
                pass
            time.sleep(0.5)
    raise TimeoutError(
        f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s"
    )
@pytest.mark.e2e @pytest.mark.e2e
def test_mmlu(e2e_router_only_rr, e2e_two_workers_dp2, e2e_model): def test_mmlu(e2e_router_only_rr, e2e_two_workers_dp2, e2e_model):
# Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance # Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance
base = e2e_router_only_rr.url base = e2e_router_only_rr.url
for w in e2e_two_workers_dp2: for w in e2e_two_workers_dp2:
r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180) r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180)
r.raise_for_status() assert (
r.status_code == 202
), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for workers to be registered
_wait_for_workers(base, expected_count=2, timeout=60.0)
args = SimpleNamespace( args = SimpleNamespace(
base_url=base, base_url=base,
...@@ -35,8 +61,13 @@ def test_genai_bench( ...@@ -35,8 +61,13 @@ def test_genai_bench(
"""Attach a worker to the regular router and run a short genai-bench.""" """Attach a worker to the regular router and run a short genai-bench."""
base = e2e_router_only_rr.url base = e2e_router_only_rr.url
for w in e2e_two_workers_dp2: for w in e2e_two_workers_dp2:
r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180) r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180)
r.raise_for_status() assert (
r.status_code == 202
), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for workers to be registered
_wait_for_workers(base, expected_count=2, timeout=60.0)
genai_bench_runner( genai_bench_runner(
router_url=base, router_url=base,
...@@ -59,8 +90,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_ ...@@ -59,8 +90,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
base = e2e_router_only_rr.url base = e2e_router_only_rr.url
worker_url = e2e_primary_worker.url worker_url = e2e_primary_worker.url
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180) r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180)
r.raise_for_status() assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for worker to be registered
_wait_for_workers(base, expected_count=1, timeout=60.0)
with requests.Session() as s: with requests.Session() as s:
for i in range(8): for i in range(8):
...@@ -77,8 +111,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_ ...@@ -77,8 +111,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
r.raise_for_status() r.raise_for_status()
# Remove the worker # Remove the worker
r = requests.post(f"{base}/remove_worker", params={"url": worker_url}, timeout=60) from urllib.parse import quote
r.raise_for_status()
encoded_url = quote(worker_url, safe="")
r = requests.delete(f"{base}/workers/{encoded_url}", timeout=60)
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
@pytest.mark.e2e @pytest.mark.e2e
...@@ -86,8 +123,11 @@ def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_m ...@@ -86,8 +123,11 @@ def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_m
base = e2e_router_only_rr.url base = e2e_router_only_rr.url
worker = e2e_primary_worker worker = e2e_primary_worker
r = requests.post(f"{base}/add_worker", params={"url": worker.url}, timeout=180) r = requests.post(f"{base}/workers", json={"url": worker.url}, timeout=180)
r.raise_for_status() assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for worker to be registered
_wait_for_workers(base, expected_count=1, timeout=60.0)
def killer(): def killer():
time.sleep(10) time.sleep(10)
...@@ -129,20 +169,30 @@ def test_dp_aware_worker_expansion_and_api_key( ...@@ -129,20 +169,30 @@ def test_dp_aware_worker_expansion_and_api_key(
# Attach worker; router should expand to dp_size logical workers # Attach worker; router should expand to dp_size logical workers
r = requests.post( r = requests.post(
f"{router_url}/add_worker", f"{router_url}/workers",
params={"url": worker_url, "api_key": api_key}, json={"url": worker_url, "api_key": api_key},
headers={"Authorization": f"Bearer {api_key}"}, headers={"Authorization": f"Bearer {api_key}"},
timeout=180, timeout=180,
) )
r.raise_for_status() assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for workers to be registered and expanded
_wait_for_workers(
router_url,
expected_count=2,
timeout=60.0,
headers={"Authorization": f"Bearer {api_key}"},
)
# Verify the expanded workers have correct URLs
r = requests.get( r = requests.get(
f"{router_url}/list_workers", f"{router_url}/workers",
headers={"Authorization": f"Bearer {api_key}"}, headers={"Authorization": f"Bearer {api_key}"},
timeout=30, timeout=30,
) )
r.raise_for_status() r.raise_for_status()
urls = r.json().get("urls", []) workers = r.json().get("workers", [])
urls = [w["url"] for w in workers]
assert len(urls) == 2 assert len(urls) == 2
assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"} assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}
......
...@@ -267,6 +267,8 @@ def popen_launch_workers_and_router( ...@@ -267,6 +267,8 @@ def popen_launch_workers_and_router(
policy, policy,
"--model-path", "--model-path",
model, model,
"--log-level",
"warn",
] ]
# Add worker URLs # Add worker URLs
......
...@@ -133,19 +133,90 @@ class RouterManager: ...@@ -133,19 +133,90 @@ class RouterManager:
time.sleep(0.2) time.sleep(0.2)
raise TimeoutError(f"Router at {base_url} did not become healthy") raise TimeoutError(f"Router at {base_url} did not become healthy")
def add_worker(self, base_url: str, worker_url: str) -> None: def add_worker(self, base_url: str, worker_url: str, timeout: float = 30.0) -> None:
r = requests.post(f"{base_url}/add_worker", params={"url": worker_url}) r = requests.post(f"{base_url}/workers", json={"url": worker_url})
assert r.status_code == 200, f"add_worker failed: {r.status_code} {r.text}" assert (
r.status_code == 202
), f"add_worker failed: {r.status_code} {r.text}" # ACCEPTED status
def remove_worker(self, base_url: str, worker_url: str) -> None: # Poll until worker is actually added and healthy
r = requests.post(f"{base_url}/remove_worker", params={"url": worker_url}) from urllib.parse import quote
assert r.status_code == 200, f"remove_worker failed: {r.status_code} {r.text}"
encoded_url = quote(worker_url, safe="")
start = time.time()
with requests.Session() as s:
while time.time() - start < timeout:
try:
r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2)
if r.status_code == 200:
data = r.json()
# Check if registration job failed
job_status = data.get("job_status")
if job_status and job_status.get("state") == "failed":
raise RuntimeError(
f"Worker registration failed: {job_status.get('message', 'Unknown error')}"
)
# Check if worker is healthy and registered (not just in job queue)
if data.get("is_healthy", False):
return
# Worker not ready yet, continue polling
except requests.RequestException:
pass
time.sleep(0.1)
raise TimeoutError(
f"Worker {worker_url} was not added and healthy after {timeout}s"
)
def remove_worker(
self, base_url: str, worker_url: str, timeout: float = 30.0
) -> None:
# URL encode the worker_url for path parameter
from urllib.parse import quote
encoded_url = quote(worker_url, safe="")
r = requests.delete(f"{base_url}/workers/{encoded_url}")
assert (
r.status_code == 202
), f"remove_worker failed: {r.status_code} {r.text}" # ACCEPTED status
# Poll until worker is actually removed (GET returns 404) or timeout
start = time.time()
last_status = None
with requests.Session() as s:
while time.time() - start < timeout:
try:
r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2)
if r.status_code == 404:
# Worker successfully removed
return
elif r.status_code == 200:
# Check if removal job failed
data = r.json()
job_status = data.get("job_status")
if job_status:
last_status = job_status
if job_status.get("state") == "failed":
raise RuntimeError(
f"Worker removal failed: {job_status.get('message', 'Unknown error')}"
)
# Worker still being processed, continue polling
except requests.RequestException:
pass
time.sleep(0.1)
# Provide detailed timeout error with last known status
error_msg = f"Worker {worker_url} was not removed after {timeout}s"
if last_status:
error_msg += f". Last job status: {last_status}"
raise TimeoutError(error_msg)
def list_workers(self, base_url: str) -> list[str]: def list_workers(self, base_url: str) -> list[str]:
r = requests.get(f"{base_url}/list_workers") r = requests.get(f"{base_url}/workers")
assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}" assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}"
data = r.json() data = r.json()
return data.get("urls", []) # Extract URLs from WorkerInfo objects
workers = data.get("workers", [])
return [w["url"] for w in workers]
def stop_all(self): def stop_all(self):
for p in self._children: for p in self._children:
......
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import subprocess import subprocess
import time import time
from pathlib import Path from pathlib import Path
from typing import Dict, Iterable, List, Tuple from typing import Dict, Iterable, List, Optional, Tuple
import pytest import pytest
import requests import requests
...@@ -84,7 +84,7 @@ def mock_workers(): ...@@ -84,7 +84,7 @@ def mock_workers():
procs: List[subprocess.Popen] = [] procs: List[subprocess.Popen] = []
def _start(n: int, args: List[str] | None = None): def _start(n: int, args: Optional[List[str]] = None):
args = args or [] args = args or []
new_procs: List[subprocess.Popen] = [] new_procs: List[subprocess.Popen] = []
urls: List[str] = [] urls: List[str] = []
......
...@@ -15,11 +15,9 @@ use tracing::{debug, error, info, warn}; ...@@ -15,11 +15,9 @@ use tracing::{debug, error, info, warn};
use crate::{ use crate::{
config::{RouterConfig, RoutingMode}, config::{RouterConfig, RoutingMode},
core::{ core::workflow::{
workflow::{ steps::WorkerRemovalRequest, WorkflowContext, WorkflowEngine, WorkflowId,
WorkflowContext, WorkflowEngine, WorkflowId, WorkflowInstanceId, WorkflowStatus, WorkflowInstanceId, WorkflowStatus,
},
WorkerManager,
}, },
metrics::RouterMetrics, metrics::RouterMetrics,
protocols::worker_spec::{JobStatus, WorkerConfigRequest}, protocols::worker_spec::{JobStatus, WorkerConfigRequest},
...@@ -320,11 +318,29 @@ impl JobQueue { ...@@ -320,11 +318,29 @@ impl JobQueue {
.await .await
} }
Job::RemoveWorker { url } => { Job::RemoveWorker { url } => {
let result = WorkerManager::remove_worker(url, context); let engine = context
.workflow_engine
.get()
.ok_or_else(|| "Workflow engine not initialized".to_string())?;
let instance_id = Self::start_worker_removal_workflow(engine, url, context).await?;
debug!(
"Started worker removal workflow for {} (instance: {})",
url, instance_id
);
let timeout_duration = Duration::from_secs(30);
let result =
Self::wait_for_workflow_completion(engine, instance_id, url, timeout_duration)
.await;
// Clean up job status when removing worker // Clean up job status when removing worker
if let Some(queue) = context.worker_job_queue.get() { if let Some(queue) = context.worker_job_queue.get() {
queue.remove_status(url); queue.remove_status(url);
} }
result result
} }
Job::InitializeWorkersFromConfig { router_config } => { Job::InitializeWorkersFromConfig { router_config } => {
...@@ -424,6 +440,27 @@ impl JobQueue { ...@@ -424,6 +440,27 @@ impl JobQueue {
.map_err(|e| format!("Failed to start worker registration workflow: {:?}", e)) .map_err(|e| format!("Failed to start worker registration workflow: {:?}", e))
} }
/// Start worker removal workflow
async fn start_worker_removal_workflow(
engine: &Arc<WorkflowEngine>,
url: &str,
context: &Arc<AppContext>,
) -> Result<WorkflowInstanceId, String> {
let removal_request = WorkerRemovalRequest {
url: url.to_string(),
dp_aware: context.router_config.dp_aware,
};
let mut workflow_context = WorkflowContext::new(WorkflowInstanceId::new());
workflow_context.set("removal_request", removal_request);
workflow_context.set_arc("app_context", Arc::clone(context));
engine
.start_workflow(WorkflowId::new("worker_removal"), workflow_context)
.await
.map_err(|e| format!("Failed to start worker removal workflow: {:?}", e))
}
/// Wait for workflow completion with adaptive polling /// Wait for workflow completion with adaptive polling
async fn wait_for_workflow_completion( async fn wait_for_workflow_completion(
engine: &Arc<WorkflowEngine>, engine: &Arc<WorkflowEngine>,
......
...@@ -29,5 +29,5 @@ pub use worker::{ ...@@ -29,5 +29,5 @@ pub use worker::{
Worker, WorkerFactory, WorkerLoadGuard, WorkerType, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
}; };
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder}; pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
pub use worker_manager::{DpInfo, LoadMonitor, ServerInfo, WorkerManager}; pub use worker_manager::{LoadMonitor, WorkerManager};
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats}; pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
This diff is collapsed.
...@@ -14,5 +14,5 @@ pub use engine::WorkflowEngine; ...@@ -14,5 +14,5 @@ pub use engine::WorkflowEngine;
pub use event::{EventBus, EventSubscriber, LoggingSubscriber, WorkflowEvent}; pub use event::{EventBus, EventSubscriber, LoggingSubscriber, WorkflowEvent};
pub use executor::{FunctionStep, StepExecutor}; pub use executor::{FunctionStep, StepExecutor};
pub use state::WorkflowStateStore; pub use state::WorkflowStateStore;
pub use steps::create_worker_registration_workflow; pub use steps::{create_worker_registration_workflow, create_worker_removal_workflow};
pub use types::*; pub use types::*;
...@@ -2,11 +2,17 @@ ...@@ -2,11 +2,17 @@
//! //!
//! This module contains concrete step implementations for various workflows: //! This module contains concrete step implementations for various workflows:
//! - Worker registration and activation //! - Worker registration and activation
//! - Worker removal
//! - Future: Tokenizer fetching, LoRA updates, etc. //! - Future: Tokenizer fetching, LoRA updates, etc.
pub mod worker_registration; pub mod worker_registration;
pub mod worker_removal;
pub use worker_registration::{ pub use worker_registration::{
create_worker_registration_workflow, ActivateWorkerStep, CreateWorkerStep, create_worker_registration_workflow, ActivateWorkerStep, CreateWorkerStep,
DetectConnectionModeStep, DiscoverMetadataStep, RegisterWorkerStep, UpdatePoliciesStep, DetectConnectionModeStep, DiscoverMetadataStep, RegisterWorkerStep, UpdatePoliciesStep,
}; };
pub use worker_removal::{
create_worker_removal_workflow, FindWorkersToRemoveStep, RemoveFromPolicyRegistryStep,
RemoveFromWorkerRegistryStep, UpdateRemainingPoliciesStep, WorkerRemovalRequest,
};
...@@ -16,13 +16,14 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; ...@@ -16,13 +16,14 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
use async_trait::async_trait; use async_trait::async_trait;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::{ use crate::{
core::{ core::{
workflow::*, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, workflow::*, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode,
DPAwareWorkerBuilder, DpInfo, HealthConfig, Worker, WorkerManager, WorkerType, DPAwareWorkerBuilder, HealthConfig, Worker, WorkerType,
}, },
grpc_client::SglangSchedulerClient, grpc_client::SglangSchedulerClient,
protocols::worker_spec::WorkerConfigRequest, protocols::worker_spec::WorkerConfigRequest,
...@@ -37,6 +38,82 @@ static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| { ...@@ -37,6 +38,82 @@ static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
.expect("Failed to create HTTP client") .expect("Failed to create HTTP client")
}); });
/// Server information returned from worker endpoints
///
/// Deserialized from the JSON body of `/get_server_info`. Every field is
/// optional because different worker builds expose different subsets.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct ServerInfo {
    /// Model identifier; some servers report it under the key "model".
    #[serde(alias = "model")]
    model_id: Option<String>,
    /// Filesystem path the model was loaded from.
    model_path: Option<String>,
    /// Data-parallel size reported by the worker.
    dp_size: Option<usize>,
    version: Option<String>,
    max_batch_size: Option<usize>,
    max_total_tokens: Option<usize>,
    max_prefill_tokens: Option<usize>,
    max_running_requests: Option<usize>,
    max_num_reqs: Option<usize>,
}
/// Data-parallel info derived from a worker's `/get_server_info` response.
#[derive(Debug, Clone)]
pub struct DpInfo {
    /// Number of data-parallel ranks on the worker.
    pub dp_size: usize,
    /// Model id; may be a fallback value ("unknown" or the last segment of
    /// model_path) when the server does not report one — see `get_dp_info`.
    pub model_id: String,
}
/// Parse server info from a JSON response using serde.
fn parse_server_info(json: Value) -> Result<ServerInfo, String> {
    serde_json::from_value(json).map_err(|e| format!("Failed to parse server info: {}", e))
}
/// Get server info from /get_server_info endpoint
async fn get_server_info(url: &str, api_key: Option<&str>) -> Result<ServerInfo, String> {
let base_url = url.trim_end_matches('/');
let server_info_url = format!("{}/get_server_info", base_url);
let mut req = HTTP_CLIENT.get(&server_info_url);
if let Some(key) = api_key {
req = req.bearer_auth(key);
}
let response = req
.send()
.await
.map_err(|e| format!("Failed to connect to {}: {}", server_info_url, e))?;
if !response.status().is_success() {
return Err(format!(
"Server returned status {} from {}",
response.status(),
server_info_url
));
}
let json = response
.json::<Value>()
.await
.map_err(|e| format!("Failed to parse response from {}: {}", server_info_url, e))?;
parse_server_info(json)
}
/// Get DP info for a worker URL.
///
/// Errors if the worker does not report `dp_size`. The model id falls back
/// to the last path segment of `model_path`, then to "unknown".
async fn get_dp_info(url: &str, api_key: Option<&str>) -> Result<DpInfo, String> {
    let info = get_server_info(url, api_key).await?;

    let dp_size = match info.dp_size {
        Some(size) => size,
        None => return Err(format!("No dp_size in response from {}", url)),
    };

    let model_id = match info.model_id {
        Some(id) => id,
        None => info
            .model_path
            .and_then(|path| path.rsplit('/').next().map(str::to_string))
            .unwrap_or_else(|| "unknown".to_string()),
    };

    Ok(DpInfo { dp_size, model_id })
}
/// Helper: Strip protocol prefix from URL /// Helper: Strip protocol prefix from URL
fn strip_protocol(url: &str) -> String { fn strip_protocol(url: &str) -> String {
url.trim_start_matches("http://") url.trim_start_matches("http://")
...@@ -83,49 +160,6 @@ async fn try_grpc_health_check(url: &str, timeout_secs: u64) -> Result<(), Strin ...@@ -83,49 +160,6 @@ async fn try_grpc_health_check(url: &str, timeout_secs: u64) -> Result<(), Strin
Ok(()) Ok(())
} }
/// Helper: Fetch HTTP metadata
async fn fetch_http_metadata(
url: &str,
api_key: Option<&str>,
) -> Result<HashMap<String, String>, String> {
let clean_url = strip_protocol(url);
let info_url = if clean_url.starts_with("http://") || clean_url.starts_with("https://") {
format!("{}/get_server_info", clean_url)
} else {
format!("http://{}/get_server_info", clean_url)
};
let mut request = HTTP_CLIENT.get(&info_url);
if let Some(key) = api_key {
request = request.header("Authorization", format!("Bearer {}", key));
}
let response = request
.send()
.await
.map_err(|e| format!("Failed to fetch HTTP metadata: {}", e))?;
let server_info: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse HTTP metadata: {}", e))?;
let mut labels = HashMap::new();
if let Some(model_path) = server_info.get("model_path").and_then(|v| v.as_str()) {
if !model_path.is_empty() {
labels.insert("model_path".to_string(), model_path.to_string());
}
}
if let Some(tokenizer_path) = server_info.get("tokenizer_path").and_then(|v| v.as_str()) {
if !tokenizer_path.is_empty() {
labels.insert("tokenizer_path".to_string(), tokenizer_path.to_string());
}
}
Ok(labels)
}
/// Helper: Fetch gRPC metadata /// Helper: Fetch gRPC metadata
async fn fetch_grpc_metadata(url: &str) -> Result<HashMap<String, String>, String> { async fn fetch_grpc_metadata(url: &str) -> Result<HashMap<String, String>, String> {
let grpc_url = if url.starts_with("grpc://") { let grpc_url = if url.starts_with("grpc://") {
...@@ -266,7 +300,18 @@ impl StepExecutor for DiscoverMetadataStep { ...@@ -266,7 +300,18 @@ impl StepExecutor for DiscoverMetadataStep {
let discovered_labels = match connection_mode.as_ref() { let discovered_labels = match connection_mode.as_ref() {
ConnectionMode::Http => { ConnectionMode::Http => {
fetch_http_metadata(&config.url, config.api_key.as_deref()).await match get_server_info(&config.url, config.api_key.as_deref()).await {
Ok(server_info) => {
let mut labels = HashMap::new();
if let Some(model_path) = server_info.model_path {
if !model_path.is_empty() {
labels.insert("model_path".to_string(), model_path);
}
}
Ok(labels)
}
Err(e) => Err(e),
}
} }
ConnectionMode::Grpc { .. } => fetch_grpc_metadata(&config.url).await, ConnectionMode::Grpc { .. } => fetch_grpc_metadata(&config.url).await,
} }
...@@ -314,7 +359,7 @@ impl StepExecutor for DiscoverDPInfoStep { ...@@ -314,7 +359,7 @@ impl StepExecutor for DiscoverDPInfoStep {
debug!("Discovering DP info for {} (DP-aware)", config.url); debug!("Discovering DP info for {} (DP-aware)", config.url);
// Get DP info from worker // Get DP info from worker
let dp_info = WorkerManager::get_dp_info(&config.url, config.api_key.as_deref()) let dp_info = get_dp_info(&config.url, config.api_key.as_deref())
.await .await
.map_err(|e| WorkflowError::StepFailed { .map_err(|e| WorkflowError::StepFailed {
step_id: StepId::new("discover_dp_info"), step_id: StepId::new("discover_dp_info"),
...@@ -327,7 +372,7 @@ impl StepExecutor for DiscoverDPInfoStep { ...@@ -327,7 +372,7 @@ impl StepExecutor for DiscoverDPInfoStep {
); );
// Store DP info in context // Store DP info in context
context.set("dp_info", Arc::new(dp_info)); context.set("dp_info", dp_info);
Ok(StepResult::Success) Ok(StepResult::Success)
} }
...@@ -522,7 +567,7 @@ impl StepExecutor for CreateWorkerStep { ...@@ -522,7 +567,7 @@ impl StepExecutor for CreateWorkerStep {
} }
// Store workers (plural) and labels in context // Store workers (plural) and labels in context
context.set("workers", Arc::new(workers)); context.set("workers", workers);
context.set("labels", final_labels); context.set("labels", final_labels);
Ok(StepResult::Success) Ok(StepResult::Success)
...@@ -595,7 +640,7 @@ impl StepExecutor for RegisterWorkerStep { ...@@ -595,7 +640,7 @@ impl StepExecutor for RegisterWorkerStep {
); );
} }
context.set("worker_ids", Arc::new(worker_ids)); context.set("worker_ids", worker_ids);
Ok(StepResult::Success) Ok(StepResult::Success)
} else { } else {
// Non-DP-aware path: Register single worker // Non-DP-aware path: Register single worker
......
//! Worker Removal Workflow Steps
//!
//! This module implements the workflow steps for removing workers from the router.
//! Handles both single worker removal and DP-aware worker removal with prefix matching.
//!
//! Steps:
//! 1. FindWorkersToRemove - Identify workers to remove based on URL (handles DP-aware prefix matching)
//! 2. RemoveFromPolicyRegistry - Remove workers from policy registry and cache-aware policies
//! 3. RemoveFromWorkerRegistry - Remove workers from worker registry
//! 4. UpdateRemainingPolicies - Update cache-aware policies for remaining workers
use std::{collections::HashSet, sync::Arc};
use async_trait::async_trait;
use tracing::{debug, info};
use crate::{
core::{workflow::*, Worker},
server::AppContext,
};
/// Request structure for worker removal
#[derive(Debug, Clone)]
pub struct WorkerRemovalRequest {
    /// Worker URL to remove. In DP-aware mode this is the base URL whose
    /// expanded "<url>@<rank>" logical workers will all be removed.
    pub url: String,
    /// Whether the router is DP-aware (removal then matches by URL prefix
    /// instead of exact URL).
    pub dp_aware: bool,
}
/// Step 1: Find workers to remove based on URL
///
/// Resolves the removal request against the worker registry and stashes the
/// matched workers, their URLs, and the affected model ids in the workflow
/// context for the later steps.
pub struct FindWorkersToRemoveStep;

#[async_trait]
impl StepExecutor for FindWorkersToRemoveStep {
    async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
        let request: Arc<WorkerRemovalRequest> = context
            .get("removal_request")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("removal_request".to_string()))?;
        let app_context: Arc<AppContext> = context
            .get("app_context")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;

        debug!(
            "Finding workers to remove for {} (dp_aware: {})",
            request.url, request.dp_aware
        );

        let targets: Vec<Arc<dyn Worker>> = if request.dp_aware {
            // DP-aware mode: one physical URL maps to several "<url>@<rank>" workers.
            let prefix = format!("{}@", request.url);
            app_context
                .worker_registry
                .get_all()
                .iter()
                .filter(|w| w.url().starts_with(&prefix))
                .cloned()
                .collect()
        } else {
            // Exact-URL lookup; yields zero or one worker.
            app_context
                .worker_registry
                .get_by_url(&request.url)
                .into_iter()
                .collect()
        };

        if targets.is_empty() {
            let error_msg = if request.dp_aware {
                format!("No workers found with prefix {}@", request.url)
            } else {
                format!("Worker {} not found", request.url)
            };
            return Err(WorkflowError::StepFailed {
                step_id: StepId::new("find_workers_to_remove"),
                message: error_msg,
            });
        }

        debug!(
            "Found {} worker(s) to remove for {}",
            targets.len(),
            request.url
        );

        // Capture URLs and model ids before handing the workers to later steps.
        let worker_urls: Vec<String> = targets.iter().map(|w| w.url().to_string()).collect();
        let affected_models: HashSet<String> =
            targets.iter().map(|w| w.model_id().to_string()).collect();

        context.set("workers_to_remove", targets);
        context.set("worker_urls", worker_urls);
        context.set("affected_models", affected_models);

        Ok(StepResult::Success)
    }

    fn is_retryable(&self, _error: &WorkflowError) -> bool {
        false // Worker not found is not retryable
    }
}
/// Step 2: Remove workers from policy registry
///
/// Detaches each worker found by step 1 from its model's cache-aware policy
/// and notifies the policy registry of the removal.
pub struct RemoveFromPolicyRegistryStep;
#[async_trait]
impl StepExecutor for RemoveFromPolicyRegistryStep {
    /// Reads `app_context` and `workers_to_remove` from the workflow context
    /// (placed there by earlier steps); fails with `ContextValueNotFound`
    /// if either is missing.
    async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
        let app_context: Arc<AppContext> = context
            .get("app_context")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
        let workers_to_remove: Arc<Vec<Arc<dyn Worker>>> = context
            .get("workers_to_remove")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("workers_to_remove".to_string()))?;
        debug!(
            "Removing {} worker(s) from policy registry",
            workers_to_remove.len()
        );
        for worker in workers_to_remove.iter() {
            let model_id = worker.model_id().to_string();
            let worker_url = worker.url();
            // Remove from the model's cache-aware policy (no-op for other policies).
            app_context
                .policy_registry
                .remove_worker_from_cache_aware(&model_id, worker_url);
            // Notify policy registry so per-model bookkeeping stays consistent.
            app_context.policy_registry.on_worker_removed(&model_id);
            debug!(
                "Removed worker {} from policy registry (model: {})",
                worker_url, model_id
            );
        }
        Ok(StepResult::Success)
    }
    fn is_retryable(&self, _error: &WorkflowError) -> bool {
        false // Policy removal is not retryable
    }
}
/// Step 3: Remove workers from worker registry
///
/// Drops each URL collected by step 1 from the worker registry.
pub struct RemoveFromWorkerRegistryStep;
#[async_trait]
impl StepExecutor for RemoveFromWorkerRegistryStep {
    /// Reads `app_context` and `worker_urls` from the workflow context.
    ///
    /// Fails with `StepFailed` if any URL was no longer present in the
    /// registry (e.g. removed concurrently). NOTE(review): workers removed
    /// before such a failure are not restored, so the step is not
    /// all-or-nothing — confirm callers tolerate a partial removal here.
    async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
        let app_context: Arc<AppContext> = context
            .get("app_context")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
        let worker_urls: Arc<Vec<String>> = context
            .get("worker_urls")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_urls".to_string()))?;
        debug!(
            "Removing {} worker(s) from worker registry",
            worker_urls.len()
        );
        // Count successful removals; remove_by_url returns None for unknown URLs.
        let mut removed_count = 0;
        for worker_url in worker_urls.iter() {
            if app_context
                .worker_registry
                .remove_by_url(worker_url)
                .is_some()
            {
                removed_count += 1;
                debug!("Removed worker {} from registry", worker_url);
            }
        }
        if removed_count != worker_urls.len() {
            return Err(WorkflowError::StepFailed {
                step_id: StepId::new("remove_from_worker_registry"),
                message: format!(
                    "Expected to remove {} workers but only removed {}",
                    worker_urls.len(),
                    removed_count
                ),
            });
        }
        Ok(StepResult::Success)
    }
    fn is_retryable(&self, _error: &WorkflowError) -> bool {
        false // Worker removal is not retryable
    }
}
/// Step 4: Update cache-aware policies for remaining workers
///
/// For each model touched by the removal, re-initializes the cache-aware
/// policy over the workers that are still registered, then logs a summary.
pub struct UpdateRemainingPoliciesStep;
#[async_trait]
impl StepExecutor for UpdateRemainingPoliciesStep {
    /// Reads `app_context`, `affected_models`, and `worker_urls` from the
    /// workflow context (placed there by earlier steps).
    async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
        let app_context: Arc<AppContext> = context
            .get("app_context")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
        let affected_models: Arc<HashSet<String>> = context
            .get("affected_models")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("affected_models".to_string()))?;
        let worker_urls: Arc<Vec<String>> = context
            .get("worker_urls")
            .ok_or_else(|| WorkflowError::ContextValueNotFound("worker_urls".to_string()))?;
        debug!(
            "Updating cache-aware policies for {} affected model(s)",
            affected_models.len()
        );
        for model_id in affected_models.iter() {
            let remaining_workers = app_context.worker_registry.get_by_model_fast(model_id);
            // Only cache-aware policies need re-initialization, and only when
            // the model still has workers left to route to.
            if let Some(policy) = app_context.policy_registry.get_policy(model_id) {
                if policy.name() == "cache_aware" && !remaining_workers.is_empty() {
                    app_context
                        .policy_registry
                        .init_cache_aware_policy(model_id, &remaining_workers);
                    debug!(
                        "Updated cache-aware policy for model {} ({} remaining workers)",
                        model_id,
                        remaining_workers.len()
                    );
                }
            }
        }
        // Log final result at info level
        if worker_urls.len() == 1 {
            info!("Removed worker {}", worker_urls[0]);
        } else {
            info!(
                "Removed {} DP-aware workers: {:?}",
                worker_urls.len(),
                worker_urls
            );
        }
        Ok(StepResult::Success)
    }
    fn is_retryable(&self, _error: &WorkflowError) -> bool {
        false // Policy update is not retryable
    }
}
/// Create a worker removal workflow definition
pub fn create_worker_removal_workflow() -> WorkflowDefinition {
use std::time::Duration;
WorkflowDefinition::new("worker_removal", "Remove worker from router")
.add_step(
StepDefinition::new(
"find_workers_to_remove",
"Find workers to remove",
Arc::new(FindWorkersToRemoveStep),
)
.with_timeout(Duration::from_secs(10))
.with_retry(RetryPolicy {
max_attempts: 1,
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
}),
)
.add_step(
StepDefinition::new(
"remove_from_policy_registry",
"Remove workers from policy registry",
Arc::new(RemoveFromPolicyRegistryStep),
)
.with_timeout(Duration::from_secs(10))
.with_retry(RetryPolicy {
max_attempts: 1,
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
}),
)
.add_step(
StepDefinition::new(
"remove_from_worker_registry",
"Remove workers from worker registry",
Arc::new(RemoveFromWorkerRegistryStep),
)
.with_timeout(Duration::from_secs(10))
.with_retry(RetryPolicy {
max_attempts: 1,
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
}),
)
.add_step(
StepDefinition::new(
"update_remaining_policies",
"Update cache-aware policies for remaining workers",
Arc::new(UpdateRemainingPoliciesStep),
)
.with_timeout(Duration::from_secs(10))
.with_retry(RetryPolicy {
max_attempts: 1,
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
}),
)
}
...@@ -149,7 +149,11 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest ...@@ -149,7 +149,11 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
Ok(ChatCompletionRequest { Ok(ChatCompletionRequest {
messages, messages,
model: req.model.clone().unwrap_or_else(|| "default".to_string()), model: if req.model.is_empty() {
"default".to_string()
} else {
req.model.clone()
},
temperature: req.temperature, temperature: req.temperature,
max_completion_tokens: req.max_output_tokens, max_completion_tokens: req.max_output_tokens,
stream: is_streaming, stream: is_streaming,
...@@ -311,7 +315,7 @@ mod tests { ...@@ -311,7 +315,7 @@ mod tests {
let req = ResponsesRequest { let req = ResponsesRequest {
input: ResponseInput::Text("Hello, world!".to_string()), input: ResponseInput::Text("Hello, world!".to_string()),
instructions: Some("You are a helpful assistant.".to_string()), instructions: Some("You are a helpful assistant.".to_string()),
model: Some("gpt-4".to_string()), model: "gpt-4".to_string(),
temperature: Some(0.7), temperature: Some(0.7),
..Default::default() ..Default::default()
}; };
......
...@@ -324,10 +324,11 @@ async fn route_responses_background( ...@@ -324,10 +324,11 @@ async fn route_responses_background(
incomplete_details: None, incomplete_details: None,
instructions: request.instructions.clone(), instructions: request.instructions.clone(),
max_output_tokens: request.max_output_tokens, max_output_tokens: request.max_output_tokens,
model: request model: if request.model.is_empty() {
.model "default".to_string()
.clone() } else {
.unwrap_or_else(|| "default".to_string()), request.model.clone()
},
output: Vec::new(), output: Vec::new(),
parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true), parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
previous_response_id: request.previous_response_id.clone(), previous_response_id: request.previous_response_id.clone(),
...@@ -622,10 +623,11 @@ async fn process_and_transform_sse_stream( ...@@ -622,10 +623,11 @@ async fn process_and_transform_sse_stream(
// Create event emitter for OpenAI-compatible streaming // Create event emitter for OpenAI-compatible streaming
let response_id = format!("resp_{}", Uuid::new_v4()); let response_id = format!("resp_{}", Uuid::new_v4());
let model = original_request let model = if original_request.model.is_empty() {
.model "default".to_string()
.clone() } else {
.unwrap_or_else(|| "default".to_string()); original_request.model.clone()
};
let created_at = chrono::Utc::now().timestamp() as u64; let created_at = chrono::Utc::now().timestamp() as u64;
let mut event_emitter = ResponseStreamEventEmitter::new(response_id, model, created_at); let mut event_emitter = ResponseStreamEventEmitter::new(response_id, model, created_at);
......
...@@ -608,10 +608,11 @@ async fn execute_tool_loop_streaming_internal( ...@@ -608,10 +608,11 @@ async fn execute_tool_loop_streaming_internal(
// Create response event emitter // Create response event emitter
let response_id = format!("resp_{}", Uuid::new_v4()); let response_id = format!("resp_{}", Uuid::new_v4());
let model = current_request let model = if current_request.model.is_empty() {
.model "default".to_string()
.clone() } else {
.unwrap_or_else(|| "default".to_string()); current_request.model.clone()
};
let created_at = SystemTime::now() let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.unwrap() .unwrap()
......
...@@ -22,8 +22,12 @@ use tracing::{error, info, warn, Level}; ...@@ -22,8 +22,12 @@ use tracing::{error, info, warn, Level};
use crate::{ use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
core::{ core::{
worker_to_info, workflow::WorkflowEngine, Job, JobQueue, JobQueueConfig, LoadMonitor, worker_to_info,
WorkerManager, WorkerRegistry, WorkerType, workflow::{
create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber,
WorkflowEngine,
},
Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry, WorkerType,
}, },
data_connector::{ data_connector::{
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
...@@ -439,51 +443,6 @@ async fn v1_conversations_delete_item( ...@@ -439,51 +443,6 @@ async fn v1_conversations_delete_item(
.await .await
} }
#[derive(Deserialize)]
struct AddWorkerQuery {
url: String,
api_key: Option<String>,
}
async fn add_worker(
State(state): State<Arc<AppState>>,
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
) -> Response {
// Warn if router has API key but worker is being added without one
if state.context.router_config.api_key.is_some() && api_key.is_none() {
warn!(
"Adding worker {} without API key while router has API key configured. \
Worker will be accessible without authentication. \
If the worker requires the same API key as the router, please specify it explicitly.",
url
);
}
let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
match result {
Ok(message) => (StatusCode::OK, message).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
}
}
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry);
Json(json!({ "urls": worker_list })).into_response()
}
async fn remove_worker(
State(state): State<Arc<AppState>>,
Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
) -> Response {
let result = WorkerManager::remove_worker(&url, &state.context);
match result {
Ok(message) => (StatusCode::OK, message).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
}
}
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response { async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client) match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
.await .await
...@@ -566,6 +525,12 @@ async fn create_worker( ...@@ -566,6 +525,12 @@ async fn create_worker(
); );
} }
// Populate dp_aware from router's configuration
let config = WorkerConfigRequest {
dp_aware: state.context.router_config.dp_aware,
..config
};
// Submit job for async processing // Submit job for async processing
let worker_url = config.url.clone(); let worker_url = config.url.clone();
let job = Job::AddWorker { let job = Job::AddWorker {
...@@ -761,9 +726,6 @@ pub fn build_app( ...@@ -761,9 +726,6 @@ pub fn build_app(
.route("/get_server_info", get(get_server_info)); .route("/get_server_info", get(get_server_info));
let admin_routes = Router::new() let admin_routes = Router::new()
.route("/add_worker", post(add_worker))
.route("/remove_worker", post(remove_worker))
.route("/list_workers", get(list_workers))
.route("/flush_cache", post(flush_cache)) .route("/flush_cache", post(flush_cache))
.route("/get_loads", get(get_loads)) .route("/get_loads", get(get_loads))
.route_layer(axum::middleware::from_fn_with_state( .route_layer(axum::middleware::from_fn_with_state(
...@@ -1018,15 +980,16 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -1018,15 +980,16 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
engine engine
.event_bus() .event_bus()
.subscribe(Arc::new(crate::core::workflow::LoggingSubscriber)) .subscribe(Arc::new(LoggingSubscriber))
.await; .await;
engine.register_workflow(crate::core::workflow::create_worker_registration_workflow()); engine.register_workflow(create_worker_registration_workflow());
engine.register_workflow(create_worker_removal_workflow());
app_context app_context
.workflow_engine .workflow_engine
.set(engine) .set(engine)
.expect("WorkflowEngine should only be initialized once"); .expect("WorkflowEngine should only be initialized once");
info!("Workflow engine initialized with worker registration workflow"); info!("Workflow engine initialized with worker registration and removal workflows");
info!( info!(
"Initializing workers for routing mode: {:?}", "Initializing workers for routing mode: {:?}",
......
...@@ -18,11 +18,7 @@ use rustls; ...@@ -18,11 +18,7 @@ use rustls;
use tokio::{task, time}; use tokio::{task, time};
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::{ use crate::{core::Job, protocols::worker_spec::WorkerConfigRequest, server::AppContext};
core::{Job, WorkerManager},
protocols::worker_spec::WorkerConfigRequest,
server::AppContext,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ServiceDiscoveryConfig { pub struct ServiceDiscoveryConfig {
...@@ -386,7 +382,7 @@ async fn handle_pod_event( ...@@ -386,7 +382,7 @@ async fn handle_pod_event(
reasoning_parser: None, reasoning_parser: None,
tool_parser: None, tool_parser: None,
chat_template: None, chat_template: None,
api_key: None, api_key: app_context.router_config.api_key.clone(),
health_check_timeout_secs: app_context.router_config.health_check.timeout_secs, health_check_timeout_secs: app_context.router_config.health_check.timeout_secs,
health_check_interval_secs: app_context health_check_interval_secs: app_context
.router_config .router_config
...@@ -453,8 +449,24 @@ async fn handle_pod_deletion( ...@@ -453,8 +449,24 @@ async fn handle_pod_deletion(
pod_info.name, pod_info.pod_type, worker_url pod_info.name, pod_info.pod_type, worker_url
); );
if let Err(e) = WorkerManager::remove_worker(&worker_url, &app_context) { let job = Job::RemoveWorker {
error!("Failed to remove worker {}: {}", worker_url, e); url: worker_url.clone(),
};
if let Some(job_queue) = app_context.worker_job_queue.get() {
if let Err(e) = job_queue.submit(job).await {
error!(
"Failed to submit worker removal job for {}: {}",
worker_url, e
);
} else {
debug!("Submitted worker removal job for {}", worker_url);
}
} else {
error!(
"JobQueue not initialized, cannot remove worker {}",
worker_url
);
} }
} else { } else {
debug!( debug!(
......
...@@ -14,7 +14,7 @@ use sglang_router_rs::{ ...@@ -14,7 +14,7 @@ use sglang_router_rs::{
config::{ config::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
}, },
core::WorkerManager, core::Job,
routers::{RouterFactory, RouterTrait}, routers::{RouterFactory, RouterTrait},
server::AppContext, server::AppContext,
}; };
...@@ -112,22 +112,51 @@ impl TestContext { ...@@ -112,22 +112,51 @@ impl TestContext {
// Create app context // Create app context
let app_context = common::create_test_context(config.clone()); let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router // Submit worker initialization job (same as real server does)
if !worker_urls.is_empty() { if !worker_urls.is_empty() {
WorkerManager::initialize_workers(&config, &app_context.worker_registry, None) let job_queue = app_context
.worker_job_queue
.get()
.expect("JobQueue should be initialized");
let job = Job::InitializeWorkersFromConfig {
router_config: Box::new(config.clone()),
};
job_queue
.submit(job)
.await .await
.expect("Failed to initialize workers"); .expect("Failed to submit worker initialization job");
// Poll until all workers are healthy (up to 10 seconds)
let expected_count = worker_urls.len();
let start = tokio::time::Instant::now();
let timeout_duration = tokio::time::Duration::from_secs(10);
loop {
let healthy_workers = app_context
.worker_registry
.get_all()
.iter()
.filter(|w| w.is_healthy())
.count();
if healthy_workers >= expected_count {
break;
}
if start.elapsed() > timeout_duration {
panic!(
"Timeout waiting for {} workers to become healthy (only {} ready)",
expected_count, healthy_workers
);
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
} }
// Create router // Create router
let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router); let router = Arc::from(router);
// Wait for router to discover workers
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
Self { Self {
workers, workers,
router, router,
...@@ -711,221 +740,6 @@ mod model_info_tests { ...@@ -711,221 +740,6 @@ mod model_info_tests {
} }
} }
#[cfg(test)]
mod worker_management_tests {
use super::*;
#[tokio::test]
async fn test_add_new_worker() {
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Start a mock worker
let mut worker = MockWorker::new(MockWorkerConfig {
port: 18301,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
});
let url = worker.start().await.unwrap();
// Add the worker
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// List workers to verify
let req = Request::builder()
.method("GET")
.uri("/list_workers")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let workers = body_json["urls"].as_array().unwrap();
assert!(workers.iter().any(|w| w.as_str().unwrap() == url));
worker.stop().await;
ctx.shutdown().await;
}
#[tokio::test]
async fn test_remove_existing_worker() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 18302,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
// Get the worker URL
let req = Request::builder()
.method("GET")
.uri("/list_workers")
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let workers = body_json["urls"].as_array().unwrap();
let worker_url = workers[0].as_str().unwrap();
// Remove the worker
let req = Request::builder()
.method("POST")
.uri(format!("/remove_worker?url={}", worker_url))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let req = Request::builder()
.method("GET")
.uri("/list_workers")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let workers = body_json["urls"].as_array().unwrap();
assert!(workers.is_empty());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_add_worker_invalid_url() {
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Invalid URL format
let req = Request::builder()
.method("POST")
.uri("/add_worker?url=not-a-valid-url")
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
// Missing URL parameter
let req = Request::builder()
.method("POST")
.uri("/add_worker")
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
// Empty URL
let req = Request::builder()
.method("POST")
.uri("/add_worker?url=")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
ctx.shutdown().await;
}
#[tokio::test]
async fn test_add_duplicate_worker() {
// Start a mock worker
let mut worker = MockWorker::new(MockWorkerConfig {
port: 18303,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
});
let url = worker.start().await.unwrap();
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Add worker first time
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
// Try to add same worker again
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
// Should return error for duplicate
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
worker.stop().await;
ctx.shutdown().await;
}
#[tokio::test]
async fn test_add_unhealthy_worker() {
// Start unhealthy worker
let mut worker = MockWorker::new(MockWorkerConfig {
port: 18304,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Unhealthy,
response_delay_ms: 0,
fail_rate: 0.0,
});
let url = worker.start().await.unwrap();
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Try to add unhealthy worker
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
// Router should reject unhealthy workers
assert!(
resp.status() == StatusCode::BAD_REQUEST
|| resp.status() == StatusCode::SERVICE_UNAVAILABLE
);
worker.stop().await;
ctx.shutdown().await;
}
}
#[cfg(test)] #[cfg(test)]
mod router_policy_tests { mod router_policy_tests {
use super::*; use super::*;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment