Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
67b65794
"llm/vscode:/vscode.git/clone" did not exist on "f4bf1d514f537af9166f72fa00feda04556fc3d5"
Unverified
Commit
67b65794
authored
Dec 06, 2024
by
Byron Hsu
Committed by
GitHub
Dec 06, 2024
Browse files
[router] support `/add_worker` api (#2369)
parent
37ee906f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
134 additions
and
18 deletions
+134
-18
rust/py_test/test_launch_server.py
rust/py_test/test_launch_server.py
+80
-1
rust/src/router.rs
rust/src/router.rs
+33
-16
rust/src/server.rs
rust/src/server.rs
+21
-1
No files found.
rust/py_test/test_launch_server.py
View file @
67b65794
import
socket
import
subprocess
import
subprocess
import
time
import
time
import
unittest
import
unittest
...
@@ -49,7 +50,7 @@ def popen_launch_router(
...
@@ -49,7 +50,7 @@ def popen_launch_router(
# Use current environment
# Use current environment
env
=
None
env
=
None
process
=
subprocess
.
Popen
(
command
,
stdout
=
None
,
stderr
=
None
,
env
=
env
)
process
=
subprocess
.
Popen
(
command
,
stdout
=
None
,
stderr
=
None
)
start_time
=
time
.
time
()
start_time
=
time
.
time
()
with
requests
.
Session
()
as
session
:
with
requests
.
Session
()
as
session
:
...
@@ -57,6 +58,52 @@ def popen_launch_router(
...
@@ -57,6 +58,52 @@ def popen_launch_router(
try
:
try
:
response
=
session
.
get
(
f
"
{
base_url
}
/health"
)
response
=
session
.
get
(
f
"
{
base_url
}
/health"
)
if
response
.
status_code
==
200
:
if
response
.
status_code
==
200
:
print
(
f
"Router
{
base_url
}
is healthy"
)
return
process
except
requests
.
RequestException
:
pass
time
.
sleep
(
10
)
raise
TimeoutError
(
"Router failed to start within the timeout period."
)
def find_available_port():
    """Return a TCP port number that is currently free on 127.0.0.1.

    A throwaway socket is bound to port 0 so the operating system picks an
    unused port. The socket is closed before returning, so the port is only
    known to have been free at the moment of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        # getsockname() yields (host, port) for AF_INET sockets.
        _address, port = probe.getsockname()
    finally:
        probe.close()
    return port
def
popen_launch_server
(
model
:
str
,
base_url
:
str
,
timeout
:
float
,
):
_
,
host
,
port
=
base_url
.
split
(
":"
)
host
=
host
[
2
:]
command
=
[
"python3"
,
"-m"
,
"sglang.launch_server"
,
"--model-path"
,
model
,
"--host"
,
host
,
"--port"
,
port
,
"--base-gpu-id"
,
"1"
,
]
process
=
subprocess
.
Popen
(
command
,
stdout
=
None
,
stderr
=
None
)
start_time
=
time
.
time
()
with
requests
.
Session
()
as
session
:
while
time
.
time
()
-
start_time
<
timeout
:
try
:
response
=
session
.
get
(
f
"
{
base_url
}
/health"
)
if
response
.
status_code
==
200
:
print
(
f
"Server
{
base_url
}
is healthy"
)
return
process
return
process
except
requests
.
RequestException
:
except
requests
.
RequestException
:
pass
pass
...
@@ -76,10 +123,13 @@ class TestEvalAccuracyMini(unittest.TestCase):
...
@@ -76,10 +123,13 @@ class TestEvalAccuracyMini(unittest.TestCase):
dp_size
=
1
,
dp_size
=
1
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
)
)
cls
.
other_process
=
[]
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
kill_process_tree
(
cls
.
process
.
pid
)
for
process
in
cls
.
other_process
:
kill_process_tree
(
process
.
pid
)
def
test_mmlu
(
self
):
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
...
@@ -98,6 +148,35 @@ class TestEvalAccuracyMini(unittest.TestCase):
...
@@ -98,6 +148,35 @@ class TestEvalAccuracyMini(unittest.TestCase):
msg
=
f
"MMLU test
{
'passed'
if
passed
else
'failed'
}
with score
{
score
:.
3
f
}
(threshold:
{
THRESHOLD
}
)"
msg
=
f
"MMLU test
{
'passed'
if
passed
else
'failed'
}
with score
{
score
:.
3
f
}
(threshold:
{
THRESHOLD
}
)"
self
.
assertGreaterEqual
(
score
,
THRESHOLD
,
msg
)
self
.
assertGreaterEqual
(
score
,
THRESHOLD
,
msg
)
def test_add_worker(self):
    """Add a freshly launched worker to the running router and re-check accuracy.

    Launches an extra server on a free local port, registers it through the
    router's `/add_worker` endpoint, asserts the router answers 200, and then
    runs the MMLU evaluation through the router against the score threshold.
    """
    # 1. start a worker, and wait until it is healthy
    port = find_available_port()
    worker_url = f"http://127.0.0.1:{port}"
    worker_process = popen_launch_server(
        self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
    )
    # Recorded so the class-level teardown can kill it along with the router.
    self.other_process.append(worker_process)

    # 2. use /add_worker api to add it to the router
    with requests.Session() as session:
        # Pass the URL via `params` so it is percent-encoded; splicing it into
        # the query string by hand breaks on characters such as `&` or `#`.
        response = session.post(
            f"{self.base_url}/add_worker", params={"url": worker_url}
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(response.status_code, 200)

    # 3. run mmlu
    args = SimpleNamespace(
        base_url=self.base_url,
        model=self.model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )

    metrics = run_eval(args)
    score = metrics["score"]
    THRESHOLD = 0.65
    passed = score >= THRESHOLD
    msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
    self.assertGreaterEqual(score, THRESHOLD, msg)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
rust/src/router.rs
View file @
67b65794
...
@@ -7,18 +7,18 @@ use log::{debug, info};
...
@@ -7,18 +7,18 @@ use log::{debug, info};
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
use
std
::
fmt
::
Debug
;
use
std
::
fmt
::
Debug
;
use
std
::
sync
::
atomic
::
AtomicUsize
;
use
std
::
sync
::
atomic
::
AtomicUsize
;
use
std
::
sync
::{
Arc
,
Mutex
};
use
std
::
sync
::{
Arc
,
Mutex
,
RwLock
};
use
std
::
thread
;
use
std
::
thread
;
use
std
::
time
::
Duration
;
use
std
::
time
::
Duration
;
#[derive(Debug)]
#[derive(Debug)]
pub
enum
Router
{
pub
enum
Router
{
RoundRobin
{
RoundRobin
{
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Arc
<
RwLock
<
Vec
<
String
>
>>
,
current_index
:
AtomicUsize
,
current_index
:
AtomicUsize
,
},
},
Random
{
Random
{
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Arc
<
RwLock
<
Vec
<
String
>
>>
,
},
},
CacheAware
{
CacheAware
{
/*
/*
...
@@ -81,7 +81,7 @@ pub enum Router {
...
@@ -81,7 +81,7 @@ pub enum Router {
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
during the next eviction cycle.
*/
*/
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Arc
<
RwLock
<
Vec
<
String
>
>>
,
tree
:
Arc
<
Mutex
<
Tree
>>
,
tree
:
Arc
<
Mutex
<
Tree
>>
,
running_queue
:
Arc
<
Mutex
<
HashMap
<
String
,
usize
>>>
,
running_queue
:
Arc
<
Mutex
<
HashMap
<
String
,
usize
>>>
,
processed_queue
:
Arc
<
Mutex
<
HashMap
<
String
,
usize
>>>
,
processed_queue
:
Arc
<
Mutex
<
HashMap
<
String
,
usize
>>>
,
...
@@ -129,9 +129,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String {
...
@@ -129,9 +129,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String {
impl
Router
{
impl
Router
{
pub
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy_config
:
PolicyConfig
)
->
Self
{
pub
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy_config
:
PolicyConfig
)
->
Self
{
match
policy_config
{
match
policy_config
{
PolicyConfig
::
RandomConfig
=>
Router
::
Random
{
worker_urls
},
PolicyConfig
::
RandomConfig
=>
Router
::
Random
{
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
)),
},
PolicyConfig
::
RoundRobinConfig
=>
Router
::
RoundRobin
{
PolicyConfig
::
RoundRobinConfig
=>
Router
::
RoundRobin
{
worker_urls
,
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
))
,
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
},
},
PolicyConfig
::
CacheAwareConfig
{
PolicyConfig
::
CacheAwareConfig
{
...
@@ -183,7 +185,7 @@ impl Router {
...
@@ -183,7 +185,7 @@ impl Router {
}
}
Router
::
CacheAware
{
Router
::
CacheAware
{
worker_urls
,
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
))
,
tree
,
tree
,
running_queue
,
running_queue
,
processed_queue
,
processed_queue
,
...
@@ -201,10 +203,10 @@ impl Router {
...
@@ -201,10 +203,10 @@ impl Router {
Router
::
RoundRobin
{
worker_urls
,
..
}
Router
::
RoundRobin
{
worker_urls
,
..
}
|
Router
::
Random
{
worker_urls
}
|
Router
::
Random
{
worker_urls
}
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
if
worker_urls
.is_empty
()
{
if
worker_urls
.
read
()
.unwrap
()
.
is_empty
()
{
None
None
}
else
{
}
else
{
Some
(
worker_urls
[
0
]
.clone
())
Some
(
worker_urls
.read
()
.unwrap
()
[
0
]
.clone
())
}
}
}
}
}
}
...
@@ -228,15 +230,15 @@ impl Router {
...
@@ -228,15 +230,15 @@ impl Router {
.fetch_update
(
.fetch_update
(
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
|
x
|
Some
((
x
+
1
)
%
worker_urls
.len
()),
|
x
|
Some
((
x
+
1
)
%
worker_urls
.
read
()
.unwrap
()
.
len
()),
)
)
.unwrap
();
.unwrap
();
worker_urls
[
idx
]
.clone
()
worker_urls
.read
()
.unwrap
()
[
idx
]
.clone
()
}
}
Router
::
Random
{
worker_urls
}
=>
{
Router
::
Random
{
worker_urls
}
=>
worker_urls
.read
()
.unwrap
()
worker_urls
[
rand
::
random
::
<
usize
>
()
%
worker_urls
.
len
()]
.clone
()
[
rand
::
random
::
<
usize
>
()
%
worker_urls
.
read
()
.unwrap
()
.len
()
]
}
.clone
(),
Router
::
CacheAware
{
Router
::
CacheAware
{
worker_urls
,
worker_urls
,
...
@@ -277,7 +279,7 @@ impl Router {
...
@@ -277,7 +279,7 @@ impl Router {
.iter
()
.iter
()
.min_by_key
(|(
_u
rl
,
&
count
)|
count
)
.min_by_key
(|(
_u
rl
,
&
count
)|
count
)
.map
(|(
url
,
_
)|
url
.clone
())
.map
(|(
url
,
_
)|
url
.clone
())
.unwrap_or_else
(||
worker_urls
[
0
]
.clone
())
.unwrap_or_else
(||
worker_urls
.read
()
.unwrap
()
[
0
]
.clone
())
}
else
{
}
else
{
// Use cache-aware routing when load is balanced
// Use cache-aware routing when load is balanced
let
(
matched_text
,
matched_worker
)
=
tree
.prefix_match
(
&
text
);
let
(
matched_text
,
matched_worker
)
=
tree
.prefix_match
(
&
text
);
...
@@ -333,7 +335,10 @@ impl Router {
...
@@ -333,7 +335,10 @@ impl Router {
// For non-streaming requests, get response first
// For non-streaming requests, get response first
let
response
=
match
res
.bytes
()
.await
{
let
response
=
match
res
.bytes
()
.await
{
Ok
(
body
)
=>
HttpResponse
::
build
(
status
)
.body
(
body
.to_vec
()),
Ok
(
body
)
=>
HttpResponse
::
build
(
status
)
.body
(
body
.to_vec
()),
Err
(
_
)
=>
HttpResponse
::
InternalServerError
()
.finish
(),
Err
(
e
)
=>
{
let
error_msg
=
format!
(
"Failed to get response body: {}"
,
e
);
HttpResponse
::
InternalServerError
()
.body
(
error_msg
)
}
};
};
// Then decrement running queue counter if using CacheAware
// Then decrement running queue counter if using CacheAware
...
@@ -379,4 +384,16 @@ impl Router {
...
@@ -379,4 +384,16 @@ impl Router {
}))
}))
}
}
}
}
/// Registers `worker_url` in this router's shared worker list.
///
/// Every policy variant keeps its workers in an `Arc<RwLock<Vec<String>>>`,
/// so the URL is appended under the write lock and becomes visible to later
/// reads of the list. A URL that is already present is left alone: a second
/// copy would make round-robin and random selection pick that worker more
/// often than the others.
///
/// # Panics
///
/// Panics if the worker-list lock is poisoned.
pub fn add_worker(&self, worker_url: String) {
    match self {
        Router::RoundRobin { worker_urls, .. }
        | Router::Random { worker_urls }
        | Router::CacheAware { worker_urls, .. } => {
            let mut urls = worker_urls.write().unwrap();
            if urls.contains(&worker_url) {
                info!("Worker already registered: {}", worker_url);
                return;
            }
            // NOTE(review): for `CacheAware` only the URL list is updated here;
            // `tree`, `running_queue` and `processed_queue` are untouched.
            // Confirm the routing path copes with a URL missing from them.
            info!("Added worker: {}", worker_url);
            urls.push(worker_url);
        }
    }
}
}
}
rust/src/server.rs
View file @
67b65794
use
crate
::
router
::
PolicyConfig
;
use
crate
::
router
::
PolicyConfig
;
use
crate
::
router
::
Router
;
use
crate
::
router
::
Router
;
use
actix_web
::{
get
,
post
,
web
,
App
,
HttpRequest
,
HttpResponse
,
HttpServer
,
Responder
};
use
actix_web
::{
delete
,
get
,
post
,
put
,
web
,
App
,
HttpRequest
,
HttpResponse
,
HttpServer
,
Responder
,
};
use
bytes
::
Bytes
;
use
bytes
::
Bytes
;
use
env_logger
::
Builder
;
use
env_logger
::
Builder
;
use
log
::{
info
,
LevelFilter
};
use
log
::{
info
,
LevelFilter
};
use
std
::
collections
::
HashMap
;
use
std
::
io
::
Write
;
use
std
::
io
::
Write
;
#[derive(Debug)]
#[derive(Debug)]
...
@@ -128,6 +131,22 @@ async fn v1_completions(
...
@@ -128,6 +131,22 @@ async fn v1_completions(
.await
.await
}
}
/// `POST /add_worker?url=<worker-url>`: registers a worker with the router.
///
/// Reads the `url` query parameter and passes it to `Router::add_worker` on
/// the router held in the shared `AppState`. Responds `200 OK` with an empty
/// body on success, or `400 Bad Request` when the parameter is missing or
/// blank.
#[post("/add_worker")]
async fn add_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    // A blank value (`?url=`) is as unusable as a missing parameter, so both
    // are rejected instead of letting an empty URL into the worker list.
    let worker_url = match query.get("url") {
        Some(url) if !url.trim().is_empty() => url.to_string(),
        _ => {
            return HttpResponse::BadRequest()
                .body("Worker URL required. Provide 'url' query parameter");
        }
    };

    data.router.add_worker(worker_url);
    HttpResponse::Ok().finish()
}
pub
struct
ServerConfig
{
pub
struct
ServerConfig
{
pub
host
:
String
,
pub
host
:
String
,
pub
port
:
u16
,
pub
port
:
u16
,
...
@@ -183,6 +202,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
...
@@ -183,6 +202,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service
(
health
)
.service
(
health
)
.service
(
health_generate
)
.service
(
health_generate
)
.service
(
get_server_info
)
.service
(
get_server_info
)
.service
(
add_worker
)
})
})
.bind
((
config
.host
,
config
.port
))
?
.bind
((
config
.host
,
config
.port
))
?
.run
()
.run
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment