sglang · commit f2d5c492 (unverified)
Authored Jul 11, 2025 by Simo Lin; committed by GitHub on Jul 11, 2025

[router] add worker abstraction (#7960)

Parent: 2a2d3478

Showing 11 changed files with 961 additions and 411 deletions (+961, -411)
Changed files:

  sgl-router/Cargo.toml                  +2    -0
  sgl-router/src/core/error.rs           +57   -0
  sgl-router/src/core/mod.rs             +16   -0
  sgl-router/src/core/worker.rs          +454  -0
  sgl-router/src/lib.rs                  +1    -0
  sgl-router/src/pd_router.rs            +111  -131
  sgl-router/src/pd_types.rs             +28   -49
  sgl-router/src/router.rs               +206  -171
  sgl-router/src/server.rs               +2    -3
  sgl-router/src/service_discovery.rs    +4    -5
  sgl-router/tests/test_pd_routing.rs    +80   -52
sgl-router/Cargo.toml (+2, -0)

@@ -30,6 +30,8 @@ tracing-appender = "0.2.3"
 kube = { version = "0.88.1", features = ["runtime", "derive"] }
 k8s-openapi = { version = "0.21.0", features = ["v1_29"] }
 futures = "0.3"
+async-trait = "0.1"
+once_cell = "1.21"
 # Added for metrics
 metrics = "0.24.2"
 metrics-exporter-prometheus = "0.17.0"
sgl-router/src/core/error.rs (new file, +57)

//! Error types for the SGLang router core
//!
//! This module defines error types used throughout the router for worker operations.

use std::fmt;

/// Worker-related errors
#[derive(Debug)]
pub enum WorkerError {
    /// Health check failed
    HealthCheckFailed { url: String, reason: String },
    /// Worker not found
    WorkerNotFound { url: String },
    /// Invalid worker configuration
    InvalidConfiguration { message: String },
    /// Network error
    NetworkError { url: String, error: String },
    /// Worker is at capacity
    WorkerAtCapacity { url: String },
}

impl fmt::Display for WorkerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            WorkerError::HealthCheckFailed { url, reason } => {
                write!(f, "Health check failed for worker {}: {}", url, reason)
            }
            WorkerError::WorkerNotFound { url } => {
                write!(f, "Worker not found: {}", url)
            }
            WorkerError::InvalidConfiguration { message } => {
                write!(f, "Invalid worker configuration: {}", message)
            }
            WorkerError::NetworkError { url, error } => {
                write!(f, "Network error for worker {}: {}", url, error)
            }
            WorkerError::WorkerAtCapacity { url } => {
                write!(f, "Worker at capacity: {}", url)
            }
        }
    }
}

impl std::error::Error for WorkerError {}

/// Result type for worker operations
pub type WorkerResult<T> = Result<T, WorkerError>;

/// Convert from reqwest errors to worker errors
impl From<reqwest::Error> for WorkerError {
    fn from(err: reqwest::Error) -> Self {
        WorkerError::NetworkError {
            url: err.url().map(|u| u.to_string()).unwrap_or_default(),
            error: err.to_string(),
        }
    }
}
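The From<reqwest::Error> impl lets callers bubble HTTP failures up with the ? operator. A minimal sketch (the probe helper is hypothetical, not part of this commit):

// Hypothetical helper: any reqwest::Error is converted into
// WorkerError::NetworkError by `?` via the From impl above.
async fn probe(url: &str) -> WorkerResult<()> {
    let resp = reqwest::get(url).await?;
    if resp.status().is_success() {
        Ok(())
    } else {
        Err(WorkerError::HealthCheckFailed {
            url: url.to_string(),
            reason: format!("status {}", resp.status()),
        })
    }
}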
sgl-router/src/core/mod.rs (new file, +16)

//! Core abstractions for the SGLang router
//!
//! This module contains the fundamental types and traits used throughout the router:
//! - Worker trait and implementations
//! - Error types
//! - Common utilities

pub mod error;
pub mod worker;

// Re-export commonly used types at the module level
pub use error::{WorkerError, WorkerResult};
pub use worker::{
    start_health_checker, BasicWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory,
    WorkerLoadGuard, WorkerType,
};
sgl-router/src/core/worker.rs (new file, +454)

use super::{WorkerError, WorkerResult};
use async_trait::async_trait;
use once_cell::sync::Lazy;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;

// Shared HTTP client for health checks
static HEALTH_CHECK_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
    reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(30)) // Default timeout, overridden per request
        .build()
        .expect("Failed to create health check HTTP client")
});

/// Core worker abstraction that represents a backend service
#[async_trait]
pub trait Worker: Send + Sync + fmt::Debug {
    /// Get the worker's URL
    fn url(&self) -> &str;

    /// Get the worker's type (Regular, Prefill, or Decode)
    fn worker_type(&self) -> WorkerType;

    /// Check if the worker is currently healthy
    fn is_healthy(&self) -> bool;

    /// Set the worker's health status
    fn set_healthy(&self, healthy: bool);

    /// Perform an async health check on the worker
    async fn check_health_async(&self) -> WorkerResult<()>;

    /// Synchronous health check wrapper (for compatibility)
    fn check_health(&self) -> WorkerResult<()> {
        // Use a small runtime for synchronous contexts
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| WorkerError::HealthCheckFailed {
                url: self.url().to_string(),
                reason: format!("Failed to create runtime: {}", e),
            })?
            .block_on(self.check_health_async())
    }

    /// Get the current load (number of active requests)
    fn load(&self) -> usize;

    /// Increment the load counter
    fn increment_load(&self);

    /// Decrement the load counter
    fn decrement_load(&self);

    /// Get the number of processed requests
    fn processed_requests(&self) -> usize;

    /// Increment the processed requests counter
    fn increment_processed(&self);

    /// Get worker-specific metadata
    fn metadata(&self) -> &WorkerMetadata;

    /// Clone the worker (for trait objects)
    fn clone_worker(&self) -> Box<dyn Worker>;
}

/// Worker type classification
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WorkerType {
    /// Regular worker for standard routing
    Regular,
    /// Prefill worker for PD disaggregated mode
    Prefill {
        /// Bootstrap port for communication with decode workers
        bootstrap_port: Option<u16>,
    },
    /// Decode worker for PD disaggregated mode
    Decode,
}

impl fmt::Display for WorkerType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            WorkerType::Regular => write!(f, "Regular"),
            WorkerType::Prefill { bootstrap_port } => match bootstrap_port {
                Some(port) => write!(f, "Prefill(bootstrap:{})", port),
                None => write!(f, "Prefill"),
            },
            WorkerType::Decode => write!(f, "Decode"),
        }
    }
}
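The Display impl makes worker types readable in logs; a small sketch of the output per variant, following the write! formats above:

// Sketch: Display output for each WorkerType variant.
assert_eq!(WorkerType::Regular.to_string(), "Regular");
assert_eq!(
    WorkerType::Prefill { bootstrap_port: Some(9000) }.to_string(),
    "Prefill(bootstrap:9000)"
);
assert_eq!(WorkerType::Prefill { bootstrap_port: None }.to_string(), "Prefill");
assert_eq!(WorkerType::Decode.to_string(), "Decode");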
/// Health check configuration
#[derive(Debug, Clone)]
pub struct HealthConfig {
    /// Timeout for health checks in seconds
    pub timeout_secs: u64,
    /// Interval between health checks in seconds
    pub check_interval_secs: u64,
    /// Health check endpoint path
    pub endpoint: String,
}

impl Default for HealthConfig {
    fn default() -> Self {
        Self {
            timeout_secs: 5,
            check_interval_secs: 30,
            endpoint: "/health".to_string(),
        }
    }
}

/// Metadata associated with a worker
#[derive(Debug, Clone)]
pub struct WorkerMetadata {
    /// Worker URL
    pub url: String,
    /// Worker type
    pub worker_type: WorkerType,
    /// Additional labels/tags
    pub labels: std::collections::HashMap<String, String>,
    /// Health check configuration
    pub health_config: HealthConfig,
}

/// Basic worker implementation
#[derive(Debug, Clone)]
pub struct BasicWorker {
    metadata: WorkerMetadata,
    load_counter: Arc<AtomicUsize>,
    processed_counter: Arc<AtomicUsize>,
    healthy: Arc<AtomicBool>,
}

impl BasicWorker {
    pub fn new(url: String, worker_type: WorkerType) -> Self {
        let metadata = WorkerMetadata {
            url: url.clone(),
            worker_type,
            labels: std::collections::HashMap::new(),
            health_config: HealthConfig::default(),
        };
        Self {
            metadata,
            load_counter: Arc::new(AtomicUsize::new(0)),
            processed_counter: Arc::new(AtomicUsize::new(0)),
            healthy: Arc::new(AtomicBool::new(true)),
        }
    }

    pub fn with_labels(mut self, labels: std::collections::HashMap<String, String>) -> Self {
        self.metadata.labels = labels;
        self
    }

    pub fn with_health_config(mut self, config: HealthConfig) -> Self {
        self.metadata.health_config = config;
        self
    }
}
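The builder-style methods compose; a usage sketch (the label values and URL are illustrative):

// Sketch: building a worker with custom labels and health settings.
let mut labels = std::collections::HashMap::new();
labels.insert("pool".to_string(), "a".to_string()); // illustrative label

let worker = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular)
    .with_labels(labels)
    .with_health_config(HealthConfig {
        timeout_secs: 2,
        check_interval_secs: 10,
        endpoint: "/health".to_string(),
    });
assert!(worker.is_healthy()); // workers start healthy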
#[async_trait]
impl Worker for BasicWorker {
    fn url(&self) -> &str {
        &self.metadata.url
    }

    fn worker_type(&self) -> WorkerType {
        self.metadata.worker_type.clone()
    }

    fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Acquire)
    }

    fn set_healthy(&self, healthy: bool) {
        self.healthy.store(healthy, Ordering::Release);
    }

    async fn check_health_async(&self) -> WorkerResult<()> {
        use std::time::Duration;

        // Perform actual HTTP health check
        let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
        let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);

        // Use the shared client with a custom timeout for this request
        match HEALTH_CHECK_CLIENT
            .get(&health_url)
            .timeout(timeout)
            .send()
            .await
        {
            Ok(response) => {
                if response.status().is_success() {
                    self.set_healthy(true);
                    Ok(())
                } else {
                    self.set_healthy(false);
                    Err(WorkerError::HealthCheckFailed {
                        url: self.url().to_string(),
                        reason: format!("Health check returned status: {}", response.status()),
                    })
                }
            }
            Err(e) => {
                self.set_healthy(false);
                Err(WorkerError::HealthCheckFailed {
                    url: self.url().to_string(),
                    reason: format!("Health check request failed: {}", e),
                })
            }
        }
    }

    fn load(&self) -> usize {
        self.load_counter.load(Ordering::Relaxed)
    }

    fn increment_load(&self) {
        self.load_counter.fetch_add(1, Ordering::Relaxed);
    }

    fn decrement_load(&self) {
        self.load_counter
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_sub(1)
            })
            .ok();
    }

    fn processed_requests(&self) -> usize {
        self.processed_counter.load(Ordering::Relaxed)
    }

    fn increment_processed(&self) {
        self.processed_counter.fetch_add(1, Ordering::Relaxed);
    }

    fn metadata(&self) -> &WorkerMetadata {
        &self.metadata
    }

    fn clone_worker(&self) -> Box<dyn Worker> {
        Box::new(self.clone())
    }
}

/// Worker factory for creating workers of different types
pub struct WorkerFactory;

impl WorkerFactory {
    /// Create a regular worker
    pub fn create_regular(url: String) -> Box<dyn Worker> {
        Box::new(BasicWorker::new(url, WorkerType::Regular))
    }

    /// Create a prefill worker with optional bootstrap port
    pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
        Box::new(BasicWorker::new(
            url,
            WorkerType::Prefill { bootstrap_port },
        ))
    }

    /// Create a decode worker
    pub fn create_decode(url: String) -> Box<dyn Worker> {
        Box::new(BasicWorker::new(url, WorkerType::Decode))
    }

    /// Create workers from URLs with automatic type detection
    pub fn create_from_urls(
        regular_urls: Vec<String>,
        prefill_urls: Vec<(String, Option<u16>)>,
        decode_urls: Vec<String>,
    ) -> (
        Vec<Box<dyn Worker>>,
        Vec<Box<dyn Worker>>,
        Vec<Box<dyn Worker>>,
    ) {
        let regular_workers: Vec<Box<dyn Worker>> = regular_urls
            .into_iter()
            .map(Self::create_regular)
            .collect();
        let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
            .into_iter()
            .map(|(url, port)| Self::create_prefill(url, port))
            .collect();
        let decode_workers: Vec<Box<dyn Worker>> = decode_urls
            .into_iter()
            .map(Self::create_decode)
            .collect();
        (regular_workers, prefill_workers, decode_workers)
    }
}
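create_from_urls builds all three pools in one call; a sketch with illustrative URLs, mirroring what the PD router constructor below does URL-by-URL:

// Sketch: one regular, one prefill (with bootstrap port), one decode worker.
let (regular, prefill, decode) = WorkerFactory::create_from_urls(
    vec!["http://r1:8080".to_string()],
    vec![("http://p1:8080".to_string(), Some(9000))],
    vec!["http://d1:8080".to_string()],
);
assert_eq!(regular.len(), 1);
assert_eq!(
    prefill[0].worker_type(),
    WorkerType::Prefill { bootstrap_port: Some(9000) }
);
assert_eq!(decode[0].worker_type(), WorkerType::Decode);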
/// Helper trait for collections of workers
pub trait WorkerCollection {
    fn healthy_workers(&self) -> Vec<&dyn Worker>;
    fn total_load(&self) -> usize;
    fn find_worker(&self, url: &str) -> Option<&dyn Worker>;
    fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>>;
}

impl WorkerCollection for Vec<Box<dyn Worker>> {
    fn healthy_workers(&self) -> Vec<&dyn Worker> {
        self.iter()
            .filter(|w| w.is_healthy())
            .map(|w| w.as_ref())
            .collect()
    }

    fn total_load(&self) -> usize {
        self.iter().map(|w| w.load()).sum()
    }

    fn find_worker(&self, url: &str) -> Option<&dyn Worker> {
        self.iter().find(|w| w.url() == url).map(|w| w.as_ref())
    }

    fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>> {
        self.iter_mut().find(|w| w.url() == url)
    }
}
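Because the trait is implemented directly on Vec<Box<dyn Worker>>, routers can query a pool without extra wrappers; a sketch with illustrative URLs:

// Sketch: querying a pool through the WorkerCollection impl.
let workers: Vec<Box<dyn Worker>> = vec![
    WorkerFactory::create_regular("http://w1:8080".to_string()),
    WorkerFactory::create_regular("http://w2:8080".to_string()),
];
assert_eq!(workers.healthy_workers().len(), 2); // all workers start healthy
assert_eq!(workers.total_load(), 0);            // no in-flight requests yet
assert!(workers.find_worker("http://w1:8080").is_some());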
/// Convert a list of worker URLs to worker trait objects
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
    urls.into_iter()
        .map(WorkerFactory::create_regular)
        .collect()
}

/// Convert worker trait objects back to URLs
pub fn workers_to_urls(workers: &[Box<dyn Worker>]) -> Vec<String> {
    workers.iter().map(|w| w.url().to_string()).collect()
}

/// RAII guard for worker load management
pub struct WorkerLoadGuard<'a> {
    workers: Vec<&'a dyn Worker>,
}

impl<'a> WorkerLoadGuard<'a> {
    /// Create a new load guard for a single worker
    pub fn new(worker: &'a dyn Worker) -> Self {
        worker.increment_load();
        Self {
            workers: vec![worker],
        }
    }

    /// Create a new load guard for multiple workers
    pub fn new_multi(workers: Vec<&'a dyn Worker>) -> Self {
        // Increment load counters for all workers
        for worker in &workers {
            worker.increment_load();
        }
        Self { workers }
    }
}

impl<'a> Drop for WorkerLoadGuard<'a> {
    fn drop(&mut self) {
        // Decrement load counters for all workers
        for worker in &self.workers {
            worker.decrement_load();
        }
    }
}
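The guard pairs the increment with a Drop-based decrement, so load counts stay correct even on panic or early return; a sketch:

// Sketch: load is incremented on guard creation and decremented on drop.
let worker = WorkerFactory::create_regular("http://w1:8080".to_string());
{
    let _guard = WorkerLoadGuard::new(worker.as_ref());
    assert_eq!(worker.load(), 1);
} // _guard dropped here, even if the scope unwinds
assert_eq!(worker.load(), 0);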
/// Health checker handle with graceful shutdown
pub struct HealthChecker {
    handle: tokio::task::JoinHandle<()>,
    shutdown: Arc<AtomicBool>,
}

impl fmt::Debug for HealthChecker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HealthChecker")
            .field("shutdown", &self.shutdown.load(Ordering::Relaxed))
            .finish()
    }
}

impl HealthChecker {
    /// Shutdown the health checker gracefully
    pub async fn shutdown(self) {
        self.shutdown.store(true, Ordering::Release);
        let _ = self.handle.await;
    }
}

/// Start an async background health checker for a collection of workers
pub fn start_health_checker(
    workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
    check_interval_secs: u64,
) -> HealthChecker {
    let shutdown = Arc::new(AtomicBool::new(false));
    let shutdown_clone = shutdown.clone();

    let handle = tokio::spawn(async move {
        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
        loop {
            interval.tick().await;

            // Check for shutdown signal
            if shutdown_clone.load(Ordering::Acquire) {
                tracing::info!("Health checker shutting down");
                break;
            }

            // Check health of all workers
            let workers_to_check = match workers.read() {
                Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::<Vec<_>>(),
                Err(poisoned) => {
                    tracing::error!("Worker lock poisoned: {}", poisoned);
                    continue;
                }
            };

            // Perform health checks concurrently
            let health_checks = workers_to_check.iter().map(|worker| {
                let worker_url = worker.url().to_string();
                let was_healthy = worker.is_healthy();
                async move {
                    match worker.check_health_async().await {
                        Ok(_) => {
                            if !was_healthy {
                                tracing::info!("Worker {} is now healthy", worker_url);
                            }
                        }
                        Err(e) => {
                            if was_healthy {
                                tracing::warn!("Worker {} health check failed: {}", worker_url, e);
                            }
                        }
                    }
                }
            });

            // Execute all health checks concurrently
            futures::future::join_all(health_checks).await;
        }
    });

    HealthChecker { handle, shutdown }
}
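A lifecycle sketch: share the pool, start the checker, and shut it down gracefully. Assumes a Tokio runtime; the URL is illustrative:

use std::sync::{Arc, RwLock};

#[tokio::main]
async fn main() {
    let workers: Arc<RwLock<Vec<Box<dyn Worker>>>> = Arc::new(RwLock::new(vec![
        WorkerFactory::create_regular("http://w1:8080".to_string()),
    ]));
    // Probe every worker every 30 seconds in the background.
    let checker = start_health_checker(Arc::clone(&workers), 30);
    // ... serve traffic ...
    checker.shutdown().await; // flips the flag and awaits the background task
}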
sgl-router/src/lib.rs (+1, -0)

@@ -2,6 +2,7 @@ use pyo3::prelude::*;
 pub mod config;
 pub mod logging;
 use std::collections::HashMap;
+pub mod core;
 pub mod openai_api_types;
 pub mod pd_router;
 pub mod pd_types;
sgl-router/src/pd_router.rs (+111, -131)

 // PD (Prefill-Decode) Router Implementation
 // This module handles routing for disaggregated prefill-decode systems
+use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
-use crate::pd_types::{
-    Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy,
-};
+use crate::pd_types::{
+    api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy,
+};
 use crate::tree::Tree;
 use actix_web::http::header::{HeaderValue, CONTENT_TYPE};

@@ -11,7 +12,6 @@ use futures_util::{StreamExt, TryStreamExt};
 use metrics::{counter, histogram};
 use serde_json::Value;
 use std::collections::HashMap;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::{Arc, Mutex, RwLock};
 use std::time::{Duration, Instant};
 use tracing::{debug, error, info, warn};

@@ -21,49 +21,17 @@ use uuid::Uuid;
 #[derive(Debug)]
 pub struct PDRouter {
-    pub prefill_workers: Arc<RwLock<Vec<EngineInfo>>>,
-    pub decode_workers: Arc<RwLock<Vec<EngineInfo>>>,
+    pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
+    pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
     pub selection_policy: PDSelectionPolicy,
     pub load_tracking: Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
     pub prefill_tree: Option<Arc<Mutex<Tree>>>,
     pub timeout_secs: u64,
     pub interval_secs: u64,
     pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
     pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
     pub http_client: reqwest::Client,
-}
-
-// RAII guard for load tracking to ensure cleanup even on panic
-struct LoadGuard<'a> {
-    tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
-    urls: Vec<String>,
-}
-
-impl<'a> LoadGuard<'a> {
-    fn new(
-        tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
-        urls: Vec<String>,
-    ) -> Self {
-        // Increment counters
-        for url in &urls {
-            let counter = tracking
-                .entry(url.clone())
-                .or_insert_with(|| Arc::new(AtomicUsize::new(0)));
-            counter.fetch_add(1, Ordering::Relaxed);
-        }
-        LoadGuard { tracking, urls }
-    }
-}
-
-impl Drop for LoadGuard<'_> {
-    fn drop(&mut self) {
-        // Guaranteed cleanup even on panic
-        for url in &self.urls {
-            if let Some(counter) = self.tracking.get(url) {
-                counter.fetch_sub(1, Ordering::Relaxed);
-            }
-        }
-    }
-}
+    _prefill_health_checker: Option<HealthChecker>,
+    _decode_health_checker: Option<HealthChecker>,
+}

 impl PDRouter {

@@ -73,9 +41,6 @@ impl PDRouter {
         url: String,
         bootstrap_port: Option<u16>,
     ) -> Result<String, PDRouterError> {
-        // Create EngineInfo for the new prefill server
-        let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port);
-
         // Wait for the new server to be healthy
         crate::router::Router::wait_for_healthy_workers(
             &[url.clone()],

@@ -84,6 +49,9 @@ impl PDRouter {
         )
         .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;

+        // Create Worker for the new prefill server
+        let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port);
+
         // Add to prefill workers list
         let mut workers = self.prefill_workers

@@ -93,15 +61,11 @@ impl PDRouter {
             })?;

         // Check if already exists
-        if workers.iter().any(|w| w.url == url) {
+        if workers.iter().any(|w| w.url() == &url) {
             return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
         }

-        workers.push(engine_info);
-
-        // Initialize load tracking
-        self.load_tracking
-            .insert(url.clone(), Arc::new(AtomicUsize::new(0)));
+        workers.push(worker);

         // Add to cache tree if using cache-aware policy
         if let Some(ref tree) = self.prefill_tree {

@@ -113,9 +77,6 @@ impl PDRouter {
     }

     pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
-        // Create EngineInfo for the new decode server
-        let engine_info = EngineInfo::new_decode(url.clone());
-
         // Wait for the new server to be healthy
         crate::router::Router::wait_for_healthy_workers(
             &[url.clone()],

@@ -124,6 +85,9 @@ impl PDRouter {
         )
         .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;

+        // Create Worker for the new decode server
+        let worker = WorkerFactory::create_decode(url.clone());
+
         // Add to decode workers list
         let mut workers = self.decode_workers

@@ -133,15 +97,14 @@ impl PDRouter {
             })?;

         // Check if already exists
-        if workers.iter().any(|w| w.url == url) {
+        if workers.iter().any(|w| w.url() == &url) {
             return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
         }

-        workers.push(engine_info);
+        workers.push(worker);

-        // Initialize load tracking
-        self.load_tracking
-            .insert(url.clone(), Arc::new(AtomicUsize::new(0)));
+        // Worker tracks its own load internally

         info!("Added decode server: {}", url);
         Ok(format!("Successfully added decode server: {}", url))

@@ -157,7 +120,7 @@ impl PDRouter {
         // Find and remove the server
         let initial_len = workers.len();
-        workers.retain(|w| w.url != url);
+        workers.retain(|w| w.url() != url);

         if workers.len() == initial_len {
             return Err(PDRouterError::WorkerNotFound {

@@ -166,7 +129,7 @@ impl PDRouter {
         }

-        // Remove from load tracking
-        self.load_tracking.remove(url);
+        // Worker load tracking is internal

         // Remove from cache tree if using cache-aware policy
         if let Some(ref tree) = self.prefill_tree {

@@ -174,7 +137,7 @@ impl PDRouter {
             let mut tree_guard = tree.lock().unwrap();
             *tree_guard = Tree::new();
             for worker in workers.iter() {
-                tree_guard.insert("", &worker.url);
+                tree_guard.insert("", worker.url());
             }
         }

@@ -192,7 +155,7 @@ impl PDRouter {
         // Find and remove the server
         let initial_len = workers.len();
-        workers.retain(|w| w.url != url);
+        workers.retain(|w| w.url() != url);

         if workers.len() == initial_len {
             return Err(PDRouterError::WorkerNotFound {

@@ -200,9 +163,6 @@ impl PDRouter {
             });
         }

-        // Remove from load tracking
-        self.load_tracking.remove(url);
-
         info!("Removed decode server: {}", url);
         Ok(format!("Successfully removed decode server: {}", url))
     }

@@ -214,41 +174,32 @@ impl PDRouter {
         timeout_secs: u64,
         interval_secs: u64,
     ) -> Result<Self, String> {
-        // Convert URLs to EngineInfo
-        let prefill_workers: Vec<EngineInfo> = prefill_urls
+        // Convert URLs to Worker trait objects
+        let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
             .into_iter()
-            .map(|(url, port)| EngineInfo::new_prefill(url, port))
+            .map(|(url, port)| WorkerFactory::create_prefill(url, port))
             .collect();

-        let decode_workers: Vec<EngineInfo> = decode_urls
+        let decode_workers: Vec<Box<dyn Worker>> = decode_urls
             .into_iter()
-            .map(EngineInfo::new_decode)
+            .map(WorkerFactory::create_decode)
             .collect();

         // Wait for PD workers to be healthy
         let all_urls: Vec<String> = prefill_workers
             .iter()
             .chain(decode_workers.iter())
-            .map(|engine| engine.url.clone())
+            .map(|worker| worker.url().to_string())
             .collect();
         crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;

         // Initialize load tracking with atomic counters
         let load_tracking = Arc::new(dashmap::DashMap::new());
-        for engine in &prefill_workers {
-            load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
-        }
-        for engine in &decode_workers {
-            load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
-        }

         // Initialize cache-aware components if needed
         let prefill_tree = match &selection_policy {
             PDSelectionPolicy::CacheAware { .. } => {
                 let tree = Arc::new(Mutex::new(Tree::new()));
                 // Initialize tree with prefill workers
-                for engine in &prefill_workers {
-                    tree.lock().unwrap().insert("", &engine.url);
+                for worker in &prefill_workers {
+                    tree.lock().unwrap().insert("", worker.url());
                 }
                 Some(tree)
             }

@@ -283,17 +234,27 @@ impl PDRouter {
             None
         };

+        let prefill_workers = Arc::new(RwLock::new(prefill_workers));
+        let decode_workers = Arc::new(RwLock::new(decode_workers));
+
+        // Start health checkers for both worker pools
+        let prefill_health_checker =
+            crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs);
+        let decode_health_checker =
+            crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
+
         Ok(PDRouter {
-            prefill_workers: Arc::new(RwLock::new(prefill_workers)),
-            decode_workers: Arc::new(RwLock::new(decode_workers)),
+            prefill_workers,
+            decode_workers,
             selection_policy,
             load_tracking,
             prefill_tree,
             timeout_secs,
             interval_secs,
             worker_loads,
             load_monitor_handle,
             http_client,
+            _prefill_health_checker: Some(prefill_health_checker),
+            _decode_health_checker: Some(decode_health_checker),
         })
     }

@@ -330,11 +291,13 @@ impl PDRouter {
         // Log routing decision
         info!(
             "PD routing: {} -> prefill={}, decode={}",
-            route, prefill.url, decode.url
+            route,
+            prefill.url(),
+            decode.url()
         );

         // Add bootstrap info using the trait method
-        if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
+        if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
             error!("Failed to add bootstrap info: {}", e);
             counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
             return HttpResponse::InternalServerError()

@@ -356,8 +319,8 @@ impl PDRouter {
             req,
             json_with_bootstrap,
             route,
-            &prefill,
-            &decode,
+            prefill.as_ref(),
+            decode.as_ref(),
             is_stream,
             return_logprob,
             start,

@@ -397,11 +360,13 @@ impl PDRouter {
         // Log routing decision
         info!(
             "PD routing: {} -> prefill={}, decode={}",
-            route, prefill.url, decode.url
+            route,
+            prefill.url(),
+            decode.url()
         );

         // Add bootstrap info using the trait method
-        if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
+        if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
             error!("Failed to add bootstrap info: {}", e);
             counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
             return HttpResponse::InternalServerError()

@@ -423,8 +388,8 @@ impl PDRouter {
             req,
             json_with_bootstrap,
             route,
-            &prefill,
-            &decode,
+            prefill.as_ref(),
+            decode.as_ref(),
             is_stream,
             return_logprob,
             start,

@@ -440,22 +405,23 @@ impl PDRouter {
         req: &HttpRequest,
         json_request: serde_json::Value,
         route: &str,
-        prefill: &EngineInfo,
-        decode: &EngineInfo,
+        prefill: &dyn Worker,
+        decode: &dyn Worker,
         is_stream: bool,
         return_logprob: bool,
         start_time: Instant,
     ) -> HttpResponse {
         // Update load tracking for both workers
-        let _guard = LoadGuard::new(
-            &self.load_tracking,
-            vec![prefill.url.clone(), decode.url.clone()],
-        );
+        let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);

         // Build requests using .json() method
-        let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request);
-        let mut decode_request = client.post(decode.api_path(route)).json(&json_request);
+        let mut prefill_request = client
+            .post(api_path(prefill.url(), route))
+            .json(&json_request);
+        let mut decode_request = client
+            .post(api_path(decode.url(), route))
+            .json(&json_request);

         // Copy headers from original request
         for (name, value) in crate::router::copy_request_headers(req) {

@@ -474,9 +440,9 @@ impl PDRouter {
         histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string())
             .record(duration.as_secs_f64());
         counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1);
-        counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string())
+        counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url().to_string())
             .increment(1);
-        counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string())
+        counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url().to_string())
             .increment(1);

         // Process decode response

@@ -486,10 +452,11 @@ impl PDRouter {
             .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

         if !status.is_success() {
-            counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
+            counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string())
                 .increment(1);
             error!(
                 "Decode server {} returned error status: {}",
-                decode.url, status
+                decode.url(),
+                status
             );

             // Return the error response from decode server

@@ -508,9 +475,10 @@ impl PDRouter {
         if let Err(e) = &prefill_result {
             error!(
                 "Prefill server {} failed (non-critical): {}",
-                prefill.url, e
+                prefill.url(),
+                e
             );
-            counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string())
+            counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url().to_string())
                 .increment(1);
         }

         if is_stream {

@@ -559,7 +527,7 @@ impl PDRouter {
             HttpResponse::build(status)
                 .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                 .streaming({
-                    let decode_url = decode.url.clone();
+                    let decode_url = decode.url().to_string();
                     res.bytes_stream().map_err(move |e| {
                         error!("Stream error from decode server {}: {}", decode_url, e);
                         counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1);

@@ -587,7 +555,7 @@ impl PDRouter {
             }
             Err(e) => {
                 error!("Decode request failed: {}", e);
-                counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
+                counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string())
                     .increment(1);
                 HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
             }

@@ -652,7 +620,7 @@ impl PDRouter {
     async fn select_pd_pair(
         &self,
         _client: &reqwest::Client,
-    ) -> Result<(EngineInfo, EngineInfo), String> {
+    ) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
         // Check we have workers
         if self.prefill_workers

@@ -681,17 +649,17 @@ impl PDRouter {
         }
     }

-    fn select_random(&self) -> Result<(EngineInfo, EngineInfo), String> {
+    fn select_random(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
         let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
         let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;

-        let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone();
-        let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone();
+        let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone_worker();
+        let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone_worker();

         Ok((prefill, decode))
     }

-    async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> {
+    async fn select_power_of_two(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
         let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
         let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;

@@ -700,33 +668,45 @@ impl PDRouter {
         let loads = self.worker_loads.borrow();

-        let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0);
-        let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0);
-        let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0);
-        let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0);
+        let p1_load = loads
+            .get(prefill_list[p1_idx].url())
+            .copied()
+            .unwrap_or(isize::MAX);
+        let p2_load = loads
+            .get(prefill_list[p2_idx].url())
+            .copied()
+            .unwrap_or(isize::MAX);
+        let d1_load = loads
+            .get(decode_list[d1_idx].url())
+            .copied()
+            .unwrap_or(isize::MAX);
+        let d2_load = loads
+            .get(decode_list[d2_idx].url())
+            .copied()
+            .unwrap_or(isize::MAX);

         info!(
             "Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
-            prefill_list[p1_idx].url,
+            prefill_list[p1_idx].url(),
             p1_load,
-            prefill_list[p2_idx].url,
+            prefill_list[p2_idx].url(),
             p2_load,
-            decode_list[d1_idx].url,
+            decode_list[d1_idx].url(),
             d1_load,
-            decode_list[d2_idx].url,
+            decode_list[d2_idx].url(),
             d2_load
         );

         let selected_prefill = if p1_load <= p2_load {
-            prefill_list[p1_idx].clone()
+            prefill_list[p1_idx].clone_worker()
         } else {
-            prefill_list[p2_idx].clone()
+            prefill_list[p2_idx].clone_worker()
         };

         let selected_decode = if d1_load <= d2_load {
-            decode_list[d1_idx].clone()
+            decode_list[d1_idx].clone_worker()
         } else {
-            decode_list[d2_idx].clone()
+            decode_list[d2_idx].clone_worker()
         };

         Ok((selected_prefill, selected_decode))

@@ -868,11 +848,11 @@ impl PDRouter {
         let mut worker_infos = Vec::new();
         for worker in self.prefill_workers.read().unwrap().iter() {
-            worker_infos.push((worker.url.clone(), "prefill"));
+            worker_infos.push((worker.url().to_string(), "prefill"));
         }
         for worker in self.decode_workers.read().unwrap().iter() {
-            worker_infos.push((worker.url.clone(), "decode"));
+            worker_infos.push((worker.url().to_string(), "decode"));
         }

         // Create tasks with URL tracking

@@ -922,7 +902,7 @@ impl PDRouter {
     pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse {
         // Get info from the first decode server to match sglang's server info format
         let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
-            workers.first().map(|w| w.url.clone())
+            workers.first().map(|w| w.url().to_string())
         } else {
             return HttpResponse::InternalServerError().body("Failed to access decode workers");
         };

@@ -967,7 +947,7 @@ impl PDRouter {
     pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse {
         // Get first prefill worker URL to avoid holding lock across await
         let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
-            workers.first().map(|w| w.url.clone())
+            workers.first().map(|w| w.url().to_string())
         } else {
             return HttpResponse::InternalServerError().body("Failed to access prefill workers");
         };

@@ -1005,14 +985,14 @@ impl PDRouter {
             .read()
             .unwrap()
             .iter()
-            .map(|w| w.url.clone())
+            .map(|w| w.url().to_string())
             .collect();
         let d_urls: Vec<_> = self
             .decode_workers
             .read()
             .unwrap()
             .iter()
-            .map(|w| w.url.clone())
+            .map(|w| w.url().to_string())
             .collect();

         let mut prefill_loads = Vec::new();

@@ -1048,7 +1028,7 @@ impl PDRouter {
         // Get model info from the first prefill server (matches original Rust PDLB behavior)
         // Get first prefill worker URL to avoid holding lock across await
         let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
-            workers.first().map(|w| w.url.clone())
+            workers.first().map(|w| w.url().to_string())
         } else {
             return HttpResponse::InternalServerError().body("Failed to access prefill workers");
         };

@@ -1084,13 +1064,13 @@ impl PDRouter {
         // Flush cache on all prefill servers
         for worker in self.prefill_workers.read().unwrap().iter() {
-            let url = format!("{}/flush_cache", worker.url);
+            let url = format!("{}/flush_cache", worker.url());
             tasks.push(client.post(&url).send());
         }

         // Flush cache on all decode servers
         for worker in self.decode_workers.read().unwrap().iter() {
-            let url = format!("{}/flush_cache", worker.url);
+            let url = format!("{}/flush_cache", worker.url());
             tasks.push(client.post(&url).send());
        }
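The power-of-two policy above samples two candidates per pool and keeps the less-loaded one, defaulting unknown loads to isize::MAX. A standalone sketch of that idea over the new Worker abstraction (the helper itself is illustrative, not part of the commit):

use std::collections::HashMap;

// Sketch: pick the less-loaded of two random candidates.
fn pick_power_of_two<'a>(
    workers: &'a [Box<dyn Worker>],
    loads: &HashMap<String, isize>,
) -> &'a dyn Worker {
    let a = rand::random::<usize>() % workers.len();
    let b = rand::random::<usize>() % workers.len();
    // Unknown loads sort last, matching the unwrap_or(isize::MAX) in the hunk above.
    let load = |i: usize| loads.get(workers[i].url()).copied().unwrap_or(isize::MAX);
    if load(a) <= load(b) { workers[a].as_ref() } else { workers[b].as_ref() }
}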
sgl-router/src/pd_types.rs (+28, -49)

 // Essential PDLB types extracted for PD routing
+use crate::core::{Worker, WorkerType};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;

@@ -28,52 +29,21 @@ pub enum PDRouterError {
     Timeout { url: String },
 }

-#[derive(Debug, Clone)]
-pub enum EngineType {
-    Prefill,
-    Decode,
-}
-
-#[derive(Debug, Clone)]
-pub struct EngineInfo {
-    pub engine_type: EngineType,
-    pub url: String,
-    pub bootstrap_port: Option<u16>,
-}
-
-impl EngineInfo {
-    pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
-        EngineInfo {
-            engine_type: EngineType::Prefill,
-            url,
-            bootstrap_port,
-        }
-    }
-
-    pub fn new_decode(url: String) -> Self {
-        EngineInfo {
-            engine_type: EngineType::Decode,
-            url,
-            bootstrap_port: None,
-        }
-    }
-
-    pub fn api_path(&self, api_path: &str) -> String {
-        if api_path.starts_with("/") {
-            format!("{}{}", self.url, api_path)
-        } else {
-            format!("{}/{}", self.url, api_path)
-        }
-    }
-
-    pub fn get_hostname(&self) -> String {
-        // Simple hostname extraction without external dependencies
-        let url = self
-            .url
-            .trim_start_matches("http://")
-            .trim_start_matches("https://");
-        url.split(':').next().unwrap_or("localhost").to_string()
-    }
-}
+// Helper functions for workers
+pub fn api_path(url: &str, api_path: &str) -> String {
+    if api_path.starts_with("/") {
+        format!("{}{}", url, api_path)
+    } else {
+        format!("{}/{}", url, api_path)
+    }
+}
+
+pub fn get_hostname(url: &str) -> String {
+    // Simple hostname extraction without external dependencies
+    let url = url
+        .trim_start_matches("http://")
+        .trim_start_matches("https://");
+    url.split(':').next().unwrap_or("localhost").to_string()
+}

 // PD-specific routing policies

@@ -112,12 +82,21 @@ pub trait Bootstrap: Send + Sync {
         bootstrap_room: BootstrapRoom,
     );

-    fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
+    fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
         let batch_size = self.get_batch_size()?;

+        // Extract bootstrap port from prefill worker if it's a prefill type
+        let bootstrap_port = match prefill_worker.worker_type() {
+            WorkerType::Prefill { bootstrap_port } => bootstrap_port,
+            _ => None,
+        };
+
+        let hostname = get_hostname(prefill_worker.url());
+
         if let Some(batch_size) = batch_size {
             self.set_bootstrap_info(
-                BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
-                BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
+                BootstrapHost::Batch(vec![hostname; batch_size]),
+                BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
                 // Use high-quality random numbers to minimize collision risk
                 BootstrapRoom::Batch(
                     (0..batch_size)

@@ -132,8 +111,8 @@ pub trait Bootstrap: Send + Sync {
             );
         } else {
             self.set_bootstrap_info(
-                BootstrapHost::Single(prefill_info.get_hostname()),
-                BootstrapPort::Single(prefill_info.bootstrap_port),
+                BootstrapHost::Single(hostname),
+                BootstrapPort::Single(bootstrap_port),
                 BootstrapRoom::Single({
                     // Use high-quality random number for single requests too
                     let r1 = rand::random::<u64>();
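The two free functions replace the old EngineInfo methods; their behavior on typical inputs, per the bodies above:

// api_path joins with exactly one slash; get_hostname strips scheme and port.
assert_eq!(api_path("http://p1:8080", "/generate"), "http://p1:8080/generate");
assert_eq!(api_path("http://p1:8080", "generate"), "http://p1:8080/generate");
assert_eq!(get_hostname("http://p1:8080"), "p1");
assert_eq!(get_hostname("https://example.com"), "example.com");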
sgl-router/src/router.rs
View file @
f2d5c492
use
crate
::
core
::{
HealthChecker
,
Worker
,
WorkerFactory
};
use
crate
::
pd_router
::
PDRouter
;
use
crate
::
pd_types
::
PDSelectionPolicy
;
use
crate
::
tree
::
Tree
;
...
...
@@ -5,7 +6,6 @@ use ::metrics::{counter, gauge, histogram};
use
actix_web
::
http
::
header
::{
HeaderValue
,
CONTENT_TYPE
};
use
actix_web
::{
HttpRequest
,
HttpResponse
};
use
futures_util
::{
StreamExt
,
TryStreamExt
};
use
std
::
collections
::
HashMap
;
use
std
::
fmt
::
Debug
;
use
std
::
sync
::
atomic
::
AtomicUsize
;
use
std
::
sync
::{
Arc
,
Mutex
,
RwLock
};
...
...
@@ -30,15 +30,17 @@ pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
#[derive(Debug)]
pub
enum
Router
{
RoundRobin
{
worker
_url
s
:
Arc
<
RwLock
<
Vec
<
String
>>>
,
workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>
>>>
,
current_index
:
AtomicUsize
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
_
health_checker
:
Option
<
HealthChecker
>
,
},
Random
{
worker
_url
s
:
Arc
<
RwLock
<
Vec
<
String
>>>
,
workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>
>>>
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
_
health_checker
:
Option
<
HealthChecker
>
,
},
PrefillDecode
{
pd_router
:
Arc
<
PDRouter
>
,
...
...
@@ -104,16 +106,15 @@ pub enum Router {
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker
_url
s
:
Arc
<
RwLock
<
Vec
<
String
>>>
,
workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>
>>>
,
tree
:
Arc
<
Mutex
<
Tree
>>
,
running_queue
:
Arc
<
Mutex
<
HashMap
<
String
,
usize
>>>
,
processed_queue
:
Arc
<
Mutex
<
HashMap
<
String
,
usize
>>>
,
cache_threshold
:
f32
,
balance_abs_threshold
:
usize
,
balance_rel_threshold
:
f32
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
_
eviction_thread
:
Option
<
thread
::
JoinHandle
<
()
>>
,
_
health_checker
:
Option
<
HealthChecker
>
,
},
}
...
...
@@ -192,25 +193,43 @@ impl Router {
}
}
// Create Worker trait objects from URLs
let
workers
:
Vec
<
Box
<
dyn
Worker
>>
=
worker_urls
.iter
()
.map
(|
url
|
WorkerFactory
::
create_regular
(
url
.clone
()))
.collect
();
// Create router based on policy...
Ok
(
match
policy_config
{
PolicyConfig
::
RandomConfig
{
timeout_secs
,
interval_secs
,
}
=>
Router
::
Random
{
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
)),
timeout_secs
,
interval_secs
,
},
}
=>
{
let
workers
=
Arc
::
new
(
RwLock
::
new
(
workers
));
let
health_checker
=
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
workers
),
interval_secs
);
Router
::
Random
{
workers
,
timeout_secs
,
interval_secs
,
_
health_checker
:
Some
(
health_checker
),
}
}
PolicyConfig
::
RoundRobinConfig
{
timeout_secs
,
interval_secs
,
}
=>
Router
::
RoundRobin
{
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
)),
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
timeout_secs
,
interval_secs
,
},
}
=>
{
let
workers
=
Arc
::
new
(
RwLock
::
new
(
workers
));
let
health_checker
=
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
workers
),
interval_secs
);
Router
::
RoundRobin
{
workers
,
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
timeout_secs
,
interval_secs
,
_
health_checker
:
Some
(
health_checker
),
}
}
PolicyConfig
::
CacheAwareConfig
{
cache_threshold
,
balance_abs_threshold
,
...
...
@@ -220,24 +239,12 @@ impl Router {
timeout_secs
,
interval_secs
,
}
=>
{
let
mut
running_queue
=
HashMap
::
new
();
for
url
in
&
worker_urls
{
running_queue
.insert
(
url
.clone
(),
0
);
}
let
mut
processed_queue
=
HashMap
::
new
();
for
url
in
&
worker_urls
{
processed_queue
.insert
(
url
.clone
(),
0
);
}
let
tree
=
Arc
::
new
(
Mutex
::
new
(
Tree
::
new
()));
let
running_queue
=
Arc
::
new
(
Mutex
::
new
(
running_queue
));
let
processed_queue
=
Arc
::
new
(
Mutex
::
new
(
processed_queue
));
// Create background eviction thread
let
tree_clone
=
Arc
::
clone
(
&
tree
);
let
processed_queue_clone
=
Arc
::
clone
(
&
processed_queue
);
let
running_queue
_clone
=
Arc
::
clone
(
&
running_queue
);
let
workers
=
Arc
::
new
(
RwLock
::
new
(
workers
)
);
let
workers
_clone
=
Arc
::
clone
(
&
workers
);
let
eviction_thread
=
thread
::
spawn
(
move
||
{
loop
{
// Sleep for the specified interval
...
...
@@ -246,32 +253,41 @@ impl Router {
let
locked_tree_clone
=
tree_clone
.lock
()
.unwrap
();
// Run eviction
locked_tree_clone
.evict_tenant_by_size
(
max_tree_size
);
// Print the process queue
let
locked_processed_queue
=
processed_queue_clone
.lock
()
.unwrap
();
info!
(
"Processed Queue: {:?}"
,
locked_processed_queue
);
// Print the running queue
let
locked_running_queue
=
running_queue_clone
.lock
()
.unwrap
();
info!
(
"Running Queue: {:?}"
,
locked_running_queue
);
drop
(
locked_tree_clone
);
// Log worker loads and processed requests
let
workers_guard
=
workers_clone
.read
()
.unwrap
();
let
loads
:
Vec
<
(
String
,
usize
)
>
=
workers_guard
.iter
()
.map
(|
w
|
(
w
.url
()
.to_string
(),
w
.load
()))
.collect
();
info!
(
"Worker loads: {:?}"
,
loads
);
let
processed
:
Vec
<
(
String
,
usize
)
>
=
workers_guard
.iter
()
.map
(|
w
|
(
w
.url
()
.to_string
(),
w
.processed_requests
()))
.collect
();
info!
(
"Processed requests: {:?}"
,
processed
);
}
});
for
url
in
&
worker
_urls
{
tree
.lock
()
.unwrap
()
.insert
(
""
,
url
);
for
worker
in
worker
s
.read
()
.unwrap
()
.iter
()
{
tree
.lock
()
.unwrap
()
.insert
(
""
,
worker
.
url
()
);
}
let
health_checker
=
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
workers
),
interval_secs
);
Router
::
CacheAware
{
worker
_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
))
,
worker
s
,
tree
,
running_queue
,
processed_queue
,
cache_threshold
,
balance_abs_threshold
,
balance_rel_threshold
,
timeout_secs
,
interval_secs
,
_
eviction_thread
:
Some
(
eviction_thread
),
_
health_checker
:
Some
(
health_checker
),
}
}
PolicyConfig
::
PrefillDecodeConfig
{
...
...
@@ -297,16 +313,18 @@ impl Router {
})
}
/// Get
a reference to the worker URLs shared across thread
s
pub
fn
get_worker_urls
(
&
self
)
->
Arc
<
RwLock
<
Vec
<
String
>
>>
{
/// Get
the current list of worker URL
s
pub
fn
get_worker_urls
(
&
self
)
->
Vec
<
String
>
{
match
self
{
Router
::
RoundRobin
{
worker_urls
,
..
}
=>
Arc
::
clone
(
worker_urls
),
Router
::
Random
{
worker_urls
,
..
}
=>
Arc
::
clone
(
worker_urls
),
Router
::
CacheAware
{
worker_urls
,
..
}
=>
Arc
::
clone
(
worker_urls
),
Router
::
PrefillDecode
{
..
}
=>
{
// For PD mode, return empty list since we manage workers differently
Arc
::
new
(
RwLock
::
new
(
Vec
::
new
()))
}
Router
::
RoundRobin
{
workers
,
..
}
|
Router
::
Random
{
workers
,
..
}
|
Router
::
CacheAware
{
workers
,
..
}
=>
workers
.read
()
.unwrap
()
.iter
()
.map
(|
w
|
w
.url
()
.to_string
())
.collect
(),
Router
::
PrefillDecode
{
..
}
=>
Vec
::
new
(),
}
}
...
...
@@ -373,13 +391,14 @@ impl Router {
fn
select_first_worker
(
&
self
)
->
Result
<
String
,
String
>
{
match
self
{
Router
::
RoundRobin
{
worker_urls
,
..
}
|
Router
::
Random
{
worker_urls
,
..
}
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
if
worker_urls
.read
()
.unwrap
()
.is_empty
()
{
Router
::
RoundRobin
{
workers
,
..
}
|
Router
::
Random
{
workers
,
..
}
|
Router
::
CacheAware
{
workers
,
..
}
=>
{
let
workers_guard
=
workers
.read
()
.unwrap
();
if
workers_guard
.is_empty
()
{
Err
(
"No workers are available"
.to_string
())
}
else
{
Ok
(
worker
_urls
.read
()
.unwrap
()[
0
]
.clone
())
Ok
(
worker
s_guard
[
0
]
.url
()
.to_string
())
}
}
Router
::
PrefillDecode
{
..
}
=>
{
...
...
@@ -514,7 +533,7 @@ impl Router {
return
HttpResponse
::
NotImplemented
()
.body
(
"route_to_all not implemented for PrefillDecode mode"
);
}
_
=>
self
.get_worker_urls
()
.read
()
.unwrap
()
.clone
()
,
_
=>
self
.get_worker_urls
(),
};
// Send requests to all workers concurrently
...
...
@@ -562,7 +581,7 @@ impl Router {
}
}
let
urls
=
self
.get_worker_urls
()
.read
()
.unwrap
()
.clone
()
;
let
urls
=
self
.get_worker_urls
();
let
prefill_urls
:
Vec
<
String
>
=
Vec
::
new
();
let
decode_urls
=
urls
;
...
...
@@ -631,6 +650,24 @@ impl Router {
.increment
(
1
);
}
// For CacheAware router, increment load before request
let
load_incremented
=
match
self
{
Router
::
CacheAware
{
workers
,
..
}
=>
{
let
workers_guard
=
workers
.read
()
.unwrap
();
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
&
worker_url
)
{
worker
.increment_load
();
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
worker
.load
()
as
f64
);
true
}
else
{
false
}
}
_
=>
false
,
};
// Send typed request directly
let
response
=
self
.send_typed_request
(
...
...
@@ -640,6 +677,7 @@ impl Router {
route
,
&
worker_url
,
is_stream
,
load_incremented
,
)
.await
;
...
...
@@ -684,44 +722,47 @@ impl Router {
}
}
// Helper method to select worker from text
// Helper method to select worker from text
(returns index for RoundRobin/Random, URL for CacheAware)
fn
select_generate_worker_from_text
(
&
self
,
text
:
&
str
)
->
String
{
match
self
{
Router
::
RoundRobin
{
worker
_url
s
,
workers
,
current_index
,
..
}
=>
{
let
workers_guard
=
workers
.read
()
.unwrap
();
let
idx
=
current_index
.fetch_update
(
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
|
x
|
Some
((
x
+
1
)
%
worker
_urls
.read
()
.unwrap
()
.len
()),
|
x
|
Some
((
x
+
1
)
%
worker
s_guard
.len
()),
)
.unwrap
();
worker
_urls
.read
()
.unwrap
()[
idx
]
.clone
()
worker
s_guard
[
idx
]
.url
()
.to_string
()
}
Router
::
Random
{
worker_urls
,
..
}
=>
worker_urls
.read
()
.unwrap
()
[
rand
::
random
::
<
usize
>
()
%
worker_urls
.read
()
.unwrap
()
.len
()]
.clone
(),
Router
::
Random
{
workers
,
..
}
=>
{
let
workers_guard
=
workers
.read
()
.unwrap
();
workers_guard
[
rand
::
random
::
<
usize
>
()
%
workers_guard
.len
()]
.url
()
.to_string
()
}
Router
::
CacheAware
{
worker
_url
s
,
workers
,
tree
,
running_queue
,
processed_queue
,
cache_threshold
,
balance_abs_threshold
,
balance_rel_threshold
,
..
}
=>
{
let
tree
=
tree
.lock
()
.unwrap
();
let
mut
running_queue
=
running_queue
.lock
()
.unwrap
();
let
workers_guard
=
workers
.read
()
.unwrap
();
// Get current load statistics
let
max_load
=
*
running_queue
.values
()
.max
()
.unwrap_or
(
&
0
);
let
min_load
=
*
running_queue
.values
()
.min
()
.unwrap_or
(
&
0
);
// Get current load statistics from workers
let
loads
:
Vec
<
usize
>
=
workers_guard
.iter
()
.map
(|
w
|
w
.load
())
.collect
();
let
max_load
=
*
loads
.iter
()
.max
()
.unwrap_or
(
&
0
);
let
min_load
=
*
loads
.iter
()
.min
()
.unwrap_or
(
&
0
);
// Load is considered imbalanced if:
// 1. (max - min) > abs_threshold AND
...
...
@@ -731,11 +772,16 @@ impl Router {
let
selected_url
=
if
is_imbalanced
{
// Log load balancing trigger and current queue state
let
worker_loads
:
Vec
<
(
String
,
usize
)
>
=
workers_guard
.iter
()
.map
(|
w
|
(
w
.url
()
.to_string
(),
w
.load
()))
.collect
();
info!
(
"Load balancing triggered due to workload imbalance:
\n
\
Max load: {}, Min load: {}
\n
\
Current
running queue
: {:?}"
,
max_load
,
min_load
,
running_queue
Current
worker loads
: {:?}"
,
max_load
,
min_load
,
worker_loads
);
counter!
(
"sgl_router_load_balancing_events_total"
)
.increment
(
1
);
...
...
@@ -743,11 +789,11 @@ impl Router {
gauge!
(
"sgl_router_min_load"
)
.set
(
min_load
as
f64
);
// Use shortest queue routing when load is imbalanced
running_queue
workers_guard
.iter
()
.min_by_key
(|
(
_u
rl
,
&
count
)|
count
)
.map
(|
(
url
,
_
)|
url
.clone
())
.unwrap_or_else
(||
worker
_urls
.read
()
.unwrap
()[
0
]
.clone
())
.min_by_key
(|
w
|
w
.load
()
)
.map
(|
w
|
w
.url
()
.to_string
())
.unwrap_or_else
(||
worker
s_guard
[
0
]
.url
()
.to_string
())
}
else
{
// Use cache-aware routing when load is balanced
let
(
matched_text
,
matched_worker
)
=
tree
.prefix_match
(
&
text
);
...
...
@@ -763,18 +809,12 @@ impl Router {
}
};
// Update queues and tree
*
running_queue
.get_mut
(
&
selected_url
)
.unwrap
()
+=
1
;
*
processed_queue
.lock
()
.unwrap
()
.get_mut
(
&
selected_url
)
.unwrap
()
+=
1
;
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
selected_url
.to_string
())
.set
(
*
running_queue
.get
(
&
selected_url
)
.unwrap
()
as
f64
);
counter!
(
"sgl_router_processed_requests_total"
,
"worker"
=>
selected_url
.to_string
())
.increment
(
1
);
// Find the selected worker and increment processed counter only
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
&
selected_url
)
{
worker
.increment_processed
();
counter!
(
"sgl_router_processed_requests_total"
,
"worker"
=>
selected_url
.to_string
())
.increment
(
1
);
}
tree
.insert
(
&
text
,
&
selected_url
);
...
...
@@ -796,6 +836,7 @@ impl Router {
route
:
&
str
,
worker_url
:
&
str
,
is_stream
:
bool
,
load_incremented
:
bool
,
// Whether load was incremented for this request
)
->
HttpResponse
{
let
start
=
Instant
::
now
();
...
...
@@ -820,6 +861,22 @@ impl Router {
Ok
(
res
)
=>
res
,
Err
(
e
)
=>
{
error!
(
"Failed to send request to {}: {}"
,
worker_url
,
e
);
// Decrement load on error for CacheAware router
if
load_incremented
{
if
let
Router
::
CacheAware
{
workers
,
..
}
=
self
{
if
let
Ok
(
workers_guard
)
=
workers
.read
()
{
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
worker_url
)
{
worker
.decrement_load
();
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
worker
.load
()
as
f64
);
}
}
}
}
return
HttpResponse
::
InternalServerError
()
.body
(
format!
(
"Request failed: {}"
,
e
));
}
};
...
...
@@ -837,13 +894,15 @@ impl Router {
}
};
// Then decrement running queue counter if using CacheAware
if
let
Router
::
CacheAware
{
running_queue
,
..
}
=
self
{
if
let
Ok
(
mut
queue
)
=
running_queue
.lock
()
{
if
let
Some
(
count
)
=
queue
.get_mut
(
worker_url
)
{
*
count
=
count
.saturating_sub
(
1
);
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
*
count
as
f64
);
// Decrement load counter for non-streaming CacheAware requests
if
load_incremented
&&
!
is_stream
{
if
let
Router
::
CacheAware
{
workers
,
..
}
=
self
{
if
let
Ok
(
workers_guard
)
=
workers
.read
()
{
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
worker_url
)
{
worker
.decrement_load
();
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
worker
.load
()
as
f64
);
}
}
}
}
...
...
@@ -855,8 +914,9 @@ impl Router {
counter!
(
"sgl_router_requests_total"
,
"route"
=>
route
.to_string
())
.increment
(
1
);
response
}
else
if
let
Router
::
CacheAware
{
running_queue
,
..
}
=
self
{
let
running_queue
=
Arc
::
clone
(
running_queue
);
}
else
if
let
Router
::
CacheAware
{
workers
,
..
}
=
self
{
// For streaming with CacheAware router, we need to manually decrement when done
let
workers
=
Arc
::
clone
(
workers
);
let
worker_url
=
worker_url
.to_string
();
HttpResponse
::
build
(
status
)
...
...
@@ -867,21 +927,28 @@ impl Router {
actix_web
::
error
::
ErrorInternalServerError
(
"Failed to read stream"
)
})
.inspect
(
move
|
bytes
|
{
let
bytes
=
bytes
.as_ref
()
.unwrap
();
if
bytes
.as_ref
()
.windows
(
12
)
.any
(|
window
|
window
==
b
"data: [DONE]"
)
{
let
mut
locked_queue
=
running_queue
.lock
()
.unwrap
();
let
count
=
locked_queue
.get_mut
(
&
worker_url
)
.unwrap
();
*
count
=
count
.saturating_sub
(
1
);
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
*
count
as
f64
);
debug!
(
"Streaming is done!!"
)
if
let
Ok
(
bytes
)
=
bytes
{
if
bytes
.as_ref
()
.windows
(
12
)
.any
(|
window
|
window
==
b
"data: [DONE]"
)
{
if
let
Ok
(
workers_guard
)
=
workers
.read
()
{
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
&
worker_url
)
{
worker
.decrement_load
();
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
worker
.load
()
as
f64
);
debug!
(
"Streaming is done!!"
)
}
}
}
}
}),
)
}
else
{
// For non-CacheAware routers, just stream without load tracking
HttpResponse
::
build
(
status
)
.insert_header
((
CONTENT_TYPE
,
HeaderValue
::
from_static
(
"text/event-stream"
)))
.streaming
(
res
.bytes_stream
()
.map_err
(|
_
|
{
...
...
@@ -935,43 +1002,27 @@ impl Router {
             Ok(res) => {
                 if res.status().is_success() {
                     match self {
-                        Router::RoundRobin { worker_urls, .. }
-                        | Router::Random { worker_urls, .. }
-                        | Router::CacheAware { worker_urls, .. } => {
+                        Router::RoundRobin { workers, .. }
+                        | Router::Random { workers, .. }
+                        | Router::CacheAware { workers, .. } => {
                             info!("Worker {} health check passed", worker_url);
-                            let mut urls = worker_urls.write().unwrap();
-                            if urls.contains(&worker_url.to_string()) {
+                            let mut workers_guard = workers.write().unwrap();
+                            if workers_guard.iter().any(|w| w.url() == worker_url) {
                                 return Err(format!("Worker {} already exists", worker_url));
                             }
                             info!("Added worker: {}", worker_url);
-                            urls.push(worker_url.to_string());
-                            gauge!("sgl_router_active_workers").set(urls.len() as f64);
+                            let new_worker = WorkerFactory::create_regular(worker_url.to_string());
+                            workers_guard.push(new_worker);
+                            gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
                         }
                         Router::PrefillDecode { .. } => {
                             return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
                         }
                     }
-                    // If cache aware, initialize the queues for the new worker
-                    if let Router::CacheAware { running_queue, processed_queue, tree, .. } = self {
-                        // Add worker to running queue with initial count of 0
-                        running_queue.lock().unwrap().insert(worker_url.to_string(), 0);
-                        // Add worker to processed queue with initial count of 0
-                        processed_queue.lock().unwrap().insert(worker_url.to_string(), 0);
+                    // If cache aware, add worker to tree
+                    if let Router::CacheAware { tree, .. } = self {
                         // Add worker to tree
                         tree.lock().unwrap().insert("", worker_url);
                     }
...
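The add_worker path above now rejects duplicates by URL before constructing a worker through `WorkerFactory::create_regular`. A minimal, self-contained illustration of that guard, with `SimpleWorker` as a hypothetical stand-in for the factory's output:

struct SimpleWorker {
    url: String,
}

impl SimpleWorker {
    fn url(&self) -> &str {
        &self.url
    }
}

fn add_worker(workers: &mut Vec<SimpleWorker>, url: &str) -> Result<(), String> {
    // Reject a second registration of the same URL.
    if workers.iter().any(|w| w.url() == url) {
        return Err(format!("Worker {} already exists", url));
    }
    workers.push(SimpleWorker { url: url.to_string() });
    Ok(())
}

fn main() {
    let mut workers = Vec::new();
    assert!(add_worker(&mut workers, "http://worker1:8080").is_ok());
    assert!(add_worker(&mut workers, "http://worker1:8080").is_err());
}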
@@ -1013,14 +1064,14 @@ impl Router {
     pub fn remove_worker(&self, worker_url: &str) {
         match self {
-            Router::RoundRobin { worker_urls, .. }
-            | Router::Random { worker_urls, .. }
-            | Router::CacheAware { worker_urls, .. } => {
-                let mut urls = worker_urls.write().unwrap();
-                if let Some(index) = urls.iter().position(|url| url == &worker_url) {
-                    urls.remove(index);
+            Router::RoundRobin { workers, .. }
+            | Router::Random { workers, .. }
+            | Router::CacheAware { workers, .. } => {
+                let mut workers_guard = workers.write().unwrap();
+                if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
+                    workers_guard.remove(index);
                     info!("Removed worker: {}", worker_url);
-                    gauge!("sgl_router_active_workers").set(urls.len() as f64);
+                    gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
                 } else {
                     warn!("Worker {} not found, skipping removal", worker_url);
                     return;
...
@@ -1033,26 +1084,9 @@ impl Router {
         }
         // if cache aware, remove the worker from the tree
-        if let Router::CacheAware { tree, running_queue, processed_queue, .. } = self {
+        if let Router::CacheAware { tree, .. } = self {
             tree.lock().unwrap().remove_tenant(&worker_url);
-            running_queue.lock().unwrap().remove(&worker_url.to_string());
-            processed_queue.lock().unwrap().remove(&worker_url.to_string());
-            info!("Removed worker from tree and cleaned up queues: {}", worker_url);
+            info!("Removed worker from tree: {}", worker_url);
         }
...
@@ -1241,21 +1275,22 @@ mod tests {
use
crate
::
service_discovery
::
PodType
;
fn
create_test_regular_router
()
->
Router
{
let
workers
=
vec!
[
WorkerFactory
::
create_regular
(
"http://worker1:8080"
.to_string
()),
WorkerFactory
::
create_regular
(
"http://worker2:8080"
.to_string
()),
];
Router
::
Random
{
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
vec!
[
"http://worker1:8080"
.to_string
(),
"http://worker2:8080"
.to_string
(),
])),
workers
:
Arc
::
new
(
RwLock
::
new
(
workers
)),
timeout_secs
:
5
,
interval_secs
:
1
,
_
health_checker
:
None
,
}
}
#[test]
fn
test_router_get_worker_urls_regular
()
{
let
router
=
create_test_regular_router
();
let
worker_urls
=
router
.get_worker_urls
();
let
urls
=
worker_urls
.read
()
.unwrap
();
let
urls
=
router
.get_worker_urls
();
assert_eq!
(
urls
.len
(),
2
);
assert
!
(
urls
.contains
(
&
"http://worker1:8080"
.to_string
()));
...
...
sgl-router/src/server.rs View file @ f2d5c492
...
@@ -236,8 +236,7 @@ async fn add_worker(
 #[get("/list_workers")]
 async fn list_workers(data: web::Data<AppState>) -> impl Responder {
-    let workers = data.router.get_worker_urls();
-    let worker_list = workers.read().unwrap().clone();
+    let worker_list = data.router.get_worker_urls();
     HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
 }
...
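With this change `get_worker_urls()` returns an owned `Vec<String>` snapshot instead of the `Arc<RwLock<Vec<String>>>` itself, so the handler serializes it directly. A sketch of the resulting response body (using `serde_json`, which the handler already depends on):

fn main() {
    let worker_list = vec![
        "http://worker1:8080".to_string(),
        "http://worker2:8080".to_string(),
    ];
    // Same shape the /list_workers endpoint returns.
    let body = serde_json::json!({ "urls": worker_list });
    assert_eq!(
        body.to_string(),
        r#"{"urls":["http://worker1:8080","http://worker2:8080"]}"#
    );
}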
@@ -381,7 +380,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
     info!("✅ Serving router on {}:{}", config.host, config.port);
     info!(
         "✅ Serving workers on {:?}",
-        app_state.router.get_worker_urls().read().unwrap()
+        app_state.router.get_worker_urls()
     );
     HttpServer::new(move || {
...
sgl-router/src/service_discovery.rs View file @ f2d5c492
...
@@ -547,11 +547,12 @@ mod tests {
     // Helper to create a Router instance for testing event handlers
     fn create_test_router() -> Arc<Router> {
-        let worker_urls = Arc::new(RwLock::new(Vec::new()));
+        let workers = Arc::new(RwLock::new(Vec::new()));
         Arc::new(Router::Random {
-            worker_urls,
+            workers,
             timeout_secs: 5,
             interval_secs: 1,
             _health_checker: None,
         })
     }
...
@@ -878,8 +879,6 @@ mod tests {
         assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
-        assert!(!router
-            .get_worker_urls()
-            .read()
-            .unwrap()
-            .contains(&pod_info.worker_url(port)));
     }
...
@@ -907,7 +906,7 @@ mod tests {
         .await;
         assert!(tracked_pods.lock().unwrap().is_empty());
-        assert!(router.get_worker_urls().read().unwrap().is_empty());
+        assert!(router.get_worker_urls().is_empty());
     }

     #[tokio::test]
...
sgl-router/tests/test_pd_routing.rs View file @ f2d5c492
...
@@ -12,7 +12,7 @@
 mod test_pd_routing {
     use rand::Rng;
     use serde_json::json;
-    use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy};
+    use sglang_router_rs::pd_types::PDSelectionPolicy;
     use sglang_router_rs::router::{PolicyConfig, Router};

     // Test-only struct to help validate PD request parsing
...
@@ -51,40 +51,35 @@ mod test_pd_routing {
     // ========================================================================
     #[test]
-    fn test_engine_info_creation() {
-        // Test EngineInfo creation for prefill servers
-        let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
-        match prefill_engine.engine_type {
-            EngineType::Prefill => (),
-            _ => panic!("Expected Prefill engine type"),
-        }
-        assert_eq!(prefill_engine.url, "http://prefill:8080");
-        assert_eq!(prefill_engine.bootstrap_port, Some(9000));
-        assert_eq!(prefill_engine.get_hostname(), "prefill");
-
-        // Test EngineInfo creation for decode servers
-        let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string());
-        match decode_engine.engine_type {
-            EngineType::Decode => (),
-            _ => panic!("Expected Decode engine type"),
-        }
-        assert_eq!(decode_engine.url, "http://decode:8080");
-        assert_eq!(decode_engine.bootstrap_port, None);
-        assert_eq!(decode_engine.get_hostname(), "decode");
-
-        // Test API path generation
-        assert_eq!(prefill_engine.api_path("/generate"), "http://prefill:8080/generate");
-        assert_eq!(prefill_engine.api_path("health"), "http://prefill:8080/health");
-        assert_eq!(decode_engine.api_path("/v1/chat/completions"), "http://decode:8080/v1/chat/completions");
+    fn test_worker_types() {
+        use sglang_router_rs::core::{WorkerFactory, WorkerType};
+
+        // Test worker creation for prefill servers
+        let prefill_worker = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
+        assert_eq!(prefill_worker.url(), "http://prefill:8080");
+        match prefill_worker.worker_type() {
+            WorkerType::Prefill { bootstrap_port } => {
+                assert_eq!(bootstrap_port, Some(9000));
+            }
+            _ => panic!("Expected Prefill worker type"),
+        }
+
+        // Test worker creation for decode servers
+        let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
+        assert_eq!(decode_worker.url(), "http://decode:8080");
+        match decode_worker.worker_type() {
+            WorkerType::Decode => (),
+            _ => panic!("Expected Decode worker type"),
+        }
+
+        // Test regular worker creation
+        let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
+        assert_eq!(regular_worker.url(), "http://regular:8080");
+        match regular_worker.worker_type() {
+            WorkerType::Regular => (),
+            _ => panic!("Expected Regular worker type"),
+        }
     }

     #[test]
...
@@ -230,6 +225,9 @@ mod test_pd_routing {
     #[test]
     fn test_bootstrap_injection_simulation() {
+        use sglang_router_rs::core::{WorkerFactory, WorkerType};
+        use sglang_router_rs::pd_types::get_hostname;
+
         // Since we can't test the actual inject_bootstrap_fields function here
         // (it's private in the router module), we'll test the expected behavior
...
@@ -240,15 +238,24 @@ mod test_pd_routing {
             "temperature": 0.7
         });

+        // Create a prefill worker to simulate injection
+        let prefill_worker = WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));
+        // Extract bootstrap port from worker type
+        let bootstrap_port = match prefill_worker.worker_type() {
+            WorkerType::Prefill { bootstrap_port } => bootstrap_port,
+            _ => None,
+        };
+
         // Simulate what inject_bootstrap_fields would do
-        let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
-        single_json["bootstrap_host"] = json!(prefill_info.get_hostname());
-        single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port);
+        single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
+        single_json["bootstrap_port"] = json!(bootstrap_port);
         single_json["bootstrap_room"] = json!(12345u64); // Random room ID

         // Verify bootstrap fields are added correctly
         assert_eq!(single_json["bootstrap_host"], "prefill1");
-        assert_eq!(single_json["bootstrap_port"], 9000);
+        assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
         assert!(single_json["bootstrap_room"].is_u64());
         assert_eq!(single_json["temperature"], 0.7); // Original field preserved
...
@@ -259,8 +266,9 @@ mod test_pd_routing {
         });

         let batch_size = 3;
-        batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
-        batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
+        let hostname = get_hostname(prefill_worker.url());
+        batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
+        batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
         batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);

         // Verify batch bootstrap fields
...
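A worked example of the batch shape these assertions check: one `bootstrap_host`, `bootstrap_port`, and `bootstrap_room` entry per batch element. The `Option<u16>` port type is an assumption for illustration:

fn main() {
    let hostname = "prefill1";
    let bootstrap_port: Option<u16> = Some(9000); // port type assumed
    let batch_size = 3;
    let mut req = serde_json::json!({ "text": ["a", "b", "c"] });
    req["bootstrap_host"] = serde_json::json!(vec![hostname; batch_size]);
    req["bootstrap_port"] = serde_json::json!(vec![bootstrap_port; batch_size]);
    req["bootstrap_room"] = serde_json::json!(vec![111u64, 222u64, 333u64]);
    assert_eq!(
        req["bootstrap_host"],
        serde_json::json!(["prefill1", "prefill1", "prefill1"])
    );
    assert_eq!(req["bootstrap_port"][0], serde_json::json!(9000));
}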
@@ -306,7 +314,9 @@ mod test_pd_routing {
     }

     #[test]
-    fn test_engine_info_hostname_extraction() {
+    fn test_hostname_extraction() {
+        use sglang_router_rs::pd_types::get_hostname;
+
         // Test various URL formats
         let test_cases = vec![
             ("http://localhost:8080", "localhost"),
...
@@ -318,8 +328,7 @@ mod test_pd_routing {
         ];

         for (url, expected_hostname) in test_cases {
-            let engine = EngineInfo::new_prefill(url.to_string(), None);
-            assert_eq!(engine.get_hostname(), expected_hostname);
+            assert_eq!(get_hostname(url), expected_hostname);
         }
     }
...
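`get_hostname` itself is not expanded in this diff; a plausible implementation consistent with the test cases (strip the scheme, then cut at the first `:` or `/`) might look like the sketch below. The real function lives in sgl-router/src/pd_types.rs.

fn get_hostname(url: &str) -> String {
    // Drop an http:// or https:// prefix if present.
    let rest = url
        .strip_prefix("http://")
        .or_else(|| url.strip_prefix("https://"))
        .unwrap_or(url);
    // The hostname ends at the first port separator or path segment.
    rest.split(|c: char| c == ':' || c == '/')
        .next()
        .unwrap_or(rest)
        .to_string()
}

fn main() {
    assert_eq!(get_hostname("http://localhost:8080"), "localhost");
    assert_eq!(get_hostname("http://prefill:8080"), "prefill");
}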
@@ -652,6 +661,9 @@ mod test_pd_routing {
     #[test]
     fn test_bootstrap_injection_with_benchmark_requests() {
+        use sglang_router_rs::core::{WorkerFactory, WorkerType};
+        use sglang_router_rs::pd_types::get_hostname;
+
         // Test bootstrap injection with actual benchmark request patterns
         let mut benchmark_request = json!({
             "input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
...
@@ -664,12 +676,20 @@ mod test_pd_routing {
             "stream": true
         });

-        // Simulate bootstrap injection
-        let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
+        // Create a prefill worker to simulate injection
+        let prefill_worker = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
+        // Extract bootstrap port from worker type
+        let bootstrap_port = match prefill_worker.worker_type() {
+            WorkerType::Prefill { bootstrap_port } => bootstrap_port,
+            _ => None,
+        };
         let batch_size = 16;
-        benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
-        benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
+        let hostname = get_hostname(prefill_worker.url());
+        benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
+        benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
         benchmark_request["bootstrap_room"] =
             json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
...
#[test]
fn
test_large_batch_bootstrap_injection
()
{
use
sglang_router_rs
::
core
::{
WorkerFactory
,
WorkerType
};
use
sglang_router_rs
::
pd_types
::
get_hostname
;
// Test bootstrap injection performance with very large batches
// This simulates the bench_one_batch_server.py scenario
let
large_batch_sizes
=
vec!
[
1024
,
4096
,
8192
];
...
...
@@ -787,14 +810,19 @@ mod test_pd_routing {
             "stream": true
         });

-        // Simulate bootstrap injection
-        let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
+        // Create a prefill worker to simulate injection
+        let prefill_worker = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
+        // Extract bootstrap port from worker type
+        let bootstrap_port = match prefill_worker.worker_type() {
+            WorkerType::Prefill { bootstrap_port } => bootstrap_port,
+            _ => None,
+        };
+        let hostname = get_hostname(prefill_worker.url());
-        large_batch_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
-        large_batch_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
+        large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
+        large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
         large_batch_request["bootstrap_room"] = json!((0..batch_size)
             .map(|_| rand::thread_rng().gen::<u64>())
             .collect::<Vec<_>>());
...