Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d4de9a62
Unverified
Commit
d4de9a62
authored
Dec 11, 2024
by
Byron Hsu
Committed by
GitHub
Dec 11, 2024
Browse files
[router] Refactor: decouple select and send stage (#2440)
parent
7310aede
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
101 deletions
+107
-101
rust/src/router.rs
rust/src/router.rs
+93
-44
rust/src/server.rs
rust/src/server.rs
+14
-57
No files found.
rust/src/router.rs
View file @
d4de9a62
...
...
@@ -106,28 +106,6 @@ pub enum PolicyConfig {
},
}
/// Extracts the prompt-like text from a JSON request body for the given route.
///
/// * `generate` yields the string field `"text"`.
/// * `v1/chat/completions` yields the `"messages"` value re-serialized as JSON.
/// * `v1/completions` yields the string field `"prompt"`.
///
/// A single leading `/` on `route` is ignored, so `"generate"` and
/// `"/generate"` select the same branch.
///
/// Returns an empty string when the body is not valid JSON, when the route is
/// none of the above, or when the expected field is absent (or, for the two
/// string fields, not a JSON string).
fn get_text_from_request(body: &Bytes, route: &str) -> String {
    // The body comes straight from the client: a malformed payload must not
    // panic, so it is handled like a missing field.
    let json = match serde_json::from_slice::<serde_json::Value>(body) {
        Ok(value) => value,
        Err(_) => return String::new(),
    };

    // Reads `key` as a string, falling back to "" when absent or non-string.
    let string_field = |key: &str| -> String {
        json.get(key)
            .and_then(|value| value.as_str())
            .unwrap_or("")
            .to_string()
    };

    match route.strip_prefix('/').unwrap_or(route) {
        "generate" => string_field("text"),
        // Convert messages back to a string, preserving the JSON structure.
        "v1/chat/completions" => json
            .get("messages")
            .map(|messages| serde_json::to_string(messages).unwrap_or_default())
            .unwrap_or_default(),
        "v1/completions" => string_field("prompt"),
        _ => String::new(),
    }
}
impl
Router
{
pub
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy_config
:
PolicyConfig
)
->
Result
<
Self
,
String
>
{
// Wait until all workers are healthy
...
...
@@ -204,20 +182,6 @@ impl Router {
})
}
/// Returns a clone of the first worker URL, or `None` when the worker list is
/// empty. All three router variants are handled the same way.
///
/// # Panics
///
/// Panics if `worker_urls.read()` returns an error (the result is unwrapped).
pub fn get_first(&self) -> Option<String> {
    match self {
        Router::RoundRobin { worker_urls, .. }
        | Router::Random { worker_urls }
        | Router::CacheAware { worker_urls, .. } => {
            // Acquire the read lock once. Checking `is_empty()` and indexing
            // `[0]` under two separate acquisitions would let the list change
            // in between and turn the index into a panic.
            worker_urls.read().unwrap().first().cloned()
        }
    }
}
fn
wait_for_healthy_workers
(
worker_urls
:
&
[
String
],
timeout_secs
:
u64
,
...
...
@@ -271,14 +235,76 @@ impl Router {
}
}
pub
async
fn
dispatch
(
/// Selects the first worker URL regardless of the routing policy.
///
/// # Errors
///
/// Returns `Err("No workers are available")` when the worker list is empty.
///
/// # Panics
///
/// Panics if `worker_urls.read()` returns an error (the result is unwrapped).
fn select_first_worker(&self) -> Result<String, String> {
    match self {
        Router::RoundRobin { worker_urls, .. }
        | Router::Random { worker_urls }
        | Router::CacheAware { worker_urls, .. } => {
            // One lock acquisition covers both the emptiness check and the
            // read, so the list cannot change between the two steps.
            worker_urls
                .read()
                .unwrap()
                .first()
                .cloned()
                .ok_or_else(|| "No workers are available".to_string())
        }
    }
}
/// Sends a GET for `route` to `worker_url` and relays the worker's reply.
///
/// The request URL is `worker_url` immediately followed by `route` (no
/// separator is inserted). On success the worker's status code and body are
/// copied into the returned response; a status code that cannot be converted
/// falls back to 500.
///
/// Failures are reported as 500 responses whose body describes the problem:
/// either the request could not be sent, or the response body could not be
/// read.
///
/// NOTE(review): the parameter list is `(client, worker_url, route)` to match
/// the call in `route_to_first`; the request body is not inspected here, so a
/// non-JSON or empty body can no longer cause a panic on this path.
async fn send_request(
    &self,
    client: &reqwest::Client,
    worker_url: String,
    route: &str,
) -> HttpResponse {
    let res = match client.get(format!("{}{}", worker_url, route)).send().await {
        Ok(res) => res,
        Err(e) => {
            return HttpResponse::InternalServerError().body(format!(
                "Failed to send request to worker {}: {}",
                worker_url, e
            ));
        }
    };

    // Translate the reqwest status into the actix status type via its u16 value.
    let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
        .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

    match res.bytes().await {
        Ok(body) => HttpResponse::build(status).body(body.to_vec()),
        Err(e) => HttpResponse::InternalServerError()
            .body(format!("Failed to read response body: {}", e)),
    }
}
/// Forwards `route` to the first worker and returns that worker's response.
///
/// When `select_first_worker` fails, its error message becomes the body of a
/// 500 response and no request is sent.
pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse {
    let first_worker = match self.select_first_worker() {
        Err(reason) => return HttpResponse::InternalServerError().body(reason),
        Ok(url) => url,
    };
    self.send_request(client, first_worker, route).await
}
/// Extracts the prompt-like text from a JSON request body for the given route.
///
/// * `generate` yields the string field `"text"`.
/// * `v1/chat/completions` yields the `"messages"` value re-serialized as JSON.
/// * `v1/completions` yields the string field `"prompt"`.
///
/// A single leading `/` on `route` is ignored, so `"generate"` and
/// `"/generate"` select the same branch. Without this, a route passed with its
/// leading slash would match no branch and the result would always be empty.
///
/// Returns an empty string when the body is not valid JSON, when the route is
/// none of the above, or when the expected field is absent (or, for the two
/// string fields, not a JSON string).
fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
    // The body comes straight from the client: a malformed payload must not
    // panic, so it is handled like a missing field.
    let json = match serde_json::from_slice::<serde_json::Value>(body) {
        Ok(value) => value,
        Err(_) => return String::new(),
    };

    // Reads `key` as a string, falling back to "" when absent or non-string.
    let string_field = |key: &str| -> String {
        json.get(key)
            .and_then(|value| value.as_str())
            .unwrap_or("")
            .to_string()
    };

    match route.strip_prefix('/').unwrap_or(route) {
        "generate" => string_field("text"),
        // Convert messages back to a string, preserving the JSON structure.
        "v1/chat/completions" => json
            .get("messages")
            .map(|messages| serde_json::to_string(messages).unwrap_or_default())
            .unwrap_or_default(),
        "v1/completions" => string_field("prompt"),
        _ => String::new(),
    }
}
// TODO: return Result<String, String> instead of panicking
fn
select_generate_worker
(
&
self
,
body
:
&
Bytes
,
route
:
&
str
)
->
String
{
let
text
=
self
.get_text_from_request
(
&
body
,
route
);
let
worker_url
=
match
self
{
Router
::
RoundRobin
{
...
...
@@ -366,12 +392,23 @@ impl Router {
}
};
worker_url
}
async
fn
send_generate_request
(
&
self
,
client
:
&
reqwest
::
Client
,
req
:
HttpRequest
,
body
:
Bytes
,
route
:
&
str
,
worker_url
:
&
str
,
)
->
HttpResponse
{
let
is_stream
=
serde_json
::
from_slice
::
<
serde_json
::
Value
>
(
&
body
)
.map
(|
v
|
v
.get
(
"stream"
)
.and_then
(|
s
|
s
.as_bool
())
.unwrap_or
(
false
))
.unwrap_or
(
false
);
let
res
=
match
client
.post
(
format!
(
"{}
/
{}"
,
worker_url
.clone
()
,
route
))
.post
(
format!
(
"{}{}"
,
worker_url
,
route
))
.header
(
"Content-Type"
,
req
.headers
()
...
...
@@ -403,7 +440,7 @@ impl Router {
// Then decrement running queue counter if using CacheAware
if
let
Router
::
CacheAware
{
running_queue
,
..
}
=
self
{
if
let
Ok
(
mut
queue
)
=
running_queue
.lock
()
{
if
let
Some
(
count
)
=
queue
.get_mut
(
&
worker_url
)
{
if
let
Some
(
count
)
=
queue
.get_mut
(
worker_url
)
{
*
count
=
count
.saturating_sub
(
1
);
}
}
...
...
@@ -412,7 +449,7 @@ impl Router {
response
}
else
if
let
Router
::
CacheAware
{
running_queue
,
..
}
=
self
{
let
running_queue
=
Arc
::
clone
(
running_queue
);
let
worker_url
=
worker_url
.
clone
();
let
worker_url
=
worker_url
.
to_string
();
HttpResponse
::
build
(
status
)
.insert_header
((
CONTENT_TYPE
,
HeaderValue
::
from_static
(
"text/event-stream"
)))
...
...
@@ -431,7 +468,7 @@ impl Router {
let
mut
locked_queue
=
running_queue
.lock
()
.unwrap
();
let
count
=
locked_queue
.get_mut
(
&
worker_url
)
.unwrap
();
*
count
=
count
.saturating_sub
(
1
);
debug!
(
"
s
treaming is done!!"
)
debug!
(
"
S
treaming is done!!"
)
}
}),
)
...
...
@@ -444,6 +481,18 @@ impl Router {
}
}
/// Routes a generation request in two stages: first pick a worker with
/// `select_generate_worker`, then forward the request to it with
/// `send_generate_request`.
pub async fn route_generate_request(
    &self,
    client: &reqwest::Client,
    req: HttpRequest,
    body: Bytes,
    route: &str,
) -> HttpResponse {
    // Selection only borrows the body; it is moved into the send stage afterwards.
    let chosen_worker = self.select_generate_worker(&body, route);
    let response = self
        .send_generate_request(client, req, body, route, &chosen_worker)
        .await;
    response
}
pub
async
fn
add_worker
(
&
self
,
worker_url
:
String
)
->
Result
<
String
,
String
>
{
let
interval_secs
=
10
;
// check every 10 seconds
let
timeout_secs
=
300
;
// 5 minutes
...
...
rust/src/server.rs
View file @
d4de9a62
...
...
@@ -29,84 +29,41 @@ impl AppState {
}
}
/// Issues a GET to `worker_url` followed directly by `route` and relays the
/// worker's status code and body.
///
/// A status that cannot be converted falls back to 500. If the request cannot
/// be sent or the response body cannot be read, an empty 500 response is
/// returned. The forwarded status is printed to stdout.
async fn forward_request(
    client: &reqwest::Client,
    worker_url: String,
    route: String,
) -> HttpResponse {
    let target = format!("{}{}", worker_url, route);
    let upstream = match client.get(target).send().await {
        Ok(reply) => reply,
        Err(_) => return HttpResponse::InternalServerError().finish(),
    };

    let status = actix_web::http::StatusCode::from_u16(upstream.status().as_u16())
        .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

    // print the status
    println!(
        "Forwarding Request Worker URL: {}, Route: {}, Status: {}",
        worker_url, route, status
    );

    match upstream.bytes().await {
        Err(_) => HttpResponse::InternalServerError().finish(),
        Ok(payload) => HttpResponse::build(status).body(payload.to_vec()),
    }
}
/// `GET /health`: forwards the request to the router's first worker and
/// returns that worker's response.
///
/// Worker selection and forwarding are both delegated to
/// `Router::route_to_first`, which answers with a 500 response when no worker
/// is available.
#[get("/health")]
async fn health(data: web::Data<AppState>) -> impl Responder {
    data.router.route_to_first(&data.client, "/health").await
}
/// `GET /health_generate`: forwards the request to the router's first worker
/// and returns that worker's response.
///
/// Worker selection and forwarding are both delegated to
/// `Router::route_to_first`, which answers with a 500 response when no worker
/// is available.
#[get("/health_generate")]
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
    data.router
        .route_to_first(&data.client, "/health_generate")
        .await
}
/// `GET /get_server_info`: forwards the request to the router's first worker
/// and returns that worker's response.
///
/// Worker selection and forwarding are both delegated to
/// `Router::route_to_first`, which answers with a 500 response when no worker
/// is available.
#[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
    data.router
        .route_to_first(&data.client, "/get_server_info")
        .await
}
/// `GET /v1/models`: forwards the request to the router's first worker and
/// returns that worker's response.
///
/// Worker selection and forwarding are both delegated to
/// `Router::route_to_first`, which answers with a 500 response when no worker
/// is available.
#[get("/v1/models")]
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
    data.router.route_to_first(&data.client, "/v1/models").await
}
/// `GET /get_model_info`: forwards the request to the router's first worker
/// and returns that worker's response.
///
/// Worker selection and forwarding are both delegated to
/// `Router::route_to_first`, which answers with a 500 response when no worker
/// is available.
#[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
    data.router
        .route_to_first(&data.client, "/get_model_info")
        .await
}
/// `POST /generate`: hands the request and its raw body to
/// `Router::route_generate_request`, which selects a worker and forwards the
/// request to it.
///
/// `req` and `body` are moved into that single call; they cannot be passed to
/// more than one router method.
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
    data.router
        .route_generate_request(&data.client, req, body, "/generate")
        .await
}
...
...
@@ -117,7 +74,7 @@ async fn v1_chat_completions(
data
:
web
::
Data
<
AppState
>
,
)
->
impl
Responder
{
data
.router
.
dispatch
(
&
data
.client
,
req
,
body
,
"v1/chat/completions"
)
.
route_generate_request
(
&
data
.client
,
req
,
body
,
"
/
v1/chat/completions"
)
.await
}
...
...
@@ -128,7 +85,7 @@ async fn v1_completions(
data
:
web
::
Data
<
AppState
>
,
)
->
impl
Responder
{
data
.router
.
dispatch
(
&
data
.client
,
req
,
body
,
"v1/completions"
)
.
route_generate_request
(
&
data
.client
,
req
,
body
,
"
/
v1/completions"
)
.await
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment