Unverified Commit a730ce81 authored by Rui Chen's avatar Rui Chen Committed by GitHub
Browse files

[feature] [sgl-router] Add a dp-aware routing strategy (#6869)

parent 55ecdc0a
......@@ -141,6 +141,14 @@ Process:
For unbalanced systems, this strategy tracks pending request counts per worker and routes new requests to the least busy worker. This helps maintain optimal load distribution across workers.
***Data-Parallelism Aware Routing***
An additional DP-aware routing strategy can be enabled on top of the sgl-router’s hybrid cache-aware load-balancing strategy by setting the `--dp-aware` flag when starting the router.
When this flag is enabled, the router attempts to contact the workers to retrieve the `dp_size` of each one and registers the new workers at the DP-rank level. In this mode, the router applies the cache-aware routing strategy in a more fine-grained manner, with assistance from the DP controller on the SRT side.
By default (when the flag is not set), the SRT’s DP controller distributes incoming requests across DP ranks in a round-robin fashion.
## Configuration Parameters
1. `cache_threshold`: (float, 0.0 to 1.0, default: 0.5)
......
......@@ -50,6 +50,8 @@ class RouterArgs:
eviction_interval: int = 60
max_tree_size: int = 2**24
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
dp_aware: bool = False
api_key: Optional[str] = None
log_dir: Optional[str] = None
log_level: Optional[str] = None
# Service discovery configuration
......@@ -197,6 +199,17 @@ class RouterArgs:
default=RouterArgs.max_payload_size,
help="Maximum payload size in bytes",
)
parser.add_argument(
    f"--{prefix}dp-aware",
    action="store_true",
    help="Enable data parallelism aware schedule",
)
parser.add_argument(
    f"--{prefix}api-key",
    type=str,
    default=None,
    # Fixed typo: "enaled" -> "enabled"
    help="The api key used for the authorization with the worker. "
    "Useful when the dp aware scheduling strategy is enabled.",
)
parser.add_argument(
f"--{prefix}log-dir",
type=str,
......@@ -304,6 +317,8 @@ class RouterArgs:
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
dp_aware=getattr(args, f"{prefix}dp_aware", False),
api_key=getattr(args, f"{prefix}api_key", None),
log_dir=getattr(args, f"{prefix}log_dir", None),
log_level=getattr(args, f"{prefix}log_level", None),
service_discovery=getattr(args, f"{prefix}service_discovery", False),
......@@ -463,6 +478,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size,
max_payload_size=router_args.max_payload_size,
dp_aware=router_args.dp_aware,
api_key=router_args.api_key,
log_dir=router_args.log_dir,
log_level=router_args.log_level,
service_discovery=router_args.service_discovery,
......
......@@ -31,6 +31,10 @@ class Router:
routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default: 256MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
dp_aware: Enable data parallelism aware schedule. Default: False
api_key: The api key used for the authorization with the worker.
Useful when the dp aware scheduling strategy is enabled.
Default: None
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
log_level: Logging level. Options: 'debug', 'info', 'warning', 'error', 'critical'.
service_discovery: Enable Kubernetes service discovery. When enabled, the router will
......@@ -73,6 +77,8 @@ class Router:
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
max_payload_size: int = 256 * 1024 * 1024, # 256MB
dp_aware: bool = False,
api_key: Optional[str] = None,
log_dir: Optional[str] = None,
log_level: Optional[str] = None,
service_discovery: bool = False,
......@@ -110,6 +116,8 @@ class Router:
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
max_payload_size=max_payload_size,
dp_aware=dp_aware,
api_key=api_key,
log_dir=log_dir,
log_level=log_level,
service_discovery=service_discovery,
......
......@@ -8,7 +8,7 @@ if __name__ == "__main__":
arg_parser.add_argument(
"--timeout-per-file",
type=int,
default=1000,
default=2000,
help="The time limit for running one file in seconds.",
)
args = arg_parser.parse_args()
......
......@@ -43,6 +43,7 @@ class TestLaunchRouter(unittest.TestCase):
selector=None,
service_discovery_port=80,
service_discovery_namespace=None,
dp_aware=False,
prometheus_port=None,
prometheus_host=None,
# PD-specific attributes
......@@ -111,6 +112,52 @@ class TestLaunchRouter(unittest.TestCase):
)
self.run_router_process(args)
def test_launch_router_common_with_dp_aware(self):
    """Router should launch normally when dp_aware is enabled."""
    self.run_router_process(
        self.create_router_args(
            worker_urls=["http://localhost:8000"],
            dp_aware=True,
        )
    )
def test_launch_router_with_empty_worker_urls_with_dp_aware(self):
    """Router should also launch with dp_aware enabled and no workers."""
    self.run_router_process(
        self.create_router_args(
            worker_urls=[],
            dp_aware=True,
        )
    )
def test_launch_router_common_with_dp_aware_service_discovery(self):
    """Launching with both service_discovery and dp_aware enabled must fail.

    The router's config validator rejects this combination (DP-aware
    routing is not yet compatible with service discovery), so the child
    process is expected to exit quickly and with a non-zero status.
    """
    args = self.create_router_args(
        worker_urls=["http://localhost:8000"],
        dp_aware=True,
        service_discovery=True,
        selector=["app=test-worker"],
    )

    def run_router():
        import sys

        try:
            from sglang_router.launch_router import launch_router

            router = launch_router(args)
            # Propagate the result through the process exit code; a plain
            # `return 1` from a Process target would still yield exitcode 0.
            sys.exit(1 if router is None else 0)
        except Exception as e:
            print(e)
            sys.exit(1)

    process = multiprocessing.Process(target=run_router)
    try:
        process.start()
        # Give the router time to run (or fail) config validation
        time.sleep(3)
        # The process must have exited because the config is invalid
        self.assertFalse(process.is_alive())
        self.assertNotEqual(process.exitcode, 0)
    finally:
        terminate_process(process)
def test_launch_router_pd_mode_basic(self):
"""Test basic PD router functionality without actually starting servers."""
# This test just verifies the PD router can be created and configured
......
......@@ -30,6 +30,7 @@ def popen_launch_router(
service_discovery_namespace: str = None,
prometheus_port: int = None,
prometheus_host: str = None,
dp_aware: bool = False,
):
"""
Launch the router server process.
......@@ -49,6 +50,7 @@ def popen_launch_router(
service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces.
prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled.
prometheus_host: Host address to bind the Prometheus metrics server.
dp_aware: Enable data parallelism aware routing strategy.
"""
_, host, port = base_url.split(":")
host = host[2:]
......@@ -69,10 +71,12 @@ def popen_launch_router(
"5",
"--router-policy",
policy,
"--allow-auto-truncate",
]
if api_key is not None:
command.extend(["--api-key", api_key])
command.extend(["--router-api-key", api_key])
if max_payload_size is not None:
command.extend(["--router-max-payload-size", str(max_payload_size)])
......@@ -100,6 +104,9 @@ def popen_launch_router(
if log_dir is not None:
command.extend(["--log-dir", log_dir])
if dp_aware:
command.append("--router-dp-aware")
process = subprocess.Popen(command, stdout=None, stderr=None)
start_time = time.perf_counter()
......@@ -127,6 +134,7 @@ def popen_launch_server(
model: str,
base_url: str,
timeout: float,
api_key: str = None,
):
_, host, port = base_url.split(":")
host = host[2:]
......@@ -145,6 +153,9 @@ def popen_launch_server(
"1",
]
if api_key is not None:
command.extend(["--api-key", api_key])
process = subprocess.Popen(command, stdout=None, stderr=None)
# intentionally don't wait and defer the job to the router health check
......@@ -426,6 +437,274 @@ class TestLaunchServer(unittest.TestCase):
response.status_code, 200, "Request with correct api key should succeed"
)
def test_6_mmlu_with_dp_aware(self):
    """Run the MMLU eval through a dp-aware, cache-aware router (dp_size=2)."""
    print("Running test_6_mmlu_with_dp_aware...")
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=2,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="cache_aware",
        dp_aware=True,
    )
    eval_args = SimpleNamespace(
        base_url=self.base_url,
        model=self.model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    score = run_eval(eval_args)["score"]
    THRESHOLD = 0.65
    passed = score >= THRESHOLD
    msg = f"dp aware MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
    self.assertGreaterEqual(score, THRESHOLD, msg)
def test_7_add_and_remove_worker_with_dp_aware(self):
    """Exercise /add_worker and /remove_worker with dp-aware routing on.

    Scenario: add a worker at runtime, run MMLU, remove it, run MMLU
    again (router's own workers still serve traffic), then try to add a
    worker protected by an api key the router does not know — that add
    must fail because the router cannot query /get_server_info.
    """
    print("Running test_7_add_and_remove_worker_with_dp_aware...")
    # Set dp_size = 1
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",  # make sure every worker processes requests
        dp_aware=True,  # dp aware strategy should work well with RR
    )
    # 1. Start a worker
    port = find_available_port()
    worker_url = f"http://127.0.0.1:{port}"
    worker_process = popen_launch_server(
        self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
    )
    self.other_process.append(worker_process)
    # 2. Use the /add_worker API to add it to the router
    # It will be used by the router after it is healthy
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(response.status_code, 200)
    # 3. Run mmlu
    args = SimpleNamespace(
        base_url=self.base_url,
        model=self.model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    metrics = run_eval(args)
    score = metrics["score"]
    THRESHOLD = 0.65
    passed = score >= THRESHOLD
    msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
    self.assertGreaterEqual(score, THRESHOLD, msg)
    # 4. Use the /remove_worker API to remove it from the router
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/remove_worker?url={worker_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(response.status_code, 200)
    # 5. Run mmlu again (remaining workers must still serve requests)
    metrics = run_eval(args)
    score = metrics["score"]
    THRESHOLD = 0.65
    passed = score >= THRESHOLD
    msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
    self.assertGreaterEqual(score, THRESHOLD, msg)
    # 6. Start another worker with api_key set
    terminate_and_wait(worker_process)  # terminate the old worker process
    port = find_available_port()
    worker_url = f"http://127.0.0.1:{port}"
    worker_process = popen_launch_server(
        self.model,
        worker_url,
        DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        api_key="correct_api_key",
    )
    self.other_process.append(worker_process)
    # 7. Use the /add_worker API to add it to the router
    # Should fail since the router would contact the worker's
    # /get_server_info endpoint for the dp_size info, but it
    # has no knowledge of the api key
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertNotEqual(response.status_code, 200)
def test_8_lazy_fault_tolerance_with_dp_aware(self):
    """Verify the dp-aware router survives an abrupt worker failure.

    A worker is added at runtime, then killed 10 seconds into an MMLU
    run; the router should lazily detect the failure, evict the dead
    worker, and still finish the run above the accuracy threshold.
    """
    import threading  # hoisted to the top of the method (was mid-function)

    print("Running test_8_lazy_fault_tolerance_with_dp_aware...")
    # Set dp_size = 1
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        dp_aware=True,
    )
    # 1. Start a worker
    port = find_available_port()
    worker_url = f"http://127.0.0.1:{port}"
    worker_process = popen_launch_server(
        self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
    )
    self.other_process.append(worker_process)
    # 2. Use the /add_worker API to add it to the router
    # It will be used by the router after it is healthy
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(response.status_code, 200)

    # Kill the worker after 10 seconds to mimic abrupt worker failure
    def kill_worker():
        time.sleep(10)
        kill_process_tree(worker_process.pid)
        print("Worker process killed")

    kill_thread = threading.Thread(target=kill_worker, daemon=True)
    kill_thread.start()
    # 3. Run mmlu while the worker dies mid-run
    args = SimpleNamespace(
        base_url=self.base_url,
        model=self.model,
        eval_name="mmlu",
        num_examples=256,
        num_threads=32,
        temperature=0.1,
    )
    metrics = run_eval(args)
    score = metrics["score"]
    THRESHOLD = 0.65
    passed = score >= THRESHOLD
    msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
    self.assertGreaterEqual(score, THRESHOLD, msg)
def test_9_payload_size_with_dp_aware(self):
    """Check the router's max_payload_size limit with dp-aware routing on."""
    print("Running test_9_payload_size_with_dp_aware...")
    # Start the router with a 1MB payload limit
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        max_payload_size=1 * 1024 * 1024,  # 1MB limit
        dp_aware=True,
    )

    def post_payload(size_mb):
        # Build a text payload of roughly size_mb megabytes and POST it.
        body = {
            "text": "x" * int(size_mb * 1024 * 1024),
            "temperature": 0.0,
        }
        with requests.Session() as session:
            return session.post(
                f"{self.base_url}/generate",
                json=body,
                headers={"Content-Type": "application/json"},
            )

    # Test case 1: a 0.5MB payload is under the limit and must be accepted
    response = post_payload(0.5)
    self.assertEqual(
        response.status_code,
        200,
        f"0.5MB payload should succeed but got status {response.status_code}",
    )
    # Test case 2: a 1.2MB payload exceeds the limit and must be rejected
    response = post_payload(1.2)
    self.assertEqual(
        response.status_code,
        413,  # Payload Too Large
        f"1.2MB payload should fail with 413 but got status {response.status_code}",
    )
def test_10_api_key_with_dp_aware(self):
    """Verify router-level API-key authorization with dp-aware routing on.

    Requests with a missing or wrong bearer token must get 401; the
    correct token must get 200.
    """
    print("Running test_10_api_key_with_dp_aware...")
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        api_key="correct_api_key",
        dp_aware=True,
    )
    # Test case 1: request without api key should fail
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            f"Request without api key should fail with 401 but got status {response.status_code}",
        )
    # Test case 2: request with invalid api key should fail
    with requests.Session() as session:
        # Use the session; the original accidentally bypassed it via requests.post
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
            headers={"Authorization": "Bearer 123"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            f"Request with invalid api key should fail with 401 but got status {response.status_code}",
        )
    # Test case 3: request with correct api key should succeed
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is ", "temperature": 0},
            headers={"Authorization": "Bearer correct_api_key"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            200,
            f"Request with correct api key should succeed but got status {response.status_code}",
        )
if __name__ == "__main__":
unittest.main()
......@@ -21,6 +21,10 @@ pub struct RouterConfig {
pub worker_startup_timeout_secs: u64,
/// Worker health check interval in seconds
pub worker_startup_check_interval_secs: u64,
/// Enable data parallelism aware schedule
pub dp_aware: bool,
/// The api key used for the authorization with the worker
pub api_key: Option<String>,
/// Service discovery configuration (optional)
pub discovery: Option<DiscoveryConfig>,
/// Metrics configuration (optional)
......@@ -205,6 +209,8 @@ impl Default for RouterConfig {
request_timeout_secs: 600,
worker_startup_timeout_secs: 300,
worker_startup_check_interval_secs: 10,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
......@@ -311,6 +317,8 @@ mod tests {
request_timeout_secs: 30,
worker_startup_timeout_secs: 60,
worker_startup_check_interval_secs: 5,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig::default()),
metrics: Some(MetricsConfig::default()),
log_dir: Some("/var/log".to_string()),
......@@ -727,6 +735,8 @@ mod tests {
request_timeout_secs: 120,
worker_startup_timeout_secs: 60,
worker_startup_check_interval_secs: 5,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true,
namespace: Some("sglang".to_string()),
......@@ -774,6 +784,8 @@ mod tests {
request_timeout_secs: 300,
worker_startup_timeout_secs: 180,
worker_startup_check_interval_secs: 15,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true,
namespace: None,
......@@ -812,6 +824,8 @@ mod tests {
request_timeout_secs: 900,
worker_startup_timeout_secs: 600,
worker_startup_check_interval_secs: 20,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true,
namespace: Some("production".to_string()),
......
......@@ -313,6 +313,14 @@ impl ConfigValidator {
}
}
// Service discovery currently conflicts with dp_aware routing,
// since the combination is not fully supported yet
if has_service_discovery && config.dp_aware {
return Err(ConfigError::IncompatibleConfig {
reason: "DP-aware routing is not compatible with service discovery".to_string(),
});
}
Ok(())
}
......
......@@ -17,6 +17,8 @@ pub enum WorkerError {
NetworkError { url: String, error: String },
/// Worker is at capacity
WorkerAtCapacity { url: String },
/// Invalid URL format
InvalidUrl { url: String },
}
impl fmt::Display for WorkerError {
......@@ -37,6 +39,9 @@ impl fmt::Display for WorkerError {
WorkerError::WorkerAtCapacity { url } => {
write!(f, "Worker at capacity: {}", url)
}
WorkerError::InvalidUrl { url } => {
write!(f, "Invalid URL format: {}", url)
}
}
}
}
......
......@@ -162,6 +162,27 @@ impl BasicWorker {
self.metadata.health_config = config;
self
}
/// Return the worker URL with any trailing `@dp_rank` suffix stripped.
///
/// dp-aware worker URLs look like `http://host:port@dp_rank`; plain URLs
/// are returned unchanged. Errors if there is more than one `@` or the
/// rank part is not an integer.
pub fn normalised_url(&self) -> WorkerResult<&str> {
    let url = self.url();
    match url.split_once('@') {
        None => Ok(url),
        Some((base, rank)) => {
            // Reject extra '@' separators or a non-numeric dp_rank.
            if rank.contains('@') || rank.parse::<usize>().is_err() {
                Err(WorkerError::InvalidUrl {
                    url: url.to_string(),
                })
            } else {
                Ok(base)
            }
        }
    }
}
}
#[async_trait]
......@@ -186,7 +207,8 @@ impl Worker for BasicWorker {
use std::time::Duration;
// Perform actual HTTP health check
let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
let url = self.normalised_url()?;
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
// Use the shared client with a custom timeout for this request
......@@ -203,7 +225,7 @@ impl Worker for BasicWorker {
} else {
self.set_healthy(false);
Err(WorkerError::HealthCheckFailed {
url: self.url().to_string(),
url: url.to_string(),
reason: format!("Health check returned status: {}", response.status()),
})
}
......@@ -211,7 +233,7 @@ impl Worker for BasicWorker {
Err(e) => {
self.set_healthy(false);
Err(WorkerError::HealthCheckFailed {
url: self.url().to_string(),
url: url.to_string(),
reason: format!("Health check request failed: {}", e),
})
}
......
......@@ -37,6 +37,8 @@ struct Router {
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
dp_aware: bool,
api_key: Option<String>,
log_dir: Option<String>,
log_level: Option<String>,
service_discovery: bool,
......@@ -136,6 +138,8 @@ impl Router {
request_timeout_secs: self.request_timeout_secs,
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
worker_startup_check_interval_secs: self.worker_startup_check_interval,
dp_aware: self.dp_aware,
api_key: self.api_key.clone(),
discovery,
metrics,
log_dir: self.log_dir.clone(),
......@@ -161,6 +165,8 @@ impl Router {
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24),
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
dp_aware = false,
api_key = None,
log_dir = None,
log_level = None,
service_discovery = false,
......@@ -193,6 +199,8 @@ impl Router {
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
dp_aware: bool,
api_key: Option<String>,
log_dir: Option<String>,
log_level: Option<String>,
service_discovery: bool,
......@@ -225,6 +233,8 @@ impl Router {
eviction_interval_secs,
max_tree_size,
max_payload_size,
dp_aware,
api_key,
log_dir,
log_level,
service_discovery,
......
......@@ -45,6 +45,8 @@ impl RouterFactory {
policy,
router_config.worker_startup_timeout_secs,
router_config.worker_startup_check_interval_secs,
router_config.dp_aware,
router_config.api_key.clone(),
)?;
Ok(Box::new(router))
......
......@@ -30,6 +30,8 @@ pub struct Router {
policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
_health_checker: Option<HealthChecker>,
......@@ -42,6 +44,8 @@ impl Router {
policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
) -> Result<Self, String> {
// Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len());
......@@ -51,6 +55,14 @@ impl Router {
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
}
let worker_urls = if dp_aware {
// worker address now in the format of "http://host:port@dp_rank"
Self::get_dp_aware_workers(&worker_urls, &api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
} else {
worker_urls
};
// Create Worker trait objects from URLs
let workers: Vec<Box<dyn Worker>> = worker_urls
.iter()
......@@ -89,6 +101,8 @@ impl Router {
policy,
timeout_secs,
interval_secs,
dp_aware,
api_key,
_worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle,
_health_checker: Some(health_checker),
......@@ -160,6 +174,62 @@ impl Router {
}
}
/// Query a worker's `/get_server_info` endpoint (blocking) and return its
/// `dp_size`. The optional api key is sent as a bearer token.
fn get_worker_dp_size(worker_url: &str, api_key: &Option<String>) -> Result<usize, String> {
    let sync_client = reqwest::blocking::Client::new();
    let mut req_builder = sync_client.get(&format!("{}/get_server_info", worker_url));
    if let Some(key) = api_key {
        req_builder = req_builder.bearer_auth(key);
    }

    let res = req_builder
        .send()
        .map_err(|e| format!("error response: {}", e))?;
    if !res.status().is_success() {
        return Err(format!("unexpected status code: {}", res.status()));
    }

    let body = res
        .text()
        .map_err(|e| format!("failed to read text from response: {}", e))?;
    let server_info: serde_json::Value =
        serde_json::from_str(&body).map_err(|e| format!("failed to decode JSON: {}", e))?;

    let dp_size = server_info
        .get("dp_size")
        .and_then(|v| v.as_u64())
        .ok_or_else(|| String::from("dp_size not found or not an u64"))?;
    // Guard the u64 -> usize narrowing (relevant on 32-bit targets).
    usize::try_from(dp_size).map_err(|_| format!("dp_size is too large: {}", dp_size))
}
// Given a list of workers, expand each URL into one entry per DP rank,
// formatted as "{url}@{dp_rank}".
fn get_dp_aware_workers(
    worker_urls: &[String],
    api_key: &Option<String>,
) -> Result<Vec<String>, String> {
    let mut expanded: Vec<String> = Vec::new();
    for url in worker_urls {
        let dp_size = Self::get_worker_dp_size(url, api_key)
            .map_err(|e| format!("Failed to get DP size for {}: {}", url, e))?;
        expanded.extend((0..dp_size).map(|rank| format!("{}@{}", url, rank)));
    }
    Ok(expanded)
}
fn select_first_worker(&self) -> Result<String, String> {
let workers_guard = self.workers.read().unwrap();
if workers_guard.is_empty() {
......@@ -178,6 +248,21 @@ impl Router {
) -> HttpResponse {
let request_id = get_request_id(req);
let start = Instant::now();
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
}
};
worker_url_prefix
} else {
worker_url
};
let mut request_builder = client.get(format!("{}{}", worker_url, route));
// Copy all headers from original request except for /health because it does not need authorization
......@@ -292,7 +377,7 @@ impl Router {
worker_url = %worker_url,
"Removing failed worker"
);
self.remove_worker(&worker_url);
self.remove_failed_worker(&worker_url);
break;
}
}
......@@ -392,7 +477,7 @@ impl Router {
request_id = %request_id,
"Removing failed worker after typed request failures worker_url={}", worker_url
);
self.remove_worker(&worker_url);
self.remove_failed_worker(&worker_url);
break;
}
}
......@@ -415,6 +500,23 @@ impl Router {
}
}
// TODO (rui): Better accommodate to the Worker abstraction
/// Split a dp-aware worker URL of the form "http://host:port@dp_rank"
/// into its base URL and numeric dp_rank.
fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
    // Exactly one '@' separator is required.
    let (base, rank_str) = worker_url
        .split_once('@')
        .filter(|(_, rest)| !rest.contains('@'))
        .ok_or_else(|| format!("invalid worker_url format: {}", worker_url))?;
    let dp_rank = rank_str.parse::<usize>().map_err(|_| {
        format!(
            "failed to parse dp_rank from worker_url: {}",
            worker_url
        )
    })?;
    Ok((base, dp_rank))
}
// Send typed request directly without conversion
async fn send_typed_request<T: serde::Serialize>(
&self,
......@@ -429,9 +531,47 @@ impl Router {
let request_id = get_request_id(req);
let start = Instant::now();
let mut request_builder = client
let mut request_builder = if self.dp_aware {
let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
}
};
// Parse the request body
let mut json_val = match serde_json::to_value(typed_req) {
Ok(j) => j,
Err(e) => {
return HttpResponse::BadRequest()
.body(format!("Convert into serde_json::Value failed: {}", e));
}
};
// Insert the data_parallel_rank field
if let Some(map) = json_val.as_object_mut() {
map.insert(
String::from("data_parallel_rank"),
serde_json::json!(dp_rank),
);
debug!(
"Modified request body: {}",
serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
);
} else {
return HttpResponse::BadRequest()
.body("Failed to insert the data_parallel_rank field into the request body");
}
client
.post(format!("{}{}", worker_url_prefix, route))
.json(&json_val)
} else {
client
.post(format!("{}{}", worker_url, route))
.json(typed_req); // Use json() directly with typed request
.json(typed_req) // Use json() directly with typed request
};
// Copy all headers from original request
for (name, value) in copy_request_headers(req) {
......@@ -560,12 +700,35 @@ impl Router {
Ok(res) => {
if res.status().is_success() {
let mut workers_guard = self.workers.write().unwrap();
if self.dp_aware {
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
let url_vec = vec![String::from(worker_url)];
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
let mut worker_added: bool = false;
for dp_url in &dp_url_vec {
if workers_guard.iter().any(|w| w.url() == dp_url) {
warn!("Worker {} already exists", dp_url);
continue;
}
info!("Added worker: {}", dp_url);
let new_worker = WorkerFactory::create_regular(dp_url.to_string());
workers_guard.push(new_worker);
worker_added = true;
}
if !worker_added {
return Err(format!("No worker added for {}", worker_url));
}
} else {
if workers_guard.iter().any(|w| w.url() == worker_url) {
return Err(format!("Worker {} already exists", worker_url));
}
info!("Added worker: {}", worker_url);
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
workers_guard.push(new_worker);
}
RouterMetrics::set_active_workers(workers_guard.len());
// If cache aware policy, initialize the worker in the tree
......@@ -612,7 +775,53 @@ impl Router {
}
}
/// Remove all the worker(s) that match the URL prefix
pub fn remove_worker(&self, worker_url: &str) {
if self.dp_aware {
// remove dp-aware workers in a prefix-matching fashion
// without contacting the remote worker
let mut candidate_workers: Vec<String> = Vec::new();
let mut removed_workers: Vec<String> = Vec::new();
let worker_url_prefix = format!("{}@", worker_url);
{
// find the candidate workers to be removed
let workers_guard = self.workers.read().unwrap();
for w in workers_guard.iter() {
if w.url().starts_with(&worker_url_prefix) {
candidate_workers.push(w.url().to_string());
}
}
}
{
// do the removing on the worker_urls
let mut workers_guard = self.workers.write().unwrap();
for dp_url in candidate_workers.iter() {
if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
workers_guard.remove(index);
info!("Removed worker: {}", dp_url);
removed_workers.push(dp_url.to_string());
} else {
warn!("Worker {} not found, skipping removal", dp_url);
continue;
}
}
RouterMetrics::set_active_workers(workers_guard.len());
}
// If cache aware policy, remove the workers from the tree
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
for dp_url in removed_workers.iter() {
cache_aware.remove_worker(dp_url);
info!("Removed worker from tree: {}", dp_url);
}
}
} else {
let mut workers_guard = self.workers.write().unwrap();
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
workers_guard.remove(index);
......@@ -623,6 +832,30 @@ impl Router {
return;
}
// If cache aware policy, remove the workers from the tree
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker(worker_url);
info!("Removed worker from tree: {}", worker_url);
}
}
}
/// Remove a specific failed worker; for internal usage
fn remove_failed_worker(&self, worker_url: &str) {
let mut workers_guard = self.workers.write().unwrap();
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
workers_guard.remove(index);
info!("Removed failed worker: {}", worker_url);
RouterMetrics::set_active_workers(workers_guard.len());
} else {
warn!("Worker {} not found, skipping removal", worker_url);
return;
}
// If cache aware policy, remove the worker from the tree
if let Some(cache_aware) = self
.policy
......@@ -634,6 +867,20 @@ impl Router {
}
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
......@@ -710,6 +957,20 @@ impl Router {
// Static version of get_worker_load for use in monitoring task
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
let worker_url = if worker_url.contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
debug!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
......@@ -862,6 +1123,19 @@ impl RouterTrait for Router {
// Send requests to all workers concurrently without headers
let mut tasks = Vec::new();
for worker_url in &worker_urls {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
}
};
worker_url_prefix
} else {
worker_url
};
let request_builder = client.post(format!("{}/flush_cache", worker_url));
tasks.push(request_builder.send());
}
......@@ -948,6 +1222,8 @@ mod tests {
policy: Arc::new(RandomPolicy::new()),
timeout_secs: 5,
interval_secs: 1,
dp_aware: false,
api_key: None,
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
_health_checker: None,
......
......@@ -581,7 +581,7 @@ mod tests {
use crate::routers::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new(vec![], policy, 5, 1).unwrap();
let router = Router::new(vec![], policy, 5, 1, false, None).unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}
......
......@@ -31,6 +31,8 @@ impl TestContext {
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
......@@ -950,6 +952,8 @@ mod error_tests {
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
......
......@@ -16,6 +16,8 @@ pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig {
request_timeout_secs: 600,
worker_startup_timeout_secs: 300,
worker_startup_check_interval_secs: 10,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
......@@ -37,6 +39,8 @@ pub fn create_test_config_no_workers() -> RouterConfig {
request_timeout_secs: 600,
worker_startup_timeout_secs: 0, // No wait
worker_startup_check_interval_secs: 10,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
......
......@@ -42,6 +42,8 @@ impl RequestTestContext {
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
......
......@@ -46,6 +46,8 @@ impl StreamingTestContext {
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
......
......@@ -169,6 +169,8 @@ mod test_pd_routing {
request_timeout_secs: 60,
worker_startup_timeout_secs: 10,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment