Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
67b65794
Unverified
Commit
67b65794
authored
Dec 06, 2024
by
Byron Hsu
Committed by
GitHub
Dec 06, 2024
Browse files
[router] support `/add_worker` api (#2369)
parent
37ee906f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
134 additions
and
18 deletions
+134
-18
rust/py_test/test_launch_server.py
rust/py_test/test_launch_server.py
+80
-1
rust/src/router.rs
rust/src/router.rs
+33
-16
rust/src/server.rs
rust/src/server.rs
+21
-1
No files found.
rust/py_test/test_launch_server.py
View file @
67b65794
import
socket
import
subprocess
import
time
import
unittest
...
...
@@ -49,7 +50,7 @@ def popen_launch_router(
# Use current environment
env
=
None
process
=
subprocess
.
Popen
(
command
,
stdout
=
None
,
stderr
=
None
,
env
=
env
)
process
=
subprocess
.
Popen
(
command
,
stdout
=
None
,
stderr
=
None
)
start_time
=
time
.
time
()
with
requests
.
Session
()
as
session
:
...
...
@@ -57,6 +58,52 @@ def popen_launch_router(
try
:
response
=
session
.
get
(
f
"
{
base_url
}
/health"
)
if
response
.
status_code
==
200
:
print
(
f
"Router
{
base_url
}
is healthy"
)
return
process
except
requests
.
RequestException
:
pass
time
.
sleep
(
10
)
raise
TimeoutError
(
"Router failed to start within the timeout period."
)
def
find_available_port
():
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
s
.
bind
((
"127.0.0.1"
,
0
))
return
s
.
getsockname
()[
1
]
def
popen_launch_server
(
model
:
str
,
base_url
:
str
,
timeout
:
float
,
):
_
,
host
,
port
=
base_url
.
split
(
":"
)
host
=
host
[
2
:]
command
=
[
"python3"
,
"-m"
,
"sglang.launch_server"
,
"--model-path"
,
model
,
"--host"
,
host
,
"--port"
,
port
,
"--base-gpu-id"
,
"1"
,
]
process
=
subprocess
.
Popen
(
command
,
stdout
=
None
,
stderr
=
None
)
start_time
=
time
.
time
()
with
requests
.
Session
()
as
session
:
while
time
.
time
()
-
start_time
<
timeout
:
try
:
response
=
session
.
get
(
f
"
{
base_url
}
/health"
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
base_url
}
is healthy"
)
return
process
except
requests
.
RequestException
:
pass
...
...
@@ -76,10 +123,13 @@ class TestEvalAccuracyMini(unittest.TestCase):
dp_size
=
1
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
)
cls
.
other_process
=
[]
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
for
process
in
cls
.
other_process
:
kill_process_tree
(
process
.
pid
)
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
...
...
@@ -98,6 +148,35 @@ class TestEvalAccuracyMini(unittest.TestCase):
msg
=
f
"MMLU test
{
'passed'
if
passed
else
'failed'
}
with score
{
score
:.
3
f
}
(threshold:
{
THRESHOLD
}
)"
self
.
assertGreaterEqual
(
score
,
THRESHOLD
,
msg
)
def
test_add_worker
(
self
):
# 1. start a worker, and wait until it is healthy
port
=
find_available_port
()
worker_url
=
f
"http://127.0.0.1:
{
port
}
"
worker_process
=
popen_launch_server
(
self
.
model
,
worker_url
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
self
.
other_process
.
append
(
worker_process
)
# 2. use /add_worker api to add it the the router
with
requests
.
Session
()
as
session
:
response
=
session
.
post
(
f
"
{
self
.
base_url
}
/add_worker?url=
{
worker_url
}
"
)
print
(
f
"status code:
{
response
.
status_code
}
, response:
{
response
.
text
}
"
)
self
.
assertEqual
(
response
.
status_code
,
200
)
# 3. run mmlu
args
=
SimpleNamespace
(
base_url
=
self
.
base_url
,
model
=
self
.
model
,
eval_name
=
"mmlu"
,
num_examples
=
64
,
num_threads
=
32
,
temperature
=
0.1
,
)
metrics
=
run_eval
(
args
)
score
=
metrics
[
"score"
]
THRESHOLD
=
0.65
passed
=
score
>=
THRESHOLD
msg
=
f
"MMLU test
{
'passed'
if
passed
else
'failed'
}
with score
{
score
:.
3
f
}
(threshold:
{
THRESHOLD
}
)"
self
.
assertGreaterEqual
(
score
,
THRESHOLD
,
msg
)
if
__name__
==
"__main__"
:
unittest
.
main
()
rust/src/router.rs
View file @
67b65794
...
...
@@ -7,18 +7,18 @@ use log::{debug, info};
use
std
::
collections
::
HashMap
;
use
std
::
fmt
::
Debug
;
use
std
::
sync
::
atomic
::
AtomicUsize
;
use
std
::
sync
::{
Arc
,
Mutex
};
use
std
::
sync
::{
Arc
,
Mutex
,
RwLock
};
use
std
::
thread
;
use
std
::
time
::
Duration
;
#[derive(Debug)]
pub
enum
Router
{
RoundRobin
{
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Arc
<
RwLock
<
Vec
<
String
>
>>
,
current_index
:
AtomicUsize
,
},
Random
{
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Arc
<
RwLock
<
Vec
<
String
>
>>
,
},
CacheAware
{
/*
...
...
@@ -81,7 +81,7 @@ pub enum Router {
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Arc
<
RwLock
<
Vec
<
String
>
>>
,
tree
:
Arc
<
Mutex
<
Tree
>>
,
running_queue
:
Arc
<
Mutex
<
HashMap
<
String
,
usize
>>>
,
processed_queue
:
Arc
<
Mutex
<
HashMap
<
String
,
usize
>>>
,
...
...
@@ -129,9 +129,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String {
impl
Router
{
pub
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy_config
:
PolicyConfig
)
->
Self
{
match
policy_config
{
PolicyConfig
::
RandomConfig
=>
Router
::
Random
{
worker_urls
},
PolicyConfig
::
RandomConfig
=>
Router
::
Random
{
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
)),
},
PolicyConfig
::
RoundRobinConfig
=>
Router
::
RoundRobin
{
worker_urls
,
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
))
,
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
},
PolicyConfig
::
CacheAwareConfig
{
...
...
@@ -183,7 +185,7 @@ impl Router {
}
Router
::
CacheAware
{
worker_urls
,
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
))
,
tree
,
running_queue
,
processed_queue
,
...
...
@@ -201,10 +203,10 @@ impl Router {
Router
::
RoundRobin
{
worker_urls
,
..
}
|
Router
::
Random
{
worker_urls
}
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
if
worker_urls
.is_empty
()
{
if
worker_urls
.
read
()
.unwrap
()
.
is_empty
()
{
None
}
else
{
Some
(
worker_urls
[
0
]
.clone
())
Some
(
worker_urls
.read
()
.unwrap
()
[
0
]
.clone
())
}
}
}
...
...
@@ -228,15 +230,15 @@ impl Router {
.fetch_update
(
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
|
x
|
Some
((
x
+
1
)
%
worker_urls
.len
()),
|
x
|
Some
((
x
+
1
)
%
worker_urls
.
read
()
.unwrap
()
.
len
()),
)
.unwrap
();
worker_urls
[
idx
]
.clone
()
worker_urls
.read
()
.unwrap
()
[
idx
]
.clone
()
}
Router
::
Random
{
worker_urls
}
=>
{
worker_urls
[
rand
::
random
::
<
usize
>
()
%
worker_urls
.
len
()]
.clone
()
}
Router
::
Random
{
worker_urls
}
=>
worker_urls
.read
()
.unwrap
()
[
rand
::
random
::
<
usize
>
()
%
worker_urls
.
read
()
.unwrap
()
.len
()
]
.clone
(),
Router
::
CacheAware
{
worker_urls
,
...
...
@@ -277,7 +279,7 @@ impl Router {
.iter
()
.min_by_key
(|(
_u
rl
,
&
count
)|
count
)
.map
(|(
url
,
_
)|
url
.clone
())
.unwrap_or_else
(||
worker_urls
[
0
]
.clone
())
.unwrap_or_else
(||
worker_urls
.read
()
.unwrap
()
[
0
]
.clone
())
}
else
{
// Use cache-aware routing when load is balanced
let
(
matched_text
,
matched_worker
)
=
tree
.prefix_match
(
&
text
);
...
...
@@ -333,7 +335,10 @@ impl Router {
// For non-streaming requests, get response first
let
response
=
match
res
.bytes
()
.await
{
Ok
(
body
)
=>
HttpResponse
::
build
(
status
)
.body
(
body
.to_vec
()),
Err
(
_
)
=>
HttpResponse
::
InternalServerError
()
.finish
(),
Err
(
e
)
=>
{
let
error_msg
=
format!
(
"Failed to get response body: {}"
,
e
);
HttpResponse
::
InternalServerError
()
.body
(
error_msg
)
}
};
// Then decrement running queue counter if using CacheAware
...
...
@@ -379,4 +384,16 @@ impl Router {
}))
}
}
pub
fn
add_worker
(
&
self
,
worker_url
:
String
)
{
match
self
{
Router
::
RoundRobin
{
worker_urls
,
..
}
|
Router
::
Random
{
worker_urls
}
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
let
mut
urls
=
worker_urls
.write
()
.unwrap
();
info!
(
"Added worker: {}"
,
worker_url
);
urls
.push
(
worker_url
);
}
}
}
}
rust/src/server.rs
View file @
67b65794
use
crate
::
router
::
PolicyConfig
;
use
crate
::
router
::
Router
;
use
actix_web
::{
get
,
post
,
web
,
App
,
HttpRequest
,
HttpResponse
,
HttpServer
,
Responder
};
use
actix_web
::{
delete
,
get
,
post
,
put
,
web
,
App
,
HttpRequest
,
HttpResponse
,
HttpServer
,
Responder
,
};
use
bytes
::
Bytes
;
use
env_logger
::
Builder
;
use
log
::{
info
,
LevelFilter
};
use
std
::
collections
::
HashMap
;
use
std
::
io
::
Write
;
#[derive(Debug)]
...
...
@@ -128,6 +131,22 @@ async fn v1_completions(
.await
}
#[post(
"/add_worker"
)]
async
fn
add_worker
(
query
:
web
::
Query
<
HashMap
<
String
,
String
>>
,
data
:
web
::
Data
<
AppState
>
,
)
->
impl
Responder
{
let
worker_url
=
match
query
.get
(
"url"
)
{
Some
(
url
)
=>
url
.to_string
(),
None
=>
{
return
HttpResponse
::
BadRequest
()
.body
(
"Worker URL required. Provide 'url' query parameter"
)
}
};
data
.router
.add_worker
(
worker_url
);
HttpResponse
::
Ok
()
.finish
()
}
pub
struct
ServerConfig
{
pub
host
:
String
,
pub
port
:
u16
,
...
...
@@ -183,6 +202,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service
(
health
)
.service
(
health_generate
)
.service
(
get_server_info
)
.service
(
add_worker
)
})
.bind
((
config
.host
,
config
.port
))
?
.run
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment