Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2ab97023
"docs/source/api/vscode:/vscode.git/clone" did not exist on "7ff2662b88958cb8691de78a626d2463cf742a02"
Unverified
Commit
2ab97023
authored
Jul 27, 2025
by
Simo Lin
Committed by
GitHub
Jul 27, 2025
Browse files
[router] add different policies for p node and d node (#8395)
parent
0bcc195f
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
537 additions
and
82 deletions
+537
-82
sgl-router/README.md
sgl-router/README.md
+13
-1
sgl-router/py_src/sglang_router/launch_router.py
sgl-router/py_src/sglang_router/launch_router.py
+65
-2
sgl-router/py_src/sglang_router/router.py
sgl-router/py_src/sglang_router/router.py
+8
-0
sgl-router/src/config/types.rs
sgl-router/src/config/types.rs
+190
-0
sgl-router/src/config/validation.rs
sgl-router/src/config/validation.rs
+101
-0
sgl-router/src/lib.rs
sgl-router/src/lib.rs
+31
-16
sgl-router/src/policies/cache_aware.rs
sgl-router/src/policies/cache_aware.rs
+5
-1
sgl-router/src/routers/factory.rs
sgl-router/src/routers/factory.rs
+21
-7
sgl-router/src/routers/pd_router.rs
sgl-router/src/routers/pd_router.rs
+97
-55
sgl-router/tests/test_pd_routing.rs
sgl-router/tests/test_pd_routing.rs
+6
-0
No files found.
sgl-router/README.md
View file @
2ab97023
...
@@ -120,6 +120,16 @@ python -m sglang_router.launch_router \
...
@@ -120,6 +120,16 @@ python -m sglang_router.launch_router \
--prefill-selector
app
=
sglang
component
=
prefill
\
--prefill-selector
app
=
sglang
component
=
prefill
\
--decode-selector
app
=
sglang
component
=
decode
\
--decode-selector
app
=
sglang
component
=
decode
\
--service-discovery-namespace
sglang-system
--service-discovery-namespace
sglang-system
# With separate routing policies:
python
-m
sglang_router.launch_router
\
--pd-disaggregation
\
--prefill-policy
cache_aware
\
--decode-policy
power_of_two
\
--service-discovery
\
--prefill-selector
app
=
sglang
component
=
prefill
\
--decode-selector
app
=
sglang
component
=
decode
\
--service-discovery-namespace
sglang-system
```
```
#### Kubernetes Pod Configuration
#### Kubernetes Pod Configuration
...
@@ -226,7 +236,9 @@ python -m sglang_router.launch_router \
...
@@ -226,7 +236,9 @@ python -m sglang_router.launch_router \
-
`--decode`
: Initial decode server URL
-
`--decode`
: Initial decode server URL
-
`--prefill-selector`
: Label selector for prefill pods
-
`--prefill-selector`
: Label selector for prefill pods
-
`--decode-selector`
: Label selector for decode pods
-
`--decode-selector`
: Label selector for decode pods
-
`--policy`
: Routing policy (
`cache_aware`
,
`random`
,
`power_of_two`
)
-
`--policy`
: Routing policy (
`cache_aware`
,
`random`
,
`power_of_two`
,
`round_robin`
)
-
`--prefill-policy`
: Separate routing policy for prefill nodes (optional, overrides
`--policy`
for prefill)
-
`--decode-policy`
: Separate routing policy for decode nodes (optional, overrides
`--policy`
for decode)
## Development
## Development
...
...
sgl-router/py_src/sglang_router/launch_router.py
View file @
2ab97023
...
@@ -40,6 +40,8 @@ class RouterArgs:
...
@@ -40,6 +40,8 @@ class RouterArgs:
# Routing policy
# Routing policy
policy
:
str
=
"cache_aware"
policy
:
str
=
"cache_aware"
prefill_policy
:
Optional
[
str
]
=
None
# Specific policy for prefill nodes in PD mode
decode_policy
:
Optional
[
str
]
=
None
# Specific policy for decode nodes in PD mode
worker_startup_timeout_secs
:
int
=
300
worker_startup_timeout_secs
:
int
=
300
worker_startup_check_interval
:
int
=
10
worker_startup_check_interval
:
int
=
10
cache_threshold
:
float
=
0.5
cache_threshold
:
float
=
0.5
...
@@ -108,7 +110,21 @@ class RouterArgs:
...
@@ -108,7 +110,21 @@ class RouterArgs:
type
=
str
,
type
=
str
,
default
=
RouterArgs
.
policy
,
default
=
RouterArgs
.
policy
,
choices
=
[
"random"
,
"round_robin"
,
"cache_aware"
,
"power_of_two"
],
choices
=
[
"random"
,
"round_robin"
,
"cache_aware"
,
"power_of_two"
],
help
=
"Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode"
,
help
=
"Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden"
,
)
parser
.
add_argument
(
f
"--
{
prefix
}
prefill-policy"
,
type
=
str
,
default
=
None
,
choices
=
[
"random"
,
"round_robin"
,
"cache_aware"
,
"power_of_two"
],
help
=
"Specific policy for prefill nodes in PD mode. If not specified, uses the main policy"
,
)
parser
.
add_argument
(
f
"--
{
prefix
}
decode-policy"
,
type
=
str
,
default
=
None
,
choices
=
[
"random"
,
"round_robin"
,
"cache_aware"
,
"power_of_two"
],
help
=
"Specific policy for decode nodes in PD mode. If not specified, uses the main policy"
,
)
)
# PD-specific arguments
# PD-specific arguments
...
@@ -266,6 +282,8 @@ class RouterArgs:
...
@@ -266,6 +282,8 @@ class RouterArgs:
prefill_urls
=
prefill_urls
,
prefill_urls
=
prefill_urls
,
decode_urls
=
decode_urls
,
decode_urls
=
decode_urls
,
policy
=
getattr
(
args
,
f
"
{
prefix
}
policy"
),
policy
=
getattr
(
args
,
f
"
{
prefix
}
policy"
),
prefill_policy
=
getattr
(
args
,
f
"
{
prefix
}
prefill_policy"
,
None
),
decode_policy
=
getattr
(
args
,
f
"
{
prefix
}
decode_policy"
,
None
),
worker_startup_timeout_secs
=
getattr
(
worker_startup_timeout_secs
=
getattr
(
args
,
f
"
{
prefix
}
worker_startup_timeout_secs"
args
,
f
"
{
prefix
}
worker_startup_timeout_secs"
),
),
...
@@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
...
@@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
if
not
router_args
.
decode_urls
:
if
not
router_args
.
decode_urls
:
raise
ValueError
(
"PD disaggregation mode requires --decode"
)
raise
ValueError
(
"PD disaggregation mode requires --decode"
)
# Warn about policy usage in PD mode
if
(
router_args
.
prefill_policy
and
router_args
.
decode_policy
and
router_args
.
policy
):
logger
.
warning
(
"Both --prefill-policy and --decode-policy are specified. "
"The main --policy flag will be ignored for PD mode."
)
elif
(
router_args
.
prefill_policy
and
not
router_args
.
decode_policy
and
router_args
.
policy
):
logger
.
info
(
f
"Using --prefill-policy '
{
router_args
.
prefill_policy
}
' for prefill nodes "
f
"and --policy '
{
router_args
.
policy
}
' for decode nodes."
)
elif
(
router_args
.
decode_policy
and
not
router_args
.
prefill_policy
and
router_args
.
policy
):
logger
.
info
(
f
"Using --policy '
{
router_args
.
policy
}
' for prefill nodes "
f
"and --decode-policy '
{
router_args
.
decode_policy
}
' for decode nodes."
)
# Create router with unified constructor
# Create router with unified constructor
router
=
Router
(
router
=
Router
(
worker_urls
=
(
worker_urls
=
(
...
@@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
...
@@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
decode_urls
=
(
decode_urls
=
(
router_args
.
decode_urls
if
router_args
.
pd_disaggregation
else
None
router_args
.
decode_urls
if
router_args
.
pd_disaggregation
else
None
),
),
prefill_policy
=
(
policy_from_str
(
router_args
.
prefill_policy
)
if
router_args
.
prefill_policy
else
None
),
decode_policy
=
(
policy_from_str
(
router_args
.
decode_policy
)
if
router_args
.
decode_policy
else
None
),
)
)
router
.
start
()
router
.
start
()
...
@@ -455,12 +512,18 @@ Examples:
...
@@ -455,12 +512,18 @@ Examples:
# Regular mode
# Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
# PD disaggregated mode
# PD disaggregated mode
with same policy for both
python -m sglang_router.launch_router --pd-disaggregation
\\
python -m sglang_router.launch_router --pd-disaggregation
\\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none
\\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none
\\
--decode http://decode1:8001 --decode http://decode2:8001
\\
--decode http://decode1:8001 --decode http://decode2:8001
\\
--policy cache_aware
--policy cache_aware
# PD mode with different policies for prefill and decode
python -m sglang_router.launch_router --pd-disaggregation
\\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none
\\
--decode http://decode1:8001 --decode http://decode2:8001
\\
--prefill-policy cache_aware --decode-policy power_of_two
"""
,
"""
,
formatter_class
=
CustomHelpFormatter
,
formatter_class
=
CustomHelpFormatter
,
)
)
...
...
sgl-router/py_src/sglang_router/router.py
View file @
2ab97023
...
@@ -50,6 +50,10 @@ class Router:
...
@@ -50,6 +50,10 @@ class Router:
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
If not specified, uses the main policy. Default: None
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
If not specified, uses the main policy. Default: None
"""
"""
def
__init__
(
def
__init__
(
...
@@ -79,6 +83,8 @@ class Router:
...
@@ -79,6 +83,8 @@ class Router:
pd_disaggregation
:
bool
=
False
,
pd_disaggregation
:
bool
=
False
,
prefill_urls
:
Optional
[
List
[
tuple
]]
=
None
,
prefill_urls
:
Optional
[
List
[
tuple
]]
=
None
,
decode_urls
:
Optional
[
List
[
str
]]
=
None
,
decode_urls
:
Optional
[
List
[
str
]]
=
None
,
prefill_policy
:
Optional
[
PolicyType
]
=
None
,
decode_policy
:
Optional
[
PolicyType
]
=
None
,
):
):
if
selector
is
None
:
if
selector
is
None
:
selector
=
{}
selector
=
{}
...
@@ -113,6 +119,8 @@ class Router:
...
@@ -113,6 +119,8 @@ class Router:
pd_disaggregation
=
pd_disaggregation
,
pd_disaggregation
=
pd_disaggregation
,
prefill_urls
=
prefill_urls
,
prefill_urls
=
prefill_urls
,
decode_urls
=
decode_urls
,
decode_urls
=
decode_urls
,
prefill_policy
=
prefill_policy
,
decode_policy
=
decode_policy
,
)
)
def
start
(
self
)
->
None
:
def
start
(
self
)
->
None
:
...
...
sgl-router/src/config/types.rs
View file @
2ab97023
...
@@ -46,6 +46,12 @@ pub enum RoutingMode {
...
@@ -46,6 +46,12 @@ pub enum RoutingMode {
prefill_urls
:
Vec
<
(
String
,
Option
<
u16
>
)
>
,
prefill_urls
:
Vec
<
(
String
,
Option
<
u16
>
)
>
,
/// Decode worker URLs
/// Decode worker URLs
decode_urls
:
Vec
<
String
>
,
decode_urls
:
Vec
<
String
>
,
/// Optional separate policy for prefill workers
#[serde(skip_serializing_if
=
"Option::is_none"
)]
prefill_policy
:
Option
<
PolicyConfig
>
,
/// Optional separate policy for decode workers
#[serde(skip_serializing_if
=
"Option::is_none"
)]
decode_policy
:
Option
<
PolicyConfig
>
,
},
},
}
}
...
@@ -60,9 +66,32 @@ impl RoutingMode {
...
@@ -60,9 +66,32 @@ impl RoutingMode {
RoutingMode
::
PrefillDecode
{
RoutingMode
::
PrefillDecode
{
prefill_urls
,
prefill_urls
,
decode_urls
,
decode_urls
,
..
}
=>
prefill_urls
.len
()
+
decode_urls
.len
(),
}
=>
prefill_urls
.len
()
+
decode_urls
.len
(),
}
}
}
}
/// Get the effective prefill policy for PD mode
/// Falls back to the main policy if no specific prefill policy is set
pub
fn
get_prefill_policy
<
'a
>
(
&
'a
self
,
main_policy
:
&
'a
PolicyConfig
)
->
&
'a
PolicyConfig
{
match
self
{
RoutingMode
::
PrefillDecode
{
prefill_policy
,
..
}
=>
{
prefill_policy
.as_ref
()
.unwrap_or
(
main_policy
)
}
_
=>
main_policy
,
}
}
/// Get the effective decode policy for PD mode
/// Falls back to the main policy if no specific decode policy is set
pub
fn
get_decode_policy
<
'a
>
(
&
'a
self
,
main_policy
:
&
'a
PolicyConfig
)
->
&
'a
PolicyConfig
{
match
self
{
RoutingMode
::
PrefillDecode
{
decode_policy
,
..
}
=>
{
decode_policy
.as_ref
()
.unwrap_or
(
main_policy
)
}
_
=>
main_policy
,
}
}
}
}
/// Policy configuration for routing
/// Policy configuration for routing
...
@@ -307,6 +336,8 @@ mod tests {
...
@@ -307,6 +336,8 @@ mod tests {
let
pd
=
RoutingMode
::
PrefillDecode
{
let
pd
=
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill1"
.to_string
(),
Some
(
8001
))],
prefill_urls
:
vec!
[(
"http://prefill1"
.to_string
(),
Some
(
8001
))],
decode_urls
:
vec!
[
"http://decode1"
.to_string
()],
decode_urls
:
vec!
[
"http://decode1"
.to_string
()],
prefill_policy
:
None
,
decode_policy
:
None
,
};
};
assert
!
(
pd
.is_pd_mode
());
assert
!
(
pd
.is_pd_mode
());
}
}
...
@@ -332,6 +363,8 @@ mod tests {
...
@@ -332,6 +363,8 @@ mod tests {
"http://decode2"
.to_string
(),
"http://decode2"
.to_string
(),
"http://decode3"
.to_string
(),
"http://decode3"
.to_string
(),
],
],
prefill_policy
:
None
,
decode_policy
:
None
,
};
};
assert_eq!
(
pd
.worker_count
(),
5
);
assert_eq!
(
pd
.worker_count
(),
5
);
...
@@ -355,6 +388,8 @@ mod tests {
...
@@ -355,6 +388,8 @@ mod tests {
let
pd
=
RoutingMode
::
PrefillDecode
{
let
pd
=
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill1"
.to_string
(),
Some
(
8001
))],
prefill_urls
:
vec!
[(
"http://prefill1"
.to_string
(),
Some
(
8001
))],
decode_urls
:
vec!
[
"http://decode1"
.to_string
()],
decode_urls
:
vec!
[
"http://decode1"
.to_string
()],
prefill_policy
:
None
,
decode_policy
:
None
,
};
};
let
json
=
serde_json
::
to_string
(
&
pd
)
.unwrap
();
let
json
=
serde_json
::
to_string
(
&
pd
)
.unwrap
();
assert
!
(
json
.contains
(
"
\"
type
\"
:
\"
prefill_decode
\"
"
));
assert
!
(
json
.contains
(
"
\"
type
\"
:
\"
prefill_decode
\"
"
));
...
@@ -551,6 +586,8 @@ mod tests {
...
@@ -551,6 +586,8 @@ mod tests {
mode
:
RoutingMode
::
PrefillDecode
{
mode
:
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[],
prefill_urls
:
vec!
[],
decode_urls
:
vec!
[],
decode_urls
:
vec!
[],
prefill_policy
:
None
,
decode_policy
:
None
,
},
},
..
Default
::
default
()
..
Default
::
default
()
};
};
...
@@ -674,6 +711,8 @@ mod tests {
...
@@ -674,6 +711,8 @@ mod tests {
"http://decode1:8000"
.to_string
(),
"http://decode1:8000"
.to_string
(),
"http://decode2:8000"
.to_string
(),
"http://decode2:8000"
.to_string
(),
],
],
prefill_policy
:
None
,
decode_policy
:
None
,
},
},
policy
:
PolicyConfig
::
PowerOfTwo
{
policy
:
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
30
,
load_check_interval_secs
:
30
,
...
@@ -800,4 +839,155 @@ mod tests {
...
@@ -800,4 +839,155 @@ mod tests {
Some
(
"production"
.to_string
())
Some
(
"production"
.to_string
())
);
);
}
}
// ============= Policy Fallback Tests =============
#[test]
fn
test_pd_policy_fallback_both_specified
()
{
// When both prefill and decode policies are specified, they should be used
let
pd
=
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill1"
.to_string
(),
None
)],
decode_urls
:
vec!
[
"http://decode1"
.to_string
()],
prefill_policy
:
Some
(
PolicyConfig
::
CacheAware
{
cache_threshold
:
0.5
,
balance_abs_threshold
:
32
,
balance_rel_threshold
:
1.1
,
eviction_interval_secs
:
60
,
max_tree_size
:
1000
,
}),
decode_policy
:
Some
(
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
60
,
}),
};
let
main_policy
=
PolicyConfig
::
Random
;
// Both specific policies should be used
match
pd
.get_prefill_policy
(
&
main_policy
)
{
PolicyConfig
::
CacheAware
{
..
}
=>
{}
// Success
_
=>
panic!
(
"Expected CacheAware for prefill"
),
}
match
pd
.get_decode_policy
(
&
main_policy
)
{
PolicyConfig
::
PowerOfTwo
{
..
}
=>
{}
// Success
_
=>
panic!
(
"Expected PowerOfTwo for decode"
),
}
}
#[test]
fn
test_pd_policy_fallback_only_prefill
()
{
// When only prefill policy is specified, decode should use main policy
let
pd
=
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill1"
.to_string
(),
None
)],
decode_urls
:
vec!
[
"http://decode1"
.to_string
()],
prefill_policy
:
Some
(
PolicyConfig
::
CacheAware
{
cache_threshold
:
0.5
,
balance_abs_threshold
:
32
,
balance_rel_threshold
:
1.1
,
eviction_interval_secs
:
60
,
max_tree_size
:
1000
,
}),
decode_policy
:
None
,
};
let
main_policy
=
PolicyConfig
::
RoundRobin
;
// Prefill should use specific policy
match
pd
.get_prefill_policy
(
&
main_policy
)
{
PolicyConfig
::
CacheAware
{
..
}
=>
{}
// Success
_
=>
panic!
(
"Expected CacheAware for prefill"
),
}
// Decode should fall back to main policy
match
pd
.get_decode_policy
(
&
main_policy
)
{
PolicyConfig
::
RoundRobin
=>
{}
// Success
_
=>
panic!
(
"Expected RoundRobin for decode"
),
}
}
#[test]
fn
test_pd_policy_fallback_only_decode
()
{
// When only decode policy is specified, prefill should use main policy
let
pd
=
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill1"
.to_string
(),
None
)],
decode_urls
:
vec!
[
"http://decode1"
.to_string
()],
prefill_policy
:
None
,
decode_policy
:
Some
(
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
60
,
}),
};
let
main_policy
=
PolicyConfig
::
Random
;
// Prefill should fall back to main policy
match
pd
.get_prefill_policy
(
&
main_policy
)
{
PolicyConfig
::
Random
=>
{}
// Success
_
=>
panic!
(
"Expected Random for prefill"
),
}
// Decode should use specific policy
match
pd
.get_decode_policy
(
&
main_policy
)
{
PolicyConfig
::
PowerOfTwo
{
..
}
=>
{}
// Success
_
=>
panic!
(
"Expected PowerOfTwo for decode"
),
}
}
#[test]
fn
test_pd_policy_fallback_none_specified
()
{
// When no specific policies are specified, both should use main policy
let
pd
=
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill1"
.to_string
(),
None
)],
decode_urls
:
vec!
[
"http://decode1"
.to_string
()],
prefill_policy
:
None
,
decode_policy
:
None
,
};
let
main_policy
=
PolicyConfig
::
CacheAware
{
cache_threshold
:
0.7
,
balance_abs_threshold
:
20
,
balance_rel_threshold
:
1.5
,
eviction_interval_secs
:
300
,
max_tree_size
:
2000
,
};
// Both should fall back to main policy
match
pd
.get_prefill_policy
(
&
main_policy
)
{
PolicyConfig
::
CacheAware
{
cache_threshold
,
..
}
=>
{
assert
!
((
cache_threshold
-
0.7
)
.abs
()
<
0.0001
);
}
_
=>
panic!
(
"Expected CacheAware for prefill"
),
}
match
pd
.get_decode_policy
(
&
main_policy
)
{
PolicyConfig
::
CacheAware
{
cache_threshold
,
..
}
=>
{
assert
!
((
cache_threshold
-
0.7
)
.abs
()
<
0.0001
);
}
_
=>
panic!
(
"Expected CacheAware for decode"
),
}
}
#[test]
fn
test_regular_mode_policy_fallback
()
{
// For regular mode, the helper methods should just return the main policy
let
regular
=
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"http://worker1"
.to_string
()],
};
let
main_policy
=
PolicyConfig
::
RoundRobin
;
// Both methods should return main policy for regular mode
match
regular
.get_prefill_policy
(
&
main_policy
)
{
PolicyConfig
::
RoundRobin
=>
{}
// Success
_
=>
panic!
(
"Expected RoundRobin for regular mode"
),
}
match
regular
.get_decode_policy
(
&
main_policy
)
{
PolicyConfig
::
RoundRobin
=>
{}
// Success
_
=>
panic!
(
"Expected RoundRobin for regular mode"
),
}
}
}
}
sgl-router/src/config/validation.rs
View file @
2ab97023
...
@@ -41,6 +41,8 @@ impl ConfigValidator {
...
@@ -41,6 +41,8 @@ impl ConfigValidator {
RoutingMode
::
PrefillDecode
{
RoutingMode
::
PrefillDecode
{
prefill_urls
,
prefill_urls
,
decode_urls
,
decode_urls
,
prefill_policy
,
decode_policy
,
}
=>
{
}
=>
{
// Only require URLs if service discovery is disabled
// Only require URLs if service discovery is disabled
if
!
has_service_discovery
{
if
!
has_service_discovery
{
...
@@ -78,6 +80,14 @@ impl ConfigValidator {
...
@@ -78,6 +80,14 @@ impl ConfigValidator {
}
}
}
}
}
}
// Validate optional prefill and decode policies
if
let
Some
(
p_policy
)
=
prefill_policy
{
Self
::
validate_policy
(
p_policy
)
?
;
}
if
let
Some
(
d_policy
)
=
decode_policy
{
Self
::
validate_policy
(
d_policy
)
?
;
}
}
}
}
}
Ok
(())
Ok
(())
...
@@ -272,6 +282,35 @@ impl ConfigValidator {
...
@@ -272,6 +282,35 @@ impl ConfigValidator {
});
});
}
}
}
}
// For PD mode, validate that policies have sufficient workers
if
let
RoutingMode
::
PrefillDecode
{
prefill_urls
,
decode_urls
,
prefill_policy
,
decode_policy
,
}
=
&
config
.mode
{
// Check power-of-two for prefill
if
let
Some
(
PolicyConfig
::
PowerOfTwo
{
..
})
=
prefill_policy
{
if
prefill_urls
.len
()
<
2
{
return
Err
(
ConfigError
::
IncompatibleConfig
{
reason
:
"Power-of-two policy for prefill requires at least 2 prefill workers"
.to_string
(),
});
}
}
// Check power-of-two for decode
if
let
Some
(
PolicyConfig
::
PowerOfTwo
{
..
})
=
decode_policy
{
if
decode_urls
.len
()
<
2
{
return
Err
(
ConfigError
::
IncompatibleConfig
{
reason
:
"Power-of-two policy for decode requires at least 2 decode workers"
.to_string
(),
});
}
}
}
}
}
Ok
(())
Ok
(())
...
@@ -430,6 +469,8 @@ mod tests {
...
@@ -430,6 +469,8 @@ mod tests {
RoutingMode
::
PrefillDecode
{
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill:8000"
.to_string
(),
Some
(
8081
))],
prefill_urls
:
vec!
[(
"http://prefill:8000"
.to_string
(),
Some
(
8081
))],
decode_urls
:
vec!
[
"http://decode:8000"
.to_string
()],
decode_urls
:
vec!
[
"http://decode:8000"
.to_string
()],
prefill_policy
:
None
,
decode_policy
:
None
,
},
},
PolicyConfig
::
Random
,
PolicyConfig
::
Random
,
);
);
...
@@ -444,6 +485,8 @@ mod tests {
...
@@ -444,6 +485,8 @@ mod tests {
RoutingMode
::
PrefillDecode
{
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill:8000"
.to_string
(),
None
)],
prefill_urls
:
vec!
[(
"http://prefill:8000"
.to_string
(),
None
)],
decode_urls
:
vec!
[
"http://decode:8000"
.to_string
()],
decode_urls
:
vec!
[
"http://decode:8000"
.to_string
()],
prefill_policy
:
None
,
decode_policy
:
None
,
},
},
PolicyConfig
::
RoundRobin
,
PolicyConfig
::
RoundRobin
,
);
);
...
@@ -459,6 +502,8 @@ mod tests {
...
@@ -459,6 +502,8 @@ mod tests {
RoutingMode
::
PrefillDecode
{
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill:8000"
.to_string
(),
None
)],
prefill_urls
:
vec!
[(
"http://prefill:8000"
.to_string
(),
None
)],
decode_urls
:
vec!
[
"http://decode:8000"
.to_string
()],
decode_urls
:
vec!
[
"http://decode:8000"
.to_string
()],
prefill_policy
:
None
,
decode_policy
:
None
,
},
},
PolicyConfig
::
CacheAware
{
PolicyConfig
::
CacheAware
{
cache_threshold
:
0.5
,
cache_threshold
:
0.5
,
...
@@ -491,4 +536,60 @@ mod tests {
...
@@ -491,4 +536,60 @@ mod tests {
let
result
=
ConfigValidator
::
validate
(
&
config
);
let
result
=
ConfigValidator
::
validate
(
&
config
);
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
}
}
#[test]
fn
test_validate_pd_mode_with_separate_policies
()
{
// Test PD mode with different policies for prefill and decode
let
config
=
RouterConfig
::
new
(
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[
(
"http://prefill1:8000"
.to_string
(),
None
),
(
"http://prefill2:8000"
.to_string
(),
None
),
],
decode_urls
:
vec!
[
"http://decode1:8000"
.to_string
(),
"http://decode2:8000"
.to_string
(),
],
prefill_policy
:
Some
(
PolicyConfig
::
CacheAware
{
cache_threshold
:
0.5
,
balance_abs_threshold
:
32
,
balance_rel_threshold
:
1.1
,
eviction_interval_secs
:
60
,
max_tree_size
:
1000
,
}),
decode_policy
:
Some
(
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
60
,
}),
},
PolicyConfig
::
Random
,
// Main policy as fallback
);
let
result
=
ConfigValidator
::
validate
(
&
config
);
assert
!
(
result
.is_ok
());
}
#[test]
fn
test_validate_pd_mode_power_of_two_insufficient_workers
()
{
// Test that power-of-two policy requires at least 2 workers
let
config
=
RouterConfig
::
new
(
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill1:8000"
.to_string
(),
None
)],
// Only 1 prefill
decode_urls
:
vec!
[
"http://decode1:8000"
.to_string
(),
"http://decode2:8000"
.to_string
(),
],
prefill_policy
:
Some
(
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
60
,
}),
// Requires 2+ workers
decode_policy
:
None
,
},
PolicyConfig
::
Random
,
);
let
result
=
ConfigValidator
::
validate
(
&
config
);
assert
!
(
result
.is_err
());
if
let
Err
(
e
)
=
result
{
assert
!
(
e
.to_string
()
.contains
(
"prefill requires at least 2"
));
}
}
}
}
sgl-router/src/lib.rs
View file @
2ab97023
...
@@ -54,6 +54,8 @@ struct Router {
...
@@ -54,6 +54,8 @@ struct Router {
// PD-specific fields (only used when pd_disaggregation is true)
// PD-specific fields (only used when pd_disaggregation is true)
prefill_urls
:
Option
<
Vec
<
(
String
,
Option
<
u16
>
)
>>
,
prefill_urls
:
Option
<
Vec
<
(
String
,
Option
<
u16
>
)
>>
,
decode_urls
:
Option
<
Vec
<
String
>>
,
decode_urls
:
Option
<
Vec
<
String
>>
,
prefill_policy
:
Option
<
PolicyType
>
,
decode_policy
:
Option
<
PolicyType
>
,
}
}
impl
Router
{
impl
Router
{
...
@@ -63,11 +65,31 @@ impl Router {
...
@@ -63,11 +65,31 @@ impl Router {
DiscoveryConfig
,
MetricsConfig
,
PolicyConfig
as
ConfigPolicyConfig
,
RoutingMode
,
DiscoveryConfig
,
MetricsConfig
,
PolicyConfig
as
ConfigPolicyConfig
,
RoutingMode
,
};
};
// Convert policy helper function
let
convert_policy
=
|
policy
:
&
PolicyType
|
->
ConfigPolicyConfig
{
match
policy
{
PolicyType
::
Random
=>
ConfigPolicyConfig
::
Random
,
PolicyType
::
RoundRobin
=>
ConfigPolicyConfig
::
RoundRobin
,
PolicyType
::
CacheAware
=>
ConfigPolicyConfig
::
CacheAware
{
cache_threshold
:
self
.cache_threshold
,
balance_abs_threshold
:
self
.balance_abs_threshold
,
balance_rel_threshold
:
self
.balance_rel_threshold
,
eviction_interval_secs
:
self
.eviction_interval_secs
,
max_tree_size
:
self
.max_tree_size
,
},
PolicyType
::
PowerOfTwo
=>
ConfigPolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
5
,
// Default value
},
}
};
// Determine routing mode
// Determine routing mode
let
mode
=
if
self
.pd_disaggregation
{
let
mode
=
if
self
.pd_disaggregation
{
RoutingMode
::
PrefillDecode
{
RoutingMode
::
PrefillDecode
{
prefill_urls
:
self
.prefill_urls
.clone
()
.unwrap_or_default
(),
prefill_urls
:
self
.prefill_urls
.clone
()
.unwrap_or_default
(),
decode_urls
:
self
.decode_urls
.clone
()
.unwrap_or_default
(),
decode_urls
:
self
.decode_urls
.clone
()
.unwrap_or_default
(),
prefill_policy
:
self
.prefill_policy
.as_ref
()
.map
(
convert_policy
),
decode_policy
:
self
.decode_policy
.as_ref
()
.map
(
convert_policy
),
}
}
}
else
{
}
else
{
RoutingMode
::
Regular
{
RoutingMode
::
Regular
{
...
@@ -75,21 +97,8 @@ impl Router {
...
@@ -75,21 +97,8 @@ impl Router {
}
}
};
};
// Convert policy
// Convert main policy
let
policy
=
match
self
.policy
{
let
policy
=
convert_policy
(
&
self
.policy
);
PolicyType
::
Random
=>
ConfigPolicyConfig
::
Random
,
PolicyType
::
RoundRobin
=>
ConfigPolicyConfig
::
RoundRobin
,
PolicyType
::
CacheAware
=>
ConfigPolicyConfig
::
CacheAware
{
cache_threshold
:
self
.cache_threshold
,
balance_abs_threshold
:
self
.balance_abs_threshold
,
balance_rel_threshold
:
self
.balance_rel_threshold
,
eviction_interval_secs
:
self
.eviction_interval_secs
,
max_tree_size
:
self
.max_tree_size
,
},
PolicyType
::
PowerOfTwo
=>
ConfigPolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
5
,
// Default value
},
};
// Service discovery configuration
// Service discovery configuration
let
discovery
=
if
self
.service_discovery
{
let
discovery
=
if
self
.service_discovery
{
...
@@ -163,7 +172,9 @@ impl Router {
...
@@ -163,7 +172,9 @@ impl Router {
request_timeout_secs
=
600
,
// Add configurable request timeout
request_timeout_secs
=
600
,
// Add configurable request timeout
pd_disaggregation
=
false
,
// New flag for PD mode
pd_disaggregation
=
false
,
// New flag for PD mode
prefill_urls
=
None,
prefill_urls
=
None,
decode_urls
=
None
decode_urls
=
None,
prefill_policy
=
None,
decode_policy
=
None
))]
))]
fn
new
(
fn
new
(
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Vec
<
String
>
,
...
@@ -193,6 +204,8 @@ impl Router {
...
@@ -193,6 +204,8 @@ impl Router {
pd_disaggregation
:
bool
,
pd_disaggregation
:
bool
,
prefill_urls
:
Option
<
Vec
<
(
String
,
Option
<
u16
>
)
>>
,
prefill_urls
:
Option
<
Vec
<
(
String
,
Option
<
u16
>
)
>>
,
decode_urls
:
Option
<
Vec
<
String
>>
,
decode_urls
:
Option
<
Vec
<
String
>>
,
prefill_policy
:
Option
<
PolicyType
>
,
decode_policy
:
Option
<
PolicyType
>
,
)
->
PyResult
<
Self
>
{
)
->
PyResult
<
Self
>
{
Ok
(
Router
{
Ok
(
Router
{
host
,
host
,
...
@@ -222,6 +235,8 @@ impl Router {
...
@@ -222,6 +235,8 @@ impl Router {
pd_disaggregation
,
pd_disaggregation
,
prefill_urls
,
prefill_urls
,
decode_urls
,
decode_urls
,
prefill_policy
,
decode_policy
,
})
})
}
}
...
...
sgl-router/src/policies/cache_aware.rs
View file @
2ab97023
...
@@ -254,7 +254,11 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
...
@@ -254,7 +254,11 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
decode_workers
:
&
[
Box
<
dyn
Worker
>
],
decode_workers
:
&
[
Box
<
dyn
Worker
>
],
request_text
:
Option
<&
str
>
,
request_text
:
Option
<&
str
>
,
)
->
Option
<
(
usize
,
usize
)
>
{
)
->
Option
<
(
usize
,
usize
)
>
{
// In PD mode:
// DEPRECATED: This method is no longer used when separate policies are configured.
// The PD router now uses separate policies for prefill and decode selection.
// This implementation remains for backward compatibility when a single policy is used.
// In PD mode with single policy:
// - Prefill: Use cache-aware routing for better cache utilization
// - Prefill: Use cache-aware routing for better cache utilization
// - Decode: Use least-load routing for better load distribution
// - Decode: Use least-load routing for better load distribution
...
...
sgl-router/src/routers/factory.rs
View file @
2ab97023
...
@@ -17,7 +17,16 @@ impl RouterFactory {
...
@@ -17,7 +17,16 @@ impl RouterFactory {
RoutingMode
::
PrefillDecode
{
RoutingMode
::
PrefillDecode
{
prefill_urls
,
prefill_urls
,
decode_urls
,
decode_urls
,
}
=>
Self
::
create_pd_router
(
prefill_urls
,
decode_urls
,
&
config
.policy
,
config
),
prefill_policy
,
decode_policy
,
}
=>
Self
::
create_pd_router
(
prefill_urls
,
decode_urls
,
prefill_policy
.as_ref
(),
decode_policy
.as_ref
(),
&
config
.policy
,
config
,
),
}
}
}
}
...
@@ -45,18 +54,23 @@ impl RouterFactory {
...
@@ -45,18 +54,23 @@ impl RouterFactory {
fn
create_pd_router
(
fn
create_pd_router
(
prefill_urls
:
&
[(
String
,
Option
<
u16
>
)],
prefill_urls
:
&
[(
String
,
Option
<
u16
>
)],
decode_urls
:
&
[
String
],
decode_urls
:
&
[
String
],
policy_config
:
&
PolicyConfig
,
prefill_policy_config
:
Option
<&
PolicyConfig
>
,
decode_policy_config
:
Option
<&
PolicyConfig
>
,
main_policy_config
:
&
PolicyConfig
,
router_config
:
&
RouterConfig
,
router_config
:
&
RouterConfig
,
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
// Create policy directly from PolicyConfig
// Create policies - use specific policies if provided, otherwise fall back to main policy
// All policies now support PD mode through the select_worker_pair method
let
prefill_policy
=
let
policy
=
PolicyFactory
::
create_from_config
(
policy_config
);
PolicyFactory
::
create_from_config
(
prefill_policy_config
.unwrap_or
(
main_policy_config
));
let
decode_policy
=
PolicyFactory
::
create_from_config
(
decode_policy_config
.unwrap_or
(
main_policy_config
));
// Create PD router with
injec
te
d
polic
y
// Create PD router with
separa
te polic
ies
let
router
=
PDRouter
::
new
(
let
router
=
PDRouter
::
new
(
prefill_urls
.to_vec
(),
prefill_urls
.to_vec
(),
decode_urls
.to_vec
(),
decode_urls
.to_vec
(),
policy
,
prefill_policy
,
decode_policy
,
router_config
.worker_startup_timeout_secs
,
router_config
.worker_startup_timeout_secs
,
router_config
.worker_startup_check_interval_secs
,
router_config
.worker_startup_check_interval_secs
,
)
?
;
)
?
;
...
...
sgl-router/src/routers/pd_router.rs
View file @
2ab97023
...
@@ -22,8 +22,10 @@ use uuid::Uuid;
...
@@ -22,8 +22,10 @@ use uuid::Uuid;
pub
struct
PDRouter
{
pub
struct
PDRouter
{
pub
prefill_workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>>>>
,
pub
prefill_workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>>>>
,
pub
decode_workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>>>>
,
pub
decode_workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>>>>
,
pub
policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
pub
prefill_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
pub
decode_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
pub
prefill_tree
:
Option
<
Arc
<
Mutex
<
Tree
>>>
,
pub
prefill_tree
:
Option
<
Arc
<
Mutex
<
Tree
>>>
,
pub
decode_tree
:
Option
<
Arc
<
Mutex
<
Tree
>>>
,
pub
timeout_secs
:
u64
,
pub
timeout_secs
:
u64
,
pub
interval_secs
:
u64
,
pub
interval_secs
:
u64
,
pub
worker_loads
:
Arc
<
tokio
::
sync
::
watch
::
Receiver
<
HashMap
<
String
,
isize
>>>
,
pub
worker_loads
:
Arc
<
tokio
::
sync
::
watch
::
Receiver
<
HashMap
<
String
,
isize
>>>
,
...
@@ -66,7 +68,7 @@ impl PDRouter {
...
@@ -66,7 +68,7 @@ impl PDRouter {
workers
.push
(
worker
);
workers
.push
(
worker
);
// Add to cache tree if using cache-aware policy
// Add to cache tree if using cache-aware policy
for prefill
if
let
Some
(
ref
tree
)
=
self
.prefill_tree
{
if
let
Some
(
ref
tree
)
=
self
.prefill_tree
{
tree
.lock
()
.unwrap
()
.insert
(
""
,
&
url
);
tree
.lock
()
.unwrap
()
.insert
(
""
,
&
url
);
}
}
...
@@ -102,6 +104,11 @@ impl PDRouter {
...
@@ -102,6 +104,11 @@ impl PDRouter {
workers
.push
(
worker
);
workers
.push
(
worker
);
// Add to cache tree if using cache-aware policy for decode
if
let
Some
(
ref
tree
)
=
self
.decode_tree
{
tree
.lock
()
.unwrap
()
.insert
(
""
,
&
url
);
}
info!
(
"Added decode server: {}"
,
url
);
info!
(
"Added decode server: {}"
,
url
);
Ok
(
format!
(
"Successfully added decode server: {}"
,
url
))
Ok
(
format!
(
"Successfully added decode server: {}"
,
url
))
}
}
...
@@ -126,12 +133,7 @@ impl PDRouter {
...
@@ -126,12 +133,7 @@ impl PDRouter {
// Remove from cache tree if using cache-aware policy
// Remove from cache tree if using cache-aware policy
if
let
Some
(
ref
tree
)
=
self
.prefill_tree
{
if
let
Some
(
ref
tree
)
=
self
.prefill_tree
{
// Note: Tree doesn't have a remove method, so we rebuild it
tree
.lock
()
.unwrap
()
.remove_tenant
(
url
);
let
mut
tree_guard
=
tree
.lock
()
.unwrap
();
*
tree_guard
=
Tree
::
new
();
for
worker
in
workers
.iter
()
{
tree_guard
.insert
(
""
,
worker
.url
());
}
}
}
info!
(
"Removed prefill server: {}"
,
url
);
info!
(
"Removed prefill server: {}"
,
url
);
...
@@ -156,6 +158,11 @@ impl PDRouter {
...
@@ -156,6 +158,11 @@ impl PDRouter {
});
});
}
}
// Remove from the cache tree if using cache-aware policy for decode
if
let
Some
(
ref
tree
)
=
self
.decode_tree
{
tree
.lock
()
.unwrap
()
.remove_tenant
(
url
);
}
info!
(
"Removed decode server: {}"
,
url
);
info!
(
"Removed decode server: {}"
,
url
);
Ok
(
format!
(
"Successfully removed decode server: {}"
,
url
))
Ok
(
format!
(
"Successfully removed decode server: {}"
,
url
))
}
}
...
@@ -163,7 +170,8 @@ impl PDRouter {
...
@@ -163,7 +170,8 @@ impl PDRouter {
pub
fn
new
(
pub
fn
new
(
prefill_urls
:
Vec
<
(
String
,
Option
<
u16
>
)
>
,
prefill_urls
:
Vec
<
(
String
,
Option
<
u16
>
)
>
,
decode_urls
:
Vec
<
String
>
,
decode_urls
:
Vec
<
String
>
,
policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
prefill_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
decode_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
timeout_secs
:
u64
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
interval_secs
:
u64
,
)
->
Result
<
Self
,
String
>
{
)
->
Result
<
Self
,
String
>
{
...
@@ -192,10 +200,10 @@ impl PDRouter {
...
@@ -192,10 +200,10 @@ impl PDRouter {
)
?
;
)
?
;
}
}
// Initialize cache-aware components if needed
// Initialize cache-aware components if needed
for prefill policy
let
prefill_tree
=
if
policy
.name
()
==
"cache_aware"
{
let
prefill_tree
=
if
prefill_
policy
.name
()
==
"cache_aware"
{
// Initialize the policy's internal tree with prefill workers
// Initialize the policy's internal tree with prefill workers
if
let
Some
(
cache_policy
)
=
policy
if
let
Some
(
cache_policy
)
=
prefill_
policy
.as_any
()
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
{
...
@@ -212,6 +220,26 @@ impl PDRouter {
...
@@ -212,6 +220,26 @@ impl PDRouter {
None
None
};
};
// Initialize cache-aware components if needed for decode policy
let
decode_tree
=
if
decode_policy
.name
()
==
"cache_aware"
{
// Initialize the policy's internal tree with decode workers
if
let
Some
(
cache_policy
)
=
decode_policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_policy
.init_workers
(
&
decode_workers
);
}
let
tree
=
Arc
::
new
(
Mutex
::
new
(
Tree
::
new
()));
// Initialize tree with decode workers
for
worker
in
&
decode_workers
{
tree
.lock
()
.unwrap
()
.insert
(
""
,
worker
.url
());
}
Some
(
tree
)
}
else
{
None
};
// Set up background load monitoring for power-of-two selection
// Set up background load monitoring for power-of-two selection
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
());
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
());
let
worker_loads
=
Arc
::
new
(
rx
);
let
worker_loads
=
Arc
::
new
(
rx
);
...
@@ -222,25 +250,28 @@ impl PDRouter {
...
@@ -222,25 +250,28 @@ impl PDRouter {
.build
()
.build
()
.map_err
(|
e
|
format!
(
"Failed to create HTTP client: {}"
,
e
))
?
;
.map_err
(|
e
|
format!
(
"Failed to create HTTP client: {}"
,
e
))
?
;
let
load_monitor_handle
=
if
policy
.name
()
==
"power_of_two"
{
let
load_monitor_handle
=
let
monitor_urls
=
all_urls
.clone
();
if
prefill_policy
.name
()
==
"power_of_two"
||
decode_policy
.name
()
==
"power_of_two"
{
let
monitor_interval
=
interval_secs
;
let
monitor_urls
=
all_urls
.clone
();
let
monitor_client
=
http_client
.clone
();
let
monitor_interval
=
interval_secs
;
let
policy_clone
=
Arc
::
clone
(
&
policy
);
let
monitor_client
=
http_client
.clone
();
let
prefill_policy_clone
=
Arc
::
clone
(
&
prefill_policy
);
Some
(
Arc
::
new
(
tokio
::
spawn
(
async
move
{
let
decode_policy_clone
=
Arc
::
clone
(
&
decode_policy
);
Self
::
monitor_worker_loads_with_client
(
monitor_urls
,
Some
(
Arc
::
new
(
tokio
::
spawn
(
async
move
{
tx
,
Self
::
monitor_worker_loads_with_client
(
monitor_interval
,
monitor_urls
,
monitor_client
,
tx
,
policy_clone
,
monitor_interval
,
)
monitor_client
,
.await
;
prefill_policy_clone
,
})))
decode_policy_clone
,
}
else
{
)
None
.await
;
};
})))
}
else
{
None
};
let
prefill_workers
=
Arc
::
new
(
RwLock
::
new
(
prefill_workers
));
let
prefill_workers
=
Arc
::
new
(
RwLock
::
new
(
prefill_workers
));
let
decode_workers
=
Arc
::
new
(
RwLock
::
new
(
decode_workers
));
let
decode_workers
=
Arc
::
new
(
RwLock
::
new
(
decode_workers
));
...
@@ -254,8 +285,10 @@ impl PDRouter {
...
@@ -254,8 +285,10 @@ impl PDRouter {
Ok
(
PDRouter
{
Ok
(
PDRouter
{
prefill_workers
,
prefill_workers
,
decode_workers
,
decode_workers
,
policy
,
prefill_policy
,
decode_policy
,
prefill_tree
,
prefill_tree
,
decode_tree
,
timeout_secs
,
timeout_secs
,
interval_secs
,
interval_secs
,
worker_loads
,
worker_loads
,
...
@@ -736,18 +769,21 @@ impl PDRouter {
...
@@ -736,18 +769,21 @@ impl PDRouter {
return
Err
(
"No decode workers available. Please check if decode servers are configured and healthy."
.to_string
());
return
Err
(
"No decode workers available. Please check if decode servers are configured and healthy."
.to_string
());
}
}
// Use the policy to select worker pair
// Select prefill worker using prefill policy
match
self
let
prefill_idx
=
self
.policy
.prefill_policy
.select_worker_pair
(
&
prefill_workers
,
&
decode_workers
,
request_text
)
.select_worker
(
&
prefill_workers
,
request_text
)
{
.ok_or
(
"Failed to select prefill worker"
)
?
;
Some
((
prefill_idx
,
decode_idx
))
=>
{
let
prefill
=
prefill_workers
[
prefill_idx
]
.clone_worker
();
// Select decode worker using decode policy
let
decode
=
decode_workers
[
decode_idx
]
.clone_worker
();
let
decode_idx
=
self
Ok
((
prefill
,
decode
))
.decode_policy
}
.select_worker
(
&
decode_workers
,
request_text
)
None
=>
Err
(
"Failed to select worker pair"
.to_string
()),
.ok_or
(
"Failed to select decode worker"
)
?
;
}
let
prefill
=
prefill_workers
[
prefill_idx
]
.clone_worker
();
let
decode
=
decode_workers
[
decode_idx
]
.clone_worker
();
Ok
((
prefill
,
decode
))
}
}
// Background task to monitor worker loads with shared client
// Background task to monitor worker loads with shared client
...
@@ -756,7 +792,8 @@ impl PDRouter {
...
@@ -756,7 +792,8 @@ impl PDRouter {
tx
:
tokio
::
sync
::
watch
::
Sender
<
HashMap
<
String
,
isize
>>
,
tx
:
tokio
::
sync
::
watch
::
Sender
<
HashMap
<
String
,
isize
>>
,
interval_secs
:
u64
,
interval_secs
:
u64
,
client
:
reqwest
::
Client
,
client
:
reqwest
::
Client
,
policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
prefill_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
decode_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
)
{
)
{
loop
{
loop
{
let
mut
loads
=
HashMap
::
new
();
let
mut
loads
=
HashMap
::
new
();
...
@@ -781,8 +818,9 @@ impl PDRouter {
...
@@ -781,8 +818,9 @@ impl PDRouter {
debug!
(
"Worker loads updated: {:?}"
,
loads
);
debug!
(
"Worker loads updated: {:?}"
,
loads
);
// Update the policy with current loads
// Update both policies with current loads
policy
.update_loads
(
&
loads
);
prefill_policy
.update_loads
(
&
loads
);
decode_policy
.update_loads
(
&
loads
);
// Check if receiver is still active
// Check if receiver is still active
if
tx
.send
(
loads
)
.is_err
()
{
if
tx
.send
(
loads
)
.is_err
()
{
...
@@ -1463,13 +1501,16 @@ mod tests {
...
@@ -1463,13 +1501,16 @@ mod tests {
use
actix_web
::
test
::
TestRequest
;
use
actix_web
::
test
::
TestRequest
;
fn
create_test_pd_router
()
->
PDRouter
{
fn
create_test_pd_router
()
->
PDRouter
{
let
policy
=
Arc
::
new
(
RandomPolicy
::
new
());
let
prefill_policy
=
Arc
::
new
(
RandomPolicy
::
new
());
let
decode_policy
=
Arc
::
new
(
RandomPolicy
::
new
());
PDRouter
{
PDRouter
{
prefill_workers
:
Arc
::
new
(
RwLock
::
new
(
vec!
[])),
prefill_workers
:
Arc
::
new
(
RwLock
::
new
(
vec!
[])),
decode_workers
:
Arc
::
new
(
RwLock
::
new
(
vec!
[])),
decode_workers
:
Arc
::
new
(
RwLock
::
new
(
vec!
[])),
policy
,
prefill_policy
,
decode_policy
,
prefill_tree
:
None
,
prefill_tree
:
None
,
decode_tree
:
None
,
timeout_secs
:
5
,
timeout_secs
:
5
,
interval_secs
:
1
,
interval_secs
:
1
,
worker_loads
:
Arc
::
new
(
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
())
.1
),
worker_loads
:
Arc
::
new
(
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
())
.1
),
...
@@ -1608,9 +1649,9 @@ mod tests {
...
@@ -1608,9 +1649,9 @@ mod tests {
#[tokio::test]
#[tokio::test]
async
fn
test_cache_tree_operations
()
{
async
fn
test_cache_tree_operations
()
{
let
policy
=
Arc
::
new
(
CacheAwarePolicy
::
new
());
let
cache_
policy
=
Arc
::
new
(
CacheAwarePolicy
::
new
());
let
mut
router
=
create_test_pd_router
();
let
mut
router
=
create_test_pd_router
();
router
.policy
=
policy
;
router
.
prefill_
policy
=
cache_
policy
;
// Initialize cache tree
// Initialize cache tree
let
tree
=
Arc
::
new
(
Mutex
::
new
(
Tree
::
new
()));
let
tree
=
Arc
::
new
(
Mutex
::
new
(
Tree
::
new
()));
...
@@ -1638,9 +1679,9 @@ mod tests {
...
@@ -1638,9 +1679,9 @@ mod tests {
#[tokio::test]
#[tokio::test]
async
fn
test_cache_tree_rebuild_on_remove
()
{
async
fn
test_cache_tree_rebuild_on_remove
()
{
let
policy
=
Arc
::
new
(
CacheAwarePolicy
::
new
());
let
cache_
policy
=
Arc
::
new
(
CacheAwarePolicy
::
new
());
let
mut
router
=
create_test_pd_router
();
let
mut
router
=
create_test_pd_router
();
router
.policy
=
policy
;
router
.
prefill_
policy
=
cache_
policy
;
// Initialize cache tree
// Initialize cache tree
let
tree
=
Arc
::
new
(
Mutex
::
new
(
Tree
::
new
()));
let
tree
=
Arc
::
new
(
Mutex
::
new
(
Tree
::
new
()));
...
@@ -1880,9 +1921,10 @@ mod tests {
...
@@ -1880,9 +1921,10 @@ mod tests {
#[tokio::test]
#[tokio::test]
async
fn
test_load_monitor_updates
()
{
async
fn
test_load_monitor_updates
()
{
let
policy
=
Arc
::
new
(
crate
::
policies
::
PowerOfTwoPolicy
::
new
());
let
power_of_two_
policy
=
Arc
::
new
(
crate
::
policies
::
PowerOfTwoPolicy
::
new
());
let
mut
router
=
create_test_pd_router
();
let
mut
router
=
create_test_pd_router
();
router
.policy
=
policy
;
router
.prefill_policy
=
power_of_two_policy
.clone
();
router
.decode_policy
=
power_of_two_policy
;
// Create load channel
// Create load channel
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
());
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
());
...
...
sgl-router/tests/test_pd_routing.rs
View file @
2ab97023
...
@@ -122,6 +122,8 @@ mod test_pd_routing {
...
@@ -122,6 +122,8 @@ mod test_pd_routing {
"http://decode1:8080"
.to_string
(),
"http://decode1:8080"
.to_string
(),
"http://decode2:8080"
.to_string
(),
"http://decode2:8080"
.to_string
(),
],
],
prefill_policy
:
None
,
decode_policy
:
None
,
},
},
PolicyConfig
::
Random
,
PolicyConfig
::
Random
,
),
),
...
@@ -129,6 +131,8 @@ mod test_pd_routing {
...
@@ -129,6 +131,8 @@ mod test_pd_routing {
RoutingMode
::
PrefillDecode
{
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill:8080"
.to_string
(),
Some
(
9000
))],
prefill_urls
:
vec!
[(
"http://prefill:8080"
.to_string
(),
Some
(
9000
))],
decode_urls
:
vec!
[
"http://decode:8080"
.to_string
()],
decode_urls
:
vec!
[
"http://decode:8080"
.to_string
()],
prefill_policy
:
None
,
decode_policy
:
None
,
},
},
PolicyConfig
::
PowerOfTwo
{
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
5
,
load_check_interval_secs
:
5
,
...
@@ -142,6 +146,8 @@ mod test_pd_routing {
...
@@ -142,6 +146,8 @@ mod test_pd_routing {
(
"http://p3:8080"
.to_string
(),
Some
(
9002
)),
(
"http://p3:8080"
.to_string
(),
Some
(
9002
)),
],
],
decode_urls
:
vec!
[
"http://d1:8080"
.to_string
(),
"http://d2:8080"
.to_string
()],
decode_urls
:
vec!
[
"http://d1:8080"
.to_string
(),
"http://d2:8080"
.to_string
()],
prefill_policy
:
None
,
decode_policy
:
None
,
},
},
PolicyConfig
::
CacheAware
{
PolicyConfig
::
CacheAware
{
cache_threshold
:
0.7
,
cache_threshold
:
0.7
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment