Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
56321e9f
Unverified
Commit
56321e9f
authored
Sep 22, 2025
by
Jimmy
Committed by
GitHub
Sep 21, 2025
Browse files
[Router]fix: fix get_load missing api_key (#10385)
parent
12d6cf18
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
372 additions
and
111 deletions
+372
-111
sgl-router/py_test/e2e/test_regular_router.py
sgl-router/py_test/e2e/test_regular_router.py
+3
-1
sgl-router/py_test/fixtures/mock_worker.py
sgl-router/py_test/fixtures/mock_worker.py
+2
-1
sgl-router/src/core/worker.rs
sgl-router/src/core/worker.rs
+49
-16
sgl-router/src/core/worker_builder.rs
sgl-router/src/core/worker_builder.rs
+24
-0
sgl-router/src/core/worker_registry.rs
sgl-router/src/core/worker_registry.rs
+16
-0
sgl-router/src/policies/cache_aware.rs
sgl-router/src/policies/cache_aware.rs
+2
-0
sgl-router/src/policies/mod.rs
sgl-router/src/policies/mod.rs
+3
-0
sgl-router/src/protocols/worker_spec.rs
sgl-router/src/protocols/worker_spec.rs
+4
-0
sgl-router/src/routers/grpc/pd_router.rs
sgl-router/src/routers/grpc/pd_router.rs
+5
-1
sgl-router/src/routers/grpc/router.rs
sgl-router/src/routers/grpc/router.rs
+5
-1
sgl-router/src/routers/http/openai_router.rs
sgl-router/src/routers/http/openai_router.rs
+5
-1
sgl-router/src/routers/http/pd_router.rs
sgl-router/src/routers/http/pd_router.rs
+69
-22
sgl-router/src/routers/http/router.rs
sgl-router/src/routers/http/router.rs
+69
-27
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+5
-1
sgl-router/src/routers/router_manager.rs
sgl-router/src/routers/router_manager.rs
+47
-19
sgl-router/src/routers/worker_initializer.rs
sgl-router/src/routers/worker_initializer.rs
+40
-12
sgl-router/src/server.rs
sgl-router/src/server.rs
+6
-5
sgl-router/src/service_discovery.rs
sgl-router/src/service_discovery.rs
+9
-4
sgl-router/tests/cache_aware_backward_compat_test.rs
sgl-router/tests/cache_aware_backward_compat_test.rs
+6
-0
sgl-router/tests/policy_registry_integration.rs
sgl-router/tests/policy_registry_integration.rs
+3
-0
No files found.
sgl-router/py_test/e2e/test_regular_router.py
View file @
56321e9f
...
@@ -129,7 +129,9 @@ def test_dp_aware_worker_expansion_and_api_key(
...
@@ -129,7 +129,9 @@ def test_dp_aware_worker_expansion_and_api_key(
# Attach worker; router should expand to dp_size logical workers
# Attach worker; router should expand to dp_size logical workers
r
=
requests
.
post
(
r
=
requests
.
post
(
f
"
{
router_url
}
/add_worker"
,
params
=
{
"url"
:
worker_url
},
timeout
=
180
f
"
{
router_url
}
/add_worker"
,
params
=
{
"url"
:
worker_url
,
"api_key"
:
api_key
},
timeout
=
180
,
)
)
r
.
raise_for_status
()
r
.
raise_for_status
()
...
...
sgl-router/py_test/fixtures/mock_worker.py
View file @
56321e9f
...
@@ -139,7 +139,8 @@ def create_app(args: argparse.Namespace) -> FastAPI:
...
@@ -139,7 +139,8 @@ def create_app(args: argparse.Namespace) -> FastAPI:
)
)
@
app
.
get
(
"/get_load"
)
@
app
.
get
(
"/get_load"
)
async
def
get_load
():
async
def
get_load
(
request
:
Request
):
check_api_key
(
request
)
return
JSONResponse
({
"load"
:
_inflight
})
return
JSONResponse
({
"load"
:
_inflight
})
def
make_json_response
(
obj
:
dict
,
status_code
:
int
=
200
)
->
JSONResponse
:
def
make_json_response
(
obj
:
dict
,
status_code
:
int
=
200
)
->
JSONResponse
:
...
...
sgl-router/src/core/worker.rs
View file @
56321e9f
...
@@ -24,7 +24,8 @@ static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
...
@@ -24,7 +24,8 @@ static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
pub
trait
Worker
:
Send
+
Sync
+
fmt
::
Debug
{
pub
trait
Worker
:
Send
+
Sync
+
fmt
::
Debug
{
/// Get the worker's URL
/// Get the worker's URL
fn
url
(
&
self
)
->
&
str
;
fn
url
(
&
self
)
->
&
str
;
/// Get the worker's API key
fn
api_key
(
&
self
)
->
&
Option
<
String
>
;
/// Get the worker's type (Regular, Prefill, or Decode)
/// Get the worker's type (Regular, Prefill, or Decode)
fn
worker_type
(
&
self
)
->
WorkerType
;
fn
worker_type
(
&
self
)
->
WorkerType
;
...
@@ -323,6 +324,8 @@ pub struct WorkerMetadata {
...
@@ -323,6 +324,8 @@ pub struct WorkerMetadata {
pub
labels
:
std
::
collections
::
HashMap
<
String
,
String
>
,
pub
labels
:
std
::
collections
::
HashMap
<
String
,
String
>
,
/// Health check configuration
/// Health check configuration
pub
health_config
:
HealthConfig
,
pub
health_config
:
HealthConfig
,
/// API key
pub
api_key
:
Option
<
String
>
,
}
}
/// Basic worker implementation
/// Basic worker implementation
...
@@ -379,6 +382,10 @@ impl Worker for BasicWorker {
...
@@ -379,6 +382,10 @@ impl Worker for BasicWorker {
&
self
.metadata.url
&
self
.metadata.url
}
}
fn
api_key
(
&
self
)
->
&
Option
<
String
>
{
&
self
.metadata.api_key
}
fn
worker_type
(
&
self
)
->
WorkerType
{
fn
worker_type
(
&
self
)
->
WorkerType
{
self
.metadata.worker_type
.clone
()
self
.metadata.worker_type
.clone
()
}
}
...
@@ -548,6 +555,10 @@ impl Worker for DPAwareWorker {
...
@@ -548,6 +555,10 @@ impl Worker for DPAwareWorker {
self
.base_worker
.url
()
self
.base_worker
.url
()
}
}
fn
api_key
(
&
self
)
->
&
Option
<
String
>
{
self
.base_worker
.api_key
()
}
fn
worker_type
(
&
self
)
->
WorkerType
{
fn
worker_type
(
&
self
)
->
WorkerType
{
self
.base_worker
.worker_type
()
self
.base_worker
.worker_type
()
}
}
...
@@ -650,19 +661,21 @@ impl WorkerFactory {
...
@@ -650,19 +661,21 @@ impl WorkerFactory {
dp_rank
:
usize
,
dp_rank
:
usize
,
dp_size
:
usize
,
dp_size
:
usize
,
worker_type
:
WorkerType
,
worker_type
:
WorkerType
,
api_key
:
Option
<
String
>
,
)
->
Box
<
dyn
Worker
>
{
)
->
Box
<
dyn
Worker
>
{
Box
::
new
(
let
mut
builder
=
DPAwareWorkerBuilder
::
new
(
base_url
,
dp_rank
,
dp_size
)
DPAwareWorkerBuilder
::
new
(
base_url
,
dp_rank
,
dp_size
)
.worker_type
(
worker_type
);
.worker_type
(
worker_type
)
if
let
Some
(
api_key
)
=
api_key
{
.build
(),
builder
=
builder
.api_key
(
api_key
);
)
}
Box
::
new
(
builder
.build
())
}
}
#[allow(dead_code)]
#[allow(dead_code)]
/// Get DP size from a worker
/// Get DP size from a worker
async
fn
get_worker_dp_size
(
url
:
&
str
,
api_key
:
&
Option
<
String
>
)
->
WorkerResult
<
usize
>
{
async
fn
get_worker_dp_size
(
url
:
&
str
,
api_key
:
&
Option
<
String
>
)
->
WorkerResult
<
usize
>
{
let
mut
req_builder
=
WORKER_CLIENT
.get
(
format!
(
"{}/get_server_info"
,
url
));
let
mut
req_builder
=
WORKER_CLIENT
.get
(
format!
(
"{}/get_server_info"
,
url
));
if
let
Some
(
key
)
=
api_key
{
if
let
Some
(
key
)
=
&
api_key
{
req_builder
=
req_builder
.bearer_auth
(
key
);
req_builder
=
req_builder
.bearer_auth
(
key
);
}
}
...
@@ -708,14 +721,18 @@ impl WorkerFactory {
...
@@ -708,14 +721,18 @@ impl WorkerFactory {
}
}
/// Convert a list of worker URLs to worker trait objects
/// Convert a list of worker URLs to worker trait objects
pub
fn
urls_to_workers
(
urls
:
Vec
<
String
>
)
->
Vec
<
Box
<
dyn
Worker
>>
{
pub
fn
urls_to_workers
(
urls
:
Vec
<
String
>
,
api_key
:
Option
<
String
>
)
->
Vec
<
Box
<
dyn
Worker
>>
{
urls
.into_iter
()
urls
.into_iter
()
.map
(|
url
|
{
.map
(|
url
|
{
Box
::
new
(
let
worker_builder
=
BasicWorkerBuilder
::
new
(
url
)
.worker_type
(
WorkerType
::
Regular
);
BasicWorkerBuilder
::
new
(
url
)
.worker_type
(
WorkerType
::
Regular
)
let
worker
=
if
let
Some
(
ref
api_key
)
=
api_key
{
.build
(),
worker_builder
.api_key
(
api_key
.clone
())
.build
()
)
as
Box
<
dyn
Worker
>
}
else
{
worker_builder
.build
()
};
Box
::
new
(
worker
)
as
Box
<
dyn
Worker
>
})
})
.collect
()
.collect
()
}
}
...
@@ -961,6 +978,7 @@ mod tests {
...
@@ -961,6 +978,7 @@ mod tests {
use
crate
::
core
::
BasicWorkerBuilder
;
use
crate
::
core
::
BasicWorkerBuilder
;
let
worker
=
BasicWorkerBuilder
::
new
(
"http://test:8080"
)
let
worker
=
BasicWorkerBuilder
::
new
(
"http://test:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
assert_eq!
(
worker
.url
(),
"http://test:8080"
);
assert_eq!
(
worker
.url
(),
"http://test:8080"
);
assert_eq!
(
worker
.worker_type
(),
WorkerType
::
Regular
);
assert_eq!
(
worker
.worker_type
(),
WorkerType
::
Regular
);
...
@@ -998,6 +1016,7 @@ mod tests {
...
@@ -998,6 +1016,7 @@ mod tests {
let
worker
=
BasicWorkerBuilder
::
new
(
"http://test:8080"
)
let
worker
=
BasicWorkerBuilder
::
new
(
"http://test:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.health_config
(
custom_config
.clone
())
.health_config
(
custom_config
.clone
())
.api_key
(
"test_api_key"
)
.build
();
.build
();
assert_eq!
(
worker
.metadata
()
.health_config.timeout_secs
,
15
);
assert_eq!
(
worker
.metadata
()
.health_config.timeout_secs
,
15
);
...
@@ -1011,6 +1030,7 @@ mod tests {
...
@@ -1011,6 +1030,7 @@ mod tests {
use
crate
::
core
::
BasicWorkerBuilder
;
use
crate
::
core
::
BasicWorkerBuilder
;
let
worker
=
BasicWorkerBuilder
::
new
(
"http://worker1:8080"
)
let
worker
=
BasicWorkerBuilder
::
new
(
"http://worker1:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
assert_eq!
(
worker
.url
(),
"http://worker1:8080"
);
assert_eq!
(
worker
.url
(),
"http://worker1:8080"
);
}
}
...
@@ -1020,6 +1040,7 @@ mod tests {
...
@@ -1020,6 +1040,7 @@ mod tests {
use
crate
::
core
::
BasicWorkerBuilder
;
use
crate
::
core
::
BasicWorkerBuilder
;
let
regular
=
BasicWorkerBuilder
::
new
(
"http://test:8080"
)
let
regular
=
BasicWorkerBuilder
::
new
(
"http://test:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
assert_eq!
(
regular
.worker_type
(),
WorkerType
::
Regular
);
assert_eq!
(
regular
.worker_type
(),
WorkerType
::
Regular
);
...
@@ -1027,6 +1048,7 @@ mod tests {
...
@@ -1027,6 +1048,7 @@ mod tests {
.worker_type
(
WorkerType
::
Prefill
{
.worker_type
(
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
9090
),
bootstrap_port
:
Some
(
9090
),
})
})
.api_key
(
"test_api_key"
)
.build
();
.build
();
assert_eq!
(
assert_eq!
(
prefill
.worker_type
(),
prefill
.worker_type
(),
...
@@ -1037,6 +1059,7 @@ mod tests {
...
@@ -1037,6 +1059,7 @@ mod tests {
let
decode
=
BasicWorkerBuilder
::
new
(
"http://test:8080"
)
let
decode
=
BasicWorkerBuilder
::
new
(
"http://test:8080"
)
.worker_type
(
WorkerType
::
Decode
)
.worker_type
(
WorkerType
::
Decode
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
assert_eq!
(
decode
.worker_type
(),
WorkerType
::
Decode
);
assert_eq!
(
decode
.worker_type
(),
WorkerType
::
Decode
);
}
}
...
@@ -1065,6 +1088,7 @@ mod tests {
...
@@ -1065,6 +1088,7 @@ mod tests {
use
crate
::
core
::
BasicWorkerBuilder
;
use
crate
::
core
::
BasicWorkerBuilder
;
let
worker
=
BasicWorkerBuilder
::
new
(
"http://test:8080"
)
let
worker
=
BasicWorkerBuilder
::
new
(
"http://test:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
// Initial load is 0
// Initial load is 0
...
@@ -1350,7 +1374,7 @@ mod tests {
...
@@ -1350,7 +1374,7 @@ mod tests {
fn
test_urls_to_workers
()
{
fn
test_urls_to_workers
()
{
let
urls
=
vec!
[
"http://w1:8080"
.to_string
(),
"http://w2:8080"
.to_string
()];
let
urls
=
vec!
[
"http://w1:8080"
.to_string
(),
"http://w2:8080"
.to_string
()];
let
workers
=
urls_to_workers
(
urls
);
let
workers
=
urls_to_workers
(
urls
,
Some
(
"test_api_key"
.to_string
())
);
assert_eq!
(
workers
.len
(),
2
);
assert_eq!
(
workers
.len
(),
2
);
assert_eq!
(
workers
[
0
]
.url
(),
"http://w1:8080"
);
assert_eq!
(
workers
[
0
]
.url
(),
"http://w1:8080"
);
assert_eq!
(
workers
[
1
]
.url
(),
"http://w2:8080"
);
assert_eq!
(
workers
[
1
]
.url
(),
"http://w2:8080"
);
...
@@ -1547,6 +1571,7 @@ mod tests {
...
@@ -1547,6 +1571,7 @@ mod tests {
1
,
1
,
4
,
4
,
WorkerType
::
Regular
,
WorkerType
::
Regular
,
Some
(
"test_api_key"
.to_string
()),
);
);
assert_eq!
(
worker
.url
(),
"http://worker1:8080@1"
);
assert_eq!
(
worker
.url
(),
"http://worker1:8080@1"
);
...
@@ -1565,6 +1590,7 @@ mod tests {
...
@@ -1565,6 +1590,7 @@ mod tests {
WorkerType
::
Prefill
{
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
8090
),
bootstrap_port
:
Some
(
8090
),
},
},
Some
(
"test_api_key"
.to_string
()),
);
);
assert_eq!
(
worker
.url
(),
"http://worker1:8080@0"
);
assert_eq!
(
worker
.url
(),
"http://worker1:8080@0"
);
...
@@ -1680,8 +1706,13 @@ mod tests {
...
@@ -1680,8 +1706,13 @@ mod tests {
.worker_type
(
WorkerType
::
Decode
)
.worker_type
(
WorkerType
::
Decode
)
.build
(),
.build
(),
);
);
let
dp_aware_regular
=
let
dp_aware_regular
=
WorkerFactory
::
create_dp_aware
(
WorkerFactory
::
create_dp_aware
(
"http://dp:8080"
.to_string
(),
0
,
2
,
WorkerType
::
Regular
);
"http://dp:8080"
.to_string
(),
0
,
2
,
WorkerType
::
Regular
,
Some
(
"test_api_key"
.to_string
()),
);
let
dp_aware_prefill
=
WorkerFactory
::
create_dp_aware
(
let
dp_aware_prefill
=
WorkerFactory
::
create_dp_aware
(
"http://dp-prefill:8080"
.to_string
(),
"http://dp-prefill:8080"
.to_string
(),
1
,
1
,
...
@@ -1689,12 +1720,14 @@ mod tests {
...
@@ -1689,12 +1720,14 @@ mod tests {
WorkerType
::
Prefill
{
WorkerType
::
Prefill
{
bootstrap_port
:
None
,
bootstrap_port
:
None
,
},
},
Some
(
"test_api_key"
.to_string
()),
);
);
let
dp_aware_decode
=
WorkerFactory
::
create_dp_aware
(
let
dp_aware_decode
=
WorkerFactory
::
create_dp_aware
(
"http://dp-decode:8080"
.to_string
(),
"http://dp-decode:8080"
.to_string
(),
0
,
0
,
4
,
4
,
WorkerType
::
Decode
,
WorkerType
::
Decode
,
Some
(
"test_api_key"
.to_string
()),
);
);
let
workers
:
Vec
<
Box
<
dyn
Worker
>>
=
vec!
[
let
workers
:
Vec
<
Box
<
dyn
Worker
>>
=
vec!
[
...
...
sgl-router/src/core/worker_builder.rs
View file @
56321e9f
...
@@ -11,6 +11,7 @@ pub struct BasicWorkerBuilder {
...
@@ -11,6 +11,7 @@ pub struct BasicWorkerBuilder {
url
:
String
,
url
:
String
,
// Optional fields with defaults
// Optional fields with defaults
api_key
:
Option
<
String
>
,
worker_type
:
WorkerType
,
worker_type
:
WorkerType
,
connection_mode
:
ConnectionMode
,
connection_mode
:
ConnectionMode
,
labels
:
HashMap
<
String
,
String
>
,
labels
:
HashMap
<
String
,
String
>
,
...
@@ -24,6 +25,7 @@ impl BasicWorkerBuilder {
...
@@ -24,6 +25,7 @@ impl BasicWorkerBuilder {
pub
fn
new
(
url
:
impl
Into
<
String
>
)
->
Self
{
pub
fn
new
(
url
:
impl
Into
<
String
>
)
->
Self
{
Self
{
Self
{
url
:
url
.into
(),
url
:
url
.into
(),
api_key
:
None
,
worker_type
:
WorkerType
::
Regular
,
worker_type
:
WorkerType
::
Regular
,
connection_mode
:
ConnectionMode
::
Http
,
connection_mode
:
ConnectionMode
::
Http
,
labels
:
HashMap
::
new
(),
labels
:
HashMap
::
new
(),
...
@@ -37,6 +39,7 @@ impl BasicWorkerBuilder {
...
@@ -37,6 +39,7 @@ impl BasicWorkerBuilder {
pub
fn
new_with_type
(
url
:
impl
Into
<
String
>
,
worker_type
:
WorkerType
)
->
Self
{
pub
fn
new_with_type
(
url
:
impl
Into
<
String
>
,
worker_type
:
WorkerType
)
->
Self
{
Self
{
Self
{
url
:
url
.into
(),
url
:
url
.into
(),
api_key
:
None
,
worker_type
,
worker_type
,
connection_mode
:
ConnectionMode
::
Http
,
connection_mode
:
ConnectionMode
::
Http
,
labels
:
HashMap
::
new
(),
labels
:
HashMap
::
new
(),
...
@@ -46,6 +49,12 @@ impl BasicWorkerBuilder {
...
@@ -46,6 +49,12 @@ impl BasicWorkerBuilder {
}
}
}
}
/// Set the API key
pub
fn
api_key
(
mut
self
,
api_key
:
impl
Into
<
String
>
)
->
Self
{
self
.api_key
=
Some
(
api_key
.into
());
self
}
/// Set the worker type (Regular, Prefill, or Decode)
/// Set the worker type (Regular, Prefill, or Decode)
pub
fn
worker_type
(
mut
self
,
worker_type
:
WorkerType
)
->
Self
{
pub
fn
worker_type
(
mut
self
,
worker_type
:
WorkerType
)
->
Self
{
self
.worker_type
=
worker_type
;
self
.worker_type
=
worker_type
;
...
@@ -98,6 +107,7 @@ impl BasicWorkerBuilder {
...
@@ -98,6 +107,7 @@ impl BasicWorkerBuilder {
let
metadata
=
WorkerMetadata
{
let
metadata
=
WorkerMetadata
{
url
:
self
.url
.clone
(),
url
:
self
.url
.clone
(),
api_key
:
self
.api_key
,
worker_type
:
self
.worker_type
,
worker_type
:
self
.worker_type
,
connection_mode
:
self
.connection_mode
,
connection_mode
:
self
.connection_mode
,
labels
:
self
.labels
,
labels
:
self
.labels
,
...
@@ -121,6 +131,7 @@ impl BasicWorkerBuilder {
...
@@ -121,6 +131,7 @@ impl BasicWorkerBuilder {
pub
struct
DPAwareWorkerBuilder
{
pub
struct
DPAwareWorkerBuilder
{
// Required fields
// Required fields
base_url
:
String
,
base_url
:
String
,
api_key
:
Option
<
String
>
,
dp_rank
:
usize
,
dp_rank
:
usize
,
dp_size
:
usize
,
dp_size
:
usize
,
...
@@ -138,6 +149,7 @@ impl DPAwareWorkerBuilder {
...
@@ -138,6 +149,7 @@ impl DPAwareWorkerBuilder {
pub
fn
new
(
base_url
:
impl
Into
<
String
>
,
dp_rank
:
usize
,
dp_size
:
usize
)
->
Self
{
pub
fn
new
(
base_url
:
impl
Into
<
String
>
,
dp_rank
:
usize
,
dp_size
:
usize
)
->
Self
{
Self
{
Self
{
base_url
:
base_url
.into
(),
base_url
:
base_url
.into
(),
api_key
:
None
,
dp_rank
,
dp_rank
,
dp_size
,
dp_size
,
worker_type
:
WorkerType
::
Regular
,
worker_type
:
WorkerType
::
Regular
,
...
@@ -158,6 +170,7 @@ impl DPAwareWorkerBuilder {
...
@@ -158,6 +170,7 @@ impl DPAwareWorkerBuilder {
)
->
Self
{
)
->
Self
{
Self
{
Self
{
base_url
:
base_url
.into
(),
base_url
:
base_url
.into
(),
api_key
:
None
,
dp_rank
,
dp_rank
,
dp_size
,
dp_size
,
worker_type
,
worker_type
,
...
@@ -169,6 +182,12 @@ impl DPAwareWorkerBuilder {
...
@@ -169,6 +182,12 @@ impl DPAwareWorkerBuilder {
}
}
}
}
/// Set the API key
pub
fn
api_key
(
mut
self
,
api_key
:
impl
Into
<
String
>
)
->
Self
{
self
.api_key
=
Some
(
api_key
.into
());
self
}
/// Set the worker type (Regular, Prefill, or Decode)
/// Set the worker type (Regular, Prefill, or Decode)
pub
fn
worker_type
(
mut
self
,
worker_type
:
WorkerType
)
->
Self
{
pub
fn
worker_type
(
mut
self
,
worker_type
:
WorkerType
)
->
Self
{
self
.worker_type
=
worker_type
;
self
.worker_type
=
worker_type
;
...
@@ -228,6 +247,10 @@ impl DPAwareWorkerBuilder {
...
@@ -228,6 +247,10 @@ impl DPAwareWorkerBuilder {
if
let
Some
(
client
)
=
self
.grpc_client
{
if
let
Some
(
client
)
=
self
.grpc_client
{
builder
=
builder
.grpc_client
(
client
);
builder
=
builder
.grpc_client
(
client
);
}
}
// Add API key if provided
if
let
Some
(
api_key
)
=
self
.api_key
{
builder
=
builder
.api_key
(
api_key
);
}
let
base_worker
=
builder
.build
();
let
base_worker
=
builder
.build
();
...
@@ -382,6 +405,7 @@ mod tests {
...
@@ -382,6 +405,7 @@ mod tests {
.connection_mode
(
ConnectionMode
::
Http
)
.connection_mode
(
ConnectionMode
::
Http
)
.labels
(
labels
.clone
())
.labels
(
labels
.clone
())
.health_config
(
health_config
.clone
())
.health_config
(
health_config
.clone
())
.api_key
(
"test_api_key"
)
.build
();
.build
();
assert_eq!
(
worker
.url
(),
"http://localhost:8080@3"
);
assert_eq!
(
worker
.url
(),
"http://localhost:8080@3"
);
...
...
sgl-router/src/core/worker_registry.rs
View file @
56321e9f
...
@@ -256,6 +256,18 @@ impl WorkerRegistry {
...
@@ -256,6 +256,18 @@ impl WorkerRegistry {
.collect
()
.collect
()
}
}
pub
fn
get_all_urls_with_api_key
(
&
self
)
->
Vec
<
(
String
,
Option
<
String
>
)
>
{
self
.workers
.iter
()
.map
(|
entry
|
{
(
entry
.value
()
.url
()
.to_string
(),
entry
.value
()
.api_key
()
.clone
(),
)
})
.collect
()
}
/// Get all model IDs with workers
/// Get all model IDs with workers
pub
fn
get_models
(
&
self
)
->
Vec
<
String
>
{
pub
fn
get_models
(
&
self
)
->
Vec
<
String
>
{
self
.model_workers
self
.model_workers
...
@@ -442,6 +454,7 @@ mod tests {
...
@@ -442,6 +454,7 @@ mod tests {
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.labels
(
labels
)
.labels
(
labels
)
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.api_key
(
"test_api_key"
)
.build
(),
.build
(),
);
);
...
@@ -477,6 +490,7 @@ mod tests {
...
@@ -477,6 +490,7 @@ mod tests {
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.labels
(
labels1
)
.labels
(
labels1
)
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.api_key
(
"test_api_key"
)
.build
(),
.build
(),
);
);
...
@@ -487,6 +501,7 @@ mod tests {
...
@@ -487,6 +501,7 @@ mod tests {
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.labels
(
labels2
)
.labels
(
labels2
)
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.api_key
(
"test_api_key"
)
.build
(),
.build
(),
);
);
...
@@ -497,6 +512,7 @@ mod tests {
...
@@ -497,6 +512,7 @@ mod tests {
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.labels
(
labels3
)
.labels
(
labels3
)
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.api_key
(
"test_api_key"
)
.build
(),
.build
(),
);
);
...
...
sgl-router/src/policies/cache_aware.rs
View file @
56321e9f
...
@@ -465,11 +465,13 @@ mod tests {
...
@@ -465,11 +465,13 @@ mod tests {
Arc
::
new
(
Arc
::
new
(
BasicWorkerBuilder
::
new
(
"http://w1:8000"
)
BasicWorkerBuilder
::
new
(
"http://w1:8000"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
(),
.build
(),
),
),
Arc
::
new
(
Arc
::
new
(
BasicWorkerBuilder
::
new
(
"http://w2:8000"
)
BasicWorkerBuilder
::
new
(
"http://w2:8000"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
(),
.build
(),
),
),
];
];
...
...
sgl-router/src/policies/mod.rs
View file @
56321e9f
...
@@ -129,16 +129,19 @@ mod tests {
...
@@ -129,16 +129,19 @@ mod tests {
Arc
::
new
(
Arc
::
new
(
BasicWorkerBuilder
::
new
(
"http://w1:8000"
)
BasicWorkerBuilder
::
new
(
"http://w1:8000"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
(),
.build
(),
),
),
Arc
::
new
(
Arc
::
new
(
BasicWorkerBuilder
::
new
(
"http://w2:8000"
)
BasicWorkerBuilder
::
new
(
"http://w2:8000"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key2"
)
.build
(),
.build
(),
),
),
Arc
::
new
(
Arc
::
new
(
BasicWorkerBuilder
::
new
(
"http://w3:8000"
)
BasicWorkerBuilder
::
new
(
"http://w3:8000"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
(),
.build
(),
),
),
];
];
...
...
sgl-router/src/protocols/worker_spec.rs
View file @
56321e9f
...
@@ -11,6 +11,10 @@ pub struct WorkerConfigRequest {
...
@@ -11,6 +11,10 @@ pub struct WorkerConfigRequest {
/// Worker URL (required)
/// Worker URL (required)
pub
url
:
String
,
pub
url
:
String
,
/// Worker API key (optional)
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
api_key
:
Option
<
String
>
,
/// Model ID (optional, will query from server if not provided)
/// Model ID (optional, will query from server if not provided)
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
model_id
:
Option
<
String
>
,
pub
model_id
:
Option
<
String
>
,
...
...
sgl-router/src/routers/grpc/pd_router.rs
View file @
56321e9f
...
@@ -353,7 +353,11 @@ impl RouterTrait for GrpcPDRouter {
...
@@ -353,7 +353,11 @@ impl RouterTrait for GrpcPDRouter {
#[async_trait]
#[async_trait]
impl
WorkerManagement
for
GrpcPDRouter
{
impl
WorkerManagement
for
GrpcPDRouter
{
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
,
_
api_key
:
&
Option
<
String
>
,
)
->
Result
<
String
,
String
>
{
Err
(
"Not implemented"
.to_string
())
Err
(
"Not implemented"
.to_string
())
}
}
...
...
sgl-router/src/routers/grpc/router.rs
View file @
56321e9f
...
@@ -282,7 +282,11 @@ impl RouterTrait for GrpcRouter {
...
@@ -282,7 +282,11 @@ impl RouterTrait for GrpcRouter {
#[async_trait]
#[async_trait]
impl
WorkerManagement
for
GrpcRouter
{
impl
WorkerManagement
for
GrpcRouter
{
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
,
_
api_key
:
&
Option
<
String
>
,
)
->
Result
<
String
,
String
>
{
Err
(
"Not implemented"
.to_string
())
Err
(
"Not implemented"
.to_string
())
}
}
...
...
sgl-router/src/routers/http/openai_router.rs
View file @
56321e9f
...
@@ -67,7 +67,11 @@ impl OpenAIRouter {
...
@@ -67,7 +67,11 @@ impl OpenAIRouter {
#[async_trait]
#[async_trait]
impl
super
::
super
::
WorkerManagement
for
OpenAIRouter
{
impl
super
::
super
::
WorkerManagement
for
OpenAIRouter
{
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
,
_
api_key
:
&
Option
<
String
>
,
)
->
Result
<
String
,
String
>
{
Err
(
"Cannot add workers to OpenAI router"
.to_string
())
Err
(
"Cannot add workers to OpenAI router"
.to_string
())
}
}
...
...
sgl-router/src/routers/http/pd_router.rs
View file @
56321e9f
...
@@ -46,6 +46,8 @@ pub struct PDRouter {
...
@@ -46,6 +46,8 @@ pub struct PDRouter {
pub
prefill_client
:
Client
,
pub
prefill_client
:
Client
,
pub
retry_config
:
RetryConfig
,
pub
retry_config
:
RetryConfig
,
pub
circuit_breaker_config
:
CircuitBreakerConfig
,
pub
circuit_breaker_config
:
CircuitBreakerConfig
,
pub
api_key
:
Option
<
String
>
,
// Channel for sending prefill responses to background workers for draining
// Channel for sending prefill responses to background workers for draining
prefill_drain_tx
:
mpsc
::
Sender
<
reqwest
::
Response
>
,
prefill_drain_tx
:
mpsc
::
Sender
<
reqwest
::
Response
>
,
}
}
...
@@ -113,21 +115,25 @@ impl PDRouter {
...
@@ -113,21 +115,25 @@ impl PDRouter {
(
results
,
errors
)
(
results
,
errors
)
}
}
fn
_
get_worker_url_and_key
(
&
self
,
w
:
&
Arc
<
dyn
Worker
>
)
->
(
String
,
Option
<
String
>
)
{
(
w
.url
()
.to_string
(),
w
.api_key
()
.clone
())
}
// Helper to get prefill worker URLs
// Helper to get prefill worker URLs
fn
get_prefill_worker_urls
(
&
self
)
->
Vec
<
String
>
{
fn
get_prefill_worker_urls
_with_api_key
(
&
self
)
->
Vec
<
(
String
,
Option
<
String
>
)
>
{
self
.worker_registry
self
.worker_registry
.get_prefill_workers
()
.get_prefill_workers
()
.iter
()
.iter
()
.map
(|
w
|
w
.url
()
.to_string
(
))
.map
(|
w
|
self
._get_worker_url_and_key
(
w
))
.collect
()
.collect
()
}
}
// Helper to get decode worker URLs
// Helper to get decode worker URLs
fn
get_decode_worker_urls
(
&
self
)
->
Vec
<
String
>
{
fn
get_decode_worker_urls
_with_api_key
(
&
self
)
->
Vec
<
(
String
,
Option
<
String
>
)
>
{
self
.worker_registry
self
.worker_registry
.get_decode_workers
()
.get_decode_workers
()
.iter
()
.iter
()
.map
(|
w
|
w
.url
()
.to_string
(
))
.map
(|
w
|
self
._get_worker_url_and_key
(
w
))
.collect
()
.collect
()
}
}
...
@@ -208,6 +214,7 @@ impl PDRouter {
...
@@ -208,6 +214,7 @@ impl PDRouter {
pub
async
fn
add_prefill_server
(
pub
async
fn
add_prefill_server
(
&
self
,
&
self
,
url
:
String
,
url
:
String
,
api_key
:
Option
<
String
>
,
bootstrap_port
:
Option
<
u16
>
,
bootstrap_port
:
Option
<
u16
>
,
)
->
Result
<
String
,
PDRouterError
>
{
)
->
Result
<
String
,
PDRouterError
>
{
// Wait for the new server to be healthy
// Wait for the new server to be healthy
...
@@ -220,10 +227,15 @@ impl PDRouter {
...
@@ -220,10 +227,15 @@ impl PDRouter {
// Create Worker for the new prefill server with circuit breaker configuration
// Create Worker for the new prefill server with circuit breaker configuration
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let
worker
=
BasicWorkerBuilder
::
new
(
url
.clone
())
let
worker
_builder
=
BasicWorkerBuilder
::
new
(
url
.clone
())
.worker_type
(
WorkerType
::
Prefill
{
bootstrap_port
})
.worker_type
(
WorkerType
::
Prefill
{
bootstrap_port
})
.circuit_breaker_config
(
self
.circuit_breaker_config
.clone
())
.circuit_breaker_config
(
self
.circuit_breaker_config
.clone
());
.build
();
let
worker
=
if
let
Some
(
api_key
)
=
api_key
{
worker_builder
.api_key
(
api_key
)
.build
()
}
else
{
worker_builder
.build
()
};
let
worker_arc
:
Arc
<
dyn
Worker
>
=
Arc
::
new
(
worker
);
let
worker_arc
:
Arc
<
dyn
Worker
>
=
Arc
::
new
(
worker
);
...
@@ -243,7 +255,11 @@ impl PDRouter {
...
@@ -243,7 +255,11 @@ impl PDRouter {
Ok
(
format!
(
"Successfully added prefill server: {}"
,
url
))
Ok
(
format!
(
"Successfully added prefill server: {}"
,
url
))
}
}
pub
async
fn
add_decode_server
(
&
self
,
url
:
String
)
->
Result
<
String
,
PDRouterError
>
{
pub
async
fn
add_decode_server
(
&
self
,
url
:
String
,
api_key
:
Option
<
String
>
,
)
->
Result
<
String
,
PDRouterError
>
{
// Wait for the new server to be healthy
// Wait for the new server to be healthy
self
.wait_for_server_health
(
&
url
)
.await
?
;
self
.wait_for_server_health
(
&
url
)
.await
?
;
...
@@ -254,10 +270,15 @@ impl PDRouter {
...
@@ -254,10 +270,15 @@ impl PDRouter {
// Create Worker for the new decode server with circuit breaker configuration
// Create Worker for the new decode server with circuit breaker configuration
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let
worker
=
BasicWorkerBuilder
::
new
(
url
.clone
())
let
worker
_builder
=
BasicWorkerBuilder
::
new
(
url
.clone
())
.worker_type
(
WorkerType
::
Decode
)
.worker_type
(
WorkerType
::
Decode
)
.circuit_breaker_config
(
self
.circuit_breaker_config
.clone
())
.circuit_breaker_config
(
self
.circuit_breaker_config
.clone
());
.build
();
let
worker
=
if
let
Some
(
api_key
)
=
api_key
{
worker_builder
.api_key
(
api_key
)
.build
()
}
else
{
worker_builder
.build
()
};
let
worker_arc
:
Arc
<
dyn
Worker
>
=
Arc
::
new
(
worker
);
let
worker_arc
:
Arc
<
dyn
Worker
>
=
Arc
::
new
(
worker
);
...
@@ -366,6 +387,12 @@ impl PDRouter {
...
@@ -366,6 +387,12 @@ impl PDRouter {
.chain
(
decode_workers
.iter
())
.chain
(
decode_workers
.iter
())
.map
(|
w
|
w
.url
()
.to_string
())
.map
(|
w
|
w
.url
()
.to_string
())
.collect
();
.collect
();
// Get all worker API keys for monitoring
let
all_api_keys
:
Vec
<
Option
<
String
>>
=
prefill_workers
.iter
()
.chain
(
decode_workers
.iter
())
.map
(|
w
|
w
.api_key
()
.clone
())
.collect
();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let
circuit_breaker_config
=
ctx
.router_config
.effective_circuit_breaker_config
();
let
circuit_breaker_config
=
ctx
.router_config
.effective_circuit_breaker_config
();
...
@@ -387,6 +414,7 @@ impl PDRouter {
...
@@ -387,6 +414,7 @@ impl PDRouter {
let
load_monitor_handle
=
let
load_monitor_handle
=
if
prefill_policy
.name
()
==
"power_of_two"
||
decode_policy
.name
()
==
"power_of_two"
{
if
prefill_policy
.name
()
==
"power_of_two"
||
decode_policy
.name
()
==
"power_of_two"
{
let
monitor_urls
=
all_urls
.clone
();
let
monitor_urls
=
all_urls
.clone
();
let
monitor_api_keys
=
all_api_keys
.clone
();
let
monitor_interval
=
ctx
.router_config.worker_startup_check_interval_secs
;
let
monitor_interval
=
ctx
.router_config.worker_startup_check_interval_secs
;
let
monitor_client
=
ctx
.client
.clone
();
let
monitor_client
=
ctx
.client
.clone
();
let
prefill_policy_clone
=
Arc
::
clone
(
&
prefill_policy
);
let
prefill_policy_clone
=
Arc
::
clone
(
&
prefill_policy
);
...
@@ -395,6 +423,7 @@ impl PDRouter {
...
@@ -395,6 +423,7 @@ impl PDRouter {
Some
(
Arc
::
new
(
tokio
::
spawn
(
async
move
{
Some
(
Arc
::
new
(
tokio
::
spawn
(
async
move
{
Self
::
monitor_worker_loads_with_client
(
Self
::
monitor_worker_loads_with_client
(
monitor_urls
,
monitor_urls
,
monitor_api_keys
,
tx
,
tx
,
monitor_interval
,
monitor_interval
,
monitor_client
,
monitor_client
,
...
@@ -500,6 +529,7 @@ impl PDRouter {
...
@@ -500,6 +529,7 @@ impl PDRouter {
prefill_drain_tx
,
prefill_drain_tx
,
retry_config
:
ctx
.router_config
.effective_retry_config
(),
retry_config
:
ctx
.router_config
.effective_retry_config
(),
circuit_breaker_config
:
core_cb_config
,
circuit_breaker_config
:
core_cb_config
,
api_key
:
ctx
.router_config.api_key
.clone
(),
})
})
}
}
...
@@ -1150,6 +1180,7 @@ impl PDRouter {
...
@@ -1150,6 +1180,7 @@ impl PDRouter {
// Background task to monitor worker loads with shared client
// Background task to monitor worker loads with shared client
async
fn
monitor_worker_loads_with_client
(
async
fn
monitor_worker_loads_with_client
(
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Vec
<
String
>
,
worker_api_keys
:
Vec
<
Option
<
String
>>
,
tx
:
tokio
::
sync
::
watch
::
Sender
<
HashMap
<
String
,
isize
>>
,
tx
:
tokio
::
sync
::
watch
::
Sender
<
HashMap
<
String
,
isize
>>
,
interval_secs
:
u64
,
interval_secs
:
u64
,
client
:
Client
,
client
:
Client
,
...
@@ -1161,11 +1192,13 @@ impl PDRouter {
...
@@ -1161,11 +1192,13 @@ impl PDRouter {
let
futures
:
Vec
<
_
>
=
worker_urls
let
futures
:
Vec
<
_
>
=
worker_urls
.iter
()
.iter
()
.map
(|
url
|
{
.zip
(
worker_api_keys
.iter
())
.map
(|(
url
,
api_key
)|
{
let
client
=
client
.clone
();
let
client
=
client
.clone
();
let
url
=
url
.clone
();
let
url
=
url
.clone
();
let
api_key
=
api_key
.clone
();
async
move
{
async
move
{
let
load
=
get_worker_load
(
&
client
,
&
url
)
.await
.unwrap_or
(
0
);
let
load
=
get_worker_load
(
&
client
,
&
url
,
&
api_key
)
.await
.unwrap_or
(
0
);
(
url
,
load
)
(
url
,
load
)
}
}
})
})
...
@@ -1515,8 +1548,16 @@ impl PDRouter {
...
@@ -1515,8 +1548,16 @@ impl PDRouter {
// Helper functions
// Helper functions
async
fn
get_worker_load
(
client
:
&
Client
,
worker_url
:
&
str
)
->
Option
<
isize
>
{
async
fn
get_worker_load
(
match
client
.get
(
format!
(
"{}/get_load"
,
worker_url
))
.send
()
.await
{
client
:
&
Client
,
worker_url
:
&
str
,
api_key
:
&
Option
<
String
>
,
)
->
Option
<
isize
>
{
let
mut
req_builder
=
client
.get
(
format!
(
"{}/get_load"
,
worker_url
));
if
let
Some
(
key
)
=
api_key
{
req_builder
=
req_builder
.bearer_auth
(
key
);
}
match
req_builder
.send
()
.await
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
Ok
(
bytes
)
=>
match
serde_json
::
from_slice
::
<
Value
>
(
&
bytes
)
{
Ok
(
bytes
)
=>
match
serde_json
::
from_slice
::
<
Value
>
(
&
bytes
)
{
Ok
(
data
)
=>
data
Ok
(
data
)
=>
data
...
@@ -1550,7 +1591,11 @@ async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> {
...
@@ -1550,7 +1591,11 @@ async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> {
#[async_trait]
#[async_trait]
impl
WorkerManagement
for
PDRouter
{
impl
WorkerManagement
for
PDRouter
{
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
async
fn
add_worker
(
&
self
,
_
worker_url
:
&
str
,
_
api_key
:
&
Option
<
String
>
,
)
->
Result
<
String
,
String
>
{
// For PD router, we don't support adding workers via this generic method
// For PD router, we don't support adding workers via this generic method
Err
(
Err
(
"PD router requires specific add_prefill_server or add_decode_server methods"
"PD router requires specific add_prefill_server or add_decode_server methods"
...
@@ -1956,9 +2001,9 @@ impl RouterTrait for PDRouter {
...
@@ -1956,9 +2001,9 @@ impl RouterTrait for PDRouter {
let
mut
errors
=
Vec
::
new
();
let
mut
errors
=
Vec
::
new
();
// Process prefill workers
// Process prefill workers
let
prefill_urls
=
self
.get_prefill_worker_urls
();
let
prefill_urls
_with_key
=
self
.get_prefill_worker_urls
_with_api_key
();
for
worker_url
in
prefill_urls
{
for
(
worker_url
,
api_key
)
in
prefill_urls
_with_key
{
match
get_worker_load
(
&
self
.client
,
&
worker_url
)
.await
{
match
get_worker_load
(
&
self
.client
,
&
worker_url
,
&
api_key
)
.await
{
Some
(
load
)
=>
{
Some
(
load
)
=>
{
loads
.insert
(
format!
(
"prefill_{}"
,
worker_url
),
load
);
loads
.insert
(
format!
(
"prefill_{}"
,
worker_url
),
load
);
}
}
...
@@ -1969,9 +2014,9 @@ impl RouterTrait for PDRouter {
...
@@ -1969,9 +2014,9 @@ impl RouterTrait for PDRouter {
}
}
// Process decode workers
// Process decode workers
let
decode_urls
=
self
.get_decode_worker_urls
();
let
decode_urls
_with_key
=
self
.get_decode_worker_urls
_with_api_key
();
for
worker_url
in
decode_urls
{
for
(
worker_url
,
api_key
)
in
decode_urls
_with_key
{
match
get_worker_load
(
&
self
.client
,
&
worker_url
)
.await
{
match
get_worker_load
(
&
self
.client
,
&
worker_url
,
&
api_key
)
.await
{
Some
(
load
)
=>
{
Some
(
load
)
=>
{
loads
.insert
(
format!
(
"decode_{}"
,
worker_url
),
load
);
loads
.insert
(
format!
(
"decode_{}"
,
worker_url
),
load
);
}
}
...
@@ -2069,12 +2114,14 @@ mod tests {
...
@@ -2069,12 +2114,14 @@ mod tests {
prefill_drain_tx
:
mpsc
::
channel
(
100
)
.0
,
prefill_drain_tx
:
mpsc
::
channel
(
100
)
.0
,
retry_config
:
RetryConfig
::
default
(),
retry_config
:
RetryConfig
::
default
(),
circuit_breaker_config
:
CircuitBreakerConfig
::
default
(),
circuit_breaker_config
:
CircuitBreakerConfig
::
default
(),
api_key
:
Some
(
"test_api_key"
.to_string
()),
}
}
}
}
fn
create_test_worker
(
url
:
String
,
worker_type
:
WorkerType
,
healthy
:
bool
)
->
Box
<
dyn
Worker
>
{
fn
create_test_worker
(
url
:
String
,
worker_type
:
WorkerType
,
healthy
:
bool
)
->
Box
<
dyn
Worker
>
{
let
worker
=
BasicWorkerBuilder
::
new
(
url
)
let
worker
=
BasicWorkerBuilder
::
new
(
url
)
.worker_type
(
worker_type
)
.worker_type
(
worker_type
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
worker
.set_healthy
(
healthy
);
worker
.set_healthy
(
healthy
);
Box
::
new
(
worker
)
Box
::
new
(
worker
)
...
...
sgl-router/src/routers/http/router.rs
View file @
56321e9f
...
@@ -38,6 +38,7 @@ pub struct Router {
...
@@ -38,6 +38,7 @@ pub struct Router {
worker_startup_timeout_secs
:
u64
,
worker_startup_timeout_secs
:
u64
,
worker_startup_check_interval_secs
:
u64
,
worker_startup_check_interval_secs
:
u64
,
dp_aware
:
bool
,
dp_aware
:
bool
,
#[allow(dead_code)]
api_key
:
Option
<
String
>
,
api_key
:
Option
<
String
>
,
retry_config
:
RetryConfig
,
retry_config
:
RetryConfig
,
circuit_breaker_config
:
CircuitBreakerConfig
,
circuit_breaker_config
:
CircuitBreakerConfig
,
...
@@ -71,7 +72,6 @@ impl Router {
...
@@ -71,7 +72,6 @@ impl Router {
};
};
// Cache-aware policies are initialized in WorkerInitializer
// Cache-aware policies are initialized in WorkerInitializer
// Setup load monitoring for PowerOfTwo policy
// Setup load monitoring for PowerOfTwo policy
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
());
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
());
let
worker_loads
=
Arc
::
new
(
rx
);
let
worker_loads
=
Arc
::
new
(
rx
);
...
@@ -82,6 +82,14 @@ impl Router {
...
@@ -82,6 +82,14 @@ impl Router {
// Check if default policy is power_of_two for load monitoring
// Check if default policy is power_of_two for load monitoring
let
load_monitor_handle
=
if
default_policy
.name
()
==
"power_of_two"
{
let
load_monitor_handle
=
if
default_policy
.name
()
==
"power_of_two"
{
let
monitor_urls
=
worker_urls
.clone
();
let
monitor_urls
=
worker_urls
.clone
();
let
monitor_api_keys
=
monitor_urls
.iter
()
.map
(|
url
|
{
ctx
.worker_registry
.get_by_url
(
url
)
.and_then
(|
w
|
w
.api_key
()
.clone
())
})
.collect
::
<
Vec
<
Option
<
String
>>>
();
let
monitor_interval
=
ctx
.router_config.worker_startup_check_interval_secs
;
let
monitor_interval
=
ctx
.router_config.worker_startup_check_interval_secs
;
let
policy_clone
=
default_policy
.clone
();
let
policy_clone
=
default_policy
.clone
();
let
client_clone
=
ctx
.client
.clone
();
let
client_clone
=
ctx
.client
.clone
();
...
@@ -89,6 +97,7 @@ impl Router {
...
@@ -89,6 +97,7 @@ impl Router {
Some
(
Arc
::
new
(
tokio
::
spawn
(
async
move
{
Some
(
Arc
::
new
(
tokio
::
spawn
(
async
move
{
Self
::
monitor_worker_loads
(
Self
::
monitor_worker_loads
(
monitor_urls
,
monitor_urls
,
monitor_api_keys
,
tx
,
tx
,
monitor_interval
,
monitor_interval
,
policy_clone
,
policy_clone
,
...
@@ -912,7 +921,11 @@ impl Router {
...
@@ -912,7 +921,11 @@ impl Router {
}
}
}
}
pub
async
fn
add_worker
(
&
self
,
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
pub
async
fn
add_worker
(
&
self
,
worker_url
:
&
str
,
api_key
:
&
Option
<
String
>
,
)
->
Result
<
String
,
String
>
{
let
start_time
=
std
::
time
::
Instant
::
now
();
let
start_time
=
std
::
time
::
Instant
::
now
();
let
client
=
reqwest
::
Client
::
builder
()
let
client
=
reqwest
::
Client
::
builder
()
.timeout
(
Duration
::
from_secs
(
self
.worker_startup_timeout_secs
))
.timeout
(
Duration
::
from_secs
(
self
.worker_startup_timeout_secs
))
...
@@ -938,7 +951,7 @@ impl Router {
...
@@ -938,7 +951,7 @@ impl Router {
// Need to contact the worker to extract the dp_size,
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
// and add them as multiple workers
let
url_vec
=
vec!
[
String
::
from
(
worker_url
)];
let
url_vec
=
vec!
[
String
::
from
(
worker_url
)];
let
dp_url_vec
=
Self
::
get_dp_aware_workers
(
&
url_vec
,
&
self
.
api_key
)
let
dp_url_vec
=
Self
::
get_dp_aware_workers
(
&
url_vec
,
api_key
)
.map_err
(|
e
|
format!
(
"Failed to get dp-aware workers: {}"
,
e
))
?
;
.map_err
(|
e
|
format!
(
"Failed to get dp-aware workers: {}"
,
e
))
?
;
let
mut
worker_added
:
bool
=
false
;
let
mut
worker_added
:
bool
=
false
;
for
dp_url
in
&
dp_url_vec
{
for
dp_url
in
&
dp_url_vec
{
...
@@ -948,10 +961,18 @@ impl Router {
...
@@ -948,10 +961,18 @@ impl Router {
}
}
info!
(
"Added worker: {}"
,
dp_url
);
info!
(
"Added worker: {}"
,
dp_url
);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let
new_worker
=
BasicWorkerBuilder
::
new
(
dp_url
.to_string
())
let
new_worker_builder
=
.worker_type
(
WorkerType
::
Regular
)
BasicWorkerBuilder
::
new
(
dp_url
.to_string
())
.circuit_breaker_config
(
self
.circuit_breaker_config
.clone
())
.worker_type
(
WorkerType
::
Regular
)
.build
();
.circuit_breaker_config
(
self
.circuit_breaker_config
.clone
(),
);
let
new_worker
=
if
let
Some
(
api_key
)
=
api_key
{
new_worker_builder
.api_key
(
api_key
)
.build
()
}
else
{
new_worker_builder
.build
()
};
let
worker_arc
=
Arc
::
new
(
new_worker
);
let
worker_arc
=
Arc
::
new
(
new_worker
);
self
.worker_registry
.register
(
worker_arc
.clone
());
self
.worker_registry
.register
(
worker_arc
.clone
());
...
@@ -978,10 +999,16 @@ impl Router {
...
@@ -978,10 +999,16 @@ impl Router {
info!
(
"Added worker: {}"
,
worker_url
);
info!
(
"Added worker: {}"
,
worker_url
);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let
new_worker
=
BasicWorkerBuilder
::
new
(
worker_url
.to_string
())
let
new_worker_builder
=
.worker_type
(
WorkerType
::
Regular
)
BasicWorkerBuilder
::
new
(
worker_url
.to_string
())
.circuit_breaker_config
(
self
.circuit_breaker_config
.clone
())
.worker_type
(
WorkerType
::
Regular
)
.build
();
.circuit_breaker_config
(
self
.circuit_breaker_config
.clone
());
let
new_worker
=
if
let
Some
(
api_key
)
=
api_key
{
new_worker_builder
.api_key
(
api_key
)
.build
()
}
else
{
new_worker_builder
.build
()
};
let
worker_arc
=
Arc
::
new
(
new_worker
);
let
worker_arc
=
Arc
::
new
(
new_worker
);
self
.worker_registry
.register
(
worker_arc
.clone
());
self
.worker_registry
.register
(
worker_arc
.clone
());
...
@@ -1094,7 +1121,7 @@ impl Router {
...
@@ -1094,7 +1121,7 @@ impl Router {
}
}
}
}
async
fn
get_worker_load
(
&
self
,
worker_url
:
&
str
)
->
Option
<
isize
>
{
async
fn
get_worker_load
(
&
self
,
worker_url
:
&
str
,
api_key
:
&
Option
<
String
>
)
->
Option
<
isize
>
{
let
worker_url
=
if
self
.dp_aware
{
let
worker_url
=
if
self
.dp_aware
{
// Need to extract the URL from "http://host:port@dp_rank"
// Need to extract the URL from "http://host:port@dp_rank"
let
(
worker_url_prefix
,
_
dp_rank
)
=
match
Self
::
extract_dp_rank
(
worker_url
)
{
let
(
worker_url_prefix
,
_
dp_rank
)
=
match
Self
::
extract_dp_rank
(
worker_url
)
{
...
@@ -1109,12 +1136,12 @@ impl Router {
...
@@ -1109,12 +1136,12 @@ impl Router {
worker_url
worker_url
};
};
match
self
let
mut
req_builder
=
self
.client
.get
(
format!
(
"{}/get_load"
,
worker_url
));
.client
if
let
Some
(
key
)
=
api_key
{
.get
(
format!
(
"{}/get_load"
,
worker_url
))
req_builder
=
req_builder
.bearer_auth
(
key
);
.send
()
}
.await
{
match
req_builder
.send
()
.await
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
Ok
(
bytes
)
=>
match
serde_json
::
from_slice
::
<
serde_json
::
Value
>
(
&
bytes
)
{
Ok
(
bytes
)
=>
match
serde_json
::
from_slice
::
<
serde_json
::
Value
>
(
&
bytes
)
{
Ok
(
data
)
=>
data
Ok
(
data
)
=>
data
...
@@ -1149,6 +1176,7 @@ impl Router {
...
@@ -1149,6 +1176,7 @@ impl Router {
// Background task to monitor worker loads
// Background task to monitor worker loads
async
fn
monitor_worker_loads
(
async
fn
monitor_worker_loads
(
worker_urls
:
Vec
<
String
>
,
worker_urls
:
Vec
<
String
>
,
worker_api_keys
:
Vec
<
Option
<
String
>>
,
tx
:
tokio
::
sync
::
watch
::
Sender
<
HashMap
<
String
,
isize
>>
,
tx
:
tokio
::
sync
::
watch
::
Sender
<
HashMap
<
String
,
isize
>>
,
interval_secs
:
u64
,
interval_secs
:
u64
,
policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
...
@@ -1160,8 +1188,8 @@ impl Router {
...
@@ -1160,8 +1188,8 @@ impl Router {
interval
.tick
()
.await
;
interval
.tick
()
.await
;
let
mut
loads
=
HashMap
::
new
();
let
mut
loads
=
HashMap
::
new
();
for
url
in
&
worker_urls
{
for
(
url
,
api_key
)
in
worker_urls
.iter
()
.zip
(
worker_api_keys
.iter
())
{
if
let
Some
(
load
)
=
Self
::
get_worker_load_static
(
&
client
,
url
)
.await
{
if
let
Some
(
load
)
=
Self
::
get_worker_load_static
(
&
client
,
url
,
api_key
)
.await
{
loads
.insert
(
url
.clone
(),
load
);
loads
.insert
(
url
.clone
(),
load
);
}
}
}
}
...
@@ -1179,7 +1207,11 @@ impl Router {
...
@@ -1179,7 +1207,11 @@ impl Router {
}
}
// Static version of get_worker_load for use in monitoring task
// Static version of get_worker_load for use in monitoring task
async
fn
get_worker_load_static
(
client
:
&
reqwest
::
Client
,
worker_url
:
&
str
)
->
Option
<
isize
>
{
async
fn
get_worker_load_static
(
client
:
&
reqwest
::
Client
,
worker_url
:
&
str
,
api_key
:
&
Option
<
String
>
,
)
->
Option
<
isize
>
{
let
worker_url
=
if
worker_url
.contains
(
"@"
)
{
let
worker_url
=
if
worker_url
.contains
(
"@"
)
{
// Need to extract the URL from "http://host:port@dp_rank"
// Need to extract the URL from "http://host:port@dp_rank"
let
(
worker_url_prefix
,
_
dp_rank
)
=
match
Self
::
extract_dp_rank
(
worker_url
)
{
let
(
worker_url_prefix
,
_
dp_rank
)
=
match
Self
::
extract_dp_rank
(
worker_url
)
{
...
@@ -1194,7 +1226,11 @@ impl Router {
...
@@ -1194,7 +1226,11 @@ impl Router {
worker_url
worker_url
};
};
match
client
.get
(
format!
(
"{}/get_load"
,
worker_url
))
.send
()
.await
{
let
mut
req_builder
=
client
.get
(
format!
(
"{}/get_load"
,
worker_url
));
if
let
Some
(
key
)
=
api_key
{
req_builder
=
req_builder
.bearer_auth
(
key
);
}
match
req_builder
.send
()
.await
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
Ok
(
bytes
)
=>
match
serde_json
::
from_slice
::
<
serde_json
::
Value
>
(
&
bytes
)
{
Ok
(
bytes
)
=>
match
serde_json
::
from_slice
::
<
serde_json
::
Value
>
(
&
bytes
)
{
Ok
(
data
)
=>
data
Ok
(
data
)
=>
data
...
@@ -1250,8 +1286,12 @@ use async_trait::async_trait;
...
@@ -1250,8 +1286,12 @@ use async_trait::async_trait;
#[async_trait]
#[async_trait]
impl
WorkerManagement
for
Router
{
impl
WorkerManagement
for
Router
{
async
fn
add_worker
(
&
self
,
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
async
fn
add_worker
(
Router
::
add_worker
(
self
,
worker_url
)
.await
&
self
,
worker_url
:
&
str
,
api_key
:
&
Option
<
String
>
,
)
->
Result
<
String
,
String
>
{
Router
::
add_worker
(
self
,
worker_url
,
api_key
)
.await
}
}
fn
remove_worker
(
&
self
,
worker_url
:
&
str
)
{
fn
remove_worker
(
&
self
,
worker_url
:
&
str
)
{
...
@@ -1457,12 +1497,12 @@ impl RouterTrait for Router {
...
@@ -1457,12 +1497,12 @@ impl RouterTrait for Router {
}
}
async
fn
get_worker_loads
(
&
self
)
->
Response
{
async
fn
get_worker_loads
(
&
self
)
->
Response
{
let
urls
=
self
.
get_
worker_
urls
();
let
urls
_with_key
=
self
.worker_
registry
.get_all_urls_with_api_key
();
let
mut
loads
=
Vec
::
new
();
let
mut
loads
=
Vec
::
new
();
// Get loads from all workers
// Get loads from all workers
for
url
in
&
urls
{
for
(
url
,
api_key
)
in
&
urls_with_key
{
let
load
=
self
.get_worker_load
(
url
)
.await
.unwrap_or
(
-
1
);
let
load
=
self
.get_worker_load
(
url
,
api_key
)
.await
.unwrap_or
(
-
1
);
loads
.push
(
serde_json
::
json!
({
loads
.push
(
serde_json
::
json!
({
"worker"
:
url
,
"worker"
:
url
,
"load"
:
load
"load"
:
load
...
@@ -1521,9 +1561,11 @@ mod tests {
...
@@ -1521,9 +1561,11 @@ mod tests {
// Register test workers
// Register test workers
let
worker1
=
BasicWorkerBuilder
::
new
(
"http://worker1:8080"
)
let
worker1
=
BasicWorkerBuilder
::
new
(
"http://worker1:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
let
worker2
=
BasicWorkerBuilder
::
new
(
"http://worker2:8080"
)
let
worker2
=
BasicWorkerBuilder
::
new
(
"http://worker2:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
worker_registry
.register
(
Arc
::
new
(
worker1
));
worker_registry
.register
(
Arc
::
new
(
worker1
));
worker_registry
.register
(
Arc
::
new
(
worker2
));
worker_registry
.register
(
Arc
::
new
(
worker2
));
...
...
sgl-router/src/routers/mod.rs
View file @
56321e9f
...
@@ -33,7 +33,11 @@ pub use http::{openai_router, pd_router, pd_types, router};
...
@@ -33,7 +33,11 @@ pub use http::{openai_router, pd_router, pd_types, router};
#[async_trait]
#[async_trait]
pub
trait
WorkerManagement
:
Send
+
Sync
{
pub
trait
WorkerManagement
:
Send
+
Sync
{
/// Add a worker to the router
/// Add a worker to the router
async
fn
add_worker
(
&
self
,
worker_url
:
&
str
)
->
Result
<
String
,
String
>
;
async
fn
add_worker
(
&
self
,
worker_url
:
&
str
,
api_key
:
&
Option
<
String
>
,
)
->
Result
<
String
,
String
>
;
/// Remove a worker from the router
/// Remove a worker from the router
fn
remove_worker
(
&
self
,
worker_url
:
&
str
);
fn
remove_worker
(
&
self
,
worker_url
:
&
str
);
...
...
sgl-router/src/routers/router_manager.rs
View file @
56321e9f
...
@@ -161,7 +161,7 @@ impl RouterManager {
...
@@ -161,7 +161,7 @@ impl RouterManager {
let
model_id
=
if
let
Some
(
model_id
)
=
config
.model_id
{
let
model_id
=
if
let
Some
(
model_id
)
=
config
.model_id
{
model_id
model_id
}
else
{
}
else
{
match
self
.query_server_info
(
&
config
.url
)
.await
{
match
self
.query_server_info
(
&
config
.url
,
&
config
.api_key
)
.await
{
Ok
(
info
)
=>
{
Ok
(
info
)
=>
{
// Extract model_id from server info
// Extract model_id from server info
info
.model_id
info
.model_id
...
@@ -208,29 +208,44 @@ impl RouterManager {
...
@@ -208,29 +208,44 @@ impl RouterManager {
}
}
let
worker
=
match
config
.worker_type
.as_deref
()
{
let
worker
=
match
config
.worker_type
.as_deref
()
{
Some
(
"prefill"
)
=>
Box
::
new
(
Some
(
"prefill"
)
=>
{
BasicWorkerBuilder
::
new
(
config
.url
.clone
())
let
mut
builder
=
BasicWorkerBuilder
::
new
(
config
.url
.clone
())
.worker_type
(
WorkerType
::
Prefill
{
.worker_type
(
WorkerType
::
Prefill
{
bootstrap_port
:
config
.bootstrap_port
,
bootstrap_port
:
config
.bootstrap_port
,
})
})
.labels
(
labels
.clone
())
.labels
(
labels
.clone
())
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
());
.build
(),
)
as
Box
<
dyn
Worker
>
,
if
let
Some
(
api_key
)
=
config
.api_key
.clone
()
{
Some
(
"decode"
)
=>
Box
::
new
(
builder
=
builder
.api_key
(
api_key
);
BasicWorkerBuilder
::
new
(
config
.url
.clone
())
}
Box
::
new
(
builder
.build
())
as
Box
<
dyn
Worker
>
}
Some
(
"decode"
)
=>
{
let
mut
builder
=
BasicWorkerBuilder
::
new
(
config
.url
.clone
())
.worker_type
(
WorkerType
::
Decode
)
.worker_type
(
WorkerType
::
Decode
)
.labels
(
labels
.clone
())
.labels
(
labels
.clone
())
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
());
.build
(),
)
as
Box
<
dyn
Worker
>
,
if
let
Some
(
api_key
)
=
config
.api_key
.clone
()
{
_
=>
Box
::
new
(
builder
=
builder
.api_key
(
api_key
);
BasicWorkerBuilder
::
new
(
config
.url
.clone
())
}
Box
::
new
(
builder
.build
())
as
Box
<
dyn
Worker
>
}
_
=>
{
let
mut
builder
=
BasicWorkerBuilder
::
new
(
config
.url
.clone
())
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.labels
(
labels
.clone
())
.labels
(
labels
.clone
())
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
())
.circuit_breaker_config
(
CircuitBreakerConfig
::
default
());
.build
(),
)
as
Box
<
dyn
Worker
>
,
if
let
Some
(
api_key
)
=
config
.api_key
.clone
()
{
builder
=
builder
.api_key
(
api_key
);
}
Box
::
new
(
builder
.build
())
as
Box
<
dyn
Worker
>
}
};
};
// Register worker
// Register worker
...
@@ -346,10 +361,18 @@ impl RouterManager {
...
@@ -346,10 +361,18 @@ impl RouterManager {
}
}
/// Query server info from a worker URL
/// Query server info from a worker URL
async
fn
query_server_info
(
&
self
,
url
:
&
str
)
->
Result
<
ServerInfo
,
String
>
{
async
fn
query_server_info
(
&
self
,
url
:
&
str
,
api_key
:
&
Option
<
String
>
,
)
->
Result
<
ServerInfo
,
String
>
{
let
info_url
=
format!
(
"{}/get_server_info"
,
url
.trim_end_matches
(
'/'
));
let
info_url
=
format!
(
"{}/get_server_info"
,
url
.trim_end_matches
(
'/'
));
match
self
.client
.get
(
&
info_url
)
.send
()
.await
{
let
mut
req_builder
=
self
.client
.get
(
&
info_url
);
if
let
Some
(
key
)
=
api_key
{
req_builder
=
req_builder
.bearer_auth
(
key
);
}
match
req_builder
.send
()
.await
{
Ok
(
response
)
=>
{
Ok
(
response
)
=>
{
if
response
.status
()
.is_success
()
{
if
response
.status
()
.is_success
()
{
response
response
...
@@ -477,10 +500,15 @@ impl RouterManager {
...
@@ -477,10 +500,15 @@ impl RouterManager {
#[async_trait]
#[async_trait]
impl
WorkerManagement
for
RouterManager
{
impl
WorkerManagement
for
RouterManager
{
/// Add a worker - in multi-router mode, this adds to the registry
/// Add a worker - in multi-router mode, this adds to the registry
async
fn
add_worker
(
&
self
,
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
async
fn
add_worker
(
&
self
,
worker_url
:
&
str
,
api_key
:
&
Option
<
String
>
,
)
->
Result
<
String
,
String
>
{
// Create a basic worker config request
// Create a basic worker config request
let
config
=
WorkerConfigRequest
{
let
config
=
WorkerConfigRequest
{
url
:
worker_url
.to_string
(),
url
:
worker_url
.to_string
(),
api_key
:
api_key
.clone
(),
model_id
:
None
,
model_id
:
None
,
worker_type
:
None
,
worker_type
:
None
,
priority
:
None
,
priority
:
None
,
...
...
sgl-router/src/routers/worker_initializer.rs
View file @
56321e9f
...
@@ -27,8 +27,12 @@ impl WorkerInitializer {
...
@@ -27,8 +27,12 @@ impl WorkerInitializer {
match
&
config
.mode
{
match
&
config
.mode
{
RoutingMode
::
Regular
{
worker_urls
}
=>
{
RoutingMode
::
Regular
{
worker_urls
}
=>
{
// use router's api_key, repeat for each worker
let
worker_api_keys
:
Vec
<
Option
<
String
>>
=
worker_urls
.iter
()
.map
(|
_
|
config
.api_key
.clone
())
.collect
();
Self
::
create_regular_workers
(
Self
::
create_regular_workers
(
worker_urls
,
worker_urls
,
&
worker_api_keys
,
&
config
.connection_mode
,
&
config
.connection_mode
,
config
,
config
,
worker_registry
,
worker_registry
,
...
@@ -41,8 +45,16 @@ impl WorkerInitializer {
...
@@ -41,8 +45,16 @@ impl WorkerInitializer {
decode_urls
,
decode_urls
,
..
..
}
=>
{
}
=>
{
// use router's api_key, repeat for each prefill/decode worker
let
prefill_api_keys
:
Vec
<
Option
<
String
>>
=
prefill_urls
.iter
()
.map
(|
_
|
config
.api_key
.clone
())
.collect
();
let
decode_api_keys
:
Vec
<
Option
<
String
>>
=
decode_urls
.iter
()
.map
(|
_
|
config
.api_key
.clone
())
.collect
();
Self
::
create_prefill_workers
(
Self
::
create_prefill_workers
(
prefill_urls
,
prefill_urls
,
&
prefill_api_keys
,
&
config
.connection_mode
,
&
config
.connection_mode
,
config
,
config
,
worker_registry
,
worker_registry
,
...
@@ -51,6 +63,7 @@ impl WorkerInitializer {
...
@@ -51,6 +63,7 @@ impl WorkerInitializer {
.await
?
;
.await
?
;
Self
::
create_decode_workers
(
Self
::
create_decode_workers
(
decode_urls
,
decode_urls
,
&
decode_api_keys
,
&
config
.connection_mode
,
&
config
.connection_mode
,
config
,
config
,
worker_registry
,
worker_registry
,
...
@@ -79,6 +92,7 @@ impl WorkerInitializer {
...
@@ -79,6 +92,7 @@ impl WorkerInitializer {
/// Create regular workers for standard routing mode
/// Create regular workers for standard routing mode
async
fn
create_regular_workers
(
async
fn
create_regular_workers
(
urls
:
&
[
String
],
urls
:
&
[
String
],
api_keys
:
&
[
Option
<
String
>
],
config_connection_mode
:
&
ConfigConnectionMode
,
config_connection_mode
:
&
ConfigConnectionMode
,
config
:
&
RouterConfig
,
config
:
&
RouterConfig
,
registry
:
&
Arc
<
WorkerRegistry
>
,
registry
:
&
Arc
<
WorkerRegistry
>
,
...
@@ -109,14 +123,18 @@ impl WorkerInitializer {
...
@@ -109,14 +123,18 @@ impl WorkerInitializer {
let
mut
registered_workers
:
HashMap
<
String
,
Vec
<
Arc
<
dyn
Worker
>>>
=
HashMap
::
new
();
let
mut
registered_workers
:
HashMap
<
String
,
Vec
<
Arc
<
dyn
Worker
>>>
=
HashMap
::
new
();
for
url
in
urls
{
for
(
url
,
api_key
)
in
urls
.iter
()
.zip
(
api_keys
.iter
())
{
// TODO: Add DP-aware support when we have dp_rank/dp_size info
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let
worker
=
BasicWorkerBuilder
::
new
(
url
.clone
())
let
worker
_builder
=
BasicWorkerBuilder
::
new
(
url
.clone
())
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.connection_mode
(
connection_mode
.clone
())
.connection_mode
(
connection_mode
.clone
())
.circuit_breaker_config
(
core_cb_config
.clone
())
.circuit_breaker_config
(
core_cb_config
.clone
())
.health_config
(
health_config
.clone
())
.health_config
(
health_config
.clone
());
.build
();
let
worker
=
if
let
Some
(
api_key
)
=
api_key
.clone
()
{
worker_builder
.api_key
(
api_key
)
.build
()
}
else
{
worker_builder
.build
()
};
let
worker_arc
=
Arc
::
new
(
worker
)
as
Arc
<
dyn
Worker
>
;
let
worker_arc
=
Arc
::
new
(
worker
)
as
Arc
<
dyn
Worker
>
;
let
model_id
=
worker_arc
.model_id
();
let
model_id
=
worker_arc
.model_id
();
...
@@ -148,6 +166,7 @@ impl WorkerInitializer {
...
@@ -148,6 +166,7 @@ impl WorkerInitializer {
/// Create prefill workers for disaggregated routing mode
/// Create prefill workers for disaggregated routing mode
async
fn
create_prefill_workers
(
async
fn
create_prefill_workers
(
prefill_entries
:
&
[(
String
,
Option
<
u16
>
)],
prefill_entries
:
&
[(
String
,
Option
<
u16
>
)],
api_keys
:
&
[
Option
<
String
>
],
config_connection_mode
:
&
ConfigConnectionMode
,
config_connection_mode
:
&
ConfigConnectionMode
,
config
:
&
RouterConfig
,
config
:
&
RouterConfig
,
registry
:
&
Arc
<
WorkerRegistry
>
,
registry
:
&
Arc
<
WorkerRegistry
>
,
...
@@ -181,16 +200,20 @@ impl WorkerInitializer {
...
@@ -181,16 +200,20 @@ impl WorkerInitializer {
let
mut
registered_workers
:
HashMap
<
String
,
Vec
<
Arc
<
dyn
Worker
>>>
=
HashMap
::
new
();
let
mut
registered_workers
:
HashMap
<
String
,
Vec
<
Arc
<
dyn
Worker
>>>
=
HashMap
::
new
();
for
(
url
,
bootstrap_port
)
in
prefill_entries
{
for
(
(
url
,
bootstrap_port
)
,
api_key
)
in
prefill_entries
.iter
()
.zip
(
api_keys
.iter
())
{
// TODO: Add DP-aware support when we have dp_rank/dp_size info
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let
worker
=
BasicWorkerBuilder
::
new
(
url
.clone
())
let
worker
_builder
=
BasicWorkerBuilder
::
new
(
url
.clone
())
.worker_type
(
WorkerType
::
Prefill
{
.worker_type
(
WorkerType
::
Prefill
{
bootstrap_port
:
*
bootstrap_port
,
bootstrap_port
:
*
bootstrap_port
,
})
})
.connection_mode
(
connection_mode
.clone
())
.connection_mode
(
connection_mode
.clone
())
.circuit_breaker_config
(
core_cb_config
.clone
())
.circuit_breaker_config
(
core_cb_config
.clone
())
.health_config
(
health_config
.clone
())
.health_config
(
health_config
.clone
());
.build
();
let
worker
=
if
let
Some
(
api_key
)
=
api_key
.clone
()
{
worker_builder
.api_key
(
api_key
)
.build
()
}
else
{
worker_builder
.build
()
};
let
worker_arc
=
Arc
::
new
(
worker
)
as
Arc
<
dyn
Worker
>
;
let
worker_arc
=
Arc
::
new
(
worker
)
as
Arc
<
dyn
Worker
>
;
let
model_id
=
worker_arc
.model_id
();
let
model_id
=
worker_arc
.model_id
();
...
@@ -227,6 +250,7 @@ impl WorkerInitializer {
...
@@ -227,6 +250,7 @@ impl WorkerInitializer {
/// Create decode workers for disaggregated routing mode
/// Create decode workers for disaggregated routing mode
async
fn
create_decode_workers
(
async
fn
create_decode_workers
(
urls
:
&
[
String
],
urls
:
&
[
String
],
api_keys
:
&
[
Option
<
String
>
],
config_connection_mode
:
&
ConfigConnectionMode
,
config_connection_mode
:
&
ConfigConnectionMode
,
config
:
&
RouterConfig
,
config
:
&
RouterConfig
,
registry
:
&
Arc
<
WorkerRegistry
>
,
registry
:
&
Arc
<
WorkerRegistry
>
,
...
@@ -257,14 +281,18 @@ impl WorkerInitializer {
...
@@ -257,14 +281,18 @@ impl WorkerInitializer {
let
mut
registered_workers
:
HashMap
<
String
,
Vec
<
Arc
<
dyn
Worker
>>>
=
HashMap
::
new
();
let
mut
registered_workers
:
HashMap
<
String
,
Vec
<
Arc
<
dyn
Worker
>>>
=
HashMap
::
new
();
for
url
in
urls
{
for
(
url
,
api_key
)
in
urls
.iter
()
.zip
(
api_keys
.iter
())
{
// TODO: Add DP-aware support when we have dp_rank/dp_size info
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let
worker
=
BasicWorkerBuilder
::
new
(
url
.clone
())
let
worker
_builder
=
BasicWorkerBuilder
::
new
(
url
.clone
())
.worker_type
(
WorkerType
::
Decode
)
.worker_type
(
WorkerType
::
Decode
)
.connection_mode
(
connection_mode
.clone
())
.connection_mode
(
connection_mode
.clone
())
.circuit_breaker_config
(
core_cb_config
.clone
())
.circuit_breaker_config
(
core_cb_config
.clone
())
.health_config
(
health_config
.clone
())
.health_config
(
health_config
.clone
());
.build
();
let
worker
=
if
let
Some
(
api_key
)
=
api_key
.clone
()
{
worker_builder
.api_key
(
api_key
)
.build
()
}
else
{
worker_builder
.build
()
};
let
worker_arc
=
Arc
::
new
(
worker
)
as
Arc
<
dyn
Worker
>
;
let
worker_arc
=
Arc
::
new
(
worker
)
as
Arc
<
dyn
Worker
>
;
let
model_id
=
worker_arc
.model_id
();
let
model_id
=
worker_arc
.model_id
();
...
...
sgl-router/src/server.rs
View file @
56321e9f
...
@@ -282,15 +282,16 @@ async fn v1_responses_list_input_items(
...
@@ -282,15 +282,16 @@ async fn v1_responses_list_input_items(
// ---------- Worker management endpoints (Legacy) ----------
// ---------- Worker management endpoints (Legacy) ----------
#[derive(Deserialize)]
#[derive(Deserialize)]
struct
Url
Query
{
struct
AddWorker
Query
{
url
:
String
,
url
:
String
,
api_key
:
Option
<
String
>
,
}
}
async
fn
add_worker
(
async
fn
add_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
State
(
state
):
State
<
Arc
<
AppState
>>
,
Query
(
Url
Query
{
url
}):
Query
<
Url
Query
>
,
Query
(
AddWorker
Query
{
url
,
api_key
}):
Query
<
AddWorker
Query
>
,
)
->
Response
{
)
->
Response
{
match
state
.router
.add_worker
(
&
url
)
.await
{
match
state
.router
.add_worker
(
&
url
,
&
api_key
)
.await
{
Ok
(
message
)
=>
(
StatusCode
::
OK
,
message
)
.into_response
(),
Ok
(
message
)
=>
(
StatusCode
::
OK
,
message
)
.into_response
(),
Err
(
error
)
=>
(
StatusCode
::
BAD_REQUEST
,
error
)
.into_response
(),
Err
(
error
)
=>
(
StatusCode
::
BAD_REQUEST
,
error
)
.into_response
(),
}
}
...
@@ -303,7 +304,7 @@ async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
...
@@ -303,7 +304,7 @@ async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
async
fn
remove_worker
(
async
fn
remove_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
State
(
state
):
State
<
Arc
<
AppState
>>
,
Query
(
Url
Query
{
url
}):
Query
<
Url
Query
>
,
Query
(
AddWorker
Query
{
url
,
..
}):
Query
<
AddWorker
Query
>
,
)
->
Response
{
)
->
Response
{
state
.router
.remove_worker
(
&
url
);
state
.router
.remove_worker
(
&
url
);
(
(
...
@@ -337,7 +338,7 @@ async fn create_worker(
...
@@ -337,7 +338,7 @@ async fn create_worker(
}
}
}
else
{
}
else
{
// In single router mode, use the router's add_worker with basic config
// In single router mode, use the router's add_worker with basic config
match
state
.router
.add_worker
(
&
config
.url
)
.await
{
match
state
.router
.add_worker
(
&
config
.url
,
&
config
.api_key
)
.await
{
Ok
(
message
)
=>
{
Ok
(
message
)
=>
{
let
response
=
WorkerApiResponse
{
let
response
=
WorkerApiResponse
{
success
:
true
,
success
:
true
,
...
...
sgl-router/src/service_discovery.rs
View file @
56321e9f
...
@@ -389,16 +389,20 @@ async fn handle_pod_event(
...
@@ -389,16 +389,20 @@ async fn handle_pod_event(
if
let
Some
(
pd_router
)
=
router
.as_any
()
.downcast_ref
::
<
PDRouter
>
()
{
if
let
Some
(
pd_router
)
=
router
.as_any
()
.downcast_ref
::
<
PDRouter
>
()
{
match
&
pod_info
.pod_type
{
match
&
pod_info
.pod_type
{
Some
(
PodType
::
Prefill
)
=>
pd_router
Some
(
PodType
::
Prefill
)
=>
pd_router
.add_prefill_server
(
worker_url
.clone
(),
pod_info
.bootstrap_port
)
.add_prefill_server
(
worker_url
.clone
(),
pd_router
.api_key
.clone
(),
pod_info
.bootstrap_port
,
)
.await
.await
.map_err
(|
e
|
e
.to_string
()),
.map_err
(|
e
|
e
.to_string
()),
Some
(
PodType
::
Decode
)
=>
pd_router
Some
(
PodType
::
Decode
)
=>
pd_router
.add_decode_server
(
worker_url
.clone
())
.add_decode_server
(
worker_url
.clone
()
,
pd_router
.api_key
.clone
()
)
.await
.await
.map_err
(|
e
|
e
.to_string
()),
.map_err
(|
e
|
e
.to_string
()),
Some
(
PodType
::
Regular
)
|
None
=>
{
Some
(
PodType
::
Regular
)
|
None
=>
{
// Fall back to regular add_worker for regular pods
// Fall back to regular add_worker for regular pods
router
.add_worker
(
&
worker_url
)
.await
router
.add_worker
(
&
worker_url
,
&
pd_router
.api_key
)
.await
}
}
}
}
}
else
{
}
else
{
...
@@ -406,7 +410,8 @@ async fn handle_pod_event(
...
@@ -406,7 +410,8 @@ async fn handle_pod_event(
}
}
}
else
{
}
else
{
// Regular mode or no pod type specified
// Regular mode or no pod type specified
router
.add_worker
(
&
worker_url
)
.await
// In pod, no need api key
router
.add_worker
(
&
worker_url
,
&
None
)
.await
};
};
match
result
{
match
result
{
...
...
sgl-router/tests/cache_aware_backward_compat_test.rs
View file @
56321e9f
...
@@ -18,6 +18,7 @@ fn test_backward_compatibility_with_empty_model_id() {
...
@@ -18,6 +18,7 @@ fn test_backward_compatibility_with_empty_model_id() {
// Create workers with empty model_id (simulating existing routers)
// Create workers with empty model_id (simulating existing routers)
let
worker1
=
BasicWorkerBuilder
::
new
(
"http://worker1:8080"
)
let
worker1
=
BasicWorkerBuilder
::
new
(
"http://worker1:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
// No model_id label - should default to "unknown"
// No model_id label - should default to "unknown"
...
@@ -25,6 +26,7 @@ fn test_backward_compatibility_with_empty_model_id() {
...
@@ -25,6 +26,7 @@ fn test_backward_compatibility_with_empty_model_id() {
labels2
.insert
(
"model_id"
.to_string
(),
"unknown"
.to_string
());
labels2
.insert
(
"model_id"
.to_string
(),
"unknown"
.to_string
());
let
worker2
=
BasicWorkerBuilder
::
new
(
"http://worker2:8080"
)
let
worker2
=
BasicWorkerBuilder
::
new
(
"http://worker2:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.labels
(
labels2
)
.labels
(
labels2
)
.build
();
.build
();
...
@@ -59,6 +61,7 @@ fn test_mixed_model_ids() {
...
@@ -59,6 +61,7 @@ fn test_mixed_model_ids() {
// Create workers with different model_id scenarios
// Create workers with different model_id scenarios
let
worker1
=
BasicWorkerBuilder
::
new
(
"http://worker1:8080"
)
let
worker1
=
BasicWorkerBuilder
::
new
(
"http://worker1:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
// No model_id label - defaults to "unknown" which goes to "default" tree
// No model_id label - defaults to "unknown" which goes to "default" tree
...
@@ -67,6 +70,7 @@ fn test_mixed_model_ids() {
...
@@ -67,6 +70,7 @@ fn test_mixed_model_ids() {
let
worker2
=
BasicWorkerBuilder
::
new
(
"http://worker2:8080"
)
let
worker2
=
BasicWorkerBuilder
::
new
(
"http://worker2:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.labels
(
labels2
)
.labels
(
labels2
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
let
mut
labels3
=
HashMap
::
new
();
let
mut
labels3
=
HashMap
::
new
();
...
@@ -123,10 +127,12 @@ fn test_remove_worker_by_url_backward_compat() {
...
@@ -123,10 +127,12 @@ fn test_remove_worker_by_url_backward_compat() {
let
worker1
=
BasicWorkerBuilder
::
new
(
"http://worker1:8080"
)
let
worker1
=
BasicWorkerBuilder
::
new
(
"http://worker1:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.labels
(
labels1
)
.labels
(
labels1
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
let
worker2
=
BasicWorkerBuilder
::
new
(
"http://worker2:8080"
)
let
worker2
=
BasicWorkerBuilder
::
new
(
"http://worker2:8080"
)
.worker_type
(
WorkerType
::
Regular
)
.worker_type
(
WorkerType
::
Regular
)
.api_key
(
"test_api_key"
)
.build
();
.build
();
// No model_id label - defaults to "unknown"
// No model_id label - defaults to "unknown"
...
...
sgl-router/tests/policy_registry_integration.rs
View file @
56321e9f
...
@@ -41,6 +41,7 @@ async fn test_policy_registry_with_router_manager() {
...
@@ -41,6 +41,7 @@ async fn test_policy_registry_with_router_manager() {
let
_
worker1_config
=
WorkerConfigRequest
{
let
_
worker1_config
=
WorkerConfigRequest
{
url
:
"http://worker1:8000"
.to_string
(),
url
:
"http://worker1:8000"
.to_string
(),
model_id
:
Some
(
"llama-3"
.to_string
()),
model_id
:
Some
(
"llama-3"
.to_string
()),
api_key
:
Some
(
"test_api_key"
.to_string
()),
worker_type
:
None
,
worker_type
:
None
,
priority
:
None
,
priority
:
None
,
cost
:
None
,
cost
:
None
,
...
@@ -66,6 +67,7 @@ async fn test_policy_registry_with_router_manager() {
...
@@ -66,6 +67,7 @@ async fn test_policy_registry_with_router_manager() {
let
_
worker2_config
=
WorkerConfigRequest
{
let
_
worker2_config
=
WorkerConfigRequest
{
url
:
"http://worker2:8000"
.to_string
(),
url
:
"http://worker2:8000"
.to_string
(),
model_id
:
Some
(
"llama-3"
.to_string
()),
model_id
:
Some
(
"llama-3"
.to_string
()),
api_key
:
Some
(
"test_api_key"
.to_string
()),
worker_type
:
None
,
worker_type
:
None
,
priority
:
None
,
priority
:
None
,
cost
:
None
,
cost
:
None
,
...
@@ -86,6 +88,7 @@ async fn test_policy_registry_with_router_manager() {
...
@@ -86,6 +88,7 @@ async fn test_policy_registry_with_router_manager() {
let
_
worker3_config
=
WorkerConfigRequest
{
let
_
worker3_config
=
WorkerConfigRequest
{
url
:
"http://worker3:8000"
.to_string
(),
url
:
"http://worker3:8000"
.to_string
(),
model_id
:
Some
(
"gpt-4"
.to_string
()),
model_id
:
Some
(
"gpt-4"
.to_string
()),
api_key
:
Some
(
"test_api_key"
.to_string
()),
worker_type
:
None
,
worker_type
:
None
,
priority
:
None
,
priority
:
None
,
cost
:
None
,
cost
:
None
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment