Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2fa0462c
Unverified
Commit
2fa0462c
authored
Aug 04, 2025
by
Simo Lin
Committed by
GitHub
Aug 04, 2025
Browse files
[router] introduce dp worker abstraction (#8639)
parent
915140fd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
567 additions
and
11 deletions
+567
-11
sgl-router/src/core/mod.rs
sgl-router/src/core/mod.rs
+2
-2
sgl-router/src/core/worker.rs
sgl-router/src/core/worker.rs
+565
-9
No files found.
sgl-router/src/core/mod.rs
View file @
2fa0462c
...
...
@@ -11,6 +11,6 @@ pub mod worker;
// Re-export commonly used types at the module level
pub
use
error
::{
WorkerError
,
WorkerResult
};
pub
use
worker
::{
start_health_checker
,
BasicWorker
,
HealthChecker
,
Worker
,
WorkerCollection
,
WorkerFactory
,
WorkerLoadGuard
,
WorkerType
,
start_health_checker
,
BasicWorker
,
DPAwareWorker
,
HealthChecker
,
Worker
,
WorkerCollection
,
WorkerFactory
,
WorkerLoadGuard
,
WorkerType
,
};
sgl-router/src/core/worker.rs
View file @
2fa0462c
use
super
::{
WorkerError
,
WorkerResult
};
use
async_trait
::
async_trait
;
use
futures
;
use
once_cell
::
sync
::
Lazy
;
use
serde_json
;
use
std
::
fmt
;
use
std
::
sync
::
atomic
::{
AtomicBool
,
AtomicUsize
,
Ordering
};
use
std
::
sync
::
Arc
;
// Shared HTTP client for
health checks
static
HEALTH_CHECK
_CLIENT
:
Lazy
<
reqwest
::
Client
>
=
Lazy
::
new
(||
{
// Shared HTTP client for
worker operations (health checks, server info, etc.)
static
WORKER
_CLIENT
:
Lazy
<
reqwest
::
Client
>
=
Lazy
::
new
(||
{
reqwest
::
Client
::
builder
()
.timeout
(
std
::
time
::
Duration
::
from_secs
(
30
))
// Default timeout, overridden per request
.build
()
.expect
(
"Failed to create
health check
HTTP client"
)
.expect
(
"Failed to create
worker
HTTP client"
)
});
/// Core worker abstraction that represents a backend service
...
...
@@ -64,6 +66,43 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Clone the worker (for trait objects)
fn
clone_worker
(
&
self
)
->
Box
<
dyn
Worker
>
;
// === DP-aware methods ===
/// Check if this worker is DP-aware
fn
is_dp_aware
(
&
self
)
->
bool
{
false
}
/// Get the base URL without any DP rank suffix
fn
base_url
(
&
self
)
->
&
str
{
self
.url
()
}
/// Get DP rank if this is a DP-aware worker
fn
dp_rank
(
&
self
)
->
Option
<
usize
>
{
None
}
/// Get DP size if this worker is part of a DP group
fn
dp_size
(
&
self
)
->
Option
<
usize
>
{
None
}
/// Transform a request for DP-aware routing
async
fn
prepare_request
(
&
self
,
req
:
serde_json
::
Value
)
->
WorkerResult
<
serde_json
::
Value
>
{
Ok
(
req
)
}
/// Get the actual endpoint URL for requests
fn
endpoint_url
(
&
self
,
route
:
&
str
)
->
String
{
format!
(
"{}{}"
,
self
.base_url
(),
route
)
}
/// Check if this worker can handle a specific request
fn
can_handle
(
&
self
,
_
req
:
&
serde_json
::
Value
)
->
bool
{
true
}
}
/// Worker type classification
...
...
@@ -212,12 +251,7 @@ impl Worker for BasicWorker {
let
timeout
=
Duration
::
from_secs
(
self
.metadata.health_config.timeout_secs
);
// Use the shared client with a custom timeout for this request
match
HEALTH_CHECK_CLIENT
.get
(
&
health_url
)
.timeout
(
timeout
)
.send
()
.await
{
match
WORKER_CLIENT
.get
(
&
health_url
)
.timeout
(
timeout
)
.send
()
.await
{
Ok
(
response
)
=>
{
if
response
.status
()
.is_success
()
{
self
.set_healthy
(
true
);
...
...
@@ -273,6 +307,160 @@ impl Worker for BasicWorker {
}
}
/// A DP-aware worker that handles data-parallel routing
#[derive(Debug,
Clone)]
pub
struct
DPAwareWorker
{
/// The underlying basic worker
base_worker
:
BasicWorker
,
/// DP rank for this worker
dp_rank
:
usize
,
/// Total DP size
dp_size
:
usize
,
/// Base URL without DP suffix
base_url
:
String
,
}
impl
DPAwareWorker
{
/// Create a new DP-aware worker of any type
pub
fn
new
(
base_url
:
String
,
dp_rank
:
usize
,
dp_size
:
usize
,
worker_type
:
WorkerType
)
->
Self
{
// Create URL with DP rank suffix for identification
let
worker_url
=
format!
(
"{}@{}"
,
base_url
,
dp_rank
);
let
base_worker
=
BasicWorker
::
new
(
worker_url
,
worker_type
);
Self
{
base_worker
,
dp_rank
,
dp_size
,
base_url
,
}
}
}
#[async_trait]
impl
Worker
for
DPAwareWorker
{
fn
url
(
&
self
)
->
&
str
{
self
.base_worker
.url
()
}
fn
worker_type
(
&
self
)
->
WorkerType
{
self
.base_worker
.worker_type
()
}
fn
is_healthy
(
&
self
)
->
bool
{
self
.base_worker
.is_healthy
()
}
fn
set_healthy
(
&
self
,
healthy
:
bool
)
{
self
.base_worker
.set_healthy
(
healthy
);
}
async
fn
check_health_async
(
&
self
)
->
WorkerResult
<
()
>
{
// Use base URL for health checks
let
health_url
=
format!
(
"{}/health"
,
self
.base_url
);
let
timeout
=
std
::
time
::
Duration
::
from_secs
(
self
.base_worker.metadata.health_config.timeout_secs
);
let
health_result
=
async
{
let
response
=
WORKER_CLIENT
.get
(
&
health_url
)
.timeout
(
timeout
)
.send
()
.await
.map_err
(|
e
|
format!
(
"Health check request failed: {}"
,
e
))
?
;
if
response
.status
()
.is_success
()
{
Ok
(())
}
else
{
Err
(
format!
(
"Health check returned status: {}"
,
response
.status
()
))
}
}
.await
;
match
health_result
{
Ok
(())
=>
{
self
.set_healthy
(
true
);
Ok
(())
}
Err
(
reason
)
=>
{
self
.set_healthy
(
false
);
Err
(
WorkerError
::
HealthCheckFailed
{
url
:
self
.base_url
.clone
(),
reason
,
})
}
}
}
fn
load
(
&
self
)
->
usize
{
self
.base_worker
.load
()
}
fn
increment_load
(
&
self
)
{
self
.base_worker
.increment_load
();
}
fn
decrement_load
(
&
self
)
{
self
.base_worker
.decrement_load
();
}
fn
processed_requests
(
&
self
)
->
usize
{
self
.base_worker
.processed_requests
()
}
fn
increment_processed
(
&
self
)
{
self
.base_worker
.increment_processed
();
}
fn
metadata
(
&
self
)
->
&
WorkerMetadata
{
self
.base_worker
.metadata
()
}
fn
clone_worker
(
&
self
)
->
Box
<
dyn
Worker
>
{
Box
::
new
(
self
.clone
())
}
// DP-aware specific implementations
fn
is_dp_aware
(
&
self
)
->
bool
{
true
}
fn
base_url
(
&
self
)
->
&
str
{
&
self
.base_url
}
fn
dp_rank
(
&
self
)
->
Option
<
usize
>
{
Some
(
self
.dp_rank
)
}
fn
dp_size
(
&
self
)
->
Option
<
usize
>
{
Some
(
self
.dp_size
)
}
async
fn
prepare_request
(
&
self
,
mut
req
:
serde_json
::
Value
)
->
WorkerResult
<
serde_json
::
Value
>
{
// Inject data_parallel_rank into the request
if
let
Some
(
map
)
=
req
.as_object_mut
()
{
map
.insert
(
"data_parallel_rank"
.to_string
(),
serde_json
::
json!
(
self
.dp_rank
),
);
Ok
(
req
)
}
else
{
Err
(
WorkerError
::
InvalidConfiguration
{
message
:
"Request must be a JSON object for DP-aware routing"
.to_string
(),
})
}
}
fn
endpoint_url
(
&
self
,
route
:
&
str
)
->
String
{
// Use base URL for actual requests
format!
(
"{}{}"
,
self
.base_url
,
route
)
}
}
/// Worker factory for creating workers of different types
pub
struct
WorkerFactory
;
...
...
@@ -318,6 +506,133 @@ impl WorkerFactory {
(
regular_workers
,
prefill_workers
,
decode_workers
)
}
/// Create a DP-aware worker of specified type
pub
fn
create_dp_aware
(
base_url
:
String
,
dp_rank
:
usize
,
dp_size
:
usize
,
worker_type
:
WorkerType
,
)
->
Box
<
dyn
Worker
>
{
Box
::
new
(
DPAwareWorker
::
new
(
base_url
,
dp_rank
,
dp_size
,
worker_type
))
}
/// Get DP size from a worker
async
fn
get_worker_dp_size
(
url
:
&
str
,
api_key
:
&
Option
<
String
>
)
->
WorkerResult
<
usize
>
{
let
mut
req_builder
=
WORKER_CLIENT
.get
(
&
format!
(
"{}/get_server_info"
,
url
));
if
let
Some
(
key
)
=
api_key
{
req_builder
=
req_builder
.bearer_auth
(
key
);
}
let
response
=
req_builder
.send
()
.await
.map_err
(|
e
|
WorkerError
::
NetworkError
{
url
:
url
.to_string
(),
error
:
e
.to_string
(),
})
?
;
if
!
response
.status
()
.is_success
()
{
return
Err
(
WorkerError
::
NetworkError
{
url
:
url
.to_string
(),
error
:
format!
(
"Server returned: {}"
,
response
.status
()),
});
}
let
info
:
serde_json
::
Value
=
response
.json
()
.await
.map_err
(|
e
|
WorkerError
::
NetworkError
{
url
:
url
.to_string
(),
error
:
format!
(
"Failed to parse JSON: {}"
,
e
),
})
?
;
let
dp_size
=
info
.get
(
"dp_size"
)
.and_then
(|
v
|
v
.as_u64
())
.ok_or_else
(||
WorkerError
::
InvalidConfiguration
{
message
:
"dp_size not found in server info"
.to_string
(),
})
?
;
if
dp_size
>
usize
::
MAX
as
u64
{
return
Err
(
WorkerError
::
InvalidConfiguration
{
message
:
format!
(
"dp_size is too large: {}"
,
dp_size
),
});
}
Ok
(
dp_size
as
usize
)
}
/// Private helper to create DP-aware workers of any type
async
fn
create_dp_aware_workers_of_type
(
url
:
&
str
,
api_key
:
&
Option
<
String
>
,
worker_type
:
WorkerType
,
)
->
WorkerResult
<
Vec
<
Box
<
dyn
Worker
>>>
{
let
dp_size
=
Self
::
get_worker_dp_size
(
url
,
api_key
)
.await
?
;
let
workers
=
(
0
..
dp_size
)
.map
(|
rank
|
Self
::
create_dp_aware
(
url
.to_string
(),
rank
,
dp_size
,
worker_type
.clone
()))
.collect
();
Ok
(
workers
)
}
/// Create DP-aware regular workers from a single URL
pub
async
fn
create_dp_aware_regular_workers
(
url
:
&
str
,
api_key
:
&
Option
<
String
>
,
)
->
WorkerResult
<
Vec
<
Box
<
dyn
Worker
>>>
{
Self
::
create_dp_aware_workers_of_type
(
url
,
api_key
,
WorkerType
::
Regular
)
.await
}
/// Create DP-aware prefill workers from a single URL
pub
async
fn
create_dp_aware_prefill_workers
(
url
:
&
str
,
bootstrap_port
:
Option
<
u16
>
,
api_key
:
&
Option
<
String
>
,
)
->
WorkerResult
<
Vec
<
Box
<
dyn
Worker
>>>
{
Self
::
create_dp_aware_workers_of_type
(
url
,
api_key
,
WorkerType
::
Prefill
{
bootstrap_port
})
.await
}
/// Create DP-aware decode workers from a single URL
pub
async
fn
create_dp_aware_decode_workers
(
url
:
&
str
,
api_key
:
&
Option
<
String
>
,
)
->
WorkerResult
<
Vec
<
Box
<
dyn
Worker
>>>
{
Self
::
create_dp_aware_workers_of_type
(
url
,
api_key
,
WorkerType
::
Decode
)
.await
}
/// Create workers based on configuration (for regular router)
pub
async
fn
create_workers
(
urls
:
Vec
<
String
>
,
dp_aware
:
bool
,
api_key
:
&
Option
<
String
>
,
)
->
WorkerResult
<
Vec
<
Box
<
dyn
Worker
>>>
{
if
dp_aware
{
// Create futures for all worker creations
let
worker_futs
=
urls
.iter
()
.map
(|
url
|
Self
::
create_dp_aware_regular_workers
(
url
,
api_key
));
// Execute all futures concurrently and flatten results
let
all_workers
=
futures
::
future
::
try_join_all
(
worker_futs
)
.await
?
.into_iter
()
.flatten
()
.collect
();
Ok
(
all_workers
)
}
else
{
Ok
(
urls
.into_iter
()
.map
(|
url
|
Self
::
create_regular
(
url
))
.collect
())
}
}
}
/// Helper trait for collections of workers
...
...
@@ -1086,4 +1401,245 @@ mod tests {
// Should be well over 1M ops/sec
assert
!
(
ops_per_sec
>
1_000_000.0
);
}
// ===== Tests for DPAwareWorker =====
#[test]
fn
test_dp_aware_worker_creation
()
{
let
dp_worker
=
DPAwareWorker
::
new
(
"http://worker1:8080"
.to_string
(),
2
,
4
,
WorkerType
::
Regular
);
assert_eq!
(
dp_worker
.url
(),
"http://worker1:8080@2"
);
assert_eq!
(
dp_worker
.base_url
(),
"http://worker1:8080"
);
assert
!
(
dp_worker
.is_dp_aware
());
assert_eq!
(
dp_worker
.dp_rank
(),
Some
(
2
));
assert_eq!
(
dp_worker
.dp_size
(),
Some
(
4
));
assert_eq!
(
dp_worker
.worker_type
(),
WorkerType
::
Regular
);
}
#[test]
fn
test_dp_aware_worker_creation_prefill
()
{
let
dp_worker
=
DPAwareWorker
::
new
(
"http://worker1:8080"
.to_string
(),
1
,
2
,
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
9090
),
},
);
assert_eq!
(
dp_worker
.url
(),
"http://worker1:8080@1"
);
assert
!
(
dp_worker
.is_dp_aware
());
assert_eq!
(
dp_worker
.worker_type
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
9090
)
}
);
}
#[test]
fn
test_dp_aware_worker_creation_decode
()
{
let
dp_worker
=
DPAwareWorker
::
new
(
"http://worker1:8080"
.to_string
(),
0
,
4
,
WorkerType
::
Decode
);
assert_eq!
(
dp_worker
.url
(),
"http://worker1:8080@0"
);
assert
!
(
dp_worker
.is_dp_aware
());
assert_eq!
(
dp_worker
.worker_type
(),
WorkerType
::
Decode
);
}
#[tokio::test]
async
fn
test_dp_aware_prepare_request
()
{
let
dp_worker
=
DPAwareWorker
::
new
(
"http://worker1:8080"
.to_string
(),
3
,
8
,
WorkerType
::
Regular
);
let
original_req
=
serde_json
::
json!
({
"prompt"
:
"Hello"
,
"max_tokens"
:
100
});
let
prepared_req
=
dp_worker
.prepare_request
(
original_req
)
.await
.unwrap
();
assert_eq!
(
prepared_req
[
"prompt"
],
"Hello"
);
assert_eq!
(
prepared_req
[
"max_tokens"
],
100
);
assert_eq!
(
prepared_req
[
"data_parallel_rank"
],
3
);
}
#[tokio::test]
async
fn
test_dp_aware_prepare_request_invalid
()
{
let
dp_worker
=
DPAwareWorker
::
new
(
"http://worker1:8080"
.to_string
(),
0
,
4
,
WorkerType
::
Regular
);
// Non-object JSON should fail
let
invalid_req
=
serde_json
::
json!
(
"not an object"
);
let
result
=
dp_worker
.prepare_request
(
invalid_req
)
.await
;
assert
!
(
result
.is_err
());
match
result
.unwrap_err
()
{
WorkerError
::
InvalidConfiguration
{
message
}
=>
{
assert
!
(
message
.contains
(
"JSON object"
));
}
_
=>
panic!
(
"Expected InvalidConfiguration error"
),
}
}
#[test]
fn
test_dp_aware_endpoint_url
()
{
let
dp_worker
=
DPAwareWorker
::
new
(
"http://worker1:8080"
.to_string
(),
1
,
4
,
WorkerType
::
Regular
);
assert_eq!
(
dp_worker
.endpoint_url
(
"/generate"
),
"http://worker1:8080/generate"
);
assert_eq!
(
dp_worker
.endpoint_url
(
"/health"
),
"http://worker1:8080/health"
);
}
#[test]
fn
test_dp_aware_worker_delegated_methods
()
{
let
dp_worker
=
DPAwareWorker
::
new
(
"http://worker1:8080"
.to_string
(),
0
,
2
,
WorkerType
::
Regular
);
// Test health status
assert
!
(
dp_worker
.is_healthy
());
dp_worker
.set_healthy
(
false
);
assert
!
(
!
dp_worker
.is_healthy
());
// Test load tracking
assert_eq!
(
dp_worker
.load
(),
0
);
dp_worker
.increment_load
();
assert_eq!
(
dp_worker
.load
(),
1
);
dp_worker
.decrement_load
();
assert_eq!
(
dp_worker
.load
(),
0
);
// Test processed tracking
assert_eq!
(
dp_worker
.processed_requests
(),
0
);
dp_worker
.increment_processed
();
assert_eq!
(
dp_worker
.processed_requests
(),
1
);
}
// ===== Tests for WorkerFactory async methods =====
#[tokio::test]
async
fn
test_factory_create_dp_aware
()
{
let
worker
=
WorkerFactory
::
create_dp_aware
(
"http://worker1:8080"
.to_string
(),
1
,
4
,
WorkerType
::
Regular
,
);
assert_eq!
(
worker
.url
(),
"http://worker1:8080@1"
);
assert
!
(
worker
.is_dp_aware
());
assert_eq!
(
worker
.dp_rank
(),
Some
(
1
));
assert_eq!
(
worker
.dp_size
(),
Some
(
4
));
assert_eq!
(
worker
.worker_type
(),
WorkerType
::
Regular
);
}
#[tokio::test]
async
fn
test_factory_create_dp_aware_prefill
()
{
let
worker
=
WorkerFactory
::
create_dp_aware
(
"http://worker1:8080"
.to_string
(),
0
,
2
,
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
8090
),
},
);
assert_eq!
(
worker
.url
(),
"http://worker1:8080@0"
);
assert
!
(
worker
.is_dp_aware
());
assert_eq!
(
worker
.worker_type
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
8090
)
}
);
}
#[tokio::test]
async
fn
test_factory_create_workers_regular
()
{
let
urls
=
vec!
[
"http://w1:8080"
.to_string
(),
"http://w2:8080"
.to_string
()];
let
workers
=
WorkerFactory
::
create_workers
(
urls
,
false
,
&
None
)
.await
.unwrap
();
assert_eq!
(
workers
.len
(),
2
);
assert
!
(
!
workers
[
0
]
.is_dp_aware
());
assert
!
(
!
workers
[
1
]
.is_dp_aware
());
assert_eq!
(
workers
[
0
]
.url
(),
"http://w1:8080"
);
assert_eq!
(
workers
[
1
]
.url
(),
"http://w2:8080"
);
}
// ===== Integration tests =====
#[tokio::test]
async
fn
test_mixed_worker_types
()
{
// Create a mix of worker types
let
regular
=
WorkerFactory
::
create_regular
(
"http://regular:8080"
.to_string
());
let
prefill
=
WorkerFactory
::
create_prefill
(
"http://prefill:8080"
.to_string
(),
Some
(
9090
));
let
decode
=
WorkerFactory
::
create_decode
(
"http://decode:8080"
.to_string
());
let
dp_aware_regular
=
WorkerFactory
::
create_dp_aware
(
"http://dp:8080"
.to_string
(),
0
,
2
,
WorkerType
::
Regular
);
let
dp_aware_prefill
=
WorkerFactory
::
create_dp_aware
(
"http://dp-prefill:8080"
.to_string
(),
1
,
2
,
WorkerType
::
Prefill
{
bootstrap_port
:
None
,
},
);
let
dp_aware_decode
=
WorkerFactory
::
create_dp_aware
(
"http://dp-decode:8080"
.to_string
(),
0
,
4
,
WorkerType
::
Decode
,
);
let
workers
:
Vec
<
Box
<
dyn
Worker
>>
=
vec!
[
regular
,
prefill
,
decode
,
dp_aware_regular
,
dp_aware_prefill
,
dp_aware_decode
,
];
// Test that they all implement Worker trait properly
for
worker
in
&
workers
{
assert
!
(
worker
.is_healthy
());
assert_eq!
(
worker
.load
(),
0
);
assert_eq!
(
worker
.processed_requests
(),
0
);
}
// Test specific behaviors
assert
!
(
!
workers
[
0
]
.is_dp_aware
());
// regular
assert
!
(
!
workers
[
1
]
.is_dp_aware
());
// prefill
assert
!
(
!
workers
[
2
]
.is_dp_aware
());
// decode
assert
!
(
workers
[
3
]
.is_dp_aware
());
// dp_aware_regular
assert
!
(
workers
[
4
]
.is_dp_aware
());
// dp_aware_prefill
assert
!
(
workers
[
5
]
.is_dp_aware
());
// dp_aware_decode
// Test worker types
assert_eq!
(
workers
[
0
]
.worker_type
(),
WorkerType
::
Regular
);
assert_eq!
(
workers
[
1
]
.worker_type
(),
WorkerType
::
Prefill
{
bootstrap_port
:
Some
(
9090
)
}
);
assert_eq!
(
workers
[
2
]
.worker_type
(),
WorkerType
::
Decode
);
assert_eq!
(
workers
[
3
]
.worker_type
(),
WorkerType
::
Regular
);
assert_eq!
(
workers
[
4
]
.worker_type
(),
WorkerType
::
Prefill
{
bootstrap_port
:
None
}
);
assert_eq!
(
workers
[
5
]
.worker_type
(),
WorkerType
::
Decode
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment