Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4d62bca5
Unverified
Commit
4d62bca5
authored
Nov 25, 2024
by
Byron Hsu
Committed by
GitHub
Nov 25, 2024
Browse files
[router] Replace print with logger (#2183)
parent
e1e595d7
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
250 additions
and
48 deletions
+250
-48
rust/Cargo.lock
rust/Cargo.lock
+102
-0
rust/Cargo.toml
rust/Cargo.toml
+3
-0
rust/py_src/sglang_router/launch_router.py
rust/py_src/sglang_router/launch_router.py
+28
-1
rust/py_src/sglang_router/launch_server.py
rust/py_src/sglang_router/launch_server.py
+35
-14
rust/py_src/sglang_router/router.py
rust/py_src/sglang_router/router.py
+3
-0
rust/src/lib.rs
rust/src/lib.rs
+14
-8
rust/src/main.rs
rust/src/main.rs
+12
-2
rust/src/router.rs
rust/src/router.rs
+5
-5
rust/src/server.rs
rust/src/server.rs
+43
-14
rust/src/tree.rs
rust/src/tree.rs
+5
-4
No files found.
rust/Cargo.lock
View file @
4d62bca5
...
@@ -237,6 +237,21 @@ dependencies = [
...
@@ -237,6 +237,21 @@ dependencies = [
"alloc-no-stdlib",
"alloc-no-stdlib",
]
]
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
[[package]]
name = "anstream"
name = "anstream"
version = "0.6.18"
version = "0.6.18"
...
@@ -411,6 +426,20 @@ version = "1.0.0"
...
@@ -411,6 +426,20 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.38"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets 0.52.6",
]
[[package]]
[[package]]
name = "clap"
name = "clap"
version = "4.5.20"
version = "4.5.20"
...
@@ -721,6 +750,29 @@ dependencies = [
...
@@ -721,6 +750,29 @@ dependencies = [
"cfg-if",
"cfg-if",
]
]
[[package]]
name = "env_filter"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab"
dependencies = [
"log",
"regex",
]
[[package]]
name = "env_logger"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d"
dependencies = [
"anstream",
"anstyle",
"env_filter",
"humantime",
"log",
]
[[package]]
[[package]]
name = "equivalent"
name = "equivalent"
version = "1.0.1"
version = "1.0.1"
...
@@ -1016,6 +1068,12 @@ version = "1.0.3"
...
@@ -1016,6 +1068,12 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
[[package]]
name = "hyper"
name = "hyper"
version = "1.5.0"
version = "1.5.0"
...
@@ -1088,6 +1146,29 @@ dependencies = [
...
@@ -1088,6 +1146,29 @@ dependencies = [
"tracing",
"tracing",
]
]
[[package]]
name = "iana-time-zone"
version = "0.1.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
[[package]]
name = "icu_collections"
name = "icu_collections"
version = "1.5.0"
version = "1.5.0"
...
@@ -1523,6 +1604,15 @@ version = "0.1.0"
...
@@ -1523,6 +1604,15 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
[[package]]
name = "number_prefix"
name = "number_prefix"
version = "0.4.0"
version = "0.4.0"
...
@@ -2116,10 +2206,13 @@ version = "0.0.0"
...
@@ -2116,10 +2206,13 @@ version = "0.0.0"
dependencies = [
dependencies = [
"actix-web",
"actix-web",
"bytes",
"bytes",
"chrono",
"clap",
"clap",
"dashmap",
"dashmap",
"env_logger",
"futures-util",
"futures-util",
"http 1.1.0",
"http 1.1.0",
"log",
"pyo3",
"pyo3",
"rand",
"rand",
"reqwest",
"reqwest",
...
@@ -2688,6 +2781,15 @@ dependencies = [
...
@@ -2688,6 +2781,15 @@ dependencies = [
"rustls-pki-types",
"rustls-pki-types",
]
]
[[package]]
name = "windows-core"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
[[package]]
name = "windows-registry"
name = "windows-registry"
version = "0.2.0"
version = "0.2.0"
...
...
rust/Cargo.toml
View file @
4d62bca5
...
@@ -26,6 +26,9 @@ pyo3 = { version = "0.22.5", features = ["extension-module"] }
...
@@ -26,6 +26,9 @@ pyo3 = { version = "0.22.5", features = ["extension-module"] }
tokenizers
=
{
version
=
"0.20.3"
,
features
=
["http"]
}
tokenizers
=
{
version
=
"0.20.3"
,
features
=
["http"]
}
dashmap
=
"6.1.0"
dashmap
=
"6.1.0"
http
=
"1.1.0"
http
=
"1.1.0"
env_logger
=
"0.11.5"
log
=
"0.4.22"
chrono
=
"0.4.38"
[profile.release]
[profile.release]
lto
=
"thin"
lto
=
"thin"
...
...
rust/py_src/sglang_router/launch_router.py
View file @
4d62bca5
import
argparse
import
argparse
import
dataclasses
import
dataclasses
import
logging
import
sys
import
sys
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
...
@@ -7,6 +8,22 @@ from sglang_router import Router
...
@@ -7,6 +8,22 @@ from sglang_router import Router
from
sglang_router_rs
import
PolicyType
from
sglang_router_rs
import
PolicyType
def setup_logger():
    """Configure and return the shared ``"router"`` logger.

    The logger level is set to INFO and a stream handler is attached that
    formats records as
    ``[Router (Python)] <timestamp> - <level> - <message>``.

    The function is safe to call more than once: the handler is attached
    only the first time, so repeated calls do not cause every record to be
    emitted several times.

    Returns:
        logging.Logger: the configured ``"router"`` logger.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so a
    # second call would otherwise stack another handler on it and duplicate
    # every log line. Tag our handler so we can recognise it later.
    already_installed = any(
        getattr(existing, "_router_default_handler", False)
        for existing in logger.handlers
    )
    if not already_installed:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        handler._router_default_handler = True
        logger.addHandler(handler)

    return logger
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
RouterArgs
:
class
RouterArgs
:
# Worker configuration
# Worker configuration
...
@@ -21,6 +38,7 @@ class RouterArgs:
...
@@ -21,6 +38,7 @@ class RouterArgs:
balance_rel_threshold
:
float
=
1.0001
balance_rel_threshold
:
float
=
1.0001
eviction_interval
:
int
=
60
eviction_interval
:
int
=
60
max_tree_size
:
int
=
2
**
24
max_tree_size
:
int
=
2
**
24
verbose
:
bool
=
False
@
staticmethod
@
staticmethod
def
add_cli_args
(
def
add_cli_args
(
...
@@ -98,6 +116,11 @@ class RouterArgs:
...
@@ -98,6 +116,11 @@ class RouterArgs:
default
=
RouterArgs
.
max_tree_size
,
default
=
RouterArgs
.
max_tree_size
,
help
=
"Maximum size of the approximation tree for cache-aware routing"
,
help
=
"Maximum size of the approximation tree for cache-aware routing"
,
)
)
parser
.
add_argument
(
f
"--
{
prefix
}
verbose"
,
action
=
"store_true"
,
help
=
"Enable verbose logging"
,
)
@
classmethod
@
classmethod
def
from_cli_args
(
def
from_cli_args
(
...
@@ -121,6 +144,7 @@ class RouterArgs:
...
@@ -121,6 +144,7 @@ class RouterArgs:
balance_rel_threshold
=
getattr
(
args
,
f
"
{
prefix
}
balance_rel_threshold"
),
balance_rel_threshold
=
getattr
(
args
,
f
"
{
prefix
}
balance_rel_threshold"
),
eviction_interval
=
getattr
(
args
,
f
"
{
prefix
}
eviction_interval"
),
eviction_interval
=
getattr
(
args
,
f
"
{
prefix
}
eviction_interval"
),
max_tree_size
=
getattr
(
args
,
f
"
{
prefix
}
max_tree_size"
),
max_tree_size
=
getattr
(
args
,
f
"
{
prefix
}
max_tree_size"
),
verbose
=
getattr
(
args
,
f
"
{
prefix
}
verbose"
,
False
),
)
)
...
@@ -145,6 +169,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
...
@@ -145,6 +169,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
Returns:
Returns:
Router instance if successful, None if failed
Router instance if successful, None if failed
"""
"""
logger
=
logging
.
getLogger
(
"router"
)
try
:
try
:
# Convert to RouterArgs if needed
# Convert to RouterArgs if needed
if
not
isinstance
(
args
,
RouterArgs
):
if
not
isinstance
(
args
,
RouterArgs
):
...
@@ -162,13 +187,14 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
...
@@ -162,13 +187,14 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
balance_rel_threshold
=
router_args
.
balance_rel_threshold
,
balance_rel_threshold
=
router_args
.
balance_rel_threshold
,
eviction_interval_secs
=
router_args
.
eviction_interval
,
eviction_interval_secs
=
router_args
.
eviction_interval
,
max_tree_size
=
router_args
.
max_tree_size
,
max_tree_size
=
router_args
.
max_tree_size
,
verbose
=
router_args
.
verbose
,
)
)
router
.
start
()
router
.
start
()
return
router
return
router
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Error starting router:
{
e
}
"
,
file
=
sys
.
stderr
)
# Logger.error() takes no `file=` keyword (that belongs to print()); passing
# it raises TypeError inside this except branch and masks the real error.
logger.error(f"Error starting router: {e}")
return
None
return
None
...
@@ -202,6 +228,7 @@ Examples:
...
@@ -202,6 +228,7 @@ Examples:
def
main
()
->
None
:
def
main
()
->
None
:
logger
=
setup_logger
()
router_args
=
parse_router_args
(
sys
.
argv
[
1
:])
router_args
=
parse_router_args
(
sys
.
argv
[
1
:])
router
=
launch_router
(
router_args
)
router
=
launch_router
(
router_args
)
...
...
rust/py_src/sglang_router/launch_server.py
View file @
4d62bca5
import
argparse
import
argparse
import
copy
import
copy
import
logging
import
multiprocessing
as
mp
import
multiprocessing
as
mp
import
os
import
os
import
random
import
random
...
@@ -17,6 +18,22 @@ from sglang.srt.utils import is_port_available
...
@@ -17,6 +18,22 @@ from sglang.srt.utils import is_port_available
from
sglang.utils
import
get_exception_traceback
from
sglang.utils
import
get_exception_traceback
def setup_logger():
    """Configure and return the shared ``"router"`` logger.

    The logger level is set to INFO and a stream handler is attached that
    formats records as
    ``[Router (Python)] <timestamp> - <level> - <message>``.

    The function is safe to call more than once: the handler is attached
    only the first time, so repeated calls do not cause every record to be
    emitted several times.

    Returns:
        logging.Logger: the configured ``"router"`` logger.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so a
    # second call would otherwise stack another handler on it and duplicate
    # every log line. Tag our handler so we can recognise it later.
    already_installed = any(
        getattr(existing, "_router_default_handler", False)
        for existing in logger.handlers
    )
    if not already_installed:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        handler._router_default_handler = True
        logger.addHandler(handler)

    return logger
# Create new process group
# Create new process group
def
run_server
(
server_args
,
dp_rank
):
def
run_server
(
server_args
,
dp_rank
):
os
.
setpgrp
()
# Create new process group
os
.
setpgrp
()
# Create new process group
...
@@ -42,20 +59,20 @@ def launch_server_process(
...
@@ -42,20 +59,20 @@ def launch_server_process(
def
cleanup_processes
(
processes
:
List
[
mp
.
Process
]):
def
cleanup_processes
(
processes
:
List
[
mp
.
Process
]):
"""Clean up all processes using process groups."""
logger
=
logging
.
getLogger
(
"router"
)
print
(
"
\n
Cleaning up processes..."
)
logger
.
info
(
"
Cleaning up processes..."
)
for
proc
in
processes
:
for
proc
in
processes
:
if
proc
.
is_alive
():
if
proc
.
is_alive
():
try
:
try
:
# Kill the entire process group
os
.
killpg
(
os
.
getpgid
(
proc
.
pid
),
signal
.
SIGTERM
)
os
.
killpg
(
os
.
getpgid
(
proc
.
pid
),
signal
.
SIGTERM
)
# Give processes some time to terminate gracefully
proc
.
join
(
timeout
=
3
)
proc
.
join
(
timeout
=
3
)
# If process is still alive, force kill
if
proc
.
is_alive
():
if
proc
.
is_alive
():
logger
.
warning
(
f
"Process
{
proc
.
pid
}
did not terminate gracefully, force killing..."
)
os
.
killpg
(
os
.
getpgid
(
proc
.
pid
),
signal
.
SIGKILL
)
os
.
killpg
(
os
.
getpgid
(
proc
.
pid
),
signal
.
SIGKILL
)
except
ProcessLookupError
:
except
ProcessLookupError
:
pass
# Process already terminated
pass
def
setup_signal_handlers
(
cleanup_func
):
def
setup_signal_handlers
(
cleanup_func
):
...
@@ -101,6 +118,8 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
...
@@ -101,6 +118,8 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
def
main
():
def
main
():
logger
=
setup_logger
()
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
mp
.
set_start_method
(
"spawn"
)
mp
.
set_start_method
(
"spawn"
)
...
@@ -130,8 +149,8 @@ def main():
...
@@ -130,8 +149,8 @@ def main():
server_processes
=
[]
server_processes
=
[]
try
:
try
:
# Launch server processes
for
i
,
worker_port
in
enumerate
(
worker_ports
):
for
i
,
worker_port
in
enumerate
(
worker_ports
):
logger
.
info
(
f
"Launching DP server process
{
i
}
on port
{
worker_port
}
"
)
proc
=
launch_server_process
(
server_args
,
worker_port
,
i
)
proc
=
launch_server_process
(
server_args
,
worker_port
,
i
)
server_processes
.
append
(
proc
)
server_processes
.
append
(
proc
)
...
@@ -140,18 +159,19 @@ def main():
...
@@ -140,18 +159,19 @@ def main():
# Wait for all servers to be healthy
# Wait for all servers to be healthy
all_healthy
=
True
all_healthy
=
True
for
port
in
worker_ports
:
for
port
in
worker_ports
:
if
not
wait_for_server_health
(
server_args
.
host
,
port
):
if
not
wait_for_server_health
(
server_args
.
host
,
port
):
print
(
f
"Server on port
{
port
}
failed to become healthy"
)
logger
.
error
(
f
"Server on port
{
port
}
failed to become healthy"
)
all_healthy
=
False
all_healthy
=
False
break
break
if
not
all_healthy
:
if
not
all_healthy
:
print
(
"Not all servers are healthy. Shutting down..."
)
logger
.
error
(
"Not all servers are healthy. Shutting down..."
)
cleanup_processes
(
server_processes
)
cleanup_processes
(
server_processes
)
sys
.
exit
(
1
)
sys
.
exit
(
1
)
print
(
"All servers are healthy. Starting router..."
)
logger
.
info
(
"All servers are healthy. Starting router..."
)
# Update router args with worker URLs
# Update router args with worker URLs
router_args
.
worker_urls
=
[
router_args
.
worker_urls
=
[
...
@@ -162,16 +182,17 @@ def main():
...
@@ -162,16 +182,17 @@ def main():
router
=
launch_router
(
router_args
)
router
=
launch_router
(
router_args
)
if
router
is
None
:
if
router
is
None
:
print
(
"Failed to start router. Shutting down..."
)
logger
.
error
(
"Failed to start router. Shutting down..."
)
cleanup_processes
(
server_processes
)
cleanup_processes
(
server_processes
)
sys
.
exit
(
1
)
sys
.
exit
(
1
)
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
print
(
"
\n
Received shutdown signal..."
)
logger
.
info
(
"
Received shutdown signal..."
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Error occurred:
{
e
}
"
)
logger
.
error
(
f
"Error occurred:
{
e
}
"
)
print
(
get_exception_traceback
())
logger
.
error
(
get_exception_traceback
())
finally
:
finally
:
logger
.
info
(
"Cleaning up processes..."
)
cleanup_processes
(
server_processes
)
cleanup_processes
(
server_processes
)
...
...
rust/py_src/sglang_router/router.py
View file @
4d62bca5
...
@@ -27,6 +27,7 @@ class Router:
...
@@ -27,6 +27,7 @@ class Router:
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
routing. Default: 60
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
verbose: Enable verbose logging. Default: False
"""
"""
def
__init__
(
def
__init__
(
...
@@ -40,6 +41,7 @@ class Router:
...
@@ -40,6 +41,7 @@ class Router:
balance_rel_threshold
:
float
=
1.0001
,
balance_rel_threshold
:
float
=
1.0001
,
eviction_interval_secs
:
int
=
60
,
eviction_interval_secs
:
int
=
60
,
max_tree_size
:
int
=
2
**
24
,
max_tree_size
:
int
=
2
**
24
,
verbose
:
bool
=
False
,
):
):
self
.
_router
=
_Router
(
self
.
_router
=
_Router
(
worker_urls
=
worker_urls
,
worker_urls
=
worker_urls
,
...
@@ -51,6 +53,7 @@ class Router:
...
@@ -51,6 +53,7 @@ class Router:
balance_rel_threshold
=
balance_rel_threshold
,
balance_rel_threshold
=
balance_rel_threshold
,
eviction_interval_secs
=
eviction_interval_secs
,
eviction_interval_secs
=
eviction_interval_secs
,
max_tree_size
=
max_tree_size
,
max_tree_size
=
max_tree_size
,
verbose
=
verbose
,
)
)
def
start
(
self
)
->
None
:
def
start
(
self
)
->
None
:
...
...
rust/src/lib.rs
View file @
4d62bca5
...
@@ -22,6 +22,7 @@ struct Router {
...
@@ -22,6 +22,7 @@ struct Router {
balance_rel_threshold
:
f32
,
balance_rel_threshold
:
f32
,
eviction_interval_secs
:
u64
,
eviction_interval_secs
:
u64
,
max_tree_size
:
usize
,
max_tree_size
:
usize
,
verbose
:
bool
,
}
}
#[pymethods]
#[pymethods]
...
@@ -36,7 +37,8 @@ impl Router {
...
@@ -36,7 +37,8 @@ impl Router {
balance_abs_threshold
=
32
,
balance_abs_threshold
=
32
,
balance_rel_threshold
=
1.0001
,
balance_rel_threshold
=
1.0001
,
eviction_interval_secs
=
60
,
eviction_interval_secs
=
60
,
max_tree_size
=
2u
size
.
pow(
24
)
max_tree_size
=
2u
size
.
pow(
24
),
verbose
=
false
))]
))]
fn
new
(
fn
new
(
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Vec
<
String
>
,
...
@@ -48,6 +50,7 @@ impl Router {
...
@@ -48,6 +50,7 @@ impl Router {
balance_rel_threshold
:
f32
,
balance_rel_threshold
:
f32
,
eviction_interval_secs
:
u64
,
eviction_interval_secs
:
u64
,
max_tree_size
:
usize
,
max_tree_size
:
usize
,
verbose
:
bool
,
)
->
PyResult
<
Self
>
{
)
->
PyResult
<
Self
>
{
Ok
(
Router
{
Ok
(
Router
{
host
,
host
,
...
@@ -59,14 +62,11 @@ impl Router {
...
@@ -59,14 +62,11 @@ impl Router {
balance_rel_threshold
,
balance_rel_threshold
,
eviction_interval_secs
,
eviction_interval_secs
,
max_tree_size
,
max_tree_size
,
verbose
,
})
})
}
}
fn
start
(
&
self
)
->
PyResult
<
()
>
{
fn
start
(
&
self
)
->
PyResult
<
()
>
{
let
host
=
self
.host
.clone
();
let
port
=
self
.port
;
let
worker_urls
=
self
.worker_urls
.clone
();
let
policy_config
=
match
&
self
.policy
{
let
policy_config
=
match
&
self
.policy
{
PolicyType
::
Random
=>
router
::
PolicyConfig
::
RandomConfig
,
PolicyType
::
Random
=>
router
::
PolicyConfig
::
RandomConfig
,
PolicyType
::
RoundRobin
=>
router
::
PolicyConfig
::
RoundRobinConfig
,
PolicyType
::
RoundRobin
=>
router
::
PolicyConfig
::
RoundRobinConfig
,
...
@@ -80,9 +80,15 @@ impl Router {
...
@@ -80,9 +80,15 @@ impl Router {
};
};
actix_web
::
rt
::
System
::
new
()
.block_on
(
async
move
{
actix_web
::
rt
::
System
::
new
()
.block_on
(
async
move
{
server
::
startup
(
host
,
port
,
worker_urls
,
policy_config
)
server
::
startup
(
server
::
ServerConfig
{
.await
host
:
self
.host
.clone
(),
.unwrap
();
port
:
self
.port
,
worker_urls
:
self
.worker_urls
.clone
(),
policy_config
,
verbose
:
self
.verbose
,
})
.await
.unwrap
();
});
});
Ok
(())
Ok
(())
...
...
rust/src/main.rs
View file @
4d62bca5
use
clap
::
Parser
;
use
clap
::
Parser
;
use
clap
::
ValueEnum
;
use
clap
::
ValueEnum
;
use
sglang_router_rs
::{
router
::
PolicyConfig
,
server
};
use
sglang_router_rs
::{
router
::
PolicyConfig
,
server
,
server
::
ServerConfig
};
#[derive(Debug,
Clone,
ValueEnum)]
#[derive(Debug,
Clone,
ValueEnum)]
pub
enum
PolicyType
{
pub
enum
PolicyType
{
...
@@ -89,6 +89,9 @@ struct Args {
...
@@ -89,6 +89,9 @@ struct Args {
help
=
"Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
help
=
"Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
)]
)]
max_tree_size
:
usize
,
max_tree_size
:
usize
,
#[arg(long,
default_value_t
=
false
,
help
=
"Enable verbose logging"
)]
verbose
:
bool
,
}
}
impl
Args
{
impl
Args
{
...
@@ -111,5 +114,12 @@ impl Args {
...
@@ -111,5 +114,12 @@ impl Args {
async
fn
main
()
->
std
::
io
::
Result
<
()
>
{
async
fn
main
()
->
std
::
io
::
Result
<
()
>
{
let
args
=
Args
::
parse
();
let
args
=
Args
::
parse
();
let
policy_config
=
args
.get_policy_config
();
let
policy_config
=
args
.get_policy_config
();
server
::
startup
(
args
.host
,
args
.port
,
args
.worker_urls
,
policy_config
)
.await
server
::
startup
(
ServerConfig
{
host
:
args
.host
,
port
:
args
.port
,
worker_urls
:
args
.worker_urls
,
policy_config
,
verbose
:
args
.verbose
,
})
.await
}
}
rust/src/router.rs
View file @
4d62bca5
...
@@ -3,6 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
...
@@ -3,6 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use
actix_web
::{
HttpRequest
,
HttpResponse
};
use
actix_web
::{
HttpRequest
,
HttpResponse
};
use
bytes
::
Bytes
;
use
bytes
::
Bytes
;
use
futures_util
::{
Stream
,
StreamExt
,
TryStreamExt
};
use
futures_util
::{
Stream
,
StreamExt
,
TryStreamExt
};
use
log
::{
debug
,
info
};
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
use
std
::
fmt
::
Debug
;
use
std
::
fmt
::
Debug
;
use
std
::
hash
::
Hash
;
use
std
::
hash
::
Hash
;
...
@@ -171,11 +172,11 @@ impl Router {
...
@@ -171,11 +172,11 @@ impl Router {
// Print the process queue
// Print the process queue
let
locked_processed_queue
=
processed_queue_clone
.lock
()
.unwrap
();
let
locked_processed_queue
=
processed_queue_clone
.lock
()
.unwrap
();
println
!
(
"Processed Queue: {:?}"
,
locked_processed_queue
);
info
!
(
"Processed Queue: {:?}"
,
locked_processed_queue
);
// Print the running queue
// Print the running queue
let
locked_running_queue
=
running_queue_clone
.lock
()
.unwrap
();
let
locked_running_queue
=
running_queue_clone
.lock
()
.unwrap
();
println
!
(
"Running Queue: {:?}"
,
locked_running_queue
);
info
!
(
"Running Queue: {:?}"
,
locked_running_queue
);
}
}
});
});
...
@@ -266,7 +267,7 @@ impl Router {
...
@@ -266,7 +267,7 @@ impl Router {
let
selected_url
=
if
is_imbalanced
{
let
selected_url
=
if
is_imbalanced
{
// Log load balancing trigger and current queue state
// Log load balancing trigger and current queue state
println
!
(
info
!
(
"Load balancing triggered due to workload imbalance:
\n
\
"Load balancing triggered due to workload imbalance:
\n
\
Max load: {}, Min load: {}
\n
\
Max load: {}, Min load: {}
\n
\
Current running queue: {:?}"
,
Current running queue: {:?}"
,
...
@@ -368,8 +369,7 @@ impl Router {
...
@@ -368,8 +369,7 @@ impl Router {
let
mut
locked_queue
=
running_queue
.lock
()
.unwrap
();
let
mut
locked_queue
=
running_queue
.lock
()
.unwrap
();
let
count
=
locked_queue
.get_mut
(
&
worker_url
)
.unwrap
();
let
count
=
locked_queue
.get_mut
(
&
worker_url
)
.unwrap
();
*
count
=
count
.saturating_sub
(
1
);
*
count
=
count
.saturating_sub
(
1
);
// print
debug!
(
"streaming is done!!"
)
// println!("streaming is done!!")
}
}
}),
}),
)
)
...
...
rust/src/server.rs
View file @
4d62bca5
...
@@ -2,6 +2,9 @@ use crate::router::PolicyConfig;
...
@@ -2,6 +2,9 @@ use crate::router::PolicyConfig;
use
crate
::
router
::
Router
;
use
crate
::
router
::
Router
;
use
actix_web
::{
get
,
post
,
web
,
App
,
HttpRequest
,
HttpResponse
,
HttpServer
,
Responder
};
use
actix_web
::{
get
,
post
,
web
,
App
,
HttpRequest
,
HttpResponse
,
HttpServer
,
Responder
};
use
bytes
::
Bytes
;
use
bytes
::
Bytes
;
use
env_logger
::
Builder
;
use
log
::{
debug
,
info
,
LevelFilter
};
use
std
::
io
::
Write
;
#[derive(Debug)]
#[derive(Debug)]
pub
struct
AppState
{
pub
struct
AppState
{
...
@@ -125,23 +128,49 @@ async fn v1_completions(
...
@@ -125,23 +128,49 @@ async fn v1_completions(
.await
.await
}
}
pub
async
fn
startup
(
pub
struct
ServerConfig
{
host
:
String
,
pub
host
:
String
,
port
:
u16
,
pub
port
:
u16
,
worker_urls
:
Vec
<
String
>
,
pub
worker_urls
:
Vec
<
String
>
,
policy_config
:
PolicyConfig
,
pub
policy_config
:
PolicyConfig
,
)
->
std
::
io
::
Result
<
()
>
{
pub
verbose
:
bool
,
println!
(
"Starting server on {}:{}"
,
host
,
port
);
}
println!
(
"Worker URLs: {:?}"
,
worker_urls
);
println!
(
"Policy Config: {:?}"
,
policy_config
);
pub
async
fn
startup
(
config
:
ServerConfig
)
->
std
::
io
::
Result
<
()
>
{
Builder
::
new
()
// Create client once with configuration
.format
(|
buf
,
record
|
{
use
chrono
::
Local
;
writeln!
(
buf
,
"[Router (Rust)] {} - {} - {}"
,
Local
::
now
()
.format
(
"%Y-%m-%d %H:%M:%S"
),
record
.level
(),
record
.args
()
)
})
.filter
(
None
,
if
config
.verbose
{
LevelFilter
::
Debug
}
else
{
LevelFilter
::
Info
},
)
.init
();
info!
(
"Starting server on {}:{}"
,
config
.host
,
config
.port
);
info!
(
"Worker URLs: {:?}"
,
config
.worker_urls
);
info!
(
"Policy Config: {:?}"
,
config
.policy_config
);
let
client
=
reqwest
::
Client
::
builder
()
let
client
=
reqwest
::
Client
::
builder
()
.build
()
.build
()
.expect
(
"Failed to create HTTP client"
);
.expect
(
"Failed to create HTTP client"
);
// Store both worker_urls and client in AppState
let
app_state
=
web
::
Data
::
new
(
AppState
::
new
(
let
app_state
=
web
::
Data
::
new
(
AppState
::
new
(
worker_urls
,
client
,
policy_config
));
config
.worker_urls
,
client
,
config
.policy_config
,
));
HttpServer
::
new
(
move
||
{
HttpServer
::
new
(
move
||
{
App
::
new
()
App
::
new
()
...
@@ -155,7 +184,7 @@ pub async fn startup(
...
@@ -155,7 +184,7 @@ pub async fn startup(
.service
(
health_generate
)
.service
(
health_generate
)
.service
(
get_server_info
)
.service
(
get_server_info
)
})
})
.bind
((
host
,
port
))
?
.bind
((
config
.host
,
config
.
port
))
?
.run
()
.run
()
.await
.await
}
}
rust/src/tree.rs
View file @
4d62bca5
use
dashmap
::
mapref
::
entry
::
Entry
;
use
dashmap
::
mapref
::
entry
::
Entry
;
use
dashmap
::
DashMap
;
use
dashmap
::
DashMap
;
use
log
::
info
;
use
rand
::
distributions
::{
Alphanumeric
,
DistString
};
use
rand
::
distributions
::{
Alphanumeric
,
DistString
};
use
rand
::
thread_rng
;
use
rand
::
thread_rng
;
use
std
::
cmp
::
min
;
use
std
::
cmp
::
min
;
...
@@ -434,9 +435,9 @@ impl Tree {
...
@@ -434,9 +435,9 @@ impl Tree {
}
}
}
}
println
!
(
"Before eviction - Used size per tenant:"
);
info
!
(
"Before eviction - Used size per tenant:"
);
for
(
tenant
,
size
)
in
&
used_size_per_tenant
{
for
(
tenant
,
size
)
in
&
used_size_per_tenant
{
println
!
(
"Tenant: {}, Size: {}"
,
tenant
,
size
);
info
!
(
"Tenant: {}, Size: {}"
,
tenant
,
size
);
}
}
// Process eviction
// Process eviction
...
@@ -490,9 +491,9 @@ impl Tree {
...
@@ -490,9 +491,9 @@ impl Tree {
}
}
}
}
println
!
(
"
\n
After eviction - Used size per tenant:"
);
info
!
(
"After eviction - Used size per tenant:"
);
for
(
tenant
,
size
)
in
&
used_size_per_tenant
{
for
(
tenant
,
size
)
in
&
used_size_per_tenant
{
println
!
(
"Tenant: {}, Size: {}"
,
tenant
,
size
);
info
!
(
"Tenant: {}, Size: {}"
,
tenant
,
size
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment