Unverified commit 5dccf697 authored by Simo Lin, committed by GitHub

[router] create worker removal step and clean up worker manager (#11921)

parent eec9e471
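
For orientation, the worker-lifecycle flow this commit builds on, condensed into one sketch: the old /add_worker, /remove_worker, and /list_workers endpoints are replaced by a RESTful /workers resource that answers 202 ACCEPTED and finishes the work asynchronously, so callers poll for completion. Addresses below are placeholders; the endpoint shapes follow the tests and helpers in this diff.

import requests
from urllib.parse import quote

base = "http://127.0.0.1:30000"        # placeholder router URL
worker_url = "http://127.0.0.1:31000"  # placeholder worker URL

# Add: POST /workers returns 202 ACCEPTED; registration runs as an async job.
r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=30)
assert r.status_code == 202

# Poll the per-worker resource until it reports healthy (job_status exposes
# failures while the registration workflow is still running).
encoded = quote(worker_url, safe="")
info = requests.get(f"{base}/workers/{encoded}", timeout=5).json()
print(info.get("is_healthy"), info.get("job_status"))

# Remove: DELETE /workers/{url} also returns 202; a subsequent GET returning
# 404 means the removal workflow finished.
r = requests.delete(f"{base}/workers/{encoded}", timeout=30)
assert r.status_code == 202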
......@@ -85,6 +85,8 @@ def _popen_launch_router(
str(prom_port),
"--router-prometheus-host",
"127.0.0.1",
"--router-log-level",
"warn",
]
proc = subprocess.Popen(cmd)
......
import time
from types import SimpleNamespace
from typing import Optional
import pytest
import requests
def _wait_for_workers(
    base_url: str, expected_count: int, timeout: float = 60.0, headers: Optional[dict] = None
) -> None:
"""Poll /workers endpoint until expected number of workers are registered."""
start = time.perf_counter()
with requests.Session() as session:
while time.perf_counter() - start < timeout:
try:
r = session.get(f"{base_url}/workers", headers=headers, timeout=5)
if r.status_code == 200:
workers = r.json().get("workers", [])
if len(workers) >= expected_count:
return
except requests.RequestException:
pass
time.sleep(0.5)
raise TimeoutError(
f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s"
)
@pytest.mark.e2e
def test_embeddings_basic(
e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model
......@@ -12,8 +34,11 @@ def test_embeddings_basic(
worker_url = e2e_primary_embedding_worker.url
# Attach embedding worker to router-only instance
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
r.raise_for_status()
r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180)
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for worker to be registered
_wait_for_workers(base, expected_count=1, timeout=60.0)
# Simple embedding request with two inputs
payload = {
......
......@@ -198,6 +198,8 @@ def pd_cluster(e2e_model: str):
"--policy",
"round_robin",
"--pd-disaggregation",
"--log-level",
"warn",
]
for url, bport in prefill:
cmd += ["--prefill", url, str(bport)]
......
......@@ -8,13 +8,39 @@ import requests
from typing import Optional

from sglang.test.run_eval import run_eval
def _wait_for_workers(
    base_url: str, expected_count: int, timeout: float = 60.0, headers: Optional[dict] = None
) -> None:
"""Poll /workers endpoint until expected number of workers are registered."""
start = time.perf_counter()
with requests.Session() as session:
while time.perf_counter() - start < timeout:
try:
r = session.get(f"{base_url}/workers", headers=headers, timeout=5)
if r.status_code == 200:
workers = r.json().get("workers", [])
if len(workers) >= expected_count:
return
except requests.RequestException:
pass
time.sleep(0.5)
raise TimeoutError(
f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s"
)
@pytest.mark.e2e
def test_mmlu(e2e_router_only_rr, e2e_two_workers_dp2, e2e_model):
# Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance
base = e2e_router_only_rr.url
for w in e2e_two_workers_dp2:
r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180)
r.raise_for_status()
r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180)
assert (
r.status_code == 202
), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for workers to be registered
_wait_for_workers(base, expected_count=2, timeout=60.0)
args = SimpleNamespace(
base_url=base,
......@@ -35,8 +61,13 @@ def test_genai_bench(
"""Attach a worker to the regular router and run a short genai-bench."""
base = e2e_router_only_rr.url
for w in e2e_two_workers_dp2:
r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180)
r.raise_for_status()
r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180)
assert (
r.status_code == 202
), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for workers to be registered
_wait_for_workers(base, expected_count=2, timeout=60.0)
genai_bench_runner(
router_url=base,
......@@ -59,8 +90,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
base = e2e_router_only_rr.url
worker_url = e2e_primary_worker.url
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
r.raise_for_status()
r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180)
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for worker to be registered
_wait_for_workers(base, expected_count=1, timeout=60.0)
with requests.Session() as s:
for i in range(8):
......@@ -77,8 +111,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
r.raise_for_status()
# Remove the worker
r = requests.post(f"{base}/remove_worker", params={"url": worker_url}, timeout=60)
r.raise_for_status()
from urllib.parse import quote
encoded_url = quote(worker_url, safe="")
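    # safe="" also percent-encodes "/", so the whole URL fits in one path segment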
r = requests.delete(f"{base}/workers/{encoded_url}", timeout=60)
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
@pytest.mark.e2e
......@@ -86,8 +123,11 @@ def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_m
base = e2e_router_only_rr.url
worker = e2e_primary_worker
r = requests.post(f"{base}/add_worker", params={"url": worker.url}, timeout=180)
r.raise_for_status()
r = requests.post(f"{base}/workers", json={"url": worker.url}, timeout=180)
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for worker to be registered
_wait_for_workers(base, expected_count=1, timeout=60.0)
def killer():
time.sleep(10)
......@@ -129,20 +169,30 @@ def test_dp_aware_worker_expansion_and_api_key(
# Attach worker; router should expand to dp_size logical workers
r = requests.post(
f"{router_url}/add_worker",
params={"url": worker_url, "api_key": api_key},
f"{router_url}/workers",
json={"url": worker_url, "api_key": api_key},
headers={"Authorization": f"Bearer {api_key}"},
timeout=180,
)
r.raise_for_status()
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
# Wait for workers to be registered and expanded
_wait_for_workers(
router_url,
expected_count=2,
timeout=60.0,
headers={"Authorization": f"Bearer {api_key}"},
)
# Verify the expanded workers have correct URLs
r = requests.get(
f"{router_url}/list_workers",
f"{router_url}/workers",
headers={"Authorization": f"Bearer {api_key}"},
timeout=30,
)
r.raise_for_status()
urls = r.json().get("urls", [])
workers = r.json().get("workers", [])
urls = [w["url"] for w in workers]
assert len(urls) == 2
assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}
......
......@@ -267,6 +267,8 @@ def popen_launch_workers_and_router(
policy,
"--model-path",
model,
"--log-level",
"warn",
]
# Add worker URLs
......
......@@ -133,19 +133,90 @@ class RouterManager:
time.sleep(0.2)
raise TimeoutError(f"Router at {base_url} did not become healthy")
def add_worker(self, base_url: str, worker_url: str) -> None:
r = requests.post(f"{base_url}/add_worker", params={"url": worker_url})
assert r.status_code == 200, f"add_worker failed: {r.status_code} {r.text}"
def add_worker(self, base_url: str, worker_url: str, timeout: float = 30.0) -> None:
r = requests.post(f"{base_url}/workers", json={"url": worker_url})
assert (
r.status_code == 202
), f"add_worker failed: {r.status_code} {r.text}" # ACCEPTED status
def remove_worker(self, base_url: str, worker_url: str) -> None:
r = requests.post(f"{base_url}/remove_worker", params={"url": worker_url})
assert r.status_code == 200, f"remove_worker failed: {r.status_code} {r.text}"
# Poll until worker is actually added and healthy
from urllib.parse import quote
encoded_url = quote(worker_url, safe="")
start = time.time()
with requests.Session() as s:
while time.time() - start < timeout:
try:
r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2)
if r.status_code == 200:
data = r.json()
# Check if registration job failed
job_status = data.get("job_status")
if job_status and job_status.get("state") == "failed":
raise RuntimeError(
f"Worker registration failed: {job_status.get('message', 'Unknown error')}"
)
# Check if worker is healthy and registered (not just in job queue)
if data.get("is_healthy", False):
return
# Worker not ready yet, continue polling
except requests.RequestException:
pass
time.sleep(0.1)
raise TimeoutError(
f"Worker {worker_url} was not added and healthy after {timeout}s"
)
def remove_worker(
self, base_url: str, worker_url: str, timeout: float = 30.0
) -> None:
# URL encode the worker_url for path parameter
from urllib.parse import quote
encoded_url = quote(worker_url, safe="")
r = requests.delete(f"{base_url}/workers/{encoded_url}")
assert (
r.status_code == 202
), f"remove_worker failed: {r.status_code} {r.text}" # ACCEPTED status
# Poll until worker is actually removed (GET returns 404) or timeout
start = time.time()
last_status = None
with requests.Session() as s:
while time.time() - start < timeout:
try:
r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2)
if r.status_code == 404:
# Worker successfully removed
return
elif r.status_code == 200:
# Check if removal job failed
data = r.json()
job_status = data.get("job_status")
if job_status:
last_status = job_status
if job_status.get("state") == "failed":
raise RuntimeError(
f"Worker removal failed: {job_status.get('message', 'Unknown error')}"
)
# Worker still being processed, continue polling
except requests.RequestException:
pass
time.sleep(0.1)
# Provide detailed timeout error with last known status
error_msg = f"Worker {worker_url} was not removed after {timeout}s"
if last_status:
error_msg += f". Last job status: {last_status}"
raise TimeoutError(error_msg)
def list_workers(self, base_url: str) -> list[str]:
r = requests.get(f"{base_url}/list_workers")
r = requests.get(f"{base_url}/workers")
assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}"
data = r.json()
return data.get("urls", [])
# Extract URLs from WorkerInfo objects
workers = data.get("workers", [])
return [w["url"] for w in workers]
def stop_all(self):
for p in self._children:
......
......@@ -2,7 +2,7 @@ import os
import subprocess
import time
from pathlib import Path
from typing import Dict, Iterable, List, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
import pytest
import requests
......@@ -84,7 +84,7 @@ def mock_workers():
procs: List[subprocess.Popen] = []
def _start(n: int, args: List[str] | None = None):
def _start(n: int, args: Optional[List[str]] = None):
args = args or []
new_procs: List[subprocess.Popen] = []
urls: List[str] = []
......
......@@ -15,11 +15,9 @@ use tracing::{debug, error, info, warn};
use crate::{
config::{RouterConfig, RoutingMode},
core::{
workflow::{
WorkflowContext, WorkflowEngine, WorkflowId, WorkflowInstanceId, WorkflowStatus,
},
WorkerManager,
core::workflow::{
steps::WorkerRemovalRequest, WorkflowContext, WorkflowEngine, WorkflowId,
WorkflowInstanceId, WorkflowStatus,
},
metrics::RouterMetrics,
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
......@@ -320,11 +318,29 @@ impl JobQueue {
.await
}
Job::RemoveWorker { url } => {
let result = WorkerManager::remove_worker(url, context);
let engine = context
.workflow_engine
.get()
.ok_or_else(|| "Workflow engine not initialized".to_string())?;
let instance_id = Self::start_worker_removal_workflow(engine, url, context).await?;
debug!(
"Started worker removal workflow for {} (instance: {})",
url, instance_id
);
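                // Bound how long this job waits for the removal workflow to finish.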
let timeout_duration = Duration::from_secs(30);
let result =
Self::wait_for_workflow_completion(engine, instance_id, url, timeout_duration)
.await;
// Clean up job status when removing worker
if let Some(queue) = context.worker_job_queue.get() {
queue.remove_status(url);
}
result
}
Job::InitializeWorkersFromConfig { router_config } => {
......@@ -424,6 +440,27 @@ impl JobQueue {
.map_err(|e| format!("Failed to start worker registration workflow: {:?}", e))
}
/// Start worker removal workflow
async fn start_worker_removal_workflow(
engine: &Arc<WorkflowEngine>,
url: &str,
context: &Arc<AppContext>,
) -> Result<WorkflowInstanceId, String> {
let removal_request = WorkerRemovalRequest {
url: url.to_string(),
dp_aware: context.router_config.dp_aware,
};
let mut workflow_context = WorkflowContext::new(WorkflowInstanceId::new());
workflow_context.set("removal_request", removal_request);
workflow_context.set_arc("app_context", Arc::clone(context));
engine
.start_workflow(WorkflowId::new("worker_removal"), workflow_context)
.await
.map_err(|e| format!("Failed to start worker removal workflow: {:?}", e))
}
/// Wait for workflow completion with adaptive polling
async fn wait_for_workflow_completion(
engine: &Arc<WorkflowEngine>,
......
......@@ -29,5 +29,5 @@ pub use worker::{
Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
};
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
pub use worker_manager::{DpInfo, LoadMonitor, ServerInfo, WorkerManager};
pub use worker_manager::{LoadMonitor, WorkerManager};
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
......@@ -6,8 +6,6 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use futures::future;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::{
sync::{watch, Mutex},
......@@ -16,698 +14,15 @@ use tokio::{
use tracing::{debug, error, info, warn};
use crate::{
config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode,
HealthCheckConfig, RouterConfig, RoutingMode,
},
core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder,
HealthConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType,
},
grpc_client::SglangSchedulerClient,
core::{ConnectionMode, WorkerRegistry, WorkerType},
policies::PolicyRegistry,
protocols::worker_spec::{
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
},
server::AppContext,
protocols::worker_spec::{FlushCacheResult, WorkerLoadInfo, WorkerLoadsResult},
};
static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.expect("Failed to create HTTP client")
});
/// Server information returned from worker endpoints
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServerInfo {
pub model_id: Option<String>,
pub model_path: Option<String>,
pub dp_size: Option<usize>,
pub version: Option<String>,
pub max_batch_size: Option<usize>,
pub max_total_tokens: Option<usize>,
pub max_prefill_tokens: Option<usize>,
pub max_running_requests: Option<usize>,
pub max_num_reqs: Option<usize>,
}
/// DP (Data Parallel) information for a worker
#[derive(Debug, Clone)]
pub struct DpInfo {
pub dp_size: usize,
pub model_id: String,
}
/// Worker discovery results gathered from backend endpoints
struct WorkerDiscovery {
labels: HashMap<String, String>,
grpc_client: Option<SglangSchedulerClient>,
}
impl WorkerDiscovery {
fn new() -> Self {
Self {
labels: HashMap::new(),
grpc_client: None,
}
}
}
/// Unified worker management
pub struct WorkerManager;
impl WorkerManager {
/// Get server info from /get_server_info endpoint
pub async fn get_server_info(url: &str, api_key: Option<&str>) -> Result<ServerInfo, String> {
let base_url = url.trim_end_matches('/');
let server_info_url = format!("{}/get_server_info", base_url);
let mut req = HTTP_CLIENT.get(&server_info_url);
if let Some(key) = api_key {
req = req.bearer_auth(key);
}
let response = req
.send()
.await
.map_err(|e| format!("Failed to connect to {}: {}", server_info_url, e))?;
if !response.status().is_success() {
return Err(format!(
"Server returned status {} from {}",
response.status(),
server_info_url
));
}
let json = response
.json::<Value>()
.await
.map_err(|e| format!("Failed to parse response from {}: {}", server_info_url, e))?;
info!(
"Successfully retrieved server info from {}",
server_info_url
);
Self::parse_server_info(json)
}
/// Get model info from /get_model_info endpoint
pub async fn get_model_info(url: &str, api_key: Option<&str>) -> Result<Value, String> {
let base_url = url.trim_end_matches('/');
let model_info_url = format!("{}/get_model_info", base_url);
let mut req = HTTP_CLIENT.get(&model_info_url);
if let Some(key) = api_key {
req = req.bearer_auth(key);
}
let response = req
.send()
.await
.map_err(|e| format!("Failed to connect to {}: {}", model_info_url, e))?;
if !response.status().is_success() {
return Err(format!(
"Server returned status {} from {}",
response.status(),
model_info_url
));
}
let json = response
.json::<Value>()
.await
.map_err(|e| format!("Failed to parse response from {}: {}", model_info_url, e))?;
info!("Successfully retrieved model info from {}", model_info_url);
Ok(json)
}
/// Get DP info for a worker URL
pub async fn get_dp_info(url: &str, api_key: Option<&str>) -> Result<DpInfo, String> {
let info = Self::get_server_info(url, api_key).await?;
let dp_size = info
.dp_size
.ok_or_else(|| format!("No dp_size in response from {}", url))?;
let model_id = info
.model_id
.or_else(|| {
info.model_path
.and_then(|path| path.split('/').next_back().map(|s| s.to_string()))
})
.unwrap_or_else(|| "unknown".to_string());
Ok(DpInfo { dp_size, model_id })
}
/// Generate DP-aware worker URLs
pub async fn get_dp_aware_urls(
base_urls: &[String],
api_key: Option<&str>,
) -> Result<Vec<String>, String> {
let mut dp_urls = Vec::new();
for base_url in base_urls {
match Self::get_dp_info(base_url, api_key).await {
Ok(dp_info) => {
info!(
"Discovered DP size {} for {} (model: {})",
dp_info.dp_size, base_url, dp_info.model_id
);
for rank in 0..dp_info.dp_size {
dp_urls.push(format!("{}@{}", base_url, rank));
}
}
Err(e) => {
return Err(format!("Failed to get DP info from {}: {}", base_url, e));
}
}
}
Ok(dp_urls)
}
/// Initialize workers from configuration at startup
pub async fn initialize_workers(
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Starting worker initialization");
// Determine connection mode from config
let connection_mode = &config.connection_mode;
match &config.mode {
RoutingMode::Regular { worker_urls } => match connection_mode {
ConfigConnectionMode::Http => {
Self::initialize_regular_workers(
worker_urls,
config,
registry,
policy_registry,
)
.await?;
}
ConfigConnectionMode::Grpc => {
Self::initialize_grpc_workers(worker_urls, config, registry, policy_registry)
.await?;
}
},
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
..
} => match connection_mode {
ConfigConnectionMode::Http => {
let prefill_entries: Vec<(&String, &Option<u16>)> =
prefill_urls.iter().map(|(url, port)| (url, port)).collect();
Self::initialize_prefill_workers(
&prefill_entries,
config,
registry,
policy_registry,
)
.await?;
Self::initialize_decode_workers(decode_urls, config, registry, policy_registry)
.await?;
}
ConfigConnectionMode::Grpc => {
Self::initialize_grpc_pd_workers(
prefill_urls,
decode_urls,
config,
registry,
policy_registry,
)
.await?;
}
},
RoutingMode::OpenAI { .. } => {
info!("OpenAI routing mode - no workers to initialize");
}
}
Self::wait_for_healthy_workers(
registry,
config.worker_startup_timeout_secs,
config.health_check.check_interval_secs,
)
.await?;
info!("Worker initialization completed successfully");
Ok(())
}
/// Initialize regular workers
async fn initialize_regular_workers(
urls: &[String],
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} regular workers", urls.len());
let connection_mode = Self::convert_connection_mode(&config.connection_mode, urls.first());
let circuit_breaker_config =
Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config());
let health_config = Self::convert_health_config(&config.health_check);
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls {
if config.dp_aware {
match Self::get_dp_info(url, config.api_key.as_deref()).await {
Ok(dp_info) => {
info!(
"Discovered DP-aware worker {} with size {}",
url, dp_info.dp_size
);
for rank in 0..dp_info.dp_size {
let mut builder =
DPAwareWorkerBuilder::new(url.clone(), rank, dp_info.dp_size)
.worker_type(WorkerType::Regular)
.connection_mode(connection_mode.clone())
.circuit_breaker_config(circuit_breaker_config.clone())
.health_config(health_config.clone());
if let Some(ref key) = config.api_key {
builder = builder.api_key(key.clone());
}
let worker = Arc::new(builder.build()) as Arc<dyn Worker>;
let model_id = worker.model_id();
let worker_id = registry.register(Arc::clone(&worker));
info!(
"Registered DP-aware worker {}@{} with ID {:?}",
url, rank, worker_id
);
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker));
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
}
Err(e) => {
return Err(format!(
"Failed to get DP info for worker {}: {}. DP-aware mode requires all workers to support DP.",
url, e
));
}
}
} else {
let worker = Self::create_basic_worker(
url.clone(),
WorkerType::Regular,
connection_mode.clone(),
config.api_key.clone(),
None,
circuit_breaker_config.clone(),
health_config.clone(),
)
.await;
Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
}
}
Self::initialize_cache_policies(&registered_workers, registry, policy_registry);
Ok(())
}
/// Initialize prefill workers for PD mode
async fn initialize_prefill_workers(
prefill_entries: &[(&String, &Option<u16>)],
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} prefill workers", prefill_entries.len());
let connection_mode = Self::convert_connection_mode(
&config.connection_mode,
prefill_entries.first().map(|(url, _)| *url),
);
let circuit_breaker_config =
Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config());
let health_config = Self::convert_health_config(&config.health_check);
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
// TODO: Add proper DP-aware support for prefill workers in PD mode
if config.dp_aware {
warn!("DP-aware mode is not yet supported for prefill workers in PD mode. Creating regular prefill workers instead.");
}
for (url, bootstrap_port) in prefill_entries {
let worker_type = WorkerType::Prefill {
bootstrap_port: **bootstrap_port,
};
let worker = Self::create_basic_worker(
(*url).clone(),
worker_type,
connection_mode.clone(),
config.api_key.clone(),
None,
circuit_breaker_config.clone(),
health_config.clone(),
)
.await;
Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
}
if let Some(policy_reg) = policy_registry {
let all_prefill_workers: Vec<Arc<dyn Worker>> = registered_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &[]);
}
Ok(())
}
/// Initialize decode workers for PD mode
async fn initialize_decode_workers(
urls: &[String],
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} decode workers", urls.len());
let connection_mode = Self::convert_connection_mode(&config.connection_mode, urls.first());
let circuit_breaker_config =
Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config());
let health_config = Self::convert_health_config(&config.health_check);
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
// TODO: Add proper DP-aware support for decode workers in PD mode
if config.dp_aware {
warn!("DP-aware mode is not yet supported for decode workers in PD mode. Creating regular decode workers instead.");
}
for url in urls {
let worker = Self::create_basic_worker(
url.clone(),
WorkerType::Decode,
connection_mode.clone(),
config.api_key.clone(),
None,
circuit_breaker_config.clone(),
health_config.clone(),
)
.await;
Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
}
if let Some(policy_reg) = policy_registry {
let all_decode_workers: Vec<Arc<dyn Worker>> = registered_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
policy_reg.init_pd_cache_aware_policies(&[], &all_decode_workers);
}
Ok(())
}
/// Initialize gRPC workers for regular mode
async fn initialize_grpc_workers(
urls: &[String],
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} gRPC regular workers", urls.len());
let circuit_breaker_config =
Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config());
let health_config = Self::convert_health_config(&config.health_check);
let connection_mode = ConnectionMode::Grpc { port: None };
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls {
let worker = Self::create_basic_worker(
url.clone(),
WorkerType::Regular,
connection_mode.clone(),
config.api_key.clone(),
None,
circuit_breaker_config.clone(),
health_config.clone(),
)
.await;
Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
info!(
"Registered gRPC worker at {} (will connect on first use)",
url
);
}
Self::initialize_cache_policies(&registered_workers, registry, policy_registry);
Ok(())
}
/// Initialize gRPC PD (Prefill-Decode) workers
async fn initialize_grpc_pd_workers(
prefill_urls: &[(String, Option<u16>)],
decode_urls: &[String],
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!(
"Creating {} gRPC prefill workers and {} gRPC decode workers",
prefill_urls.len(),
decode_urls.len()
);
let circuit_breaker_config =
Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config());
let health_config = Self::convert_health_config(&config.health_check);
let mut registered_prefill_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
let mut registered_decode_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for (url, bootstrap_port) in prefill_urls {
let worker_type = WorkerType::Prefill {
bootstrap_port: *bootstrap_port,
};
let connection_mode = ConnectionMode::Grpc {
port: *bootstrap_port,
};
let worker = Self::create_basic_worker(
url.clone(),
worker_type,
connection_mode,
config.api_key.clone(),
None,
circuit_breaker_config.clone(),
health_config.clone(),
)
.await;
Self::register_worker(
worker,
registry,
&mut registered_prefill_workers,
policy_registry,
);
info!(
"Registered gRPC prefill worker at {} (will connect on first use)",
url
);
}
// Create decode workers
for url in decode_urls {
let connection_mode = ConnectionMode::Grpc { port: None };
let worker = Self::create_basic_worker(
url.clone(),
WorkerType::Decode,
connection_mode,
config.api_key.clone(),
None,
circuit_breaker_config.clone(),
health_config.clone(),
)
.await;
Self::register_worker(
worker,
registry,
&mut registered_decode_workers,
policy_registry,
);
info!(
"Registered gRPC decode worker at {} (will connect on first use)",
url
);
}
if let Some(policy_reg) = policy_registry {
let all_prefill_workers: Vec<Arc<dyn Worker>> = registered_prefill_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
let all_decode_workers: Vec<Arc<dyn Worker>> = registered_decode_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &all_decode_workers);
}
Ok(())
}
/// Add a worker from a configuration request
///
/// Registers worker immediately with healthy=false, returns worker for async validation
pub async fn add_worker_from_config(
config: &WorkerConfigRequest,
context: &AppContext,
) -> Result<Arc<dyn Worker>, String> {
// Check if worker already exists
if context.worker_registry.get_by_url(&config.url).is_some() {
return Err(format!("Worker {} already exists", config.url));
}
let mut labels = config.labels.clone();
if let Some(model_id) = &config.model_id {
labels.insert("model_id".to_string(), model_id.clone());
}
if let Some(priority) = config.priority {
labels.insert("priority".to_string(), priority.to_string());
}
if let Some(cost) = config.cost {
labels.insert("cost".to_string(), cost.to_string());
}
if let Some(ref tokenizer_path) = config.tokenizer_path {
labels.insert("tokenizer_path".to_string(), tokenizer_path.clone());
}
if let Some(ref reasoning_parser) = config.reasoning_parser {
labels.insert("reasoning_parser".to_string(), reasoning_parser.clone());
}
if let Some(ref tool_parser) = config.tool_parser {
labels.insert("tool_parser".to_string(), tool_parser.clone());
}
if let Some(ref chat_template) = config.chat_template {
labels.insert("chat_template".to_string(), chat_template.clone());
}
let worker_type = config
.worker_type
.as_ref()
.map(|t| match t.as_str() {
"prefill" => WorkerType::Prefill {
bootstrap_port: config.bootstrap_port,
},
"decode" => WorkerType::Decode,
_ => WorkerType::Regular,
})
.unwrap_or(WorkerType::Regular);
let connection_mode = if config.url.starts_with("grpc://") {
ConnectionMode::Grpc { port: None }
} else {
ConnectionMode::Http
};
let circuit_breaker_config = Self::convert_circuit_breaker_config(
&context.router_config.effective_circuit_breaker_config(),
);
let health_config = Self::convert_health_config(&context.router_config.health_check);
// Create and register worker (starts with healthy=false)
let worker = Self::create_basic_worker(
config.url.clone(),
worker_type,
connection_mode,
config.api_key.clone(),
Some(labels.clone()),
circuit_breaker_config,
health_config,
)
.await;
worker.set_healthy(false);
context.worker_registry.register(worker.clone());
let policy_hint = labels.get("policy").map(|s| s.as_str());
let model_id = worker.model_id().to_string();
context
.policy_registry
.on_worker_added(&model_id, policy_hint);
info!("Registered worker {} (initializing)", config.url);
// Return worker for async validation
Ok(worker)
}
/// Validate and activate a worker (for async validation after registration)
pub async fn validate_and_activate_worker(
worker: &Arc<dyn Worker>,
context: &AppContext,
) -> Result<String, String> {
let url = worker.url();
// Perform health validation
WorkerFactory::validate_health(url, context.router_config.worker_startup_timeout_secs)
.await
.map_err(|e| format!("Health check failed for {}: {}", url, e))?;
// Mark as healthy
worker.set_healthy(true);
info!("Worker {} validated and activated", url);
Ok(format!("Worker {} is now healthy", url))
}
/// Add a worker from URL (legacy endpoint)
pub async fn add_worker(
url: &str,
api_key: &Option<String>,
context: &AppContext,
) -> Result<String, String> {
Self::add_worker_internal(
url,
WorkerType::Regular,
ConnectionMode::Http,
api_key.clone(),
None,
None,
context,
)
.await
}
/// Remove a worker
pub fn remove_worker(url: &str, context: &AppContext) -> Result<String, String> {
if context.router_config.dp_aware {
Self::remove_dp_aware_workers(url, context)
} else {
Self::remove_single_worker(url, context)
}
}
pub fn get_worker_urls(registry: &Arc<WorkerRegistry>) -> Vec<String> {
registry
.get_all()
......@@ -716,757 +31,6 @@ impl WorkerManager {
.collect()
}
/// Internal method to add a worker with all parameters
async fn add_worker_internal(
worker_url: &str,
worker_type: WorkerType,
connection_mode: ConnectionMode,
api_key: Option<String>,
labels: Option<HashMap<String, String>>,
policy_hint: Option<&str>,
context: &AppContext,
) -> Result<String, String> {
WorkerFactory::validate_health(
worker_url,
context.router_config.worker_startup_timeout_secs,
)
.await
.map_err(|e| format!("Health check failed: {}", e))?;
let circuit_breaker_config = Self::convert_circuit_breaker_config(
&context.router_config.effective_circuit_breaker_config(),
);
let health_config = Self::convert_health_config(&context.router_config.health_check);
if context.router_config.dp_aware {
let dp_urls = Self::get_dp_aware_urls(
&[worker_url.to_string()],
context.router_config.api_key.as_deref(),
)
.await?;
let mut workers_added = 0;
let mut model_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
let dp_size_for_base = dp_urls.len();
for (rank, dp_url) in dp_urls.iter().enumerate() {
if context.worker_registry.get_by_url(dp_url).is_some() {
info!("Worker {} already exists, skipping", dp_url);
continue;
}
let base_url = dp_url.split('@').next().unwrap().to_string();
let mut builder = DPAwareWorkerBuilder::new(base_url, rank, dp_size_for_base)
.worker_type(worker_type.clone())
.connection_mode(connection_mode.clone())
.circuit_breaker_config(circuit_breaker_config.clone())
.health_config(health_config.clone());
if let Some(ref key) = api_key {
builder = builder.api_key(key.clone());
}
if let Some(ref worker_labels) = labels {
builder = builder.labels(worker_labels.clone());
}
let worker = Arc::new(builder.build()) as Arc<dyn Worker>;
let model_id = worker.model_id().to_string();
context.worker_registry.register(worker.clone());
workers_added += 1;
model_workers
.entry(model_id.clone())
.or_default()
.push(worker);
context
.policy_registry
.on_worker_added(&model_id, policy_hint);
}
for model_id in model_workers.keys() {
let all_model_workers = context.worker_registry.get_by_model_fast(model_id);
if let Some(policy) = context.policy_registry.get_policy(model_id) {
if policy.name() == "cache_aware" {
context
.policy_registry
.init_cache_aware_policy(model_id, &all_model_workers);
}
}
}
if workers_added == 0 {
Ok(format!("All DP workers already exist for {}", worker_url))
} else {
Ok(format!(
"Added {} DP-aware workers for {}",
workers_added, worker_url
))
}
} else {
if context.worker_registry.get_by_url(worker_url).is_some() {
return Err(format!("Worker {} already exists", worker_url));
}
let worker = Self::create_basic_worker(
worker_url.to_string(),
worker_type,
connection_mode,
api_key,
labels,
circuit_breaker_config,
health_config,
)
.await;
let model_id = worker.model_id().to_string();
context.worker_registry.register(worker.clone());
context
.policy_registry
.on_worker_added(&model_id, policy_hint);
let workers = context.worker_registry.get_by_model_fast(&model_id);
if let Some(policy) = context.policy_registry.get_policy(&model_id) {
if policy.name() == "cache_aware" {
context
.policy_registry
.init_cache_aware_policy(&model_id, &workers);
}
}
Ok(format!("Worker {} added successfully", worker_url))
}
}
/// Remove a single worker
fn remove_single_worker(worker_url: &str, context: &AppContext) -> Result<String, String> {
let worker = context
.worker_registry
.get_by_url(worker_url)
.ok_or_else(|| format!("Worker {} not found", worker_url))?;
let model_id = worker.model_id().to_string();
context
.policy_registry
.remove_worker_from_cache_aware(&model_id, worker_url);
context.worker_registry.remove_by_url(worker_url);
context.policy_registry.on_worker_removed(&model_id);
let remaining_workers = context.worker_registry.get_by_model_fast(&model_id);
if let Some(policy) = context.policy_registry.get_policy(&model_id) {
if policy.name() == "cache_aware" && !remaining_workers.is_empty() {
context
.policy_registry
.init_cache_aware_policy(&model_id, &remaining_workers);
}
}
Ok(format!("Worker {} removed successfully", worker_url))
}
/// Remove DP-aware workers with prefix matching
fn remove_dp_aware_workers(worker_url: &str, context: &AppContext) -> Result<String, String> {
let worker_url_prefix = format!("{}@", worker_url);
let mut removed_workers = Vec::new();
let mut affected_models = std::collections::HashSet::new();
let all_workers = context.worker_registry.get_all();
for worker in all_workers.iter() {
if worker.url().starts_with(&worker_url_prefix) {
let model_id = worker.model_id().to_string();
affected_models.insert(model_id.clone());
context
.policy_registry
.remove_worker_from_cache_aware(&model_id, worker.url());
if context
.worker_registry
.remove_by_url(worker.url())
.is_some()
{
removed_workers.push(worker.url().to_string());
context.policy_registry.on_worker_removed(&model_id);
}
}
}
for model_id in affected_models {
let remaining_workers = context.worker_registry.get_by_model_fast(&model_id);
if let Some(policy) = context.policy_registry.get_policy(&model_id) {
if policy.name() == "cache_aware" && !remaining_workers.is_empty() {
context
.policy_registry
.init_cache_aware_policy(&model_id, &remaining_workers);
}
}
}
if removed_workers.is_empty() {
Err(format!(
"No workers found with prefix {}",
worker_url_prefix
))
} else {
Ok(format!(
"Removed {} DP-aware workers: {:?}",
removed_workers.len(),
removed_workers
))
}
}
/// Create a basic worker
async fn create_basic_worker(
url: String,
worker_type: WorkerType,
connection_mode: ConnectionMode,
api_key: Option<String>,
labels: Option<HashMap<String, String>>,
circuit_breaker_config: CircuitBreakerConfig,
health_config: HealthConfig,
) -> Arc<dyn Worker> {
let discovery =
Self::discover_worker_metadata(&url, &connection_mode, api_key.as_deref()).await;
let mut final_labels = discovery.labels;
if let Some(custom_labels) = labels {
for (key, value) in custom_labels {
final_labels.insert(key, value);
}
}
let mut builder = BasicWorkerBuilder::new(url)
.worker_type(worker_type)
.connection_mode(connection_mode)
.circuit_breaker_config(circuit_breaker_config)
.health_config(health_config);
if let Some(key) = api_key {
builder = builder.api_key(key);
}
if !final_labels.is_empty() {
builder = builder.labels(final_labels);
}
if let Some(client) = discovery.grpc_client {
builder = builder.grpc_client(client);
}
let worker = builder.build();
Arc::new(worker) as Arc<dyn Worker>
}
/// Register a worker and update policies
fn register_worker(
worker: Arc<dyn Worker>,
registry: &Arc<WorkerRegistry>,
registered_workers: &mut HashMap<String, Vec<Arc<dyn Worker>>>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) {
let model_id = worker.model_id();
let url = worker.url();
let worker_id = registry.register(Arc::clone(&worker));
info!("Registered worker {} with ID {:?}", url, worker_id);
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker));
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
/// Initialize cache-aware policies
fn initialize_cache_policies(
registered_workers: &HashMap<String, Vec<Arc<dyn Worker>>>,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) {
if let Some(policy_reg) = policy_registry {
for model_id in registered_workers.keys() {
let all_model_workers = registry.get_by_model_fast(model_id);
if let Some(policy) = policy_reg.get_policy(model_id) {
if policy.name() == "cache_aware" {
policy_reg.init_cache_aware_policy(model_id, &all_model_workers);
}
}
}
}
}
/// Wait for workers to become healthy
async fn wait_for_healthy_workers(
registry: &Arc<WorkerRegistry>,
timeout_secs: u64,
check_interval_secs: u64,
) -> Result<(), String> {
let timeout = Duration::from_secs(timeout_secs);
let check_interval = Duration::from_secs(check_interval_secs);
let start_time = std::time::Instant::now();
info!(
"Waiting for workers to become healthy (timeout: {}s)",
timeout_secs
);
let workers = registry.get_all();
if workers.is_empty() {
info!("No workers to wait for, continuing");
return Ok(());
}
// Mark all workers as unhealthy initially
info!(
"Marking {} workers as unhealthy before health checks",
workers.len()
);
for worker in &workers {
worker.set_healthy(false);
}
loop {
// 1. Filter unhealthy workers
let workers = registry.get_all();
let unhealthy_workers: Vec<_> = workers
.iter()
.filter(|w| !w.is_healthy())
.cloned()
.collect();
// 2. If all workers are healthy, return immediately
if unhealthy_workers.is_empty() {
let healthy_urls: Vec<_> = workers.iter().map(|w| w.url().to_string()).collect();
info!(
"All {} workers are healthy: {:?}",
workers.len(),
healthy_urls
);
return Ok(());
}
// Check timeout
if start_time.elapsed() > timeout {
let healthy_workers: Vec<_> = workers
.iter()
.filter(|w| w.is_healthy())
.map(|w| w.url().to_string())
.collect();
let unhealthy_urls: Vec<_> = unhealthy_workers
.iter()
.map(|w| w.url().to_string())
.collect();
error!(
"Workers failed to become healthy after {}s. Unhealthy: {:?}, Healthy: {:?}",
timeout_secs, unhealthy_urls, healthy_workers
);
return Err(format!(
"Workers failed to become healthy after {}s. Unhealthy: {:?}",
timeout_secs, unhealthy_urls
));
}
let unhealthy_urls: Vec<_> = unhealthy_workers
.iter()
.map(|w| w.url().to_string())
.collect();
info!(
"Waiting for {} workers to become healthy. Unhealthy: {:?}",
unhealthy_workers.len(),
unhealthy_urls
);
// 3. Check health of all unhealthy workers in parallel
let health_check_futures: Vec<_> = unhealthy_workers
.iter()
.map(|worker| {
let w = worker.clone();
let url = worker.url().to_string();
async move {
match w.check_health_async().await {
Ok(_) => {
w.set_healthy(true);
debug!("Worker {} now healthy", url);
}
Err(e) => {
debug!("Worker {} health check failed: {}", url, e);
}
}
}
})
.collect();
future::join_all(health_check_futures).await;
// 4. Check if all workers are now healthy after health checks
let still_unhealthy: Vec<_> = workers.iter().filter(|w| !w.is_healthy()).collect();
// 5. If all workers are now healthy, return immediately without sleeping
if still_unhealthy.is_empty() {
let healthy_urls: Vec<_> = workers.iter().map(|w| w.url().to_string()).collect();
info!(
"All {} workers are healthy: {:?}",
workers.len(),
healthy_urls
);
return Ok(());
}
// 6. Otherwise, sleep before next iteration
tokio::time::sleep(check_interval).await;
}
}
/// Gather worker metadata directly from the backend before registration.
async fn discover_worker_metadata(
url: &str,
connection_mode: &ConnectionMode,
api_key: Option<&str>,
) -> WorkerDiscovery {
match connection_mode {
ConnectionMode::Http => Self::discover_http_metadata(url, api_key).await,
ConnectionMode::Grpc { .. } => Self::discover_grpc_metadata(url).await,
}
}
async fn discover_http_metadata(url: &str, api_key: Option<&str>) -> WorkerDiscovery {
let mut discovery = WorkerDiscovery::new();
match Self::get_model_info(url, api_key).await {
Ok(model_info) => {
if let Some(model_path) = model_info.get("model_path").and_then(|v| v.as_str()) {
if !model_path.is_empty() {
discovery
.labels
.insert("model_path".to_string(), model_path.to_string());
}
}
if let Some(tokenizer_path) =
model_info.get("tokenizer_path").and_then(|v| v.as_str())
{
if !tokenizer_path.is_empty() {
discovery
.labels
.insert("tokenizer_path".to_string(), tokenizer_path.to_string());
}
}
if let Some(served_model_name) =
model_info.get("served_model_name").and_then(|v| v.as_str())
{
if !served_model_name.is_empty() {
discovery.labels.insert(
"served_model_name".to_string(),
served_model_name.to_string(),
);
}
}
if let Some(weight_version) =
model_info.get("weight_version").and_then(|v| v.as_str())
{
if !weight_version.is_empty() {
discovery
.labels
.insert("weight_version".to_string(), weight_version.to_string());
}
}
if let Some(model_type) = model_info.get("model_type").and_then(|v| v.as_str()) {
if !model_type.is_empty() {
discovery
.labels
.insert("model_type".to_string(), model_type.to_string());
}
}
if let Some(is_generation) =
model_info.get("is_generation").and_then(|v| v.as_bool())
{
discovery
.labels
.insert("is_generation".to_string(), is_generation.to_string());
}
if let Some(preferred_sampling_params) = model_info
.get("preferred_sampling_params")
.and_then(|v| v.as_str())
{
if !preferred_sampling_params.is_empty() {
discovery.labels.insert(
"preferred_sampling_params".to_string(),
preferred_sampling_params.to_string(),
);
}
}
if let Some(max_context_length) = model_info
.get("max_context_length")
.and_then(|v| v.as_i64())
{
discovery.labels.insert(
"max_context_length".to_string(),
max_context_length.to_string(),
);
}
if let Some(max_req_input_len) =
model_info.get("max_req_input_len").and_then(|v| v.as_i64())
{
discovery.labels.insert(
"max_req_input_len".to_string(),
max_req_input_len.to_string(),
);
}
}
Err(e) => {
warn!(
"Worker discovery: failed to fetch HTTP model info from {}: {}",
url, e
);
}
}
match Self::get_server_info(url, api_key).await {
Ok(server_info) => {
if let Some(model_id) = server_info.model_id {
if !model_id.is_empty() {
discovery.labels.insert("model_id".to_string(), model_id);
}
}
if let Some(model_path) = server_info.model_path {
if !model_path.is_empty() {
discovery
.labels
.insert("model_path".to_string(), model_path);
}
}
if let Some(version) = server_info.version {
if !version.is_empty() {
discovery
.labels
.insert("server_version".to_string(), version);
}
}
if let Some(max_total_tokens) = server_info.max_total_tokens {
discovery
.labels
.insert("max_total_tokens".to_string(), max_total_tokens.to_string());
}
if let Some(max_prefill_tokens) = server_info.max_prefill_tokens {
discovery.labels.insert(
"max_prefill_tokens".to_string(),
max_prefill_tokens.to_string(),
);
}
if let Some(max_running_requests) = server_info.max_running_requests {
discovery.labels.insert(
"max_running_requests".to_string(),
max_running_requests.to_string(),
);
}
}
Err(e) => {
warn!(
"Worker discovery: failed to fetch HTTP server info from {}: {}",
url, e
);
}
}
Self::finalize_model_id(&mut discovery.labels);
discovery
}
async fn discover_grpc_metadata(url: &str) -> WorkerDiscovery {
let mut discovery = WorkerDiscovery::new();
let client = match SglangSchedulerClient::connect(url).await {
Ok(client) => client,
Err(e) => {
warn!(
"Worker discovery: failed to connect to gRPC worker {}: {}",
url, e
);
return discovery;
}
};
match client.get_model_info().await {
Ok(model_info) => {
if !model_info.model_path.is_empty() {
discovery
.labels
.insert("model_path".to_string(), model_info.model_path.clone());
}
if !model_info.tokenizer_path.is_empty() {
discovery.labels.insert(
"tokenizer_path".to_string(),
model_info.tokenizer_path.clone(),
);
}
if !model_info.served_model_name.is_empty() {
discovery.labels.insert(
"served_model_name".to_string(),
model_info.served_model_name.clone(),
);
discovery
.labels
.insert("model_id".to_string(), model_info.served_model_name);
}
if !model_info.weight_version.is_empty() {
discovery.labels.insert(
"weight_version".to_string(),
model_info.weight_version.clone(),
);
}
if !model_info.model_type.is_empty() {
discovery
.labels
.insert("model_type".to_string(), model_info.model_type.clone());
}
if !model_info.preferred_sampling_params.is_empty() {
discovery.labels.insert(
"preferred_sampling_params".to_string(),
model_info.preferred_sampling_params.clone(),
);
}
discovery.labels.insert(
"is_generation".to_string(),
model_info.is_generation.to_string(),
);
if model_info.max_context_length > 0 {
discovery.labels.insert(
"max_context_length".to_string(),
model_info.max_context_length.to_string(),
);
}
if model_info.max_req_input_len > 0 {
discovery.labels.insert(
"max_req_input_len".to_string(),
model_info.max_req_input_len.to_string(),
);
}
if model_info.vocab_size > 0 {
discovery
.labels
.insert("vocab_size".to_string(), model_info.vocab_size.to_string());
}
}
Err(e) => {
warn!(
"Worker discovery: failed to fetch gRPC model info from {}: {}",
url, e
);
}
}
if !discovery.labels.contains_key("model_id") {
Self::finalize_model_id(&mut discovery.labels);
}
discovery.grpc_client = Some(client);
discovery
}
fn finalize_model_id(labels: &mut HashMap<String, String>) {
let has_model_id = labels
.get("model_id")
.map(|v| !v.trim().is_empty())
.unwrap_or(false);
if has_model_id {
return;
}
if let Some(served_name) = labels.get("served_model_name") {
if !served_name.trim().is_empty() {
labels.insert("model_id".to_string(), served_name.clone());
return;
}
}
if let Some(model_path) = labels.get("model_path") {
if !model_path.trim().is_empty() {
labels.insert("model_id".to_string(), model_path.clone());
}
}
}
/// Parse server info from JSON response
fn parse_server_info(json: Value) -> Result<ServerInfo, String> {
Ok(ServerInfo {
model_id: json
.get("model_id")
.and_then(|v| v.as_str())
.map(String::from)
.or_else(|| json.get("model").and_then(|v| v.as_str()).map(String::from)),
model_path: json
.get("model_path")
.and_then(|v| v.as_str())
.map(String::from),
dp_size: json
.get("dp_size")
.and_then(|v| v.as_u64())
.map(|v| v as usize),
version: json
.get("version")
.and_then(|v| v.as_str())
.map(String::from),
max_batch_size: json
.get("max_batch_size")
.and_then(|v| v.as_u64())
.map(|v| v as usize),
max_total_tokens: json
.get("max_total_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as usize),
max_prefill_tokens: json
.get("max_prefill_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as usize),
max_running_requests: json
.get("max_running_requests")
.and_then(|v| v.as_u64())
.map(|v| v as usize),
max_num_reqs: json
.get("max_num_reqs")
.and_then(|v| v.as_u64())
.map(|v| v as usize),
})
}
/// Convert config connection mode to core connection mode
fn convert_connection_mode(
config_mode: &ConfigConnectionMode,
_sample_url: Option<&String>,
) -> ConnectionMode {
match config_mode {
ConfigConnectionMode::Http => ConnectionMode::Http,
ConfigConnectionMode::Grpc => ConnectionMode::Grpc { port: None },
}
}
/// Convert config circuit breaker to core circuit breaker
fn convert_circuit_breaker_config(config: &ConfigCircuitBreakerConfig) -> CircuitBreakerConfig {
CircuitBreakerConfig {
failure_threshold: config.failure_threshold,
success_threshold: config.success_threshold,
timeout_duration: Duration::from_secs(config.timeout_duration_secs),
window_duration: Duration::from_secs(config.window_duration_secs),
}
}
/// Convert config health check to core health config
fn convert_health_config(config: &HealthCheckConfig) -> HealthConfig {
HealthConfig {
timeout_secs: config.timeout_secs,
check_interval_secs: config.check_interval_secs,
endpoint: config.endpoint.clone(),
failure_threshold: config.failure_threshold,
success_threshold: config.success_threshold,
}
}
/// Flush cache on all workers
///
/// Sends a POST request to /flush_cache endpoint on all HTTP workers.
......@@ -1804,69 +368,3 @@ impl Drop for LoadMonitor {
}
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
#[test]
fn test_parse_server_info() {
let json = serde_json::json!({
"model_id": "llama-3",
"model_path": "/models/llama-3",
"dp_size": 4,
"version": "0.1.0"
});
let info = WorkerManager::parse_server_info(json).unwrap();
assert_eq!(info.model_id, Some("llama-3".to_string()));
assert_eq!(info.dp_size, Some(4));
}
#[test]
fn test_parse_server_info_with_fallback() {
let json = serde_json::json!({
"model": "gpt-4",
"dp_size": 2
});
let info = WorkerManager::parse_server_info(json).unwrap();
assert_eq!(info.model_id, Some("gpt-4".to_string()));
assert_eq!(info.dp_size, Some(2));
}
#[test]
fn test_parse_server_info_minimal() {
let json = serde_json::json!({});
let info = WorkerManager::parse_server_info(json).unwrap();
assert_eq!(info.model_id, None);
assert_eq!(info.dp_size, None);
}
#[test]
fn test_finalize_model_id_prefers_existing() {
let mut labels = HashMap::new();
labels.insert("model_id".to_string(), "manual-id".to_string());
labels.insert("served_model_name".to_string(), "auto-id".to_string());
WorkerManager::finalize_model_id(&mut labels);
assert_eq!(labels.get("model_id").unwrap(), "manual-id");
}
#[test]
fn test_finalize_model_id_prefers_served_name() {
let mut labels = HashMap::new();
labels.insert("served_model_name".to_string(), "served-name".to_string());
WorkerManager::finalize_model_id(&mut labels);
assert_eq!(labels.get("model_id").unwrap(), "served-name");
}
#[test]
fn test_finalize_model_id_falls_back_to_path() {
let mut labels = HashMap::new();
labels.insert("model_path".to_string(), "/models/alpha".to_string());
WorkerManager::finalize_model_id(&mut labels);
assert_eq!(labels.get("model_id").unwrap(), "/models/alpha");
}
}
......@@ -14,5 +14,5 @@ pub use engine::WorkflowEngine;
pub use event::{EventBus, EventSubscriber, LoggingSubscriber, WorkflowEvent};
pub use executor::{FunctionStep, StepExecutor};
pub use state::WorkflowStateStore;
pub use steps::create_worker_registration_workflow;
pub use steps::{create_worker_registration_workflow, create_worker_removal_workflow};
pub use types::*;
......@@ -2,11 +2,17 @@
//!
//! This module contains concrete step implementations for various workflows:
//! - Worker registration and activation
//! - Worker removal
//! - Future: Tokenizer fetching, LoRA updates, etc.
pub mod worker_registration;
pub mod worker_removal;
pub use worker_registration::{
create_worker_registration_workflow, ActivateWorkerStep, CreateWorkerStep,
DetectConnectionModeStep, DiscoverMetadataStep, RegisterWorkerStep, UpdatePoliciesStep,
};
pub use worker_removal::{
create_worker_removal_workflow, FindWorkersToRemoveStep, RemoveFromPolicyRegistryStep,
RemoveFromWorkerRegistryStep, UpdateRemainingPoliciesStep, WorkerRemovalRequest,
};
......@@ -16,13 +16,14 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
use async_trait::async_trait;
use once_cell::sync::Lazy;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::{debug, info, warn};
use crate::{
core::{
workflow::*, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode,
DPAwareWorkerBuilder, DpInfo, HealthConfig, Worker, WorkerManager, WorkerType,
DPAwareWorkerBuilder, HealthConfig, Worker, WorkerType,
},
grpc_client::SglangSchedulerClient,
protocols::worker_spec::WorkerConfigRequest,
......@@ -37,6 +38,82 @@ static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
.expect("Failed to create HTTP client")
});
/// Server information returned from worker endpoints
#[derive(Debug, Clone, Deserialize, Serialize)]
struct ServerInfo {
#[serde(alias = "model")]
model_id: Option<String>,
model_path: Option<String>,
dp_size: Option<usize>,
version: Option<String>,
max_batch_size: Option<usize>,
max_total_tokens: Option<usize>,
max_prefill_tokens: Option<usize>,
max_running_requests: Option<usize>,
max_num_reqs: Option<usize>,
}
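// Note: #[serde(alias = "model")] preserves the fallback that the removed
// WorkerManager::parse_server_info implemented by hand, so a payload such as
// {"model": "gpt-4"} still populates model_id (cf. the deleted
// test_parse_server_info_with_fallback above).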
#[derive(Debug, Clone)]
pub struct DpInfo {
pub dp_size: usize,
pub model_id: String,
}
/// Parse server info from JSON response using serde
fn parse_server_info(json: Value) -> Result<ServerInfo, String> {
serde_json::from_value(json).map_err(|e| format!("Failed to parse server info: {}", e))
}
/// Get server info from /get_server_info endpoint
async fn get_server_info(url: &str, api_key: Option<&str>) -> Result<ServerInfo, String> {
let base_url = url.trim_end_matches('/');
let server_info_url = format!("{}/get_server_info", base_url);
let mut req = HTTP_CLIENT.get(&server_info_url);
if let Some(key) = api_key {
req = req.bearer_auth(key);
}
let response = req
.send()
.await
.map_err(|e| format!("Failed to connect to {}: {}", server_info_url, e))?;
if !response.status().is_success() {
return Err(format!(
"Server returned status {} from {}",
response.status(),
server_info_url
));
}
let json = response
.json::<Value>()
.await
.map_err(|e| format!("Failed to parse response from {}: {}", server_info_url, e))?;
parse_server_info(json)
}
/// Get DP info for a worker URL
async fn get_dp_info(url: &str, api_key: Option<&str>) -> Result<DpInfo, String> {
let info = get_server_info(url, api_key).await?;
let dp_size = info
.dp_size
.ok_or_else(|| format!("No dp_size in response from {}", url))?;
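    // Fall back from model_id to the last segment of model_path
    // (e.g. "/models/llama-3" -> "llama-3"), else "unknown".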
let model_id = info
.model_id
.or_else(|| {
info.model_path
.and_then(|path| path.split('/').next_back().map(|s| s.to_string()))
})
.unwrap_or_else(|| "unknown".to_string());
Ok(DpInfo { dp_size, model_id })
}
/// Helper: Strip protocol prefix from URL
fn strip_protocol(url: &str) -> String {
url.trim_start_matches("http://")
......@@ -83,49 +160,6 @@ async fn try_grpc_health_check(url: &str, timeout_secs: u64) -> Result<(), Strin
Ok(())
}
/// Helper: Fetch HTTP metadata
async fn fetch_http_metadata(
url: &str,
api_key: Option<&str>,
) -> Result<HashMap<String, String>, String> {
let clean_url = strip_protocol(url);
let info_url = if clean_url.starts_with("http://") || clean_url.starts_with("https://") {
format!("{}/get_server_info", clean_url)
} else {
format!("http://{}/get_server_info", clean_url)
};
let mut request = HTTP_CLIENT.get(&info_url);
if let Some(key) = api_key {
request = request.header("Authorization", format!("Bearer {}", key));
}
let response = request
.send()
.await
.map_err(|e| format!("Failed to fetch HTTP metadata: {}", e))?;
let server_info: Value = response
.json()
.await
.map_err(|e| format!("Failed to parse HTTP metadata: {}", e))?;
let mut labels = HashMap::new();
if let Some(model_path) = server_info.get("model_path").and_then(|v| v.as_str()) {
if !model_path.is_empty() {
labels.insert("model_path".to_string(), model_path.to_string());
}
}
if let Some(tokenizer_path) = server_info.get("tokenizer_path").and_then(|v| v.as_str()) {
if !tokenizer_path.is_empty() {
labels.insert("tokenizer_path".to_string(), tokenizer_path.to_string());
}
}
Ok(labels)
}
/// Helper: Fetch gRPC metadata
async fn fetch_grpc_metadata(url: &str) -> Result<HashMap<String, String>, String> {
let grpc_url = if url.starts_with("grpc://") {
......@@ -266,7 +300,18 @@ impl StepExecutor for DiscoverMetadataStep {
let discovered_labels = match connection_mode.as_ref() {
ConnectionMode::Http => {
fetch_http_metadata(&config.url, config.api_key.as_deref()).await
match get_server_info(&config.url, config.api_key.as_deref()).await {
Ok(server_info) => {
let mut labels = HashMap::new();
if let Some(model_path) = server_info.model_path {
if !model_path.is_empty() {
labels.insert("model_path".to_string(), model_path);
}
}
Ok(labels)
}
Err(e) => Err(e),
}
}
ConnectionMode::Grpc { .. } => fetch_grpc_metadata(&config.url).await,
}
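Note that the typed path above forwards only a non-empty `model_path` label, whereas the removed `fetch_http_metadata` also collected `tokenizer_path`. A minimal sketch of the surviving extraction:

```rust
use std::collections::HashMap;

// Only a non-empty model_path becomes a label; empty or absent paths
// produce no entry, matching the Ok arm of DiscoverMetadataStep above.
fn labels_from(model_path: Option<String>) -> HashMap<String, String> {
    let mut labels = HashMap::new();
    if let Some(p) = model_path {
        if !p.is_empty() {
            labels.insert("model_path".to_string(), p);
        }
    }
    labels
}

fn main() {
    assert!(labels_from(Some(String::new())).is_empty());
    assert_eq!(labels_from(Some("meta-llama/Llama-3.1-8B".into())).len(), 1);
}
```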
......@@ -314,7 +359,7 @@ impl StepExecutor for DiscoverDPInfoStep {
debug!("Discovering DP info for {} (DP-aware)", config.url);
// Get DP info from worker
let dp_info = WorkerManager::get_dp_info(&config.url, config.api_key.as_deref())
let dp_info = get_dp_info(&config.url, config.api_key.as_deref())
.await
.map_err(|e| WorkflowError::StepFailed {
step_id: StepId::new("discover_dp_info"),
......@@ -327,7 +372,7 @@ impl StepExecutor for DiscoverDPInfoStep {
);
// Store DP info in context
context.set("dp_info", Arc::new(dp_info));
context.set("dp_info", dp_info);
Ok(StepResult::Success)
}
......@@ -522,7 +567,7 @@ impl StepExecutor for CreateWorkerStep {
}
// Store workers (plural) and labels in context
context.set("workers", Arc::new(workers));
context.set("workers", workers);
context.set("labels", final_labels);
Ok(StepResult::Success)
......@@ -595,7 +640,7 @@ impl StepExecutor for RegisterWorkerStep {
);
}
context.set("worker_ids", Arc::new(worker_ids));
context.set("worker_ids", worker_ids);
Ok(StepResult::Success)
} else {
// Non-DP-aware path: Register single worker
......
//! Worker Removal Workflow Steps
//!
//! This module implements the workflow steps for removing workers from the router.
//! Handles both single worker removal and DP-aware worker removal with prefix matching.
//!
//! Steps:
//! 1. FindWorkersToRemove - Identify workers to remove based on URL (handles DP-aware prefix matching)
//! 2. RemoveFromPolicyRegistry - Remove workers from policy registry and cache-aware policies
//! 3. RemoveFromWorkerRegistry - Remove workers from worker registry
//! 4. UpdateRemainingPolicies - Update cache-aware policies for remaining workers
use std::{collections::HashSet, sync::Arc};
use async_trait::async_trait;
use tracing::{debug, info};
use crate::{
core::{workflow::*, Worker},
server::AppContext,
};
/// Request structure for worker removal
#[derive(Debug, Clone)]
pub struct WorkerRemovalRequest {
pub url: String,
pub dp_aware: bool,
}
/// Step 1: Find workers to remove based on URL
pub struct FindWorkersToRemoveStep;
#[async_trait]
impl StepExecutor for FindWorkersToRemoveStep {
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
let request: Arc<WorkerRemovalRequest> = context
.get("removal_request")
.ok_or_else(|| WorkflowError::ContextValueNotFound("removal_request".to_string()))?;
let app_context: Arc<AppContext> = context
.get("app_context")
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
debug!(
"Finding workers to remove for {} (dp_aware: {})",
request.url, request.dp_aware
);
let workers_to_remove: Vec<Arc<dyn Worker>> = if request.dp_aware {
// DP-aware: Find all workers with matching prefix
let worker_url_prefix = format!("{}@", request.url);
let all_workers = app_context.worker_registry.get_all();
all_workers
.iter()
.filter(|worker| worker.url().starts_with(&worker_url_prefix))
.cloned()
.collect()
} else {
// Non-DP-aware: Find single worker by exact URL
match app_context.worker_registry.get_by_url(&request.url) {
Some(worker) => vec![worker],
None => Vec::new(),
}
};
if workers_to_remove.is_empty() {
let error_msg = if request.dp_aware {
format!("No workers found with prefix {}@", request.url)
} else {
format!("Worker {} not found", request.url)
};
return Err(WorkflowError::StepFailed {
step_id: StepId::new("find_workers_to_remove"),
message: error_msg,
});
}
debug!(
"Found {} worker(s) to remove for {}",
workers_to_remove.len(),
request.url
);
// Store the workers, their URLs, and the affected model IDs for subsequent steps
let worker_urls: Vec<String> = workers_to_remove
.iter()
.map(|w| w.url().to_string())
.collect();
let affected_models: HashSet<String> = workers_to_remove
.iter()
.map(|w| w.model_id().to_string())
.collect();
context.set("workers_to_remove", workers_to_remove);
context.set("worker_urls", worker_urls);
context.set("affected_models", affected_models);
Ok(StepResult::Success)
}
fn is_retryable(&self, _error: &WorkflowError) -> bool {
false // Worker not found is not retryable
}
}
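The prefix match assumes DP-aware workers register under URLs of the form `{base_url}@{suffix}` (the suffix scheme itself is not shown in this diff). A standalone sketch of the selection logic:

```rust
/// Select worker URLs to remove: exact match in the plain case,
/// `{url}@` prefix match in the DP-aware case.
fn select_for_removal<'a>(all: &'a [String], url: &str, dp_aware: bool) -> Vec<&'a String> {
    if dp_aware {
        let prefix = format!("{}@", url);
        all.iter().filter(|w| w.starts_with(prefix.as_str())).collect()
    } else {
        all.iter().filter(|w| w.as_str() == url).collect()
    }
}

fn main() {
    let all = vec![
        "http://h1:8000@0".to_string(),
        "http://h1:8000@1".to_string(),
        "http://h2:8000".to_string(),
    ];
    assert_eq!(select_for_removal(&all, "http://h1:8000", true).len(), 2);
    assert_eq!(select_for_removal(&all, "http://h2:8000", false).len(), 1);
}
```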
/// Step 2: Remove workers from policy registry
pub struct RemoveFromPolicyRegistryStep;
#[async_trait]
impl StepExecutor for RemoveFromPolicyRegistryStep {
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
let app_context: Arc<AppContext> = context
.get("app_context")
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
let workers_to_remove: Arc<Vec<Arc<dyn Worker>>> = context
.get("workers_to_remove")
.ok_or_else(|| WorkflowError::ContextValueNotFound("workers_to_remove".to_string()))?;
debug!(
"Removing {} worker(s) from policy registry",
workers_to_remove.len()
);
for worker in workers_to_remove.iter() {
let model_id = worker.model_id().to_string();
let worker_url = worker.url();
// Remove from cache-aware policy
app_context
.policy_registry
.remove_worker_from_cache_aware(&model_id, worker_url);
// Notify policy registry
app_context.policy_registry.on_worker_removed(&model_id);
debug!(
"Removed worker {} from policy registry (model: {})",
worker_url, model_id
);
}
Ok(StepResult::Success)
}
fn is_retryable(&self, _error: &WorkflowError) -> bool {
false // Policy removal is not retryable
}
}
/// Step 3: Remove workers from worker registry
pub struct RemoveFromWorkerRegistryStep;
#[async_trait]
impl StepExecutor for RemoveFromWorkerRegistryStep {
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
let app_context: Arc<AppContext> = context
.get("app_context")
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
let worker_urls: Arc<Vec<String>> = context
.get("worker_urls")
.ok_or_else(|| WorkflowError::ContextValueNotFound("worker_urls".to_string()))?;
debug!(
"Removing {} worker(s) from worker registry",
worker_urls.len()
);
let mut removed_count = 0;
for worker_url in worker_urls.iter() {
if app_context
.worker_registry
.remove_by_url(worker_url)
.is_some()
{
removed_count += 1;
debug!("Removed worker {} from registry", worker_url);
}
}
if removed_count != worker_urls.len() {
return Err(WorkflowError::StepFailed {
step_id: StepId::new("remove_from_worker_registry"),
message: format!(
"Expected to remove {} workers but only removed {}",
worker_urls.len(),
removed_count
),
});
}
Ok(StepResult::Success)
}
fn is_retryable(&self, _error: &WorkflowError) -> bool {
false // Worker removal is not retryable
}
}
/// Step 4: Update cache-aware policies for remaining workers
pub struct UpdateRemainingPoliciesStep;
#[async_trait]
impl StepExecutor for UpdateRemainingPoliciesStep {
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
let app_context: Arc<AppContext> = context
.get("app_context")
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
let affected_models: Arc<HashSet<String>> = context
.get("affected_models")
.ok_or_else(|| WorkflowError::ContextValueNotFound("affected_models".to_string()))?;
let worker_urls: Arc<Vec<String>> = context
.get("worker_urls")
.ok_or_else(|| WorkflowError::ContextValueNotFound("worker_urls".to_string()))?;
debug!(
"Updating cache-aware policies for {} affected model(s)",
affected_models.len()
);
for model_id in affected_models.iter() {
let remaining_workers = app_context.worker_registry.get_by_model_fast(model_id);
if let Some(policy) = app_context.policy_registry.get_policy(model_id) {
if policy.name() == "cache_aware" && !remaining_workers.is_empty() {
app_context
.policy_registry
.init_cache_aware_policy(model_id, &remaining_workers);
debug!(
"Updated cache-aware policy for model {} ({} remaining workers)",
model_id,
remaining_workers.len()
);
}
}
}
// Log final result at info level
if worker_urls.len() == 1 {
info!("Removed worker {}", worker_urls[0]);
} else {
info!(
"Removed {} DP-aware workers: {:?}",
worker_urls.len(),
worker_urls
);
}
Ok(StepResult::Success)
}
fn is_retryable(&self, _error: &WorkflowError) -> bool {
false // Policy update is not retryable
}
}
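`UpdateRemainingPoliciesStep` deliberately recomputes state only for models that actually lost a worker, and only while workers remain. A self-contained sketch of that `affected_models` narrowing, using plain collections rather than the real registries:

```rust
use std::collections::{HashMap, HashSet};

/// After removal, rebuild per-model state only for models that lost a worker
/// and still have workers left — the same narrowing the step above performs.
fn rebuild_affected(
    workers_by_model: &HashMap<String, Vec<String>>,
    affected: &HashSet<String>,
) -> Vec<(String, usize)> {
    affected
        .iter()
        .filter_map(|model| {
            workers_by_model
                .get(model)
                .filter(|remaining| !remaining.is_empty())
                .map(|remaining| (model.clone(), remaining.len()))
        })
        .collect()
}

fn main() {
    let mut by_model = HashMap::new();
    by_model.insert("llama".to_string(), vec!["http://w2:8000".to_string()]);
    let affected: HashSet<String> = ["llama".to_string()].into();
    // "llama" still has one worker, so its policy would be re-initialized.
    assert_eq!(rebuild_affected(&by_model, &affected), vec![("llama".to_string(), 1)]);
}
```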
/// Create a worker removal workflow definition
pub fn create_worker_removal_workflow() -> WorkflowDefinition {
use std::time::Duration;
WorkflowDefinition::new("worker_removal", "Remove worker from router")
.add_step(
StepDefinition::new(
"find_workers_to_remove",
"Find workers to remove",
Arc::new(FindWorkersToRemoveStep),
)
.with_timeout(Duration::from_secs(10))
.with_retry(RetryPolicy {
max_attempts: 1,
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
}),
)
.add_step(
StepDefinition::new(
"remove_from_policy_registry",
"Remove workers from policy registry",
Arc::new(RemoveFromPolicyRegistryStep),
)
.with_timeout(Duration::from_secs(10))
.with_retry(RetryPolicy {
max_attempts: 1,
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
}),
)
.add_step(
StepDefinition::new(
"remove_from_worker_registry",
"Remove workers from worker registry",
Arc::new(RemoveFromWorkerRegistryStep),
)
.with_timeout(Duration::from_secs(10))
.with_retry(RetryPolicy {
max_attempts: 1,
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
}),
)
.add_step(
StepDefinition::new(
"update_remaining_policies",
"Update cache-aware policies for remaining workers",
Arc::new(UpdateRemainingPoliciesStep),
)
.with_timeout(Duration::from_secs(10))
.with_retry(RetryPolicy {
max_attempts: 1,
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
}),
)
}
......@@ -149,7 +149,11 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
Ok(ChatCompletionRequest {
messages,
model: req.model.clone().unwrap_or_else(|| "default".to_string()),
model: if req.model.is_empty() {
"default".to_string()
} else {
req.model.clone()
},
temperature: req.temperature,
max_completion_tokens: req.max_output_tokens,
stream: is_streaming,
......@@ -311,7 +315,7 @@ mod tests {
let req = ResponsesRequest {
input: ResponseInput::Text("Hello, world!".to_string()),
instructions: Some("You are a helpful assistant.".to_string()),
model: Some("gpt-4".to_string()),
model: "gpt-4".to_string(),
temperature: Some(0.7),
..Default::default()
};
......
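The same empty-string fallback now appears at several call sites in this patch (one above, three below), reflecting `model` changing from `Option<String>` to `String`. If factored out, a tiny helper would capture it — a sketch, not code from this PR:

```rust
/// Fallback used wherever a request's model may be an empty string.
fn model_or_default(model: &str) -> String {
    if model.is_empty() {
        "default".to_string()
    } else {
        model.to_string()
    }
}

fn main() {
    assert_eq!(model_or_default(""), "default");
    assert_eq!(model_or_default("gpt-4"), "gpt-4");
}
```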
......@@ -324,10 +324,11 @@ async fn route_responses_background(
incomplete_details: None,
instructions: request.instructions.clone(),
max_output_tokens: request.max_output_tokens,
model: request
.model
.clone()
.unwrap_or_else(|| "default".to_string()),
model: if request.model.is_empty() {
"default".to_string()
} else {
request.model.clone()
},
output: Vec::new(),
parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
previous_response_id: request.previous_response_id.clone(),
......@@ -622,10 +623,11 @@ async fn process_and_transform_sse_stream(
// Create event emitter for OpenAI-compatible streaming
let response_id = format!("resp_{}", Uuid::new_v4());
let model = original_request
.model
.clone()
.unwrap_or_else(|| "default".to_string());
let model = if original_request.model.is_empty() {
"default".to_string()
} else {
original_request.model.clone()
};
let created_at = chrono::Utc::now().timestamp() as u64;
let mut event_emitter = ResponseStreamEventEmitter::new(response_id, model, created_at);
......
......@@ -608,10 +608,11 @@ async fn execute_tool_loop_streaming_internal(
// Create response event emitter
let response_id = format!("resp_{}", Uuid::new_v4());
let model = current_request
.model
.clone()
.unwrap_or_else(|| "default".to_string());
let model = if current_request.model.is_empty() {
"default".to_string()
} else {
current_request.model.clone()
};
let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
......
......@@ -22,8 +22,12 @@ use tracing::{error, info, warn, Level};
use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
core::{
worker_to_info, workflow::WorkflowEngine, Job, JobQueue, JobQueueConfig, LoadMonitor,
WorkerManager, WorkerRegistry, WorkerType,
worker_to_info,
workflow::{
create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber,
WorkflowEngine,
},
Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry, WorkerType,
},
data_connector::{
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
......@@ -439,51 +443,6 @@ async fn v1_conversations_delete_item(
.await
}
#[derive(Deserialize)]
struct AddWorkerQuery {
url: String,
api_key: Option<String>,
}
async fn add_worker(
State(state): State<Arc<AppState>>,
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
) -> Response {
// Warn if router has API key but worker is being added without one
if state.context.router_config.api_key.is_some() && api_key.is_none() {
warn!(
"Adding worker {} without API key while router has API key configured. \
Worker will be accessible without authentication. \
If the worker requires the same API key as the router, please specify it explicitly.",
url
);
}
let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
match result {
Ok(message) => (StatusCode::OK, message).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
}
}
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry);
Json(json!({ "urls": worker_list })).into_response()
}
async fn remove_worker(
State(state): State<Arc<AppState>>,
Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
) -> Response {
let result = WorkerManager::remove_worker(&url, &state.context);
match result {
Ok(message) => (StatusCode::OK, message).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
}
}
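These query-parameter handlers are removed along with their `/add_worker`, `/remove_worker`, and `/list_workers` routes (see the `build_app` hunk below); registration now flows through the job-queue-backed handlers such as `create_worker`. As an illustration only — the `/workers` path, JSON body shape, and accepted-style status are assumptions here, not confirmed by this hunk — a client call might look like:

```rust
// Hedged sketch: adding a worker via the assumed JSON endpoint with reqwest
// (requires reqwest's "json" feature and tokio).
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();
    let resp = client
        .post("http://127.0.0.1:30000/workers") // assumed path
        .json(&json!({ "url": "http://127.0.0.1:31000" })) // assumed body shape
        .send()
        .await?;
    // Creation is processed asynchronously via the job queue, so an
    // accepted-style status (202) rather than 200 would be the natural reply.
    println!("status: {}", resp.status());
    Ok(())
}
```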
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
.await
......@@ -566,6 +525,12 @@ async fn create_worker(
);
}
// Populate dp_aware from router's configuration
let config = WorkerConfigRequest {
dp_aware: state.context.router_config.dp_aware,
..config
};
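The rebinding above uses Rust's struct update syntax to override a single field from router config while keeping everything else the client sent. The same pattern in isolation:

```rust
#[derive(Debug, Clone)]
struct Config {
    url: String,
    dp_aware: bool,
}

fn main() {
    let config = Config { url: "http://w1:8000".into(), dp_aware: false };
    // Shadow the binding, overriding dp_aware from router configuration
    // while every other field stays as the client supplied it.
    let config = Config { dp_aware: true, ..config };
    assert!(config.dp_aware);
    assert_eq!(config.url, "http://w1:8000");
}
```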
// Submit job for async processing
let worker_url = config.url.clone();
let job = Job::AddWorker {
......@@ -761,9 +726,6 @@ pub fn build_app(
.route("/get_server_info", get(get_server_info));
let admin_routes = Router::new()
.route("/add_worker", post(add_worker))
.route("/remove_worker", post(remove_worker))
.route("/list_workers", get(list_workers))
.route("/flush_cache", post(flush_cache))
.route("/get_loads", get(get_loads))
.route_layer(axum::middleware::from_fn_with_state(
......@@ -1018,15 +980,16 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
engine
.event_bus()
.subscribe(Arc::new(crate::core::workflow::LoggingSubscriber))
.subscribe(Arc::new(LoggingSubscriber))
.await;
engine.register_workflow(crate::core::workflow::create_worker_registration_workflow());
engine.register_workflow(create_worker_registration_workflow());
engine.register_workflow(create_worker_removal_workflow());
app_context
.workflow_engine
.set(engine)
.expect("WorkflowEngine should only be initialized once");
info!("Workflow engine initialized with worker registration workflow");
info!("Workflow engine initialized with worker registration and removal workflows");
info!(
"Initializing workers for routing mode: {:?}",
......
......@@ -18,11 +18,7 @@ use rustls;
use tokio::{task, time};
use tracing::{debug, error, info, warn};
use crate::{
core::{Job, WorkerManager},
protocols::worker_spec::WorkerConfigRequest,
server::AppContext,
};
use crate::{core::Job, protocols::worker_spec::WorkerConfigRequest, server::AppContext};
#[derive(Debug, Clone)]
pub struct ServiceDiscoveryConfig {
......@@ -386,7 +382,7 @@ async fn handle_pod_event(
reasoning_parser: None,
tool_parser: None,
chat_template: None,
api_key: None,
api_key: app_context.router_config.api_key.clone(),
health_check_timeout_secs: app_context.router_config.health_check.timeout_secs,
health_check_interval_secs: app_context
.router_config
......@@ -453,8 +449,24 @@ async fn handle_pod_deletion(
pod_info.name, pod_info.pod_type, worker_url
);
if let Err(e) = WorkerManager::remove_worker(&worker_url, &app_context) {
error!("Failed to remove worker {}: {}", worker_url, e);
let job = Job::RemoveWorker {
url: worker_url.clone(),
};
if let Some(job_queue) = app_context.worker_job_queue.get() {
if let Err(e) = job_queue.submit(job).await {
error!(
"Failed to submit worker removal job for {}: {}",
worker_url, e
);
} else {
debug!("Submitted worker removal job for {}", worker_url);
}
} else {
error!(
"JobQueue not initialized, cannot remove worker {}",
worker_url
);
}
} else {
debug!(
......
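Pod deletion now funnels through the same job queue as the HTTP handlers instead of calling `WorkerManager::remove_worker` inline. A minimal illustration of this fire-and-forget pattern (a tokio mpsc stand-in, not the router's actual `JobQueue`):

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
enum Job {
    RemoveWorker { url: String },
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Job>(64);

    // A background processor drains jobs, decoupling event handling
    // (e.g. pod deletion) from the actual removal work.
    let processor = tokio::spawn(async move {
        while let Some(job) = rx.recv().await {
            match job {
                Job::RemoveWorker { url } => println!("removing worker {}", url),
            }
        }
    });

    // The event handler only submits and logs; it never blocks on removal.
    if let Err(e) = tx.send(Job::RemoveWorker { url: "http://w1:8000".into() }).await {
        eprintln!("failed to submit removal job: {}", e);
    }

    drop(tx); // close the queue so the processor exits
    processor.await.unwrap();
}
```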
......@@ -14,7 +14,7 @@ use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
},
core::WorkerManager,
core::Job,
routers::{RouterFactory, RouterTrait},
server::AppContext,
};
......@@ -112,22 +112,51 @@ impl TestContext {
// Create app context
let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router
// Submit worker initialization job (same path the real server takes)
if !worker_urls.is_empty() {
WorkerManager::initialize_workers(&config, &app_context.worker_registry, None)
let job_queue = app_context
.worker_job_queue
.get()
.expect("JobQueue should be initialized");
let job = Job::InitializeWorkersFromConfig {
router_config: Box::new(config.clone()),
};
job_queue
.submit(job)
.await
.expect("Failed to initialize workers");
.expect("Failed to submit worker initialization job");
// Poll until all workers are healthy (up to 10 seconds)
let expected_count = worker_urls.len();
let start = tokio::time::Instant::now();
let timeout_duration = tokio::time::Duration::from_secs(10);
loop {
let healthy_workers = app_context
.worker_registry
.get_all()
.iter()
.filter(|w| w.is_healthy())
.count();
if healthy_workers >= expected_count {
break;
}
if start.elapsed() > timeout_duration {
panic!(
"Timeout waiting for {} workers to become healthy (only {} ready)",
expected_count, healthy_workers
);
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
}
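The inline polling loop above could equally live in a small helper; a sketch with a hypothetical `wait_for_healthy` name, keeping the same 100 ms cadence and timeout semantics:

```rust
use std::time::Duration;
use tokio::time::{sleep, Instant};

/// Poll `count_healthy` until it reports at least `expected` workers,
/// panicking after `timeout` — mirrors the loop in `TestContext::new`.
async fn wait_for_healthy(
    expected: usize,
    timeout: Duration,
    mut count_healthy: impl FnMut() -> usize,
) {
    let start = Instant::now();
    loop {
        let healthy = count_healthy();
        if healthy >= expected {
            return;
        }
        if start.elapsed() > timeout {
            panic!(
                "Timeout waiting for {} workers to become healthy (only {} ready)",
                expected, healthy
            );
        }
        sleep(Duration::from_millis(100)).await;
    }
}

#[tokio::main]
async fn main() {
    let mut polls = 0;
    wait_for_healthy(2, Duration::from_secs(1), || {
        polls += 1;
        polls // reaches 2 on the second poll
    })
    .await;
}
```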
// Create router
let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router);
// Wait for router to discover workers
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
Self {
workers,
router,
......@@ -711,221 +740,6 @@ mod model_info_tests {
}
}
#[cfg(test)]
mod worker_management_tests {
use super::*;
#[tokio::test]
async fn test_add_new_worker() {
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Start a mock worker
let mut worker = MockWorker::new(MockWorkerConfig {
port: 18301,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
});
let url = worker.start().await.unwrap();
// Add the worker
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// List workers to verify
let req = Request::builder()
.method("GET")
.uri("/list_workers")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let workers = body_json["urls"].as_array().unwrap();
assert!(workers.iter().any(|w| w.as_str().unwrap() == url));
worker.stop().await;
ctx.shutdown().await;
}
#[tokio::test]
async fn test_remove_existing_worker() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 18302,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
// Get the worker URL
let req = Request::builder()
.method("GET")
.uri("/list_workers")
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let workers = body_json["urls"].as_array().unwrap();
let worker_url = workers[0].as_str().unwrap();
// Remove the worker
let req = Request::builder()
.method("POST")
.uri(format!("/remove_worker?url={}", worker_url))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let req = Request::builder()
.method("GET")
.uri("/list_workers")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let workers = body_json["urls"].as_array().unwrap();
assert!(workers.is_empty());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_add_worker_invalid_url() {
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Invalid URL format
let req = Request::builder()
.method("POST")
.uri("/add_worker?url=not-a-valid-url")
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
// Missing URL parameter
let req = Request::builder()
.method("POST")
.uri("/add_worker")
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
// Empty URL
let req = Request::builder()
.method("POST")
.uri("/add_worker?url=")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
ctx.shutdown().await;
}
#[tokio::test]
async fn test_add_duplicate_worker() {
// Start a mock worker
let mut worker = MockWorker::new(MockWorkerConfig {
port: 18303,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
});
let url = worker.start().await.unwrap();
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Add worker first time
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
// Try to add same worker again
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
// Should return error for duplicate
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
worker.stop().await;
ctx.shutdown().await;
}
#[tokio::test]
async fn test_add_unhealthy_worker() {
// Start unhealthy worker
let mut worker = MockWorker::new(MockWorkerConfig {
port: 18304,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Unhealthy,
response_delay_ms: 0,
fail_rate: 0.0,
});
let url = worker.start().await.unwrap();
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Try to add unhealthy worker
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
// Router should reject unhealthy workers
assert!(
resp.status() == StatusCode::BAD_REQUEST
|| resp.status() == StatusCode::SERVICE_UNAVAILABLE
);
worker.stop().await;
ctx.shutdown().await;
}
}
#[cfg(test)]
mod router_policy_tests {
use super::*;
......