Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0311ce8e
Unverified
Commit
0311ce8e
authored
Jan 20, 2025
by
Byron Hsu
Committed by
GitHub
Jan 20, 2025
Browse files
[router] Expose worker startup secs & Return error instead of panic for router init (#3016)
parent
5dfcacfc
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
124 additions
and
47 deletions
+124
-47
sgl-router/py_src/sglang_router/launch_router.py
sgl-router/py_src/sglang_router/launch_router.py
+14
-6
sgl-router/py_src/sglang_router/launch_server.py
sgl-router/py_src/sglang_router/launch_server.py
+27
-5
sgl-router/py_src/sglang_router/router.py
sgl-router/py_src/sglang_router/router.py
+3
-0
sgl-router/py_test/test_launch_router.py
sgl-router/py_test/test_launch_router.py
+1
-0
sgl-router/src/lib.rs
sgl-router/src/lib.rs
+14
-6
sgl-router/src/router.rs
sgl-router/src/router.rs
+43
-11
sgl-router/src/server.rs
sgl-router/src/server.rs
+22
-19
No files found.
sgl-router/py_src/sglang_router/launch_router.py
View file @
0311ce8e
...
@@ -33,6 +33,7 @@ class RouterArgs:
...
@@ -33,6 +33,7 @@ class RouterArgs:
# Routing policy
# Routing policy
policy
:
str
=
"cache_aware"
policy
:
str
=
"cache_aware"
worker_startup_timeout_secs
:
int
=
300
cache_threshold
:
float
=
0.5
cache_threshold
:
float
=
0.5
balance_abs_threshold
:
int
=
32
balance_abs_threshold
:
int
=
32
balance_rel_threshold
:
float
=
1.0001
balance_rel_threshold
:
float
=
1.0001
...
@@ -87,6 +88,12 @@ class RouterArgs:
...
@@ -87,6 +88,12 @@ class RouterArgs:
choices
=
[
"random"
,
"round_robin"
,
"cache_aware"
],
choices
=
[
"random"
,
"round_robin"
,
"cache_aware"
],
help
=
"Load balancing policy to use"
,
help
=
"Load balancing policy to use"
,
)
)
parser
.
add_argument
(
f
"--
{
prefix
}
worker-startup-timeout-secs"
,
type
=
int
,
default
=
RouterArgs
.
worker_startup_timeout_secs
,
help
=
"Timeout in seconds for worker startup"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
f
"--
{
prefix
}
cache-threshold"
,
f
"--
{
prefix
}
cache-threshold"
,
type
=
float
,
type
=
float
,
...
@@ -147,6 +154,9 @@ class RouterArgs:
...
@@ -147,6 +154,9 @@ class RouterArgs:
host
=
args
.
host
,
host
=
args
.
host
,
port
=
args
.
port
,
port
=
args
.
port
,
policy
=
getattr
(
args
,
f
"
{
prefix
}
policy"
),
policy
=
getattr
(
args
,
f
"
{
prefix
}
policy"
),
worker_startup_timeout_secs
=
getattr
(
args
,
f
"
{
prefix
}
worker_startup_timeout_secs"
),
cache_threshold
=
getattr
(
args
,
f
"
{
prefix
}
cache_threshold"
),
cache_threshold
=
getattr
(
args
,
f
"
{
prefix
}
cache_threshold"
),
balance_abs_threshold
=
getattr
(
args
,
f
"
{
prefix
}
balance_abs_threshold"
),
balance_abs_threshold
=
getattr
(
args
,
f
"
{
prefix
}
balance_abs_threshold"
),
balance_rel_threshold
=
getattr
(
args
,
f
"
{
prefix
}
balance_rel_threshold"
),
balance_rel_threshold
=
getattr
(
args
,
f
"
{
prefix
}
balance_rel_threshold"
),
...
@@ -188,9 +198,10 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
...
@@ -188,9 +198,10 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
router
=
Router
(
router
=
Router
(
worker_urls
=
router_args
.
worker_urls
,
worker_urls
=
router_args
.
worker_urls
,
policy
=
policy_from_str
(
router_args
.
policy
),
host
=
router_args
.
host
,
host
=
router_args
.
host
,
port
=
router_args
.
port
,
port
=
router_args
.
port
,
policy
=
policy_from_str
(
router_args
.
policy
),
worker_startup_timeout_secs
=
router_args
.
worker_startup_timeout_secs
,
cache_threshold
=
router_args
.
cache_threshold
,
cache_threshold
=
router_args
.
cache_threshold
,
balance_abs_threshold
=
router_args
.
balance_abs_threshold
,
balance_abs_threshold
=
router_args
.
balance_abs_threshold
,
balance_rel_threshold
=
router_args
.
balance_rel_threshold
,
balance_rel_threshold
=
router_args
.
balance_rel_threshold
,
...
@@ -205,7 +216,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
...
@@ -205,7 +216,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Error starting router:
{
e
}
"
)
logger
.
error
(
f
"Error starting router:
{
e
}
"
)
r
eturn
Non
e
r
aise
e
class
CustomHelpFormatter
(
class
CustomHelpFormatter
(
...
@@ -239,10 +250,7 @@ Examples:
...
@@ -239,10 +250,7 @@ Examples:
def
main
()
->
None
:
def
main
()
->
None
:
router_args
=
parse_router_args
(
sys
.
argv
[
1
:])
router_args
=
parse_router_args
(
sys
.
argv
[
1
:])
router
=
launch_router
(
router_args
)
launch_router
(
router_args
)
if
router
is
None
:
sys
.
exit
(
1
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
sgl-router/py_src/sglang_router/launch_server.py
View file @
0311ce8e
...
@@ -68,7 +68,7 @@ def run_server(server_args, dp_rank):
...
@@ -68,7 +68,7 @@ def run_server(server_args, dp_rank):
# create new process group
# create new process group
os
.
setpgrp
()
os
.
setpgrp
()
setproctitle
(
f
"sglang::server"
)
setproctitle
(
"sglang::server"
)
# Set SGLANG_DP_RANK environment variable
# Set SGLANG_DP_RANK environment variable
os
.
environ
[
"SGLANG_DP_RANK"
]
=
str
(
dp_rank
)
os
.
environ
[
"SGLANG_DP_RANK"
]
=
str
(
dp_rank
)
...
@@ -120,9 +120,26 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
...
@@ -120,9 +120,26 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
def
cleanup_processes
(
processes
:
List
[
mp
.
Process
]):
def
cleanup_processes
(
processes
:
List
[
mp
.
Process
]):
for
process
in
processes
:
for
process
in
processes
:
logger
.
info
(
f
"Terminating process
{
process
.
pid
}
"
)
logger
.
info
(
f
"Terminating process group
{
process
.
pid
}
"
)
process
.
terminate
()
try
:
logger
.
info
(
"All processes terminated"
)
os
.
killpg
(
process
.
pid
,
signal
.
SIGTERM
)
except
ProcessLookupError
:
# Process group may already be terminated
pass
# Wait for processes to terminate
for
process
in
processes
:
process
.
join
(
timeout
=
5
)
if
process
.
is_alive
():
logger
.
warning
(
f
"Process
{
process
.
pid
}
did not terminate gracefully, forcing kill"
)
try
:
os
.
killpg
(
process
.
pid
,
signal
.
SIGKILL
)
except
ProcessLookupError
:
pass
logger
.
info
(
"All process groups terminated"
)
def
main
():
def
main
():
...
@@ -173,7 +190,12 @@ def main():
...
@@ -173,7 +190,12 @@ def main():
]
]
# Start the router
# Start the router
router
=
launch_router
(
router_args
)
try
:
launch_router
(
router_args
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to start router:
{
e
}
"
)
cleanup_processes
(
server_processes
)
sys
.
exit
(
1
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
sgl-router/py_src/sglang_router/router.py
View file @
0311ce8e
...
@@ -17,6 +17,7 @@ class Router:
...
@@ -17,6 +17,7 @@ class Router:
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
host: Host address to bind the router server. Default: '127.0.0.1'
host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001
port: Port number to bind the router server. Default: 3001
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5
tree. Default: 0.5
...
@@ -37,6 +38,7 @@ class Router:
...
@@ -37,6 +38,7 @@ class Router:
policy
:
PolicyType
=
PolicyType
.
RoundRobin
,
policy
:
PolicyType
=
PolicyType
.
RoundRobin
,
host
:
str
=
"127.0.0.1"
,
host
:
str
=
"127.0.0.1"
,
port
:
int
=
3001
,
port
:
int
=
3001
,
worker_startup_timeout_secs
:
int
=
300
,
cache_threshold
:
float
=
0.50
,
cache_threshold
:
float
=
0.50
,
balance_abs_threshold
:
int
=
32
,
balance_abs_threshold
:
int
=
32
,
balance_rel_threshold
:
float
=
1.0001
,
balance_rel_threshold
:
float
=
1.0001
,
...
@@ -50,6 +52,7 @@ class Router:
...
@@ -50,6 +52,7 @@ class Router:
policy
=
policy
,
policy
=
policy
,
host
=
host
,
host
=
host
,
port
=
port
,
port
=
port
,
worker_startup_timeout_secs
=
worker_startup_timeout_secs
,
cache_threshold
=
cache_threshold
,
cache_threshold
=
cache_threshold
,
balance_abs_threshold
=
balance_abs_threshold
,
balance_abs_threshold
=
balance_abs_threshold
,
balance_rel_threshold
=
balance_rel_threshold
,
balance_rel_threshold
=
balance_rel_threshold
,
...
...
sgl-router/py_test/test_launch_router.py
View file @
0311ce8e
...
@@ -28,6 +28,7 @@ class TestLaunchRouter(unittest.TestCase):
...
@@ -28,6 +28,7 @@ class TestLaunchRouter(unittest.TestCase):
host
=
"127.0.0.1"
,
host
=
"127.0.0.1"
,
port
=
30000
,
port
=
30000
,
policy
=
"cache_aware"
,
policy
=
"cache_aware"
,
worker_startup_timeout_secs
=
600
,
cache_threshold
=
0.5
,
cache_threshold
=
0.5
,
balance_abs_threshold
=
32
,
balance_abs_threshold
=
32
,
balance_rel_threshold
=
1.0001
,
balance_rel_threshold
=
1.0001
,
...
...
sgl-router/src/lib.rs
View file @
0311ce8e
...
@@ -17,6 +17,7 @@ struct Router {
...
@@ -17,6 +17,7 @@ struct Router {
port
:
u16
,
port
:
u16
,
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Vec
<
String
>
,
policy
:
PolicyType
,
policy
:
PolicyType
,
worker_startup_timeout_secs
:
u64
,
cache_threshold
:
f32
,
cache_threshold
:
f32
,
balance_abs_threshold
:
usize
,
balance_abs_threshold
:
usize
,
balance_rel_threshold
:
f32
,
balance_rel_threshold
:
f32
,
...
@@ -34,6 +35,7 @@ impl Router {
...
@@ -34,6 +35,7 @@ impl Router {
policy
=
PolicyType::RoundRobin,
policy
=
PolicyType::RoundRobin,
host
=
String::from(
"127.0.0.1"
),
host
=
String::from(
"127.0.0.1"
),
port
=
3001
,
port
=
3001
,
worker_startup_timeout_secs
=
300
,
cache_threshold
=
0.50
,
cache_threshold
=
0.50
,
balance_abs_threshold
=
32
,
balance_abs_threshold
=
32
,
balance_rel_threshold
=
1.0001
,
balance_rel_threshold
=
1.0001
,
...
@@ -47,6 +49,7 @@ impl Router {
...
@@ -47,6 +49,7 @@ impl Router {
policy
:
PolicyType
,
policy
:
PolicyType
,
host
:
String
,
host
:
String
,
port
:
u16
,
port
:
u16
,
worker_startup_timeout_secs
:
u64
,
cache_threshold
:
f32
,
cache_threshold
:
f32
,
balance_abs_threshold
:
usize
,
balance_abs_threshold
:
usize
,
balance_rel_threshold
:
f32
,
balance_rel_threshold
:
f32
,
...
@@ -60,6 +63,7 @@ impl Router {
...
@@ -60,6 +63,7 @@ impl Router {
port
,
port
,
worker_urls
,
worker_urls
,
policy
,
policy
,
worker_startup_timeout_secs
,
cache_threshold
,
cache_threshold
,
balance_abs_threshold
,
balance_abs_threshold
,
balance_rel_threshold
,
balance_rel_threshold
,
...
@@ -72,9 +76,14 @@ impl Router {
...
@@ -72,9 +76,14 @@ impl Router {
fn
start
(
&
self
)
->
PyResult
<
()
>
{
fn
start
(
&
self
)
->
PyResult
<
()
>
{
let
policy_config
=
match
&
self
.policy
{
let
policy_config
=
match
&
self
.policy
{
PolicyType
::
Random
=>
router
::
PolicyConfig
::
RandomConfig
,
PolicyType
::
Random
=>
router
::
PolicyConfig
::
RandomConfig
{
PolicyType
::
RoundRobin
=>
router
::
PolicyConfig
::
RoundRobinConfig
,
timeout_secs
:
self
.worker_startup_timeout_secs
,
},
PolicyType
::
RoundRobin
=>
router
::
PolicyConfig
::
RoundRobinConfig
{
timeout_secs
:
self
.worker_startup_timeout_secs
,
},
PolicyType
::
CacheAware
=>
router
::
PolicyConfig
::
CacheAwareConfig
{
PolicyType
::
CacheAware
=>
router
::
PolicyConfig
::
CacheAwareConfig
{
timeout_secs
:
self
.worker_startup_timeout_secs
,
cache_threshold
:
self
.cache_threshold
,
cache_threshold
:
self
.cache_threshold
,
balance_abs_threshold
:
self
.balance_abs_threshold
,
balance_abs_threshold
:
self
.balance_abs_threshold
,
balance_rel_threshold
:
self
.balance_rel_threshold
,
balance_rel_threshold
:
self
.balance_rel_threshold
,
...
@@ -93,10 +102,9 @@ impl Router {
...
@@ -93,10 +102,9 @@ impl Router {
max_payload_size
:
self
.max_payload_size
,
max_payload_size
:
self
.max_payload_size
,
})
})
.await
.await
.unwrap
();
.map_err
(|
e
|
pyo3
::
exceptions
::
PyRuntimeError
::
new_err
(
e
.to_string
()))
?
;
});
Ok
(())
})
Ok
(())
}
}
}
}
...
...
sgl-router/src/router.rs
View file @
0311ce8e
...
@@ -3,7 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
...
@@ -3,7 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use
actix_web
::{
HttpRequest
,
HttpResponse
};
use
actix_web
::{
HttpRequest
,
HttpResponse
};
use
bytes
::
Bytes
;
use
bytes
::
Bytes
;
use
futures_util
::{
StreamExt
,
TryStreamExt
};
use
futures_util
::{
StreamExt
,
TryStreamExt
};
use
log
::{
debug
,
info
,
warn
};
use
log
::{
debug
,
error
,
info
,
warn
};
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
use
std
::
fmt
::
Debug
;
use
std
::
fmt
::
Debug
;
use
std
::
sync
::
atomic
::
AtomicUsize
;
use
std
::
sync
::
atomic
::
AtomicUsize
;
...
@@ -17,9 +17,11 @@ pub enum Router {
...
@@ -17,9 +17,11 @@ pub enum Router {
RoundRobin
{
RoundRobin
{
worker_urls
:
Arc
<
RwLock
<
Vec
<
String
>>>
,
worker_urls
:
Arc
<
RwLock
<
Vec
<
String
>>>
,
current_index
:
AtomicUsize
,
current_index
:
AtomicUsize
,
timeout_secs
:
u64
,
},
},
Random
{
Random
{
worker_urls
:
Arc
<
RwLock
<
Vec
<
String
>>>
,
worker_urls
:
Arc
<
RwLock
<
Vec
<
String
>>>
,
timeout_secs
:
u64
,
},
},
CacheAware
{
CacheAware
{
/*
/*
...
@@ -89,36 +91,51 @@ pub enum Router {
...
@@ -89,36 +91,51 @@ pub enum Router {
cache_threshold
:
f32
,
cache_threshold
:
f32
,
balance_abs_threshold
:
usize
,
balance_abs_threshold
:
usize
,
balance_rel_threshold
:
f32
,
balance_rel_threshold
:
f32
,
timeout_secs
:
u64
,
_
eviction_thread
:
Option
<
thread
::
JoinHandle
<
()
>>
,
_
eviction_thread
:
Option
<
thread
::
JoinHandle
<
()
>>
,
},
},
}
}
#[derive(Debug,
Clone)]
#[derive(Debug,
Clone)]
pub
enum
PolicyConfig
{
pub
enum
PolicyConfig
{
RandomConfig
,
RandomConfig
{
RoundRobinConfig
,
timeout_secs
:
u64
,
},
RoundRobinConfig
{
timeout_secs
:
u64
,
},
CacheAwareConfig
{
CacheAwareConfig
{
cache_threshold
:
f32
,
cache_threshold
:
f32
,
balance_abs_threshold
:
usize
,
balance_abs_threshold
:
usize
,
balance_rel_threshold
:
f32
,
balance_rel_threshold
:
f32
,
eviction_interval_secs
:
u64
,
eviction_interval_secs
:
u64
,
max_tree_size
:
usize
,
max_tree_size
:
usize
,
timeout_secs
:
u64
,
},
},
}
}
impl
Router
{
impl
Router
{
pub
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy_config
:
PolicyConfig
)
->
Result
<
Self
,
String
>
{
pub
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy_config
:
PolicyConfig
)
->
Result
<
Self
,
String
>
{
// Get timeout from policy config
let
timeout_secs
=
match
&
policy_config
{
PolicyConfig
::
RandomConfig
{
timeout_secs
}
=>
*
timeout_secs
,
PolicyConfig
::
RoundRobinConfig
{
timeout_secs
}
=>
*
timeout_secs
,
PolicyConfig
::
CacheAwareConfig
{
timeout_secs
,
..
}
=>
*
timeout_secs
,
};
// Wait until all workers are healthy
// Wait until all workers are healthy
Self
::
wait_for_healthy_workers
(
&
worker_urls
,
300
,
10
)
?
;
Self
::
wait_for_healthy_workers
(
&
worker_urls
,
timeout_secs
,
10
)
?
;
// Create router based on policy...
// Create router based on policy...
Ok
(
match
policy_config
{
Ok
(
match
policy_config
{
PolicyConfig
::
RandomConfig
=>
Router
::
Random
{
PolicyConfig
::
RandomConfig
{
timeout_secs
}
=>
Router
::
Random
{
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
)),
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
)),
timeout_secs
,
},
},
PolicyConfig
::
RoundRobinConfig
=>
Router
::
RoundRobin
{
PolicyConfig
::
RoundRobinConfig
{
timeout_secs
}
=>
Router
::
RoundRobin
{
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
)),
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
)),
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
timeout_secs
,
},
},
PolicyConfig
::
CacheAwareConfig
{
PolicyConfig
::
CacheAwareConfig
{
cache_threshold
,
cache_threshold
,
...
@@ -126,6 +143,7 @@ impl Router {
...
@@ -126,6 +143,7 @@ impl Router {
balance_rel_threshold
,
balance_rel_threshold
,
eviction_interval_secs
,
eviction_interval_secs
,
max_tree_size
,
max_tree_size
,
timeout_secs
,
}
=>
{
}
=>
{
let
mut
running_queue
=
HashMap
::
new
();
let
mut
running_queue
=
HashMap
::
new
();
for
url
in
&
worker_urls
{
for
url
in
&
worker_urls
{
...
@@ -176,6 +194,7 @@ impl Router {
...
@@ -176,6 +194,7 @@ impl Router {
cache_threshold
,
cache_threshold
,
balance_abs_threshold
,
balance_abs_threshold
,
balance_rel_threshold
,
balance_rel_threshold
,
timeout_secs
,
_
eviction_thread
:
Some
(
eviction_thread
),
_
eviction_thread
:
Some
(
eviction_thread
),
}
}
}
}
...
@@ -192,6 +211,10 @@ impl Router {
...
@@ -192,6 +211,10 @@ impl Router {
loop
{
loop
{
if
start_time
.elapsed
()
>
Duration
::
from_secs
(
timeout_secs
)
{
if
start_time
.elapsed
()
>
Duration
::
from_secs
(
timeout_secs
)
{
error!
(
"Timeout {}s waiting for workers to become healthy"
,
timeout_secs
);
return
Err
(
format!
(
return
Err
(
format!
(
"Timeout {}s waiting for workers to become healthy"
,
"Timeout {}s waiting for workers to become healthy"
,
timeout_secs
timeout_secs
...
@@ -238,7 +261,7 @@ impl Router {
...
@@ -238,7 +261,7 @@ impl Router {
fn
select_first_worker
(
&
self
)
->
Result
<
String
,
String
>
{
fn
select_first_worker
(
&
self
)
->
Result
<
String
,
String
>
{
match
self
{
match
self
{
Router
::
RoundRobin
{
worker_urls
,
..
}
Router
::
RoundRobin
{
worker_urls
,
..
}
|
Router
::
Random
{
worker_urls
}
|
Router
::
Random
{
worker_urls
,
..
}
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
if
worker_urls
.read
()
.unwrap
()
.is_empty
()
{
if
worker_urls
.read
()
.unwrap
()
.is_empty
()
{
Err
(
"No workers are available"
.to_string
())
Err
(
"No workers are available"
.to_string
())
...
@@ -349,6 +372,7 @@ impl Router {
...
@@ -349,6 +372,7 @@ impl Router {
Router
::
RoundRobin
{
Router
::
RoundRobin
{
worker_urls
,
worker_urls
,
current_index
,
current_index
,
..
}
=>
{
}
=>
{
let
idx
=
current_index
let
idx
=
current_index
.fetch_update
(
.fetch_update
(
...
@@ -360,7 +384,7 @@ impl Router {
...
@@ -360,7 +384,7 @@ impl Router {
worker_urls
.read
()
.unwrap
()[
idx
]
.clone
()
worker_urls
.read
()
.unwrap
()[
idx
]
.clone
()
}
}
Router
::
Random
{
worker_urls
}
=>
worker_urls
.read
()
.unwrap
()
Router
::
Random
{
worker_urls
,
..
}
=>
worker_urls
.read
()
.unwrap
()
[
rand
::
random
::
<
usize
>
()
%
worker_urls
.read
()
.unwrap
()
.len
()]
[
rand
::
random
::
<
usize
>
()
%
worker_urls
.read
()
.unwrap
()
.len
()]
.clone
(),
.clone
(),
...
@@ -571,13 +595,21 @@ impl Router {
...
@@ -571,13 +595,21 @@ impl Router {
pub
async
fn
add_worker
(
&
self
,
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
pub
async
fn
add_worker
(
&
self
,
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
let
interval_secs
=
10
;
// check every 10 seconds
let
interval_secs
=
10
;
// check every 10 seconds
let
timeout_secs
=
300
;
// 5 minutes
let
timeout_secs
=
match
self
{
Router
::
Random
{
timeout_secs
,
..
}
=>
*
timeout_secs
,
Router
::
RoundRobin
{
timeout_secs
,
..
}
=>
*
timeout_secs
,
Router
::
CacheAware
{
timeout_secs
,
..
}
=>
*
timeout_secs
,
};
let
start_time
=
std
::
time
::
Instant
::
now
();
let
start_time
=
std
::
time
::
Instant
::
now
();
let
client
=
reqwest
::
Client
::
new
();
let
client
=
reqwest
::
Client
::
new
();
loop
{
loop
{
if
start_time
.elapsed
()
>
Duration
::
from_secs
(
timeout_secs
)
{
if
start_time
.elapsed
()
>
Duration
::
from_secs
(
timeout_secs
)
{
error!
(
"Timeout {}s waiting for worker {} to become healthy"
,
timeout_secs
,
worker_url
);
return
Err
(
format!
(
return
Err
(
format!
(
"Timeout {}s waiting for worker {} to become healthy"
,
"Timeout {}s waiting for worker {} to become healthy"
,
timeout_secs
,
worker_url
timeout_secs
,
worker_url
...
@@ -589,7 +621,7 @@ impl Router {
...
@@ -589,7 +621,7 @@ impl Router {
if
res
.status
()
.is_success
()
{
if
res
.status
()
.is_success
()
{
match
self
{
match
self
{
Router
::
RoundRobin
{
worker_urls
,
..
}
Router
::
RoundRobin
{
worker_urls
,
..
}
|
Router
::
Random
{
worker_urls
}
|
Router
::
Random
{
worker_urls
,
..
}
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
info!
(
"Worker {} health check passed"
,
worker_url
);
info!
(
"Worker {} health check passed"
,
worker_url
);
let
mut
urls
=
worker_urls
.write
()
.unwrap
();
let
mut
urls
=
worker_urls
.write
()
.unwrap
();
...
@@ -663,7 +695,7 @@ impl Router {
...
@@ -663,7 +695,7 @@ impl Router {
pub
fn
remove_worker
(
&
self
,
worker_url
:
&
str
)
{
pub
fn
remove_worker
(
&
self
,
worker_url
:
&
str
)
{
match
self
{
match
self
{
Router
::
RoundRobin
{
worker_urls
,
..
}
Router
::
RoundRobin
{
worker_urls
,
..
}
|
Router
::
Random
{
worker_urls
}
|
Router
::
Random
{
worker_urls
,
..
}
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
let
mut
urls
=
worker_urls
.write
()
.unwrap
();
let
mut
urls
=
worker_urls
.write
()
.unwrap
();
if
let
Some
(
index
)
=
urls
.iter
()
.position
(|
url
|
url
==
&
worker_url
)
{
if
let
Some
(
index
)
=
urls
.iter
()
.position
(|
url
|
url
==
&
worker_url
)
{
...
...
sgl-router/src/server.rs
View file @
0311ce8e
...
@@ -18,14 +18,10 @@ impl AppState {
...
@@ -18,14 +18,10 @@ impl AppState {
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Vec
<
String
>
,
client
:
reqwest
::
Client
,
client
:
reqwest
::
Client
,
policy_config
:
PolicyConfig
,
policy_config
:
PolicyConfig
,
)
->
Self
{
)
->
Result
<
Self
,
String
>
{
// Create router based on policy
// Create router based on policy
let
router
=
match
Router
::
new
(
worker_urls
,
policy_config
)
{
let
router
=
Router
::
new
(
worker_urls
,
policy_config
)
?
;
Ok
(
router
)
=>
router
,
Ok
(
Self
{
router
,
client
})
Err
(
error
)
=>
panic!
(
"Failed to create router: {}"
,
error
),
};
Self
{
router
,
client
}
}
}
}
}
...
@@ -131,6 +127,7 @@ pub struct ServerConfig {
...
@@ -131,6 +127,7 @@ pub struct ServerConfig {
}
}
pub
async
fn
startup
(
config
:
ServerConfig
)
->
std
::
io
::
Result
<
()
>
{
pub
async
fn
startup
(
config
:
ServerConfig
)
->
std
::
io
::
Result
<
()
>
{
// Initialize logger
Builder
::
new
()
Builder
::
new
()
.format
(|
buf
,
record
|
{
.format
(|
buf
,
record
|
{
use
chrono
::
Local
;
use
chrono
::
Local
;
...
@@ -152,24 +149,30 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
...
@@ -152,24 +149,30 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
)
)
.init
();
.init
();
info!
(
"🚧 Initializing router on {}:{}"
,
config
.host
,
config
.port
);
info!
(
"🚧 Initializing workers on {:?}"
,
config
.worker_urls
);
info!
(
"🚧 Policy Config: {:?}"
,
config
.policy_config
);
info!
(
"🚧 Max payload size: {} MB"
,
config
.max_payload_size
/
(
1024
*
1024
)
);
let
client
=
reqwest
::
Client
::
builder
()
let
client
=
reqwest
::
Client
::
builder
()
.build
()
.build
()
.expect
(
"Failed to create HTTP client"
);
.expect
(
"Failed to create HTTP client"
);
let
app_state
=
web
::
Data
::
new
(
AppState
::
new
(
let
app_state
=
web
::
Data
::
new
(
config
.worker_urls
.clone
(),
AppState
::
new
(
client
,
config
.worker_urls
.clone
(),
config
.policy_config
.clone
(),
client
,
));
config
.policy_config
.clone
(),
)
info!
(
"✅ Starting router on {}:{}"
,
config
.host
,
config
.port
);
.map_err
(|
e
|
std
::
io
::
Error
::
new
(
std
::
io
::
ErrorKind
::
Other
,
e
))
?
,
info!
(
"✅ Serving Worker URLs: {:?}"
,
config
.worker_urls
);
info!
(
"✅ Policy Config: {:?}"
,
config
.policy_config
);
info!
(
"✅ Max payload size: {} MB"
,
config
.max_payload_size
/
(
1024
*
1024
)
);
);
info!
(
"✅ Serving router on {}:{}"
,
config
.host
,
config
.port
);
info!
(
"✅ Serving workers on {:?}"
,
config
.worker_urls
);
HttpServer
::
new
(
move
||
{
HttpServer
::
new
(
move
||
{
App
::
new
()
App
::
new
()
.app_data
(
app_state
.clone
())
.app_data
(
app_state
.clone
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment