Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9a0cc2e9
Unverified
Commit
9a0cc2e9
authored
Jan 23, 2025
by
Byron Hsu
Committed by
GitHub
Jan 23, 2025
Browse files
[router] Forward all request headers from router to workers (#3070)
parent
7bad7e75
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
132 additions
and
25 deletions
+132
-25
scripts/killall_sglang.sh
scripts/killall_sglang.sh
+9
-0
sgl-router/py_test/test_launch_server.py
sgl-router/py_test/test_launch_server.py
+56
-0
sgl-router/src/router.rs
sgl-router/src/router.rs
+53
-15
sgl-router/src/server.rs
sgl-router/src/server.rs
+14
-10
No files found.
scripts/killall_sglang.sh
View file @
9a0cc2e9
#!/bin/bash
#!/bin/bash
# Check if sudo is available
if
command
-v
sudo
>
/dev/null 2>&1
;
then
sudo
apt-get update
sudo
apt-get
install
-y
lsof
else
apt-get update
apt-get
install
-y
lsof
fi
# Show current GPU status
# Show current GPU status
nvidia-smi
nvidia-smi
...
...
sgl-router/py_test/test_launch_server.py
View file @
9a0cc2e9
...
@@ -22,6 +22,7 @@ def popen_launch_router(
...
@@ -22,6 +22,7 @@ def popen_launch_router(
timeout
:
float
,
timeout
:
float
,
policy
:
str
=
"cache_aware"
,
policy
:
str
=
"cache_aware"
,
max_payload_size
:
int
=
None
,
max_payload_size
:
int
=
None
,
api_key
:
str
=
None
,
):
):
"""
"""
Launch the router server process.
Launch the router server process.
...
@@ -33,6 +34,7 @@ def popen_launch_router(
...
@@ -33,6 +34,7 @@ def popen_launch_router(
timeout: Server launch timeout
timeout: Server launch timeout
policy: Router policy, one of "cache_aware", "round_robin", "random"
policy: Router policy, one of "cache_aware", "round_robin", "random"
max_payload_size: Maximum payload size in bytes
max_payload_size: Maximum payload size in bytes
api_key: API key for the router
"""
"""
_
,
host
,
port
=
base_url
.
split
(
":"
)
_
,
host
,
port
=
base_url
.
split
(
":"
)
host
=
host
[
2
:]
host
=
host
[
2
:]
...
@@ -55,6 +57,9 @@ def popen_launch_router(
...
@@ -55,6 +57,9 @@ def popen_launch_router(
policy
,
policy
,
]
]
if
api_key
is
not
None
:
command
.
extend
([
"--api-key"
,
api_key
])
if
max_payload_size
is
not
None
:
if
max_payload_size
is
not
None
:
command
.
extend
([
"--router-max-payload-size"
,
str
(
max_payload_size
)])
command
.
extend
([
"--router-max-payload-size"
,
str
(
max_payload_size
)])
...
@@ -333,6 +338,57 @@ class TestLaunchServer(unittest.TestCase):
...
@@ -333,6 +338,57 @@ class TestLaunchServer(unittest.TestCase):
f
"1.2MB payload should fail with 413 but got status
{
response
.
status_code
}
"
,
f
"1.2MB payload should fail with 413 but got status
{
response
.
status_code
}
"
,
)
)
def
test_5_api_key
(
self
):
print
(
"Running test_5_api_key..."
)
self
.
process
=
popen_launch_router
(
self
.
model
,
self
.
base_url
,
dp_size
=
1
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
policy
=
"round_robin"
,
api_key
=
"correct_api_key"
,
)
# # Test case 1: request without api key should fail
with
requests
.
Session
()
as
session
:
response
=
session
.
post
(
f
"
{
self
.
base_url
}
/generate"
,
json
=
{
"text"
:
"Kanye west is, "
,
"temperature"
:
0
},
)
print
(
f
"status code:
{
response
.
status_code
}
, response:
{
response
.
text
}
"
)
self
.
assertEqual
(
response
.
status_code
,
401
,
"Request without api key should fail with 401"
,
)
# Test case 2: request with invalid api key should fail
with
requests
.
Session
()
as
session
:
response
=
requests
.
post
(
f
"
{
self
.
base_url
}
/generate"
,
json
=
{
"text"
:
"Kanye west is, "
,
"temperature"
:
0
},
headers
=
{
"Authorization"
:
"Bearer 123"
},
)
print
(
f
"status code:
{
response
.
status_code
}
, response:
{
response
.
text
}
"
)
self
.
assertEqual
(
response
.
status_code
,
401
,
"Request with invalid api key should fail with 401"
,
)
# Test case 3: request with correct api key should succeed
with
requests
.
Session
()
as
session
:
response
=
session
.
post
(
f
"
{
self
.
base_url
}
/generate"
,
json
=
{
"text"
:
"Kanye west is "
,
"temperature"
:
0
},
headers
=
{
"Authorization"
:
"Bearer correct_api_key"
},
)
print
(
f
"status code:
{
response
.
status_code
}
, response:
{
response
.
text
}
"
)
self
.
assertEqual
(
response
.
status_code
,
200
,
"Request with correct api key should succeed"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
sgl-router/src/router.rs
View file @
9a0cc2e9
...
@@ -12,6 +12,18 @@ use std::thread;
...
@@ -12,6 +12,18 @@ use std::thread;
use
std
::
time
::
Duration
;
use
std
::
time
::
Duration
;
use
tokio
;
use
tokio
;
fn
copy_request_headers
(
req
:
&
HttpRequest
)
->
Vec
<
(
String
,
String
)
>
{
req
.headers
()
.iter
()
.filter_map
(|(
name
,
value
)|
{
value
.to_str
()
.ok
()
.map
(|
v
|
(
name
.to_string
(),
v
.to_string
()))
})
.collect
()
}
#[derive(Debug)]
#[derive(Debug)]
pub
enum
Router
{
pub
enum
Router
{
RoundRobin
{
RoundRobin
{
...
@@ -303,8 +315,18 @@ impl Router {
...
@@ -303,8 +315,18 @@ impl Router {
client
:
&
reqwest
::
Client
,
client
:
&
reqwest
::
Client
,
worker_url
:
&
str
,
worker_url
:
&
str
,
route
:
&
str
,
route
:
&
str
,
req
:
&
HttpRequest
,
)
->
HttpResponse
{
)
->
HttpResponse
{
match
client
.get
(
format!
(
"{}{}"
,
worker_url
,
route
))
.send
()
.await
{
let
mut
request_builder
=
client
.get
(
format!
(
"{}{}"
,
worker_url
,
route
));
// Copy all headers from original request except for /health because it does not need authorization
if
route
!=
"/health"
{
for
(
name
,
value
)
in
copy_request_headers
(
req
)
{
request_builder
=
request_builder
.header
(
name
,
value
);
}
}
match
request_builder
.send
()
.await
{
Ok
(
res
)
=>
{
Ok
(
res
)
=>
{
let
status
=
actix_web
::
http
::
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
let
status
=
actix_web
::
http
::
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
.unwrap_or
(
actix_web
::
http
::
StatusCode
::
INTERNAL_SERVER_ERROR
);
.unwrap_or
(
actix_web
::
http
::
StatusCode
::
INTERNAL_SERVER_ERROR
);
...
@@ -322,7 +344,12 @@ impl Router {
...
@@ -322,7 +344,12 @@ impl Router {
}
}
}
}
pub
async
fn
route_to_first
(
&
self
,
client
:
&
reqwest
::
Client
,
route
:
&
str
)
->
HttpResponse
{
pub
async
fn
route_to_first
(
&
self
,
client
:
&
reqwest
::
Client
,
route
:
&
str
,
req
:
&
HttpRequest
,
)
->
HttpResponse
{
const
MAX_REQUEST_RETRIES
:
u32
=
3
;
const
MAX_REQUEST_RETRIES
:
u32
=
3
;
const
MAX_TOTAL_RETRIES
:
u32
=
6
;
const
MAX_TOTAL_RETRIES
:
u32
=
6
;
let
mut
total_retries
=
0
;
let
mut
total_retries
=
0
;
...
@@ -338,10 +365,17 @@ impl Router {
...
@@ -338,10 +365,17 @@ impl Router {
info!
(
"Retrying request after {} failed attempts"
,
total_retries
);
info!
(
"Retrying request after {} failed attempts"
,
total_retries
);
}
}
let
response
=
self
.send_request
(
client
,
&
worker_url
,
route
)
.await
;
let
response
=
self
.send_request
(
client
,
&
worker_url
,
route
,
req
)
.await
;
if
response
.status
()
.is_success
()
{
if
response
.status
()
.is_success
()
{
return
response
;
return
response
;
}
else
{
// if the worker is healthy, it means the request is bad, so return the error response
let
health_response
=
self
.send_request
(
client
,
&
worker_url
,
"/health"
,
req
)
.await
;
if
health_response
.status
()
.is_success
()
{
return
response
;
}
}
}
warn!
(
warn!
(
...
@@ -496,19 +530,16 @@ impl Router {
...
@@ -496,19 +530,16 @@ impl Router {
.map
(|
v
|
v
.get
(
"stream"
)
.and_then
(|
s
|
s
.as_bool
())
.unwrap_or
(
false
))
.map
(|
v
|
v
.get
(
"stream"
)
.and_then
(|
s
|
s
.as_bool
())
.unwrap_or
(
false
))
.unwrap_or
(
false
);
.unwrap_or
(
false
);
let
res
=
match
client
let
mut
request_builder
=
client
.post
(
format!
(
"{}{}"
,
worker_url
,
route
))
.post
(
format!
(
"{}{}"
,
worker_url
,
route
))
.header
(
.body
(
body
.to_vec
());
"Content-Type"
,
req
.headers
()
// Copy all headers from original request
.get
(
"Content-Type"
)
for
(
name
,
value
)
in
copy_request_headers
(
req
)
{
.and_then
(|
h
|
h
.to_str
()
.ok
())
request_builder
=
request_builder
.header
(
name
,
value
);
.unwrap_or
(
"application/json"
),
}
)
.body
(
body
.to_vec
())
let
res
=
match
request_builder
.send
()
.await
{
.send
()
.await
{
Ok
(
res
)
=>
res
,
Ok
(
res
)
=>
res
,
Err
(
_
)
=>
return
HttpResponse
::
InternalServerError
()
.finish
(),
Err
(
_
)
=>
return
HttpResponse
::
InternalServerError
()
.finish
(),
};
};
...
@@ -596,6 +627,13 @@ impl Router {
...
@@ -596,6 +627,13 @@ impl Router {
if
response
.status
()
.is_success
()
{
if
response
.status
()
.is_success
()
{
return
response
;
return
response
;
}
else
{
// if the worker is healthy, it means the request is bad, so return the error response
let
health_response
=
self
.send_request
(
client
,
&
worker_url
,
"/health"
,
req
)
.await
;
if
health_response
.status
()
.is_success
()
{
return
response
;
}
}
}
warn!
(
warn!
(
...
...
sgl-router/src/server.rs
View file @
9a0cc2e9
...
@@ -26,33 +26,37 @@ impl AppState {
...
@@ -26,33 +26,37 @@ impl AppState {
}
}
#[get(
"/health"
)]
#[get(
"/health"
)]
async
fn
health
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
async
fn
health
(
req
:
HttpRequest
,
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
data
.router
.route_to_first
(
&
data
.client
,
"/health"
)
.await
data
.router
.route_to_first
(
&
data
.client
,
"/health"
,
&
req
)
.await
}
}
#[get(
"/health_generate"
)]
#[get(
"/health_generate"
)]
async
fn
health_generate
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
async
fn
health_generate
(
req
:
HttpRequest
,
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
data
.router
data
.router
.route_to_first
(
&
data
.client
,
"/health_generate"
)
.route_to_first
(
&
data
.client
,
"/health_generate"
,
&
req
)
.await
.await
}
}
#[get(
"/get_server_info"
)]
#[get(
"/get_server_info"
)]
async
fn
get_server_info
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
async
fn
get_server_info
(
req
:
HttpRequest
,
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
data
.router
data
.router
.route_to_first
(
&
data
.client
,
"/get_server_info"
)
.route_to_first
(
&
data
.client
,
"/get_server_info"
,
&
req
)
.await
.await
}
}
#[get(
"/v1/models"
)]
#[get(
"/v1/models"
)]
async
fn
v1_models
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
async
fn
v1_models
(
req
:
HttpRequest
,
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
data
.router
.route_to_first
(
&
data
.client
,
"/v1/models"
)
.await
data
.router
.route_to_first
(
&
data
.client
,
"/v1/models"
,
&
req
)
.await
}
}
#[get(
"/get_model_info"
)]
#[get(
"/get_model_info"
)]
async
fn
get_model_info
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
async
fn
get_model_info
(
req
:
HttpRequest
,
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
data
.router
data
.router
.route_to_first
(
&
data
.client
,
"/get_model_info"
)
.route_to_first
(
&
data
.client
,
"/get_model_info"
,
&
req
)
.await
.await
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment