Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
96766101
Unverified
Commit
96766101
authored
Nov 06, 2024
by
Byron Hsu
Committed by
GitHub
Nov 06, 2024
Browse files
[rust] refactor server and router (#1922)
parent
a146d999
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
113 additions
and
161 deletions
+113
-161
.github/workflows/pr-test-rust.yml
.github/workflows/pr-test-rust.yml
+2
-2
rust/src/router.rs
rust/src/router.rs
+90
-74
rust/src/server.rs
rust/src/server.rs
+21
-85
No files found.
.github/workflows/pr-test-rust.yml
View file @
96766101
...
@@ -4,11 +4,11 @@ on:
...
@@ -4,11 +4,11 @@ on:
push
:
push
:
branches
:
[
main
]
branches
:
[
main
]
paths
:
paths
:
-
"
rust/*"
-
"
rust/*
*
"
pull_request
:
pull_request
:
branches
:
[
main
]
branches
:
[
main
]
paths
:
paths
:
-
"
rust/*"
-
"
rust/*
*
"
workflow_dispatch
:
workflow_dispatch
:
concurrency
:
concurrency
:
...
...
rust/src/router.rs
View file @
96766101
// src/router.rs
use
actix_web
::
http
::
header
::{
HeaderValue
,
CONTENT_TYPE
};
use
actix_web
::{
HttpRequest
,
HttpResponse
};
use
bytes
::
Bytes
;
use
futures_util
::
TryStreamExt
;
use
std
::
fmt
::
Debug
;
use
std
::
fmt
::
Debug
;
/// Generic Router trait that can be implemented with different policies
pub
trait
Router
:
Send
+
Sync
+
Debug
{
/// Select a worker URL based on the implementation's policy
/// Returns None if no worker is available
fn
select
(
&
self
)
->
Option
<
String
>
;
// get first worker
fn
get_first
(
&
self
)
->
Option
<
String
>
;
}
// Round Robin Router
#[derive(Debug)]
#[derive(Debug)]
pub
struct
RoundRobinRouter
{
pub
enum
Router
{
RoundRobin
{
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Vec
<
String
>
,
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
,
// AtomicUsize is a thread-safe integer
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
,
},
Random
{
worker_urls
:
Vec
<
String
>
,
},
}
}
impl
RoundRobinRouter
{
impl
Router
{
pub
fn
new
(
worker_urls
:
Vec
<
String
>
)
->
Self
{
pub
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy
:
String
)
->
Self
{
Self
{
match
policy
.to_lowercase
()
.as_str
()
{
"random"
=>
Router
::
Random
{
worker_urls
},
"round_robin"
=>
Router
::
RoundRobin
{
worker_urls
,
worker_urls
,
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
},
_
=>
panic!
(
"Unknown routing policy: {}. The available policies are 'random' and 'round_robin'"
,
policy
),
}
}
}
}
}
impl
Router
for
RoundRobinRouter
{
pub
fn
get_first
(
&
self
)
->
Option
<
String
>
{
fn
select
(
&
self
)
->
Option
<
String
>
{
match
self
{
if
self
.worker_urls
.is_empty
()
{
Router
::
RoundRobin
{
worker_urls
,
..
}
|
Router
::
Random
{
worker_urls
}
=>
{
return
None
;
if
worker_urls
.is_empty
()
{
None
}
else
{
Some
(
worker_urls
[
0
]
.clone
())
}
}
// Use relaxed because operation order doesn't matter in round robin
let
index
=
self
.current_index
.fetch_add
(
1
,
std
::
sync
::
atomic
::
Ordering
::
Relaxed
)
%
self
.worker_urls
.len
();
Some
(
self
.worker_urls
[
index
]
.clone
())
}
}
fn
get_first
(
&
self
)
->
Option
<
String
>
{
if
self
.worker_urls
.is_empty
()
{
return
None
;
}
}
Some
(
self
.worker_urls
[
0
]
.clone
())
}
}
}
// Random Router
pub
async
fn
dispatch
(
#[derive(Debug)]
&
self
,
pub
struct
RandomRouter
{
client
:
&
reqwest
::
Client
,
worker_urls
:
Vec
<
String
>
,
req
:
HttpRequest
,
}
body
:
Bytes
,
)
->
HttpResponse
{
let
worker_url
=
match
self
{
Router
::
RoundRobin
{
worker_urls
,
current_index
,
}
=>
{
current_index
.fetch_update
(
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
|
x
|
Some
((
x
+
1
)
%
worker_urls
.len
()),
)
.expect_err
(
"Error updating index in round robin"
);
impl
RandomRouter
{
&
worker_urls
[
current_index
.load
(
std
::
sync
::
atomic
::
Ordering
::
SeqCst
)]
pub
fn
new
(
worker_urls
:
Vec
<
String
>
)
->
Self
{
Self
{
worker_urls
}
}
}
}
Router
::
Random
{
worker_urls
}
=>
{
&
worker_urls
[
rand
::
random
::
<
usize
>
()
%
worker_urls
.len
()]
}
};
impl
Router
for
RandomRouter
{
// Check if client requested streaming
fn
select
(
&
self
)
->
Option
<
String
>
{
let
is_stream
=
serde_json
::
from_slice
::
<
serde_json
::
Value
>
(
&
body
)
use
rand
::
seq
::
SliceRandom
;
.map
(|
v
|
v
.get
(
"stream"
)
.and_then
(|
s
|
s
.as_bool
())
.unwrap_or
(
false
))
.unwrap_or
(
false
);
if
self
.worker_urls
.is_empty
()
{
let
res
=
match
client
return
None
;
.post
(
format!
(
"{}/generate"
,
worker_url
))
}
.header
(
"Content-Type"
,
req
.headers
()
.get
(
"Content-Type"
)
.and_then
(|
h
|
h
.to_str
()
.ok
())
.unwrap_or
(
"application/json"
),
)
.body
(
body
.to_vec
())
.send
()
.await
{
Ok
(
res
)
=>
res
,
Err
(
_
)
=>
return
HttpResponse
::
InternalServerError
()
.finish
(),
};
self
.worker_urls
.choose
(
&
mut
rand
::
thread_rng
())
.cloned
()
let
status
=
actix_web
::
http
::
StatusCode
::
from_u16
(
res
.status
()
.as_u16
()
)
}
.unwrap_or
(
actix_web
::
http
::
StatusCode
::
INTERNAL_SERVER_ERROR
);
fn
get_first
(
&
self
)
->
Option
<
String
>
{
if
!
is_stream
{
if
self
.worker_urls
.is_empty
()
{
match
res
.bytes
()
.await
{
return
None
;
Ok
(
body
)
=>
HttpResponse
::
build
(
status
)
.body
(
body
.to_vec
()),
Err
(
_
)
=>
HttpResponse
::
InternalServerError
()
.finish
(),
}
}
Some
(
self
.worker_urls
[
0
]
.clone
())
}
else
{
HttpResponse
::
build
(
status
)
.insert_header
((
CONTENT_TYPE
,
HeaderValue
::
from_static
(
"text/event-stream"
)))
.streaming
(
res
.bytes_stream
()
.map_err
(|
_
|
{
actix_web
::
error
::
ErrorInternalServerError
(
"Failed to read string"
)
}))
}
}
}
// create a router based on routing policy
pub
fn
create_router
(
worker_urls
:
Vec
<
String
>
,
policy
:
String
)
->
Box
<
dyn
Router
>
{
match
policy
.to_lowercase
()
.as_str
()
{
"random"
=>
Box
::
new
(
RandomRouter
::
new
(
worker_urls
)),
"round_robin"
=>
Box
::
new
(
RoundRobinRouter
::
new
(
worker_urls
)),
_
=>
panic!
(
"Unknown routing policy: {}. The available policies are 'random' and 'round_robin'"
,
policy
),
}
}
}
}
rust/src/server.rs
View file @
96766101
use
crate
::
router
::
create_router
;
use
crate
::
router
::
Router
;
use
crate
::
router
::
Router
;
use
actix_web
::
http
::
header
::{
HeaderValue
,
CONTENT_TYPE
};
use
actix_web
::{
get
,
post
,
web
,
App
,
HttpRequest
,
HttpResponse
,
HttpServer
,
Responder
};
use
actix_web
::{
get
,
post
,
web
,
App
,
HttpRequest
,
HttpResponse
,
HttpServer
,
Responder
};
use
bytes
::
Bytes
;
use
bytes
::
Bytes
;
use
futures_util
::
StreamExt
;
#[derive(Debug)]
#[derive(Debug)]
pub
struct
AppState
{
pub
struct
AppState
{
router
:
Box
<
dyn
Router
>
,
router
:
Router
,
client
:
reqwest
::
Client
,
client
:
reqwest
::
Client
,
}
}
impl
AppState
{
impl
AppState
{
pub
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy
:
String
,
client
:
reqwest
::
Client
)
->
Self
{
pub
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy
:
String
,
client
:
reqwest
::
Client
)
->
Self
{
// Create router based on policy
// Create router based on policy
let
router
=
create_router
(
worker_urls
,
policy
);
let
router
=
Router
::
new
(
worker_urls
,
policy
);
Self
{
router
,
client
}
Self
{
router
,
client
}
}
}
}
}
#[get(
"/v1/models"
)]
async
fn
forward_request
(
async
fn
v1_model
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
client
:
&
reqwest
::
Client
,
let
worker_url
=
match
data
.router
.get_first
()
{
worker_url
:
String
,
Some
(
url
)
=>
url
,
route
:
String
,
None
=>
return
HttpResponse
::
InternalServerError
()
.finish
(),
)
->
HttpResponse
{
};
match
client
.get
(
format!
(
"{}{}"
,
worker_url
,
route
))
.send
()
.await
{
// Use the shared client
match
data
.client
.get
(
format!
(
"{}/v1/models"
,
worker_url
))
.send
()
.await
{
Ok
(
res
)
=>
{
Ok
(
res
)
=>
{
let
status
=
actix_web
::
http
::
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
let
status
=
actix_web
::
http
::
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
.unwrap_or
(
actix_web
::
http
::
StatusCode
::
INTERNAL_SERVER_ERROR
);
.unwrap_or
(
actix_web
::
http
::
StatusCode
::
INTERNAL_SERVER_ERROR
);
...
@@ -48,85 +38,31 @@ async fn v1_model(data: web::Data<AppState>) -> impl Responder {
...
@@ -48,85 +38,31 @@ async fn v1_model(data: web::Data<AppState>) -> impl Responder {
}
}
}
}
#[get(
"/get_model_info"
)]
#[get(
"/v1/models"
)]
async
fn
get_model_info
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
async
fn
v1_model
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
// TODO: extract forward_to_route
let
worker_url
=
match
data
.router
.get_first
()
{
let
worker_url
=
match
data
.router
.get_first
()
{
Some
(
url
)
=>
url
,
Some
(
url
)
=>
url
,
None
=>
return
HttpResponse
::
InternalServerError
()
.finish
(),
None
=>
return
HttpResponse
::
InternalServerError
()
.finish
(),
};
};
// Use the shared client
match
data
.client
.get
(
format!
(
"{}/get_model_info"
,
worker_url
))
.send
()
.await
{
Ok
(
res
)
=>
{
let
status
=
actix_web
::
http
::
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
.unwrap_or
(
actix_web
::
http
::
StatusCode
::
INTERNAL_SERVER_ERROR
);
// print the status
forward_request
(
&
data
.client
,
worker_url
,
"/v1/models"
.to_string
())
.await
println!
(
"Worker URL: {}, Status: {}"
,
worker_url
,
status
);
match
res
.bytes
()
.await
{
Ok
(
body
)
=>
HttpResponse
::
build
(
status
)
.body
(
body
.to_vec
()),
Err
(
_
)
=>
HttpResponse
::
InternalServerError
()
.finish
(),
}
}
Err
(
_
)
=>
HttpResponse
::
InternalServerError
()
.finish
(),
}
}
}
// no deser and ser, just forward and return
#[get(
"/get_model_info"
)]
#[post(
"/generate"
)]
async
fn
get_model_info
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
async
fn
generate
(
req
:
HttpRequest
,
body
:
Bytes
,
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
let
worker_url
=
match
data
.router
.get_first
()
{
// create a router struct
// TODO: use router abstraction for different policy
let
worker_url
=
match
data
.router
.select
()
{
Some
(
url
)
=>
url
,
Some
(
url
)
=>
url
,
None
=>
return
HttpResponse
::
InternalServerError
()
.finish
(),
None
=>
return
HttpResponse
::
InternalServerError
()
.finish
(),
};
};
// Check if client requested streaming
forward_request
(
&
data
.client
,
worker_url
,
"/get_model_info"
.to_string
())
.await
let
is_stream
=
serde_json
::
from_slice
::
<
serde_json
::
Value
>
(
&
body
)
}
.map
(|
v
|
v
.get
(
"stream"
)
.and_then
(|
s
|
s
.as_bool
())
.unwrap_or
(
false
))
.unwrap_or
(
false
);
let
res
=
match
data
.client
.post
(
format!
(
"{}/generate"
,
worker_url
))
.header
(
"Content-Type"
,
req
.headers
()
.get
(
"Content-Type"
)
.and_then
(|
h
|
h
.to_str
()
.ok
())
.unwrap_or
(
"application/json"
),
)
.body
(
body
.to_vec
())
.send
()
.await
{
Ok
(
res
)
=>
res
,
Err
(
_
)
=>
return
HttpResponse
::
InternalServerError
()
.finish
(),
};
let
status
=
actix_web
::
http
::
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
.unwrap_or
(
actix_web
::
http
::
StatusCode
::
INTERNAL_SERVER_ERROR
);
if
!
is_stream
{
// no deser and ser, just forward and return
match
res
.bytes
()
.await
{
#[post(
"/generate"
)]
Ok
(
body
)
=>
HttpResponse
::
build
(
status
)
.body
(
body
.to_vec
()),
async
fn
generate
(
req
:
HttpRequest
,
body
:
Bytes
,
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
Err
(
_
)
=>
HttpResponse
::
InternalServerError
()
.finish
(),
data
.router
.dispatch
(
&
data
.client
,
req
,
body
)
.await
}
}
else
{
HttpResponse
::
build
(
status
)
.insert_header
((
CONTENT_TYPE
,
HeaderValue
::
from_static
(
"text/event-stream"
)))
.streaming
(
res
.bytes_stream
()
.map
(|
b
|
match
b
{
Ok
(
b
)
=>
Ok
::
<
_
,
actix_web
::
Error
>
(
b
),
Err
(
_
)
=>
Err
(
actix_web
::
error
::
ErrorInternalServerError
(
"Failed to read stream"
,
)),
}))
}
}
}
pub
async
fn
startup
(
pub
async
fn
startup
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment