Unverified Commit 30f2a44a authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[misc] Add PD service discovery support in router (#7361)

parent bd4f5818
......@@ -35,6 +35,7 @@ metrics = "0.24.2"
metrics-exporter-prometheus = "0.17.0"
# Added for request tracing
uuid = { version = "1.10", features = ["v4", "serde"] }
thiserror = "2.0.12"
[profile.release]
lto = "thin"
codegen-units = 1
......@@ -95,38 +95,217 @@ python -m sglang_router.launch_router \
### Kubernetes Service Discovery
SGL Router supports automatic service discovery for worker nodes in Kubernetes environments. When enabled, the router will automatically:
SGL Router supports automatic service discovery for worker nodes in Kubernetes environments. This feature works with both regular (single-server) routing and PD (Prefill-Decode) routing modes. When enabled, the router will automatically:
- Discover and add worker pods with matching labels
- Remove unhealthy or deleted worker pods
- Dynamically adjust the worker pool based on pod health and availability
- For PD mode: distinguish between prefill and decode servers based on labels
#### Command Line Usage
#### Regular Mode Service Discovery
For traditional single-server routing:
```bash
python -m sglang_router.launch_router \
--service-discovery \
--selector app=sglang-worker role=inference \
--service-discovery-port 8000 \
--service-discovery-namespace default
```
#### PD Mode Service Discovery
For PD (Prefill-Decode) disaggregated routing, service discovery can automatically discover and classify pods as either prefill or decode servers based on their labels:
```bash
python -m sglang_router.launch_router \
--pd-disaggregation \
--policy cache_aware \
--service-discovery \
--prefill-selector app=sglang component=prefill \
--decode-selector app=sglang component=decode \
--service-discovery-namespace sglang-system
```
You can also specify initial prefill and decode servers and let service discovery add more:
```bash
python -m sglang_router.launch_router \
--pd-disaggregation \
--policy cache_aware \
--prefill http://prefill-1:8000 8001 \
--decode http://decode-1:8000 \
--service-discovery \
--prefill-selector app=sglang component=prefill \
--decode-selector app=sglang component=decode \
--service-discovery-namespace sglang-system
```
#### Kubernetes Pod Configuration for PD Mode
When using PD service discovery, your Kubernetes pods need specific labels to be classified as prefill or decode servers:
**Prefill Server Pod:**
```yaml
apiVersion: v1
kind: Pod
metadata:
name: sglang-prefill-1
labels:
app: sglang
component: prefill
annotations:
sglang.ai/bootstrap-port: "9001" # Optional: Bootstrap port for Mooncake prefill coordination
spec:
containers:
- name: sglang
image: lmsys/sglang:latest
ports:
- containerPort: 8000 # Main API port
- containerPort: 9001 # Optional: Bootstrap coordination port
# ... rest of configuration
```
**Decode Server Pod:**
```yaml
apiVersion: v1
kind: Pod
metadata:
name: sglang-decode-1
labels:
app: sglang
component: decode
spec:
containers:
- name: sglang
image: lmsys/sglang:latest
ports:
- containerPort: 8000 # Main API port
# ... rest of configuration
```
**Key Requirements:**
- Prefill pods must have labels matching your `--prefill-selector`
- Decode pods must have labels matching your `--decode-selector`
- Prefill pods can optionally include bootstrap port in annotations using `sglang.ai/bootstrap-port` (defaults to None if not specified)
#### Service Discovery Arguments
**General Arguments:**
- `--service-discovery`: Enable Kubernetes service discovery feature
- `--selector`: One or more label key-value pairs for pod selection (format: key1=value1 key2=value2)
- `--service-discovery-port`: Port to use when generating worker URLs (default: 80)
- `--service-discovery-port`: Port to use when generating worker URLs (default: 8000)
- `--service-discovery-namespace`: Optional. Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)
- `--selector`: One or more label key-value pairs for pod selection in regular mode (format: key1=value1 key2=value2)
**PD Mode Arguments:**
- `--pd-disaggregation`: Enable PD (Prefill-Decode) disaggregated mode
- `--prefill`: Specify initial prefill server URL and bootstrap port (format: URL BOOTSTRAP_PORT, can be used multiple times)
- `--decode`: Specify initial decode server URL (can be used multiple times)
- `--prefill-selector`: Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)
- `--decode-selector`: Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)
- `--policy`: Routing policy (cache_aware, random, power_of_two - note: power_of_two only works in PD mode)
**Notes:**
- Bootstrap port annotation is automatically set to `sglang.ai/bootstrap-port` for Mooncake deployments
- Advanced cache tuning parameters use sensible defaults and are not exposed via CLI
#### RBAC Requirements
When using service discovery, you must configure proper Kubernetes RBAC permissions:
- **If using namespace-scoped discovery** (with `--service-discovery-namespace`):
Set up a ServiceAccount, Role, and RoleBinding
**Namespace-scoped (recommended):**
```yaml
apiVersion: v1
kind: ServiceAccount
metadata:
name: sglang-router
namespace: sglang-system
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
namespace: sglang-system
name: sglang-router
rules:
- apiGroups: [""]
resources: ["pods"]
verbs: ["get", "list", "watch"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: sglang-router
namespace: sglang-system
subjects:
- kind: ServiceAccount
name: sglang-router
namespace: sglang-system
roleRef:
kind: Role
name: sglang-router
apiGroup: rbac.authorization.k8s.io
```
**Cluster-wide (if watching all namespaces):**
```yaml
apiVersion: v1
kind: ServiceAccount
metadata:
name: sglang-router
namespace: sglang-system
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
name: sglang-router
rules:
- apiGroups: [""]
resources: ["pods"]
verbs: ["get", "list", "watch"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: sglang-router
subjects:
- kind: ServiceAccount
name: sglang-router
namespace: sglang-system
roleRef:
kind: ClusterRole
name: sglang-router
apiGroup: rbac.authorization.k8s.io
```
#### Complete Example: PD Mode with Service Discovery
Here's a complete example of running SGLang Router with PD mode and service discovery:
```bash
# Start the router with PD mode and automatic prefill/decode discovery
python -m sglang_router.launch_router \
--pd-disaggregation \
--policy cache_aware \
--service-discovery \
--prefill-selector app=sglang component=prefill environment=production \
--decode-selector app=sglang component=decode environment=production \
--service-discovery-namespace production \
--host 0.0.0.0 \
--port 8080 \
--prometheus-host 0.0.0.0 \
--prometheus-port 9090
```
This setup will:
1. Enable PD (Prefill-Decode) disaggregated routing mode with automatic pod classification
2. Watch for pods in the `production` namespace
3. Automatically add prefill servers with labels `app=sglang`, `component=prefill`, `environment=production`
4. Automatically add decode servers with labels `app=sglang`, `component=decode`, `environment=production`
5. Extract bootstrap ports from the `sglang.ai/bootstrap-port` annotation on prefill pods
6. Use cache-aware load balancing for optimal performance
7. Expose the router API on port 8080 and metrics on port 9090
- **If watching all namespaces** (without specifying namespace):
Set up a ServiceAccount, ClusterRole, and ClusterRoleBinding with permissions to list/watch pods at the cluster level
**Note:** In PD mode with service discovery, pods MUST match either the prefill or decode selector to be added. Pods that don't match either selector are ignored.
### Troubleshooting
......
......@@ -32,7 +32,7 @@ class RouterArgs:
port: int = 30000
# PD-specific configuration
pd_disaggregated: bool = False # Enable PD disaggregated mode
pd_disaggregation: bool = False # Enable PD disaggregated mode
prefill_urls: List[tuple] = dataclasses.field(
default_factory=list
) # List of (url, bootstrap_port)
......@@ -55,6 +55,10 @@ class RouterArgs:
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
service_discovery_port: int = 80
service_discovery_namespace: Optional[str] = None
# PD service discovery configuration
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
# Prometheus configuration
prometheus_port: Optional[int] = None
prometheus_host: Optional[str] = None
......@@ -108,7 +112,7 @@ class RouterArgs:
# PD-specific arguments
parser.add_argument(
f"--{prefix}pd-disaggregated",
f"--{prefix}pd-disaggregation",
action="store_true",
help="Enable PD (Prefill-Decode) disaggregated mode",
)
......@@ -207,6 +211,18 @@ class RouterArgs:
type=str,
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
)
parser.add_argument(
f"--{prefix}prefill-selector",
type=str,
nargs="+",
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
)
parser.add_argument(
f"--{prefix}decode-selector",
type=str,
nargs="+",
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
)
# Prometheus configuration
parser.add_argument(
f"--{prefix}prometheus-port",
......@@ -243,7 +259,7 @@ class RouterArgs:
worker_urls=worker_urls,
host=args.host,
port=args.port,
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
prefill_urls=prefill_urls,
decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"),
......@@ -267,6 +283,13 @@ class RouterArgs:
service_discovery_namespace=getattr(
args, f"{prefix}service_discovery_namespace", None
),
prefill_selector=cls._parse_selector(
getattr(args, f"{prefix}prefill_selector", None)
),
decode_selector=cls._parse_selector(
getattr(args, f"{prefix}decode_selector", None)
),
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
)
......@@ -355,17 +378,20 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
router_args = args
# Validate configuration based on mode
if router_args.pd_disaggregated:
# Validate PD configuration
if router_args.pd_disaggregation:
# Validate PD configuration - skip URL requirements if using service discovery
if not router_args.service_discovery:
if not router_args.prefill_urls:
raise ValueError("PD disaggregated mode requires --prefill")
raise ValueError("PD disaggregation mode requires --prefill")
if not router_args.decode_urls:
raise ValueError("PD disaggregated mode requires --decode")
raise ValueError("PD disaggregation mode requires --decode")
# Create router with unified constructor
router = Router(
worker_urls=(
router_args.worker_urls if not router_args.pd_disaggregated else []
[]
if router_args.service_discovery or router_args.pd_disaggregation
else router_args.worker_urls
),
host=router_args.host,
port=router_args.port,
......@@ -384,14 +410,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
selector=router_args.selector,
service_discovery_port=router_args.service_discovery_port,
service_discovery_namespace=router_args.service_discovery_namespace,
prefill_selector=router_args.prefill_selector,
decode_selector=router_args.decode_selector,
prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host,
pd_disaggregated=router_args.pd_disaggregated,
pd_disaggregation=router_args.pd_disaggregation,
prefill_urls=(
router_args.prefill_urls if router_args.pd_disaggregated else None
router_args.prefill_urls if router_args.pd_disaggregation else None
),
decode_urls=(
router_args.decode_urls if router_args.pd_disaggregated else None
router_args.decode_urls if router_args.pd_disaggregation else None
),
)
......@@ -425,7 +453,7 @@ Examples:
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
# PD disaggregated mode
python -m sglang_router.launch_router --pd-disaggregated \\
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware
......
......@@ -41,9 +41,13 @@ class Router:
worker URLs using this port. Default: 80
service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
watches pods across all namespaces (requires cluster-wide permissions). Default: None
prefill_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
for prefill servers (PD mode only). Default: {}
decode_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
for decode servers (PD mode only). Default: {}
prometheus_port: Port to expose Prometheus metrics. Default: None
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
"""
......@@ -68,14 +72,20 @@ class Router:
selector: Dict[str, str] = None,
service_discovery_port: int = 80,
service_discovery_namespace: Optional[str] = None,
prefill_selector: Dict[str, str] = None,
decode_selector: Dict[str, str] = None,
prometheus_port: Optional[int] = None,
prometheus_host: Optional[str] = None,
pd_disaggregated: bool = False,
pd_disaggregation: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
):
if selector is None:
selector = {}
if prefill_selector is None:
prefill_selector = {}
if decode_selector is None:
decode_selector = {}
self._router = _Router(
worker_urls=worker_urls,
......@@ -96,9 +106,11 @@ class Router:
selector=selector,
service_discovery_port=service_discovery_port,
service_discovery_namespace=service_discovery_namespace,
prefill_selector=prefill_selector,
decode_selector=decode_selector,
prometheus_port=prometheus_port,
prometheus_host=prometheus_host,
pd_disaggregated=pd_disaggregated,
pd_disaggregation=pd_disaggregation,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
)
......
......@@ -45,7 +45,7 @@ class TestLaunchRouter(unittest.TestCase):
prometheus_port=None,
prometheus_host=None,
# PD-specific attributes
pd_disaggregated=False,
pd_disaggregation=False,
prefill=None,
decode=None,
# Keep worker_urls for regular mode
......@@ -119,7 +119,7 @@ class TestLaunchRouter(unittest.TestCase):
# Test RouterArgs parsing for PD mode
# Simulate the parsed args structure from argparse with action="append"
args = self.create_router_args(
pd_disaggregated=True,
pd_disaggregation=True,
policy="power_of_two", # PowerOfTwo is only valid in PD mode
prefill=[
["http://prefill1:8080", "9000"],
......@@ -133,7 +133,7 @@ class TestLaunchRouter(unittest.TestCase):
)
router_args = RouterArgs.from_cli_args(args)
self.assertTrue(router_args.pd_disaggregated)
self.assertTrue(router_args.pd_disaggregation)
self.assertEqual(router_args.policy, "power_of_two")
self.assertEqual(len(router_args.prefill_urls), 2)
self.assertEqual(len(router_args.decode_urls), 2)
......@@ -147,7 +147,7 @@ class TestLaunchRouter(unittest.TestCase):
# Test Router creation in PD mode
router = Router(
worker_urls=[], # Empty for PD mode
pd_disaggregated=True,
pd_disaggregation=True,
prefill_urls=[
("http://prefill1:8080", 9000),
("http://prefill2:8080", None),
......@@ -165,7 +165,7 @@ class TestLaunchRouter(unittest.TestCase):
# Test 1: PowerOfTwo is only valid in PD mode
args = self.create_router_args(
pd_disaggregated=False,
pd_disaggregation=False,
policy="power_of_two",
worker_urls=["http://localhost:8000"],
)
......@@ -180,7 +180,7 @@ class TestLaunchRouter(unittest.TestCase):
# Test 2: RoundRobin is not valid in PD mode
args = self.create_router_args(
pd_disaggregated=True,
pd_disaggregation=True,
policy="round_robin",
prefill=[["http://prefill1:8080", "9000"]],
decode=[["http://decode1:8081"]],
......@@ -198,7 +198,7 @@ class TestLaunchRouter(unittest.TestCase):
# Test 3: Valid combinations should not raise errors
# Regular mode with RoundRobin
args = self.create_router_args(
pd_disaggregated=False,
pd_disaggregation=False,
policy="round_robin",
worker_urls=["http://localhost:8000"],
)
......@@ -206,7 +206,7 @@ class TestLaunchRouter(unittest.TestCase):
# PD mode with PowerOfTwo
args = self.create_router_args(
pd_disaggregated=True,
pd_disaggregation=True,
policy="power_of_two",
prefill=[["http://prefill1:8080", "9000"]],
decode=[["http://decode1:8081"]],
......@@ -214,6 +214,79 @@ class TestLaunchRouter(unittest.TestCase):
)
# This should not raise (though it may fail to connect)
def test_pd_service_discovery_args_parsing(self):
"""Test PD service discovery CLI argument parsing."""
import argparse
from sglang_router.launch_router import RouterArgs
parser = argparse.ArgumentParser()
RouterArgs.add_cli_args(parser)
args = parser.parse_args(
[
"--pd-disaggregation",
"--service-discovery",
"--prefill-selector",
"app=sglang",
"component=prefill",
"--decode-selector",
"app=sglang",
"component=decode",
"--service-discovery-port",
"8000",
"--service-discovery-namespace",
"production",
"--policy",
"cache_aware",
]
)
router_args = RouterArgs.from_cli_args(args)
self.assertTrue(router_args.pd_disaggregation)
self.assertTrue(router_args.service_discovery)
self.assertEqual(
router_args.prefill_selector, {"app": "sglang", "component": "prefill"}
)
self.assertEqual(
router_args.decode_selector, {"app": "sglang", "component": "decode"}
)
self.assertEqual(router_args.service_discovery_port, 8000)
self.assertEqual(router_args.service_discovery_namespace, "production")
def test_regular_service_discovery_args_parsing(self):
"""Test regular mode service discovery CLI argument parsing."""
import argparse
from sglang_router.launch_router import RouterArgs
parser = argparse.ArgumentParser()
RouterArgs.add_cli_args(parser)
args = parser.parse_args(
[
"--service-discovery",
"--selector",
"app=sglang-worker",
"environment=staging",
"--service-discovery-port",
"8000",
"--policy",
"round_robin",
]
)
router_args = RouterArgs.from_cli_args(args)
self.assertFalse(router_args.pd_disaggregation)
self.assertTrue(router_args.service_discovery)
self.assertEqual(
router_args.selector, {"app": "sglang-worker", "environment": "staging"}
)
self.assertEqual(router_args.prefill_selector, {})
self.assertEqual(router_args.decode_selector, {})
if __name__ == "__main__":
unittest.main()
......@@ -42,12 +42,16 @@ struct Router {
selector: HashMap<String, String>,
service_discovery_port: u16,
service_discovery_namespace: Option<String>,
// PD service discovery fields
prefill_selector: HashMap<String, String>,
decode_selector: HashMap<String, String>,
bootstrap_port_annotation: String,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
// PD mode flag
pd_disaggregated: bool,
// PD-specific fields (only used when pd_disaggregated is true)
pd_disaggregation: bool,
// PD-specific fields (only used when pd_disaggregation is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
}
......@@ -74,10 +78,13 @@ impl Router {
selector = HashMap::new(),
service_discovery_port = 80,
service_discovery_namespace = None,
prefill_selector = HashMap::new(),
decode_selector = HashMap::new(),
bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
prometheus_port = None,
prometheus_host = None,
request_timeout_secs = 600, // Add configurable request timeout
pd_disaggregated = false, // New flag for PD mode
pd_disaggregation = false, // New flag for PD mode
prefill_urls = None,
decode_urls = None
))]
......@@ -100,10 +107,13 @@ impl Router {
selector: HashMap<String, String>,
service_discovery_port: u16,
service_discovery_namespace: Option<String>,
prefill_selector: HashMap<String, String>,
decode_selector: HashMap<String, String>,
bootstrap_port_annotation: String,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
pd_disaggregated: bool,
pd_disaggregation: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
) -> PyResult<Self> {
......@@ -126,17 +136,20 @@ impl Router {
selector,
service_discovery_port,
service_discovery_namespace,
prefill_selector,
decode_selector,
bootstrap_port_annotation,
prometheus_port,
prometheus_host,
request_timeout_secs,
pd_disaggregated,
pd_disaggregation,
prefill_urls,
decode_urls,
})
}
fn start(&self) -> PyResult<()> {
let policy_config = if self.pd_disaggregated {
let policy_config = if self.pd_disaggregation {
// PD mode - map PolicyType to PDSelectionPolicy
let pd_selection_policy = match &self.policy {
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
......@@ -207,6 +220,11 @@ impl Router {
check_interval: std::time::Duration::from_secs(60),
port: self.service_discovery_port,
namespace: self.service_discovery_namespace.clone(),
// PD mode configuration
pd_mode: self.pd_disaggregation,
prefill_selector: self.prefill_selector.clone(),
decode_selector: self.decode_selector.clone(),
bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
})
} else {
None
......
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
use crate::pd_types::{Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDSelectionPolicy};
use crate::pd_types::{
Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy,
};
use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
......@@ -65,12 +67,145 @@ impl Drop for LoadGuard<'_> {
}
impl PDRouter {
// TODO: Add methods for dynamic worker management to support /register endpoint:
// - add_prefill_server(url: String, bootstrap_port: Option<u16>)
// - add_decode_server(url: String)
// - remove_prefill_server(url: &str)
// - remove_decode_server(url: &str)
// These methods will be used when service discovery is implemented for PD mode
// Dynamic worker management methods for service discovery
pub async fn add_prefill_server(
&self,
url: String,
bootstrap_port: Option<u16>,
) -> Result<String, PDRouterError> {
// Create EngineInfo for the new prefill server
let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port);
// Wait for the new server to be healthy
crate::router::Router::wait_for_healthy_workers(
&[url.clone()],
self.timeout_secs,
self.interval_secs,
)
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
// Add to prefill workers list
let mut workers = self
.prefill_workers
.write()
.map_err(|_| PDRouterError::LockError {
operation: "prefill_workers write".to_string(),
})?;
// Check if already exists
if workers.iter().any(|w| w.url == url) {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
workers.push(engine_info);
// Initialize load tracking
self.load_tracking
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
// Add to cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree {
tree.lock().unwrap().insert("", &url);
}
info!("Added prefill server: {}", url);
Ok(format!("Successfully added prefill server: {}", url))
}
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
// Create EngineInfo for the new decode server
let engine_info = EngineInfo::new_decode(url.clone());
// Wait for the new server to be healthy
crate::router::Router::wait_for_healthy_workers(
&[url.clone()],
self.timeout_secs,
self.interval_secs,
)
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
// Add to decode workers list
let mut workers = self
.decode_workers
.write()
.map_err(|_| PDRouterError::LockError {
operation: "decode_workers write".to_string(),
})?;
// Check if already exists
if workers.iter().any(|w| w.url == url) {
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
}
workers.push(engine_info);
// Initialize load tracking
self.load_tracking
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
info!("Added decode server: {}", url);
Ok(format!("Successfully added decode server: {}", url))
}
pub async fn remove_prefill_server(&self, url: &str) -> Result<String, PDRouterError> {
let mut workers = self
.prefill_workers
.write()
.map_err(|_| PDRouterError::LockError {
operation: "prefill_workers write".to_string(),
})?;
// Find and remove the server
let initial_len = workers.len();
workers.retain(|w| w.url != url);
if workers.len() == initial_len {
return Err(PDRouterError::WorkerNotFound {
url: url.to_string(),
});
}
// Remove from load tracking
self.load_tracking.remove(url);
// Remove from cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree {
// Note: Tree doesn't have a remove method, so we rebuild it
let mut tree_guard = tree.lock().unwrap();
*tree_guard = Tree::new();
for worker in workers.iter() {
tree_guard.insert("", &worker.url);
}
}
info!("Removed prefill server: {}", url);
Ok(format!("Successfully removed prefill server: {}", url))
}
pub async fn remove_decode_server(&self, url: &str) -> Result<String, PDRouterError> {
let mut workers = self
.decode_workers
.write()
.map_err(|_| PDRouterError::LockError {
operation: "decode_workers write".to_string(),
})?;
// Find and remove the server
let initial_len = workers.len();
workers.retain(|w| w.url != url);
if workers.len() == initial_len {
return Err(PDRouterError::WorkerNotFound {
url: url.to_string(),
});
}
// Remove from load tracking
self.load_tracking.remove(url);
info!("Removed decode server: {}", url);
Ok(format!("Successfully removed decode server: {}", url))
}
pub fn new(
prefill_urls: Vec<(String, Option<u16>)>,
......
......@@ -3,6 +3,31 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
// Custom error type for PD router operations
#[derive(Debug, thiserror::Error)]
pub enum PDRouterError {
#[error("Worker already exists: {url}")]
WorkerAlreadyExists { url: String },
#[error("Worker not found: {url}")]
WorkerNotFound { url: String },
#[error("Lock acquisition failed: {operation}")]
LockError { operation: String },
#[error("Health check failed for worker: {url}")]
HealthCheckFailed { url: String },
#[error("Invalid worker configuration: {reason}")]
InvalidConfiguration { reason: String },
#[error("Network error: {message}")]
NetworkError { message: String },
#[error("Timeout waiting for worker: {url}")]
Timeout { url: String },
}
#[derive(Debug, Clone)]
pub enum EngineType {
Prefill,
......
......@@ -1045,6 +1045,55 @@ impl Router {
}
}
/// Add a worker with PD mode support
pub async fn add_pd_worker(
&self,
worker_url: &str,
pod_type: crate::service_discovery::PodType,
bootstrap_port: Option<u16>,
) -> Result<String, String> {
match self {
Router::PrefillDecode { pd_router } => match pod_type {
crate::service_discovery::PodType::Prefill => pd_router
.add_prefill_server(worker_url.to_string(), bootstrap_port)
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Decode => pd_router
.add_decode_server(worker_url.to_string())
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Regular => {
Err("Regular pod type not supported in PD mode".to_string())
}
},
_ => Err("add_pd_worker only supported in PD mode".to_string()),
}
}
/// Remove a worker with PD mode support
pub async fn remove_pd_worker(
&self,
worker_url: &str,
pod_type: crate::service_discovery::PodType,
) -> Result<String, String> {
match self {
Router::PrefillDecode { pd_router } => match pod_type {
crate::service_discovery::PodType::Prefill => pd_router
.remove_prefill_server(worker_url)
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Decode => pd_router
.remove_decode_server(worker_url)
.await
.map_err(|e| e.to_string()),
crate::service_discovery::PodType::Regular => {
Err("Regular pod type not supported in PD mode".to_string())
}
},
_ => Err("remove_pd_worker only supported in PD mode".to_string()),
}
}
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
......@@ -1174,3 +1223,108 @@ impl Router {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::service_discovery::PodType;
fn create_test_regular_router() -> Router {
Router::Random {
worker_urls: Arc::new(RwLock::new(vec![
"http://worker1:8080".to_string(),
"http://worker2:8080".to_string(),
])),
timeout_secs: 5,
interval_secs: 1,
}
}
#[test]
fn test_router_get_worker_urls_regular() {
let router = create_test_regular_router();
let worker_urls = router.get_worker_urls();
let urls = worker_urls.read().unwrap();
assert_eq!(urls.len(), 2);
assert!(urls.contains(&"http://worker1:8080".to_string()));
assert!(urls.contains(&"http://worker2:8080".to_string()));
}
// #[test]
// fn test_router_get_worker_urls_pd_mode() {
// // For PD mode, get_worker_urls returns empty list
// // Note: PDRouter::new requires health checks which fail in tests
// // This test would need a mock server or different test setup
// }
#[tokio::test]
async fn test_add_pd_worker_with_regular_router() {
let router = create_test_regular_router();
let result = router
.add_pd_worker("http://new-worker:8080", PodType::Prefill, Some(8081))
.await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.contains("add_pd_worker only supported in PD mode"));
}
#[tokio::test]
async fn test_remove_pd_worker_with_regular_router() {
let router = create_test_regular_router();
let result = router
.remove_pd_worker("http://worker:8080", PodType::Decode)
.await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.contains("remove_pd_worker only supported in PD mode"));
}
// #[tokio::test]
// async fn test_add_pd_worker_with_pd_router_regular_type() {
// // Note: PDRouter::new requires health checks which fail in tests
// // This test would need a mock server or different test setup
// }
// #[tokio::test]
// async fn test_remove_pd_worker_with_pd_router_regular_type() {
// // Note: PDRouter::new requires health checks which fail in tests
// // This test would need a mock server or different test setup
// }
#[test]
fn test_select_first_worker_regular() {
let router = create_test_regular_router();
let result = router.select_first_worker();
assert!(result.is_ok());
assert_eq!(result.unwrap(), "http://worker1:8080");
}
// #[test]
// fn test_select_first_worker_pd_mode() {
// // Note: PDRouter::new requires health checks which fail in tests
// // This test would need a mock server or different test setup
// }
#[test]
fn test_wait_for_healthy_workers_empty_list() {
let result = Router::wait_for_healthy_workers(&[], 1, 1);
assert!(result.is_ok());
}
#[test]
fn test_wait_for_healthy_workers_invalid_urls() {
// This test will timeout quickly since the URLs are invalid
let result =
Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Timeout"));
}
}
This diff is collapsed.
......@@ -5,7 +5,7 @@
//! - Phase 2: Bootstrap injection and request handling
//! - Phase 3: Cache-aware selection (when implemented)
//!
//! Note: PD mode is enabled via the pd_disaggregated flag, not as a policy type.
//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type.
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
#[cfg(test)]
......@@ -90,7 +90,7 @@ mod test_pd_routing {
#[test]
fn test_pd_selection_policies() {
// Test all PD selection policy variants
// Note: These policies are only used when pd_disaggregated=true
// Note: These policies are only used when pd_disaggregation=true
let policies = vec![
PDSelectionPolicy::Random,
PDSelectionPolicy::PowerOfTwo,
......@@ -122,7 +122,7 @@ mod test_pd_routing {
#[test]
fn test_pd_router_configuration() {
// Test PrefillDecodeConfig creation with various policies
// This config is used when pd_disaggregated=true
// This config is used when pd_disaggregation=true
let configs = vec![
PolicyConfig::PrefillDecodeConfig {
selection_policy: PDSelectionPolicy::Random,
......@@ -878,7 +878,7 @@ mod test_pd_routing {
#[test]
fn test_policy_type_to_pd_selection_policy_mapping() {
// Document the mapping from PolicyType to PDSelectionPolicy
// This mapping happens in lib.rs when pd_disaggregated=true
// This mapping happens in lib.rs when pd_disaggregation=true
// PolicyType::Random -> PDSelectionPolicy::Random
// PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment