"docs/source/api/vscode:/vscode.git/clone" did not exist on "7ff2662b88958cb8691de78a626d2463cf742a02"
Unverified Commit 2ab97023 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add different policies for p node and d node (#8395)

parent 0bcc195f
...@@ -120,6 +120,16 @@ python -m sglang_router.launch_router \ ...@@ -120,6 +120,16 @@ python -m sglang_router.launch_router \
--prefill-selector app=sglang component=prefill \ --prefill-selector app=sglang component=prefill \
--decode-selector app=sglang component=decode \ --decode-selector app=sglang component=decode \
--service-discovery-namespace sglang-system --service-discovery-namespace sglang-system
# With separate routing policies:
python -m sglang_router.launch_router \
--pd-disaggregation \
--prefill-policy cache_aware \
--decode-policy power_of_two \
--service-discovery \
--prefill-selector app=sglang component=prefill \
--decode-selector app=sglang component=decode \
--service-discovery-namespace sglang-system
``` ```
#### Kubernetes Pod Configuration #### Kubernetes Pod Configuration
...@@ -226,7 +236,9 @@ python -m sglang_router.launch_router \ ...@@ -226,7 +236,9 @@ python -m sglang_router.launch_router \
- `--decode`: Initial decode server URL - `--decode`: Initial decode server URL
- `--prefill-selector`: Label selector for prefill pods - `--prefill-selector`: Label selector for prefill pods
- `--decode-selector`: Label selector for decode pods - `--decode-selector`: Label selector for decode pods
- `--policy`: Routing policy (`cache_aware`, `random`, `power_of_two`) - `--policy`: Routing policy (`cache_aware`, `random`, `power_of_two`, `round_robin`)
- `--prefill-policy`: Separate routing policy for prefill nodes (optional, overrides `--policy` for prefill)
- `--decode-policy`: Separate routing policy for decode nodes (optional, overrides `--policy` for decode)
## Development ## Development
......
...@@ -40,6 +40,8 @@ class RouterArgs: ...@@ -40,6 +40,8 @@ class RouterArgs:
# Routing policy # Routing policy
policy: str = "cache_aware" policy: str = "cache_aware"
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
worker_startup_timeout_secs: int = 300 worker_startup_timeout_secs: int = 300
worker_startup_check_interval: int = 10 worker_startup_check_interval: int = 10
cache_threshold: float = 0.5 cache_threshold: float = 0.5
...@@ -108,7 +110,21 @@ class RouterArgs: ...@@ -108,7 +110,21 @@ class RouterArgs:
type=str, type=str,
default=RouterArgs.policy, default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware", "power_of_two"], choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode", help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
)
parser.add_argument(
f"--{prefix}prefill-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
)
parser.add_argument(
f"--{prefix}decode-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
) )
# PD-specific arguments # PD-specific arguments
...@@ -266,6 +282,8 @@ class RouterArgs: ...@@ -266,6 +282,8 @@ class RouterArgs:
prefill_urls=prefill_urls, prefill_urls=prefill_urls,
decode_urls=decode_urls, decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"), policy=getattr(args, f"{prefix}policy"),
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
decode_policy=getattr(args, f"{prefix}decode_policy", None),
worker_startup_timeout_secs=getattr( worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs" args, f"{prefix}worker_startup_timeout_secs"
), ),
...@@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
if not router_args.decode_urls: if not router_args.decode_urls:
raise ValueError("PD disaggregation mode requires --decode") raise ValueError("PD disaggregation mode requires --decode")
# Warn about policy usage in PD mode
if (
router_args.prefill_policy
and router_args.decode_policy
and router_args.policy
):
logger.warning(
"Both --prefill-policy and --decode-policy are specified. "
"The main --policy flag will be ignored for PD mode."
)
elif (
router_args.prefill_policy
and not router_args.decode_policy
and router_args.policy
):
logger.info(
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
f"and --policy '{router_args.policy}' for decode nodes."
)
elif (
router_args.decode_policy
and not router_args.prefill_policy
and router_args.policy
):
logger.info(
f"Using --policy '{router_args.policy}' for prefill nodes "
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
)
# Create router with unified constructor # Create router with unified constructor
router = Router( router = Router(
worker_urls=( worker_urls=(
...@@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
decode_urls=( decode_urls=(
router_args.decode_urls if router_args.pd_disaggregation else None router_args.decode_urls if router_args.pd_disaggregation else None
), ),
prefill_policy=(
policy_from_str(router_args.prefill_policy)
if router_args.prefill_policy
else None
),
decode_policy=(
policy_from_str(router_args.decode_policy)
if router_args.decode_policy
else None
),
) )
router.start() router.start()
...@@ -455,12 +512,18 @@ Examples: ...@@ -455,12 +512,18 @@ Examples:
# Regular mode # Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
# PD disaggregated mode # PD disaggregated mode with same policy for both
python -m sglang_router.launch_router --pd-disaggregation \\ python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\ --prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\ --decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware --policy cache_aware
# PD mode with different policies for prefill and decode
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--prefill-policy cache_aware --decode-policy power_of_two
""", """,
formatter_class=CustomHelpFormatter, formatter_class=CustomHelpFormatter,
) )
......
...@@ -50,6 +50,10 @@ class Router: ...@@ -50,6 +50,10 @@ class Router:
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only) prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only) decode_urls: List of URLs for decode servers (PD mode only)
prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
If not specified, uses the main policy. Default: None
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
If not specified, uses the main policy. Default: None
""" """
def __init__( def __init__(
...@@ -79,6 +83,8 @@ class Router: ...@@ -79,6 +83,8 @@ class Router:
pd_disaggregation: bool = False, pd_disaggregation: bool = False,
prefill_urls: Optional[List[tuple]] = None, prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None, decode_urls: Optional[List[str]] = None,
prefill_policy: Optional[PolicyType] = None,
decode_policy: Optional[PolicyType] = None,
): ):
if selector is None: if selector is None:
selector = {} selector = {}
...@@ -113,6 +119,8 @@ class Router: ...@@ -113,6 +119,8 @@ class Router:
pd_disaggregation=pd_disaggregation, pd_disaggregation=pd_disaggregation,
prefill_urls=prefill_urls, prefill_urls=prefill_urls,
decode_urls=decode_urls, decode_urls=decode_urls,
prefill_policy=prefill_policy,
decode_policy=decode_policy,
) )
def start(self) -> None: def start(self) -> None:
......
...@@ -46,6 +46,12 @@ pub enum RoutingMode { ...@@ -46,6 +46,12 @@ pub enum RoutingMode {
prefill_urls: Vec<(String, Option<u16>)>, prefill_urls: Vec<(String, Option<u16>)>,
/// Decode worker URLs /// Decode worker URLs
decode_urls: Vec<String>, decode_urls: Vec<String>,
/// Optional separate policy for prefill workers
#[serde(skip_serializing_if = "Option::is_none")]
prefill_policy: Option<PolicyConfig>,
/// Optional separate policy for decode workers
#[serde(skip_serializing_if = "Option::is_none")]
decode_policy: Option<PolicyConfig>,
}, },
} }
...@@ -60,9 +66,32 @@ impl RoutingMode { ...@@ -60,9 +66,32 @@ impl RoutingMode {
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls, prefill_urls,
decode_urls, decode_urls,
..
} => prefill_urls.len() + decode_urls.len(), } => prefill_urls.len() + decode_urls.len(),
} }
} }
/// Get the effective prefill policy for PD mode
/// Falls back to the main policy if no specific prefill policy is set
pub fn get_prefill_policy<'a>(&'a self, main_policy: &'a PolicyConfig) -> &'a PolicyConfig {
match self {
RoutingMode::PrefillDecode { prefill_policy, .. } => {
prefill_policy.as_ref().unwrap_or(main_policy)
}
_ => main_policy,
}
}
/// Get the effective decode policy for PD mode
/// Falls back to the main policy if no specific decode policy is set
pub fn get_decode_policy<'a>(&'a self, main_policy: &'a PolicyConfig) -> &'a PolicyConfig {
match self {
RoutingMode::PrefillDecode { decode_policy, .. } => {
decode_policy.as_ref().unwrap_or(main_policy)
}
_ => main_policy,
}
}
} }
/// Policy configuration for routing /// Policy configuration for routing
...@@ -307,6 +336,8 @@ mod tests { ...@@ -307,6 +336,8 @@ mod tests {
let pd = RoutingMode::PrefillDecode { let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
decode_urls: vec!["http://decode1".to_string()], decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: None,
}; };
assert!(pd.is_pd_mode()); assert!(pd.is_pd_mode());
} }
...@@ -332,6 +363,8 @@ mod tests { ...@@ -332,6 +363,8 @@ mod tests {
"http://decode2".to_string(), "http://decode2".to_string(),
"http://decode3".to_string(), "http://decode3".to_string(),
], ],
prefill_policy: None,
decode_policy: None,
}; };
assert_eq!(pd.worker_count(), 5); assert_eq!(pd.worker_count(), 5);
...@@ -355,6 +388,8 @@ mod tests { ...@@ -355,6 +388,8 @@ mod tests {
let pd = RoutingMode::PrefillDecode { let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], prefill_urls: vec![("http://prefill1".to_string(), Some(8001))],
decode_urls: vec!["http://decode1".to_string()], decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: None,
}; };
let json = serde_json::to_string(&pd).unwrap(); let json = serde_json::to_string(&pd).unwrap();
assert!(json.contains("\"type\":\"prefill_decode\"")); assert!(json.contains("\"type\":\"prefill_decode\""));
...@@ -551,6 +586,8 @@ mod tests { ...@@ -551,6 +586,8 @@ mod tests {
mode: RoutingMode::PrefillDecode { mode: RoutingMode::PrefillDecode {
prefill_urls: vec![], prefill_urls: vec![],
decode_urls: vec![], decode_urls: vec![],
prefill_policy: None,
decode_policy: None,
}, },
..Default::default() ..Default::default()
}; };
...@@ -674,6 +711,8 @@ mod tests { ...@@ -674,6 +711,8 @@ mod tests {
"http://decode1:8000".to_string(), "http://decode1:8000".to_string(),
"http://decode2:8000".to_string(), "http://decode2:8000".to_string(),
], ],
prefill_policy: None,
decode_policy: None,
}, },
policy: PolicyConfig::PowerOfTwo { policy: PolicyConfig::PowerOfTwo {
load_check_interval_secs: 30, load_check_interval_secs: 30,
...@@ -800,4 +839,155 @@ mod tests { ...@@ -800,4 +839,155 @@ mod tests {
Some("production".to_string()) Some("production".to_string())
); );
} }
// ============= Policy Fallback Tests =============
#[test]
fn test_pd_policy_fallback_both_specified() {
// When both prefill and decode policies are specified, they should be used
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: Some(PolicyConfig::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 60,
max_tree_size: 1000,
}),
decode_policy: Some(PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
}),
};
let main_policy = PolicyConfig::Random;
// Both specific policies should be used
match pd.get_prefill_policy(&main_policy) {
PolicyConfig::CacheAware { .. } => {} // Success
_ => panic!("Expected CacheAware for prefill"),
}
match pd.get_decode_policy(&main_policy) {
PolicyConfig::PowerOfTwo { .. } => {} // Success
_ => panic!("Expected PowerOfTwo for decode"),
}
}
#[test]
fn test_pd_policy_fallback_only_prefill() {
// When only prefill policy is specified, decode should use main policy
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: Some(PolicyConfig::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 60,
max_tree_size: 1000,
}),
decode_policy: None,
};
let main_policy = PolicyConfig::RoundRobin;
// Prefill should use specific policy
match pd.get_prefill_policy(&main_policy) {
PolicyConfig::CacheAware { .. } => {} // Success
_ => panic!("Expected CacheAware for prefill"),
}
// Decode should fall back to main policy
match pd.get_decode_policy(&main_policy) {
PolicyConfig::RoundRobin => {} // Success
_ => panic!("Expected RoundRobin for decode"),
}
}
#[test]
fn test_pd_policy_fallback_only_decode() {
// When only decode policy is specified, prefill should use main policy
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: Some(PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
}),
};
let main_policy = PolicyConfig::Random;
// Prefill should fall back to main policy
match pd.get_prefill_policy(&main_policy) {
PolicyConfig::Random => {} // Success
_ => panic!("Expected Random for prefill"),
}
// Decode should use specific policy
match pd.get_decode_policy(&main_policy) {
PolicyConfig::PowerOfTwo { .. } => {} // Success
_ => panic!("Expected PowerOfTwo for decode"),
}
}
#[test]
fn test_pd_policy_fallback_none_specified() {
// When no specific policies are specified, both should use main policy
let pd = RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1".to_string(), None)],
decode_urls: vec!["http://decode1".to_string()],
prefill_policy: None,
decode_policy: None,
};
let main_policy = PolicyConfig::CacheAware {
cache_threshold: 0.7,
balance_abs_threshold: 20,
balance_rel_threshold: 1.5,
eviction_interval_secs: 300,
max_tree_size: 2000,
};
// Both should fall back to main policy
match pd.get_prefill_policy(&main_policy) {
PolicyConfig::CacheAware {
cache_threshold, ..
} => {
assert!((cache_threshold - 0.7).abs() < 0.0001);
}
_ => panic!("Expected CacheAware for prefill"),
}
match pd.get_decode_policy(&main_policy) {
PolicyConfig::CacheAware {
cache_threshold, ..
} => {
assert!((cache_threshold - 0.7).abs() < 0.0001);
}
_ => panic!("Expected CacheAware for decode"),
}
}
#[test]
fn test_regular_mode_policy_fallback() {
// For regular mode, the helper methods should just return the main policy
let regular = RoutingMode::Regular {
worker_urls: vec!["http://worker1".to_string()],
};
let main_policy = PolicyConfig::RoundRobin;
// Both methods should return main policy for regular mode
match regular.get_prefill_policy(&main_policy) {
PolicyConfig::RoundRobin => {} // Success
_ => panic!("Expected RoundRobin for regular mode"),
}
match regular.get_decode_policy(&main_policy) {
PolicyConfig::RoundRobin => {} // Success
_ => panic!("Expected RoundRobin for regular mode"),
}
}
} }
...@@ -41,6 +41,8 @@ impl ConfigValidator { ...@@ -41,6 +41,8 @@ impl ConfigValidator {
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls, prefill_urls,
decode_urls, decode_urls,
prefill_policy,
decode_policy,
} => { } => {
// Only require URLs if service discovery is disabled // Only require URLs if service discovery is disabled
if !has_service_discovery { if !has_service_discovery {
...@@ -78,6 +80,14 @@ impl ConfigValidator { ...@@ -78,6 +80,14 @@ impl ConfigValidator {
} }
} }
} }
// Validate optional prefill and decode policies
if let Some(p_policy) = prefill_policy {
Self::validate_policy(p_policy)?;
}
if let Some(d_policy) = decode_policy {
Self::validate_policy(d_policy)?;
}
} }
} }
Ok(()) Ok(())
...@@ -272,6 +282,35 @@ impl ConfigValidator { ...@@ -272,6 +282,35 @@ impl ConfigValidator {
}); });
} }
} }
// For PD mode, validate that policies have sufficient workers
if let RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
} = &config.mode
{
// Check power-of-two for prefill
if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy {
if prefill_urls.len() < 2 {
return Err(ConfigError::IncompatibleConfig {
reason: "Power-of-two policy for prefill requires at least 2 prefill workers".to_string(),
});
}
}
// Check power-of-two for decode
if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy {
if decode_urls.len() < 2 {
return Err(ConfigError::IncompatibleConfig {
reason:
"Power-of-two policy for decode requires at least 2 decode workers"
.to_string(),
});
}
}
}
} }
Ok(()) Ok(())
...@@ -430,6 +469,8 @@ mod tests { ...@@ -430,6 +469,8 @@ mod tests {
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))], prefill_urls: vec![("http://prefill:8000".to_string(), Some(8081))],
decode_urls: vec!["http://decode:8000".to_string()], decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
}, },
PolicyConfig::Random, PolicyConfig::Random,
); );
...@@ -444,6 +485,8 @@ mod tests { ...@@ -444,6 +485,8 @@ mod tests {
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), None)], prefill_urls: vec![("http://prefill:8000".to_string(), None)],
decode_urls: vec!["http://decode:8000".to_string()], decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
}, },
PolicyConfig::RoundRobin, PolicyConfig::RoundRobin,
); );
...@@ -459,6 +502,8 @@ mod tests { ...@@ -459,6 +502,8 @@ mod tests {
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8000".to_string(), None)], prefill_urls: vec![("http://prefill:8000".to_string(), None)],
decode_urls: vec!["http://decode:8000".to_string()], decode_urls: vec!["http://decode:8000".to_string()],
prefill_policy: None,
decode_policy: None,
}, },
PolicyConfig::CacheAware { PolicyConfig::CacheAware {
cache_threshold: 0.5, cache_threshold: 0.5,
...@@ -491,4 +536,60 @@ mod tests { ...@@ -491,4 +536,60 @@ mod tests {
let result = ConfigValidator::validate(&config); let result = ConfigValidator::validate(&config);
assert!(result.is_ok()); assert!(result.is_ok());
} }
#[test]
fn test_validate_pd_mode_with_separate_policies() {
// Test PD mode with different policies for prefill and decode
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![
("http://prefill1:8000".to_string(), None),
("http://prefill2:8000".to_string(), None),
],
decode_urls: vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
],
prefill_policy: Some(PolicyConfig::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
eviction_interval_secs: 60,
max_tree_size: 1000,
}),
decode_policy: Some(PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
}),
},
PolicyConfig::Random, // Main policy as fallback
);
let result = ConfigValidator::validate(&config);
assert!(result.is_ok());
}
#[test]
fn test_validate_pd_mode_power_of_two_insufficient_workers() {
// Test that power-of-two policy requires at least 2 workers
let config = RouterConfig::new(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill1:8000".to_string(), None)], // Only 1 prefill
decode_urls: vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
],
prefill_policy: Some(PolicyConfig::PowerOfTwo {
load_check_interval_secs: 60,
}), // Requires 2+ workers
decode_policy: None,
},
PolicyConfig::Random,
);
let result = ConfigValidator::validate(&config);
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("prefill requires at least 2"));
}
}
} }
...@@ -54,6 +54,8 @@ struct Router { ...@@ -54,6 +54,8 @@ struct Router {
// PD-specific fields (only used when pd_disaggregation is true) // PD-specific fields (only used when pd_disaggregation is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>, prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>, decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
} }
impl Router { impl Router {
...@@ -63,11 +65,31 @@ impl Router { ...@@ -63,11 +65,31 @@ impl Router {
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode, DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
}; };
// Convert policy helper function
let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
match policy {
PolicyType::Random => ConfigPolicyConfig::Random,
PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value
},
}
};
// Determine routing mode // Determine routing mode
let mode = if self.pd_disaggregation { let mode = if self.pd_disaggregation {
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls: self.prefill_urls.clone().unwrap_or_default(), prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
decode_urls: self.decode_urls.clone().unwrap_or_default(), decode_urls: self.decode_urls.clone().unwrap_or_default(),
prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
decode_policy: self.decode_policy.as_ref().map(convert_policy),
} }
} else { } else {
RoutingMode::Regular { RoutingMode::Regular {
...@@ -75,21 +97,8 @@ impl Router { ...@@ -75,21 +97,8 @@ impl Router {
} }
}; };
// Convert policy // Convert main policy
let policy = match self.policy { let policy = convert_policy(&self.policy);
PolicyType::Random => ConfigPolicyConfig::Random,
PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value
},
};
// Service discovery configuration // Service discovery configuration
let discovery = if self.service_discovery { let discovery = if self.service_discovery {
...@@ -163,7 +172,9 @@ impl Router { ...@@ -163,7 +172,9 @@ impl Router {
request_timeout_secs = 600, // Add configurable request timeout request_timeout_secs = 600, // Add configurable request timeout
pd_disaggregation = false, // New flag for PD mode pd_disaggregation = false, // New flag for PD mode
prefill_urls = None, prefill_urls = None,
decode_urls = None decode_urls = None,
prefill_policy = None,
decode_policy = None
))] ))]
fn new( fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
...@@ -193,6 +204,8 @@ impl Router { ...@@ -193,6 +204,8 @@ impl Router {
pd_disaggregation: bool, pd_disaggregation: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>, prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>, decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
) -> PyResult<Self> { ) -> PyResult<Self> {
Ok(Router { Ok(Router {
host, host,
...@@ -222,6 +235,8 @@ impl Router { ...@@ -222,6 +235,8 @@ impl Router {
pd_disaggregation, pd_disaggregation,
prefill_urls, prefill_urls,
decode_urls, decode_urls,
prefill_policy,
decode_policy,
}) })
} }
......
...@@ -254,7 +254,11 @@ impl LoadBalancingPolicy for CacheAwarePolicy { ...@@ -254,7 +254,11 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
decode_workers: &[Box<dyn Worker>], decode_workers: &[Box<dyn Worker>],
request_text: Option<&str>, request_text: Option<&str>,
) -> Option<(usize, usize)> { ) -> Option<(usize, usize)> {
// In PD mode: // DEPRECATED: This method is no longer used when separate policies are configured.
// The PD router now uses separate policies for prefill and decode selection.
// This implementation remains for backward compatibility when a single policy is used.
// In PD mode with single policy:
// - Prefill: Use cache-aware routing for better cache utilization // - Prefill: Use cache-aware routing for better cache utilization
// - Decode: Use least-load routing for better load distribution // - Decode: Use least-load routing for better load distribution
......
...@@ -17,7 +17,16 @@ impl RouterFactory { ...@@ -17,7 +17,16 @@ impl RouterFactory {
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls, prefill_urls,
decode_urls, decode_urls,
} => Self::create_pd_router(prefill_urls, decode_urls, &config.policy, config), prefill_policy,
decode_policy,
} => Self::create_pd_router(
prefill_urls,
decode_urls,
prefill_policy.as_ref(),
decode_policy.as_ref(),
&config.policy,
config,
),
} }
} }
...@@ -45,18 +54,23 @@ impl RouterFactory { ...@@ -45,18 +54,23 @@ impl RouterFactory {
fn create_pd_router( fn create_pd_router(
prefill_urls: &[(String, Option<u16>)], prefill_urls: &[(String, Option<u16>)],
decode_urls: &[String], decode_urls: &[String],
policy_config: &PolicyConfig, prefill_policy_config: Option<&PolicyConfig>,
decode_policy_config: Option<&PolicyConfig>,
main_policy_config: &PolicyConfig,
router_config: &RouterConfig, router_config: &RouterConfig,
) -> Result<Box<dyn RouterTrait>, String> { ) -> Result<Box<dyn RouterTrait>, String> {
// Create policy directly from PolicyConfig // Create policies - use specific policies if provided, otherwise fall back to main policy
// All policies now support PD mode through the select_worker_pair method let prefill_policy =
let policy = PolicyFactory::create_from_config(policy_config); PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Create PD router with injected policy // Create PD router with separate policies
let router = PDRouter::new( let router = PDRouter::new(
prefill_urls.to_vec(), prefill_urls.to_vec(),
decode_urls.to_vec(), decode_urls.to_vec(),
policy, prefill_policy,
decode_policy,
router_config.worker_startup_timeout_secs, router_config.worker_startup_timeout_secs,
router_config.worker_startup_check_interval_secs, router_config.worker_startup_check_interval_secs,
)?; )?;
......
...@@ -22,8 +22,10 @@ use uuid::Uuid; ...@@ -22,8 +22,10 @@ use uuid::Uuid;
pub struct PDRouter { pub struct PDRouter {
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>, pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>, pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub policy: Arc<dyn LoadBalancingPolicy>, pub prefill_policy: Arc<dyn LoadBalancingPolicy>,
pub decode_policy: Arc<dyn LoadBalancingPolicy>,
pub prefill_tree: Option<Arc<Mutex<Tree>>>, pub prefill_tree: Option<Arc<Mutex<Tree>>>,
pub decode_tree: Option<Arc<Mutex<Tree>>>,
pub timeout_secs: u64, pub timeout_secs: u64,
pub interval_secs: u64, pub interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>, pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
...@@ -66,7 +68,7 @@ impl PDRouter { ...@@ -66,7 +68,7 @@ impl PDRouter {
workers.push(worker); workers.push(worker);
// Add to cache tree if using cache-aware policy // Add to cache tree if using cache-aware policy for prefill
if let Some(ref tree) = self.prefill_tree { if let Some(ref tree) = self.prefill_tree {
tree.lock().unwrap().insert("", &url); tree.lock().unwrap().insert("", &url);
} }
...@@ -102,6 +104,11 @@ impl PDRouter { ...@@ -102,6 +104,11 @@ impl PDRouter {
workers.push(worker); workers.push(worker);
// Add to cache tree if using cache-aware policy for decode
if let Some(ref tree) = self.decode_tree {
tree.lock().unwrap().insert("", &url);
}
info!("Added decode server: {}", url); info!("Added decode server: {}", url);
Ok(format!("Successfully added decode server: {}", url)) Ok(format!("Successfully added decode server: {}", url))
} }
...@@ -126,12 +133,7 @@ impl PDRouter { ...@@ -126,12 +133,7 @@ impl PDRouter {
// Remove from cache tree if using cache-aware policy // Remove from cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree { if let Some(ref tree) = self.prefill_tree {
// Note: Tree doesn't have a remove method, so we rebuild it tree.lock().unwrap().remove_tenant(url);
let mut tree_guard = tree.lock().unwrap();
*tree_guard = Tree::new();
for worker in workers.iter() {
tree_guard.insert("", worker.url());
}
} }
info!("Removed prefill server: {}", url); info!("Removed prefill server: {}", url);
...@@ -156,6 +158,11 @@ impl PDRouter { ...@@ -156,6 +158,11 @@ impl PDRouter {
}); });
} }
// Remove from the cache tree if using cache-aware policy for decode
if let Some(ref tree) = self.decode_tree {
tree.lock().unwrap().remove_tenant(url);
}
info!("Removed decode server: {}", url); info!("Removed decode server: {}", url);
Ok(format!("Successfully removed decode server: {}", url)) Ok(format!("Successfully removed decode server: {}", url))
} }
...@@ -163,7 +170,8 @@ impl PDRouter { ...@@ -163,7 +170,8 @@ impl PDRouter {
pub fn new( pub fn new(
prefill_urls: Vec<(String, Option<u16>)>, prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>, decode_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>, prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64, timeout_secs: u64,
interval_secs: u64, interval_secs: u64,
) -> Result<Self, String> { ) -> Result<Self, String> {
...@@ -192,10 +200,10 @@ impl PDRouter { ...@@ -192,10 +200,10 @@ impl PDRouter {
)?; )?;
} }
// Initialize cache-aware components if needed // Initialize cache-aware components if needed for prefill policy
let prefill_tree = if policy.name() == "cache_aware" { let prefill_tree = if prefill_policy.name() == "cache_aware" {
// Initialize the policy's internal tree with prefill workers // Initialize the policy's internal tree with prefill workers
if let Some(cache_policy) = policy if let Some(cache_policy) = prefill_policy
.as_any() .as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>() .downcast_ref::<crate::policies::CacheAwarePolicy>()
{ {
...@@ -212,6 +220,26 @@ impl PDRouter { ...@@ -212,6 +220,26 @@ impl PDRouter {
None None
}; };
// Initialize cache-aware components if needed for decode policy
let decode_tree = if decode_policy.name() == "cache_aware" {
// Initialize the policy's internal tree with decode workers
if let Some(cache_policy) = decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.init_workers(&decode_workers);
}
let tree = Arc::new(Mutex::new(Tree::new()));
// Initialize tree with decode workers
for worker in &decode_workers {
tree.lock().unwrap().insert("", worker.url());
}
Some(tree)
} else {
None
};
// Set up background load monitoring for power-of-two selection // Set up background load monitoring for power-of-two selection
let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx); let worker_loads = Arc::new(rx);
...@@ -222,25 +250,28 @@ impl PDRouter { ...@@ -222,25 +250,28 @@ impl PDRouter {
.build() .build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?; .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
let load_monitor_handle = if policy.name() == "power_of_two" { let load_monitor_handle =
let monitor_urls = all_urls.clone(); if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
let monitor_interval = interval_secs; let monitor_urls = all_urls.clone();
let monitor_client = http_client.clone(); let monitor_interval = interval_secs;
let policy_clone = Arc::clone(&policy); let monitor_client = http_client.clone();
let prefill_policy_clone = Arc::clone(&prefill_policy);
Some(Arc::new(tokio::spawn(async move { let decode_policy_clone = Arc::clone(&decode_policy);
Self::monitor_worker_loads_with_client(
monitor_urls, Some(Arc::new(tokio::spawn(async move {
tx, Self::monitor_worker_loads_with_client(
monitor_interval, monitor_urls,
monitor_client, tx,
policy_clone, monitor_interval,
) monitor_client,
.await; prefill_policy_clone,
}))) decode_policy_clone,
} else { )
None .await;
}; })))
} else {
None
};
let prefill_workers = Arc::new(RwLock::new(prefill_workers)); let prefill_workers = Arc::new(RwLock::new(prefill_workers));
let decode_workers = Arc::new(RwLock::new(decode_workers)); let decode_workers = Arc::new(RwLock::new(decode_workers));
...@@ -254,8 +285,10 @@ impl PDRouter { ...@@ -254,8 +285,10 @@ impl PDRouter {
Ok(PDRouter { Ok(PDRouter {
prefill_workers, prefill_workers,
decode_workers, decode_workers,
policy, prefill_policy,
decode_policy,
prefill_tree, prefill_tree,
decode_tree,
timeout_secs, timeout_secs,
interval_secs, interval_secs,
worker_loads, worker_loads,
...@@ -736,18 +769,21 @@ impl PDRouter { ...@@ -736,18 +769,21 @@ impl PDRouter {
return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string()); return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string());
} }
// Use the policy to select worker pair // Select prefill worker using prefill policy
match self let prefill_idx = self
.policy .prefill_policy
.select_worker_pair(&prefill_workers, &decode_workers, request_text) .select_worker(&prefill_workers, request_text)
{ .ok_or("Failed to select prefill worker")?;
Some((prefill_idx, decode_idx)) => {
let prefill = prefill_workers[prefill_idx].clone_worker(); // Select decode worker using decode policy
let decode = decode_workers[decode_idx].clone_worker(); let decode_idx = self
Ok((prefill, decode)) .decode_policy
} .select_worker(&decode_workers, request_text)
None => Err("Failed to select worker pair".to_string()), .ok_or("Failed to select decode worker")?;
}
let prefill = prefill_workers[prefill_idx].clone_worker();
let decode = decode_workers[decode_idx].clone_worker();
Ok((prefill, decode))
} }
// Background task to monitor worker loads with shared client // Background task to monitor worker loads with shared client
...@@ -756,7 +792,8 @@ impl PDRouter { ...@@ -756,7 +792,8 @@ impl PDRouter {
tx: tokio::sync::watch::Sender<HashMap<String, isize>>, tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64, interval_secs: u64,
client: reqwest::Client, client: reqwest::Client,
policy: Arc<dyn LoadBalancingPolicy>, prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>,
) { ) {
loop { loop {
let mut loads = HashMap::new(); let mut loads = HashMap::new();
...@@ -781,8 +818,9 @@ impl PDRouter { ...@@ -781,8 +818,9 @@ impl PDRouter {
debug!("Worker loads updated: {:?}", loads); debug!("Worker loads updated: {:?}", loads);
// Update the policy with current loads // Update both policies with current loads
policy.update_loads(&loads); prefill_policy.update_loads(&loads);
decode_policy.update_loads(&loads);
// Check if receiver is still active // Check if receiver is still active
if tx.send(loads).is_err() { if tx.send(loads).is_err() {
...@@ -1463,13 +1501,16 @@ mod tests { ...@@ -1463,13 +1501,16 @@ mod tests {
use actix_web::test::TestRequest; use actix_web::test::TestRequest;
fn create_test_pd_router() -> PDRouter { fn create_test_pd_router() -> PDRouter {
let policy = Arc::new(RandomPolicy::new()); let prefill_policy = Arc::new(RandomPolicy::new());
let decode_policy = Arc::new(RandomPolicy::new());
PDRouter { PDRouter {
prefill_workers: Arc::new(RwLock::new(vec![])), prefill_workers: Arc::new(RwLock::new(vec![])),
decode_workers: Arc::new(RwLock::new(vec![])), decode_workers: Arc::new(RwLock::new(vec![])),
policy, prefill_policy,
decode_policy,
prefill_tree: None, prefill_tree: None,
decode_tree: None,
timeout_secs: 5, timeout_secs: 5,
interval_secs: 1, interval_secs: 1,
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
...@@ -1608,9 +1649,9 @@ mod tests { ...@@ -1608,9 +1649,9 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_cache_tree_operations() { async fn test_cache_tree_operations() {
let policy = Arc::new(CacheAwarePolicy::new()); let cache_policy = Arc::new(CacheAwarePolicy::new());
let mut router = create_test_pd_router(); let mut router = create_test_pd_router();
router.policy = policy; router.prefill_policy = cache_policy;
// Initialize cache tree // Initialize cache tree
let tree = Arc::new(Mutex::new(Tree::new())); let tree = Arc::new(Mutex::new(Tree::new()));
...@@ -1638,9 +1679,9 @@ mod tests { ...@@ -1638,9 +1679,9 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_cache_tree_rebuild_on_remove() { async fn test_cache_tree_rebuild_on_remove() {
let policy = Arc::new(CacheAwarePolicy::new()); let cache_policy = Arc::new(CacheAwarePolicy::new());
let mut router = create_test_pd_router(); let mut router = create_test_pd_router();
router.policy = policy; router.prefill_policy = cache_policy;
// Initialize cache tree // Initialize cache tree
let tree = Arc::new(Mutex::new(Tree::new())); let tree = Arc::new(Mutex::new(Tree::new()));
...@@ -1880,9 +1921,10 @@ mod tests { ...@@ -1880,9 +1921,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_load_monitor_updates() { async fn test_load_monitor_updates() {
let policy = Arc::new(crate::policies::PowerOfTwoPolicy::new()); let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
let mut router = create_test_pd_router(); let mut router = create_test_pd_router();
router.policy = policy; router.prefill_policy = power_of_two_policy.clone();
router.decode_policy = power_of_two_policy;
// Create load channel // Create load channel
let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
......
...@@ -122,6 +122,8 @@ mod test_pd_routing { ...@@ -122,6 +122,8 @@ mod test_pd_routing {
"http://decode1:8080".to_string(), "http://decode1:8080".to_string(),
"http://decode2:8080".to_string(), "http://decode2:8080".to_string(),
], ],
prefill_policy: None,
decode_policy: None,
}, },
PolicyConfig::Random, PolicyConfig::Random,
), ),
...@@ -129,6 +131,8 @@ mod test_pd_routing { ...@@ -129,6 +131,8 @@ mod test_pd_routing {
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))], prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
decode_urls: vec!["http://decode:8080".to_string()], decode_urls: vec!["http://decode:8080".to_string()],
prefill_policy: None,
decode_policy: None,
}, },
PolicyConfig::PowerOfTwo { PolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, load_check_interval_secs: 5,
...@@ -142,6 +146,8 @@ mod test_pd_routing { ...@@ -142,6 +146,8 @@ mod test_pd_routing {
("http://p3:8080".to_string(), Some(9002)), ("http://p3:8080".to_string(), Some(9002)),
], ],
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()], decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
prefill_policy: None,
decode_policy: None,
}, },
PolicyConfig::CacheAware { PolicyConfig::CacheAware {
cache_threshold: 0.7, cache_threshold: 0.7,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment