Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f2d5c492
Unverified
Commit
f2d5c492
authored
Jul 11, 2025
by
Simo Lin
Committed by
GitHub
Jul 11, 2025
Browse files
[router] add worker abstraction (#7960)
parent
2a2d3478
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
961 additions
and
411 deletions
+961
-411
sgl-router/Cargo.toml
sgl-router/Cargo.toml
+2
-0
sgl-router/src/core/error.rs
sgl-router/src/core/error.rs
+57
-0
sgl-router/src/core/mod.rs
sgl-router/src/core/mod.rs
+16
-0
sgl-router/src/core/worker.rs
sgl-router/src/core/worker.rs
+454
-0
sgl-router/src/lib.rs
sgl-router/src/lib.rs
+1
-0
sgl-router/src/pd_router.rs
sgl-router/src/pd_router.rs
+111
-131
sgl-router/src/pd_types.rs
sgl-router/src/pd_types.rs
+28
-49
sgl-router/src/router.rs
sgl-router/src/router.rs
+206
-171
sgl-router/src/server.rs
sgl-router/src/server.rs
+2
-3
sgl-router/src/service_discovery.rs
sgl-router/src/service_discovery.rs
+4
-5
sgl-router/tests/test_pd_routing.rs
sgl-router/tests/test_pd_routing.rs
+80
-52
No files found.
sgl-router/Cargo.toml
View file @
f2d5c492
...
@@ -30,6 +30,8 @@ tracing-appender = "0.2.3"
...
@@ -30,6 +30,8 @@ tracing-appender = "0.2.3"
kube
=
{
version
=
"0.88.1"
,
features
=
[
"runtime"
,
"derive"
]
}
kube
=
{
version
=
"0.88.1"
,
features
=
[
"runtime"
,
"derive"
]
}
k8s-openapi
=
{
version
=
"0.21.0"
,
features
=
["v1_29"]
}
k8s-openapi
=
{
version
=
"0.21.0"
,
features
=
["v1_29"]
}
futures
=
"0.3"
futures
=
"0.3"
async-trait
=
"0.1"
once_cell
=
"1.21"
# Added for metrics
# Added for metrics
metrics
=
"0.24.2"
metrics
=
"0.24.2"
metrics-exporter-prometheus
=
"0.17.0"
metrics-exporter-prometheus
=
"0.17.0"
...
...
sgl-router/src/core/error.rs
0 → 100644
View file @
f2d5c492
//! Error types for the SGLang router core
//!
//! This module defines error types used throughout the router for worker operations.
use
std
::
fmt
;
/// Worker-related errors
#[derive(Debug)]
pub
enum
WorkerError
{
/// Health check failed
HealthCheckFailed
{
url
:
String
,
reason
:
String
},
/// Worker not found
WorkerNotFound
{
url
:
String
},
/// Invalid worker configuration
InvalidConfiguration
{
message
:
String
},
/// Network error
NetworkError
{
url
:
String
,
error
:
String
},
/// Worker is at capacity
WorkerAtCapacity
{
url
:
String
},
}
impl
fmt
::
Display
for
WorkerError
{
fn
fmt
(
&
self
,
f
:
&
mut
fmt
::
Formatter
<
'_
>
)
->
fmt
::
Result
{
match
self
{
WorkerError
::
HealthCheckFailed
{
url
,
reason
}
=>
{
write!
(
f
,
"Health check failed for worker {}: {}"
,
url
,
reason
)
}
WorkerError
::
WorkerNotFound
{
url
}
=>
{
write!
(
f
,
"Worker not found: {}"
,
url
)
}
WorkerError
::
InvalidConfiguration
{
message
}
=>
{
write!
(
f
,
"Invalid worker configuration: {}"
,
message
)
}
WorkerError
::
NetworkError
{
url
,
error
}
=>
{
write!
(
f
,
"Network error for worker {}: {}"
,
url
,
error
)
}
WorkerError
::
WorkerAtCapacity
{
url
}
=>
{
write!
(
f
,
"Worker at capacity: {}"
,
url
)
}
}
}
}
impl
std
::
error
::
Error
for
WorkerError
{}
/// Result type for worker operations
pub
type
WorkerResult
<
T
>
=
Result
<
T
,
WorkerError
>
;
/// Convert from reqwest errors to worker errors
impl
From
<
reqwest
::
Error
>
for
WorkerError
{
fn
from
(
err
:
reqwest
::
Error
)
->
Self
{
WorkerError
::
NetworkError
{
url
:
err
.url
()
.map
(|
u
|
u
.to_string
())
.unwrap_or_default
(),
error
:
err
.to_string
(),
}
}
}
sgl-router/src/core/mod.rs
0 → 100644
View file @
f2d5c492
//! Core abstractions for the SGLang router
//!
//! This module contains the fundamental types and traits used throughout the router:
//! - Worker trait and implementations
//! - Error types
//! - Common utilities

pub mod error;
pub mod worker;

// Re-export commonly used types at the module level so callers can write
// `crate::core::Worker` instead of reaching into the submodules.
pub use error::{WorkerError, WorkerResult};
pub use worker::{
    start_health_checker, BasicWorker, HealthChecker, Worker, WorkerCollection, WorkerFactory,
    WorkerLoadGuard, WorkerType,
};
sgl-router/src/core/worker.rs
0 → 100644
View file @
f2d5c492
use
super
::{
WorkerError
,
WorkerResult
};
use
async_trait
::
async_trait
;
use
once_cell
::
sync
::
Lazy
;
use
std
::
fmt
;
use
std
::
sync
::
atomic
::{
AtomicBool
,
AtomicUsize
,
Ordering
};
use
std
::
sync
::
Arc
;
// Shared HTTP client for health checks. A reqwest::Client owns a connection
// pool, so a single lazily-built instance is reused for every probe instead
// of constructing one per check.
static HEALTH_CHECK_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
    reqwest::Client::builder()
        // Default timeout, overridden per request with the worker's
        // configured `HealthConfig::timeout_secs`.
        .timeout(std::time::Duration::from_secs(30))
        .build()
        // Panicking here is acceptable: without a client the process
        // cannot perform health checking at all.
        .expect("Failed to create health check HTTP client")
});
/// Core worker abstraction that represents a backend service.
///
/// Implementations must be `Send + Sync + Debug`; all load/health accessors
/// take `&self`, so implementations are expected to use interior mutability
/// (e.g. atomics) for their counters.
#[async_trait]
pub trait Worker: Send + Sync + fmt::Debug {
    /// Get the worker's URL
    fn url(&self) -> &str;

    /// Get the worker's type (Regular, Prefill, or Decode)
    fn worker_type(&self) -> WorkerType;

    /// Check if the worker is currently healthy
    fn is_healthy(&self) -> bool;

    /// Set the worker's health status
    fn set_healthy(&self, healthy: bool);

    /// Perform an async health check on the worker
    async fn check_health_async(&self) -> WorkerResult<()>;

    /// Synchronous health check wrapper (for compatibility).
    ///
    /// NOTE(review): this builds a fresh current-thread Tokio runtime and
    /// blocks on it; calling it from inside an existing Tokio runtime will
    /// panic — confirm all callers are purely synchronous contexts.
    fn check_health(&self) -> WorkerResult<()> {
        // Use a small runtime for synchronous contexts
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| WorkerError::HealthCheckFailed {
                url: self.url().to_string(),
                reason: format!("Failed to create runtime: {}", e),
            })?
            .block_on(self.check_health_async())
    }

    /// Get the current load (number of active requests)
    fn load(&self) -> usize;

    /// Increment the load counter
    fn increment_load(&self);

    /// Decrement the load counter
    fn decrement_load(&self);

    /// Get the number of processed requests
    fn processed_requests(&self) -> usize;

    /// Increment the processed requests counter
    fn increment_processed(&self);

    /// Get worker-specific metadata
    fn metadata(&self) -> &WorkerMetadata;

    /// Clone the worker (for trait objects)
    fn clone_worker(&self) -> Box<dyn Worker>;
}
/// Worker type classification
#[derive(Debug,
Clone,
PartialEq,
Eq,
Hash)]
pub
enum
WorkerType
{
/// Regular worker for standard routing
Regular
,
/// Prefill worker for PD disaggregated mode
Prefill
{
/// Bootstrap port for communication with decode workers
bootstrap_port
:
Option
<
u16
>
,
},
/// Decode worker for PD disaggregated mode
Decode
,
}
/// Compact textual form, e.g. `Regular`, `Prefill`, `Prefill(bootstrap:9000)`,
/// or `Decode`.
impl fmt::Display for WorkerType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // One flat match instead of a nested match on the bootstrap port.
        match self {
            WorkerType::Regular => write!(f, "Regular"),
            WorkerType::Prefill {
                bootstrap_port: Some(port),
            } => write!(f, "Prefill(bootstrap:{})", port),
            WorkerType::Prefill {
                bootstrap_port: None,
            } => write!(f, "Prefill"),
            WorkerType::Decode => write!(f, "Decode"),
        }
    }
}
/// Health check configuration
#[derive(Debug,
Clone)]
pub
struct
HealthConfig
{
/// Timeout for health checks in seconds
pub
timeout_secs
:
u64
,
/// Interval between health checks in seconds
pub
check_interval_secs
:
u64
,
/// Health check endpoint path
pub
endpoint
:
String
,
}
impl
Default
for
HealthConfig
{
fn
default
()
->
Self
{
Self
{
timeout_secs
:
5
,
check_interval_secs
:
30
,
endpoint
:
"/health"
.to_string
(),
}
}
}
/// Metadata associated with a worker
#[derive(Debug, Clone)]
pub struct WorkerMetadata {
    /// Worker URL
    pub url: String,
    /// Worker type
    pub worker_type: WorkerType,
    /// Additional labels/tags
    pub labels: std::collections::HashMap<String, String>,
    /// Health check configuration
    pub health_config: HealthConfig,
}

/// Basic worker implementation.
///
/// The counters and the health flag are held behind `Arc`, so `clone()`d
/// instances share (and mutate) the same load/health state as the original.
#[derive(Debug, Clone)]
pub struct BasicWorker {
    metadata: WorkerMetadata,
    // In-flight request count, shared across clones.
    load_counter: Arc<AtomicUsize>,
    // Total processed request count, shared across clones.
    processed_counter: Arc<AtomicUsize>,
    // Current health flag, shared across clones.
    healthy: Arc<AtomicBool>,
}
impl BasicWorker {
    /// Create a worker with empty labels and the default health configuration.
    ///
    /// The worker starts out marked healthy; the background health checker
    /// (if started) will flip the flag when probes fail.
    pub fn new(url: String, worker_type: WorkerType) -> Self {
        // `url` is moved straight into the metadata — the original code
        // cloned it needlessly (the clone was never used afterwards).
        let metadata = WorkerMetadata {
            url,
            worker_type,
            labels: std::collections::HashMap::new(),
            health_config: HealthConfig::default(),
        };
        Self {
            metadata,
            load_counter: Arc::new(AtomicUsize::new(0)),
            processed_counter: Arc::new(AtomicUsize::new(0)),
            healthy: Arc::new(AtomicBool::new(true)),
        }
    }

    /// Builder-style setter: replace the metadata labels.
    pub fn with_labels(mut self, labels: std::collections::HashMap<String, String>) -> Self {
        self.metadata.labels = labels;
        self
    }

    /// Builder-style setter: replace the health-check configuration.
    pub fn with_health_config(mut self, config: HealthConfig) -> Self {
        self.metadata.health_config = config;
        self
    }
}
#[async_trait]
impl Worker for BasicWorker {
    fn url(&self) -> &str {
        &self.metadata.url
    }

    fn worker_type(&self) -> WorkerType {
        self.metadata.worker_type.clone()
    }

    fn is_healthy(&self) -> bool {
        // Acquire pairs with the Release store in `set_healthy`.
        self.healthy.load(Ordering::Acquire)
    }

    fn set_healthy(&self, healthy: bool) {
        self.healthy.store(healthy, Ordering::Release);
    }

    /// Probe `<url><endpoint>` over HTTP, updating the health flag as a side
    /// effect: `true` on a 2xx response, `false` on a non-2xx response or a
    /// transport failure.
    async fn check_health_async(&self) -> WorkerResult<()> {
        use std::time::Duration;

        // Perform actual HTTP health check
        let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
        let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);

        // Use the shared client with a custom timeout for this request
        match HEALTH_CHECK_CLIENT
            .get(&health_url)
            .timeout(timeout)
            .send()
            .await
        {
            Ok(response) => {
                if response.status().is_success() {
                    self.set_healthy(true);
                    Ok(())
                } else {
                    // Reachable but unhealthy (non-2xx status).
                    self.set_healthy(false);
                    Err(WorkerError::HealthCheckFailed {
                        url: self.url().to_string(),
                        reason: format!("Health check returned status: {}", response.status()),
                    })
                }
            }
            Err(e) => {
                // Transport-level failure (connect error, timeout, ...).
                self.set_healthy(false);
                Err(WorkerError::HealthCheckFailed {
                    url: self.url().to_string(),
                    reason: format!("Health check request failed: {}", e),
                })
            }
        }
    }

    fn load(&self) -> usize {
        self.load_counter.load(Ordering::Relaxed)
    }

    fn increment_load(&self) {
        self.load_counter.fetch_add(1, Ordering::Relaxed);
    }

    fn decrement_load(&self) {
        // `checked_sub` makes the decrement saturate at zero: an unmatched
        // decrement leaves the counter at 0 instead of wrapping to usize::MAX
        // (fetch_update simply does not store when the closure returns None).
        self.load_counter
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_sub(1)
            })
            .ok();
    }

    fn processed_requests(&self) -> usize {
        self.processed_counter.load(Ordering::Relaxed)
    }

    fn increment_processed(&self) {
        self.processed_counter.fetch_add(1, Ordering::Relaxed);
    }

    fn metadata(&self) -> &WorkerMetadata {
        &self.metadata
    }

    fn clone_worker(&self) -> Box<dyn Worker> {
        // Derived `Clone` copies the Arc handles, so the boxed clone shares
        // counters and the health flag with `self`.
        Box::new(self.clone())
    }
}
/// Factory helpers for constructing boxed [`Worker`] trait objects of each type.
pub struct WorkerFactory;

impl WorkerFactory {
    /// Create a regular worker
    pub fn create_regular(url: String) -> Box<dyn Worker> {
        let worker = BasicWorker::new(url, WorkerType::Regular);
        Box::new(worker)
    }

    /// Create a prefill worker with optional bootstrap port
    pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
        let worker_type = WorkerType::Prefill { bootstrap_port };
        Box::new(BasicWorker::new(url, worker_type))
    }

    /// Create a decode worker
    pub fn create_decode(url: String) -> Box<dyn Worker> {
        let worker = BasicWorker::new(url, WorkerType::Decode);
        Box::new(worker)
    }

    /// Build the three worker pools (regular, prefill, decode) from their URL
    /// lists, returned in that order.
    pub fn create_from_urls(
        regular_urls: Vec<String>,
        prefill_urls: Vec<(String, Option<u16>)>,
        decode_urls: Vec<String>,
    ) -> (
        Vec<Box<dyn Worker>>,
        Vec<Box<dyn Worker>>,
        Vec<Box<dyn Worker>>,
    ) {
        let regular = regular_urls
            .into_iter()
            .map(Self::create_regular)
            .collect::<Vec<_>>();
        let prefill = prefill_urls
            .into_iter()
            .map(|(url, port)| Self::create_prefill(url, port))
            .collect::<Vec<_>>();
        let decode = decode_urls
            .into_iter()
            .map(Self::create_decode)
            .collect::<Vec<_>>();
        (regular, prefill, decode)
    }
}
/// Helper trait for collections of workers
pub trait WorkerCollection {
    /// Borrow every worker whose health flag is currently set.
    fn healthy_workers(&self) -> Vec<&dyn Worker>;
    /// Sum of the load counters across all workers (healthy or not).
    fn total_load(&self) -> usize;
    /// Look up a worker by exact URL match.
    fn find_worker(&self, url: &str) -> Option<&dyn Worker>;
    /// Mutable lookup by exact URL match.
    fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>>;
}
impl WorkerCollection for Vec<Box<dyn Worker>> {
    fn healthy_workers(&self) -> Vec<&dyn Worker> {
        // Single pass: keep only workers whose health flag is set.
        self.iter()
            .filter_map(|w| if w.is_healthy() { Some(w.as_ref()) } else { None })
            .collect()
    }

    fn total_load(&self) -> usize {
        self.iter().fold(0, |acc, w| acc + w.load())
    }

    fn find_worker(&self, url: &str) -> Option<&dyn Worker> {
        self.iter().map(|w| w.as_ref()).find(|w| w.url() == url)
    }

    fn find_worker_mut(&mut self, url: &str) -> Option<&mut Box<dyn Worker>> {
        self.iter_mut().find(|w| w.url() == url)
    }
}
/// Convert a list of worker URLs into regular worker trait objects.
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
    let mut workers = Vec::with_capacity(urls.len());
    for url in urls {
        workers.push(WorkerFactory::create_regular(url));
    }
    workers
}

/// Convert worker trait objects back to their URLs.
pub fn workers_to_urls(workers: &[Box<dyn Worker>]) -> Vec<String> {
    let mut urls = Vec::with_capacity(workers.len());
    for worker in workers {
        urls.push(worker.url().to_string());
    }
    urls
}
/// RAII guard for worker load management.
///
/// Increments each worker's load counter on construction and decrements it
/// again on drop, so the counters stay balanced even across early returns
/// and panics.
pub struct WorkerLoadGuard<'a> {
    workers: Vec<&'a dyn Worker>,
}

impl<'a> WorkerLoadGuard<'a> {
    /// Create a new load guard for a single worker
    pub fn new(worker: &'a dyn Worker) -> Self {
        // Single-worker case is just the multi-worker case with one entry.
        Self::new_multi(vec![worker])
    }

    /// Create a new load guard for multiple workers
    pub fn new_multi(workers: Vec<&'a dyn Worker>) -> Self {
        // Increment load counters for all workers up front.
        workers.iter().for_each(|w| w.increment_load());
        Self { workers }
    }
}

impl<'a> Drop for WorkerLoadGuard<'a> {
    fn drop(&mut self) {
        // Undo the increments performed at construction time.
        self.workers.iter().for_each(|w| w.decrement_load());
    }
}
/// Handle to a background health-check task, with graceful shutdown.
pub struct HealthChecker {
    handle: tokio::task::JoinHandle<()>,
    shutdown: Arc<AtomicBool>,
}

// Manual impl because `JoinHandle<()>` carries no useful Debug information;
// only the shutdown flag is shown.
impl fmt::Debug for HealthChecker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shutdown_requested = self.shutdown.load(Ordering::Relaxed);
        f.debug_struct("HealthChecker")
            .field("shutdown", &shutdown_requested)
            .finish()
    }
}

impl HealthChecker {
    /// Signal the background task to stop, then wait for it to exit.
    ///
    /// The task observes the flag on its next tick, so this may take up to
    /// one check interval to complete.
    pub async fn shutdown(self) {
        self.shutdown.store(true, Ordering::Release);
        let _ = self.handle.await;
    }
}
/// Start an async background health checker for a collection of workers.
///
/// Spawns a Tokio task that, every `check_interval_secs` seconds, probes all
/// workers concurrently and logs health-state transitions. Returns a
/// [`HealthChecker`] handle that can be used to stop the task gracefully.
pub fn start_health_checker(
    workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
    check_interval_secs: u64,
) -> HealthChecker {
    let shutdown = Arc::new(AtomicBool::new(false));
    let shutdown_clone = shutdown.clone();

    let handle = tokio::spawn(async move {
        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
        loop {
            interval.tick().await;

            // Check for shutdown signal (only observed once per tick, so
            // shutdown can lag by up to one interval).
            if shutdown_clone.load(Ordering::Acquire) {
                tracing::info!("Health checker shutting down");
                break;
            }

            // Snapshot the workers via clone_worker() so the RwLock is not
            // held across the awaits below. NOTE(review): for BasicWorker the
            // clones share Arc-backed health/load state with the originals,
            // so set_healthy() on a clone is visible to the pool — presumably
            // intended; confirm for other Worker implementations.
            let workers_to_check = match workers.read() {
                Ok(guard) => guard
                    .iter()
                    .map(|w| w.clone_worker())
                    .collect::<Vec<_>>(),
                Err(poisoned) => {
                    tracing::error!("Worker lock poisoned: {}", poisoned);
                    continue;
                }
            };

            // Build one future per worker; each logs only on a health-state
            // transition (unhealthy->healthy or healthy->unhealthy).
            let health_checks = workers_to_check.iter().map(|worker| {
                let worker_url = worker.url().to_string();
                let was_healthy = worker.is_healthy();
                async move {
                    match worker.check_health_async().await {
                        Ok(_) => {
                            if !was_healthy {
                                tracing::info!("Worker {} is now healthy", worker_url);
                            }
                        }
                        Err(e) => {
                            if was_healthy {
                                tracing::warn!("Worker {} health check failed: {}", worker_url, e);
                            }
                        }
                    }
                }
            });

            // Execute all health checks concurrently
            futures::future::join_all(health_checks).await;
        }
    });

    HealthChecker { handle, shutdown }
}
sgl-router/src/lib.rs
View file @
f2d5c492
...
@@ -2,6 +2,7 @@ use pyo3::prelude::*;
...
@@ -2,6 +2,7 @@ use pyo3::prelude::*;
pub
mod
config
;
pub
mod
config
;
pub
mod
logging
;
pub
mod
logging
;
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
pub
mod
core
;
pub
mod
openai_api_types
;
pub
mod
openai_api_types
;
pub
mod
pd_router
;
pub
mod
pd_router
;
pub
mod
pd_types
;
pub
mod
pd_types
;
...
...
sgl-router/src/pd_router.rs
View file @
f2d5c492
// PD (Prefill-Decode) Router Implementation
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
// This module handles routing for disaggregated prefill-decode systems
use
crate
::
core
::{
HealthChecker
,
Worker
,
WorkerFactory
,
WorkerLoadGuard
};
use
crate
::
pd_types
::{
use
crate
::
pd_types
::{
Bootstrap
,
ChatReqInput
,
EngineInfo
,
GenerateReqInput
,
PDRouterError
,
PDSelectionPolicy
,
api_path
,
Bootstrap
,
ChatReqInput
,
GenerateReqInput
,
PDRouterError
,
PDSelectionPolicy
,
};
};
use
crate
::
tree
::
Tree
;
use
crate
::
tree
::
Tree
;
use
actix_web
::
http
::
header
::{
HeaderValue
,
CONTENT_TYPE
};
use
actix_web
::
http
::
header
::{
HeaderValue
,
CONTENT_TYPE
};
...
@@ -11,7 +12,6 @@ use futures_util::{StreamExt, TryStreamExt};
...
@@ -11,7 +12,6 @@ use futures_util::{StreamExt, TryStreamExt};
use
metrics
::{
counter
,
histogram
};
use
metrics
::{
counter
,
histogram
};
use
serde_json
::
Value
;
use
serde_json
::
Value
;
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
atomic
::{
AtomicUsize
,
Ordering
};
use
std
::
sync
::{
Arc
,
Mutex
,
RwLock
};
use
std
::
sync
::{
Arc
,
Mutex
,
RwLock
};
use
std
::
time
::{
Duration
,
Instant
};
use
std
::
time
::{
Duration
,
Instant
};
use
tracing
::{
debug
,
error
,
info
,
warn
};
use
tracing
::{
debug
,
error
,
info
,
warn
};
...
@@ -21,49 +21,17 @@ use uuid::Uuid;
...
@@ -21,49 +21,17 @@ use uuid::Uuid;
#[derive(Debug)]
#[derive(Debug)]
pub
struct
PDRouter
{
pub
struct
PDRouter
{
pub
prefill_workers
:
Arc
<
RwLock
<
Vec
<
EngineInfo
>>>
,
pub
prefill_workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>
>>>
,
pub
decode_workers
:
Arc
<
RwLock
<
Vec
<
EngineInfo
>>>
,
pub
decode_workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>
>>>
,
pub
selection_policy
:
PDSelectionPolicy
,
pub
selection_policy
:
PDSelectionPolicy
,
pub
load_tracking
:
Arc
<
dashmap
::
DashMap
<
String
,
Arc
<
AtomicUsize
>>>
,
pub
prefill_tree
:
Option
<
Arc
<
Mutex
<
Tree
>>>
,
pub
prefill_tree
:
Option
<
Arc
<
Mutex
<
Tree
>>>
,
pub
timeout_secs
:
u64
,
pub
timeout_secs
:
u64
,
pub
interval_secs
:
u64
,
pub
interval_secs
:
u64
,
pub
worker_loads
:
Arc
<
tokio
::
sync
::
watch
::
Receiver
<
HashMap
<
String
,
isize
>>>
,
pub
worker_loads
:
Arc
<
tokio
::
sync
::
watch
::
Receiver
<
HashMap
<
String
,
isize
>>>
,
pub
load_monitor_handle
:
Option
<
Arc
<
tokio
::
task
::
JoinHandle
<
()
>>>
,
pub
load_monitor_handle
:
Option
<
Arc
<
tokio
::
task
::
JoinHandle
<
()
>>>
,
pub
http_client
:
reqwest
::
Client
,
pub
http_client
:
reqwest
::
Client
,
}
_
prefill_health_checker
:
Option
<
HealthChecker
>
,
_
decode_health_checker
:
Option
<
HealthChecker
>
,
// RAII guard for load tracking to ensure cleanup even on panic
struct
LoadGuard
<
'a
>
{
tracking
:
&
'a
Arc
<
dashmap
::
DashMap
<
String
,
Arc
<
AtomicUsize
>>>
,
urls
:
Vec
<
String
>
,
}
impl
<
'a
>
LoadGuard
<
'a
>
{
fn
new
(
tracking
:
&
'a
Arc
<
dashmap
::
DashMap
<
String
,
Arc
<
AtomicUsize
>>>
,
urls
:
Vec
<
String
>
,
)
->
Self
{
// Increment counters
for
url
in
&
urls
{
let
counter
=
tracking
.entry
(
url
.clone
())
.or_insert_with
(||
Arc
::
new
(
AtomicUsize
::
new
(
0
)));
counter
.fetch_add
(
1
,
Ordering
::
Relaxed
);
}
LoadGuard
{
tracking
,
urls
}
}
}
impl
Drop
for
LoadGuard
<
'_
>
{
fn
drop
(
&
mut
self
)
{
// Guaranteed cleanup even on panic
for
url
in
&
self
.urls
{
if
let
Some
(
counter
)
=
self
.tracking
.get
(
url
)
{
counter
.fetch_sub
(
1
,
Ordering
::
Relaxed
);
}
}
}
}
}
impl
PDRouter
{
impl
PDRouter
{
...
@@ -73,9 +41,6 @@ impl PDRouter {
...
@@ -73,9 +41,6 @@ impl PDRouter {
url
:
String
,
url
:
String
,
bootstrap_port
:
Option
<
u16
>
,
bootstrap_port
:
Option
<
u16
>
,
)
->
Result
<
String
,
PDRouterError
>
{
)
->
Result
<
String
,
PDRouterError
>
{
// Create EngineInfo for the new prefill server
let
engine_info
=
EngineInfo
::
new_prefill
(
url
.clone
(),
bootstrap_port
);
// Wait for the new server to be healthy
// Wait for the new server to be healthy
crate
::
router
::
Router
::
wait_for_healthy_workers
(
crate
::
router
::
Router
::
wait_for_healthy_workers
(
&
[
url
.clone
()],
&
[
url
.clone
()],
...
@@ -84,6 +49,9 @@ impl PDRouter {
...
@@ -84,6 +49,9 @@ impl PDRouter {
)
)
.map_err
(|
_
|
PDRouterError
::
HealthCheckFailed
{
url
:
url
.clone
()
})
?
;
.map_err
(|
_
|
PDRouterError
::
HealthCheckFailed
{
url
:
url
.clone
()
})
?
;
// Create Worker for the new prefill server
let
worker
=
WorkerFactory
::
create_prefill
(
url
.clone
(),
bootstrap_port
);
// Add to prefill workers list
// Add to prefill workers list
let
mut
workers
=
self
let
mut
workers
=
self
.prefill_workers
.prefill_workers
...
@@ -93,15 +61,11 @@ impl PDRouter {
...
@@ -93,15 +61,11 @@ impl PDRouter {
})
?
;
})
?
;
// Check if already exists
// Check if already exists
if
workers
.iter
()
.any
(|
w
|
w
.url
==
url
)
{
if
workers
.iter
()
.any
(|
w
|
w
.url
()
==
&
url
)
{
return
Err
(
PDRouterError
::
WorkerAlreadyExists
{
url
:
url
.clone
()
});
return
Err
(
PDRouterError
::
WorkerAlreadyExists
{
url
:
url
.clone
()
});
}
}
workers
.push
(
engine_info
);
workers
.push
(
worker
);
// Initialize load tracking
self
.load_tracking
.insert
(
url
.clone
(),
Arc
::
new
(
AtomicUsize
::
new
(
0
)));
// Add to cache tree if using cache-aware policy
// Add to cache tree if using cache-aware policy
if
let
Some
(
ref
tree
)
=
self
.prefill_tree
{
if
let
Some
(
ref
tree
)
=
self
.prefill_tree
{
...
@@ -113,9 +77,6 @@ impl PDRouter {
...
@@ -113,9 +77,6 @@ impl PDRouter {
}
}
pub
async
fn
add_decode_server
(
&
self
,
url
:
String
)
->
Result
<
String
,
PDRouterError
>
{
pub
async
fn
add_decode_server
(
&
self
,
url
:
String
)
->
Result
<
String
,
PDRouterError
>
{
// Create EngineInfo for the new decode server
let
engine_info
=
EngineInfo
::
new_decode
(
url
.clone
());
// Wait for the new server to be healthy
// Wait for the new server to be healthy
crate
::
router
::
Router
::
wait_for_healthy_workers
(
crate
::
router
::
Router
::
wait_for_healthy_workers
(
&
[
url
.clone
()],
&
[
url
.clone
()],
...
@@ -124,6 +85,9 @@ impl PDRouter {
...
@@ -124,6 +85,9 @@ impl PDRouter {
)
)
.map_err
(|
_
|
PDRouterError
::
HealthCheckFailed
{
url
:
url
.clone
()
})
?
;
.map_err
(|
_
|
PDRouterError
::
HealthCheckFailed
{
url
:
url
.clone
()
})
?
;
// Create Worker for the new decode server
let
worker
=
WorkerFactory
::
create_decode
(
url
.clone
());
// Add to decode workers list
// Add to decode workers list
let
mut
workers
=
self
let
mut
workers
=
self
.decode_workers
.decode_workers
...
@@ -133,15 +97,14 @@ impl PDRouter {
...
@@ -133,15 +97,14 @@ impl PDRouter {
})
?
;
})
?
;
// Check if already exists
// Check if already exists
if
workers
.iter
()
.any
(|
w
|
w
.url
==
url
)
{
if
workers
.iter
()
.any
(|
w
|
w
.url
()
==
&
url
)
{
return
Err
(
PDRouterError
::
WorkerAlreadyExists
{
url
:
url
.clone
()
});
return
Err
(
PDRouterError
::
WorkerAlreadyExists
{
url
:
url
.clone
()
});
}
}
workers
.push
(
engine_info
);
workers
.push
(
worker
);
// Initialize load tracking
// Initialize load tracking
self
.load_tracking
// Worker tracks its own load internally
.insert
(
url
.clone
(),
Arc
::
new
(
AtomicUsize
::
new
(
0
)));
info!
(
"Added decode server: {}"
,
url
);
info!
(
"Added decode server: {}"
,
url
);
Ok
(
format!
(
"Successfully added decode server: {}"
,
url
))
Ok
(
format!
(
"Successfully added decode server: {}"
,
url
))
...
@@ -157,7 +120,7 @@ impl PDRouter {
...
@@ -157,7 +120,7 @@ impl PDRouter {
// Find and remove the server
// Find and remove the server
let
initial_len
=
workers
.len
();
let
initial_len
=
workers
.len
();
workers
.retain
(|
w
|
w
.url
!=
url
);
workers
.retain
(|
w
|
w
.url
()
!=
url
);
if
workers
.len
()
==
initial_len
{
if
workers
.len
()
==
initial_len
{
return
Err
(
PDRouterError
::
WorkerNotFound
{
return
Err
(
PDRouterError
::
WorkerNotFound
{
...
@@ -166,7 +129,7 @@ impl PDRouter {
...
@@ -166,7 +129,7 @@ impl PDRouter {
}
}
// Remove from load tracking
// Remove from load tracking
self
.
load
_
tracking
.remove
(
url
);
// Worker
load
tracking
is internal
// Remove from cache tree if using cache-aware policy
// Remove from cache tree if using cache-aware policy
if
let
Some
(
ref
tree
)
=
self
.prefill_tree
{
if
let
Some
(
ref
tree
)
=
self
.prefill_tree
{
...
@@ -174,7 +137,7 @@ impl PDRouter {
...
@@ -174,7 +137,7 @@ impl PDRouter {
let
mut
tree_guard
=
tree
.lock
()
.unwrap
();
let
mut
tree_guard
=
tree
.lock
()
.unwrap
();
*
tree_guard
=
Tree
::
new
();
*
tree_guard
=
Tree
::
new
();
for
worker
in
workers
.iter
()
{
for
worker
in
workers
.iter
()
{
tree_guard
.insert
(
""
,
&
worker
.url
);
tree_guard
.insert
(
""
,
worker
.url
()
);
}
}
}
}
...
@@ -192,7 +155,7 @@ impl PDRouter {
...
@@ -192,7 +155,7 @@ impl PDRouter {
// Find and remove the server
// Find and remove the server
let
initial_len
=
workers
.len
();
let
initial_len
=
workers
.len
();
workers
.retain
(|
w
|
w
.url
!=
url
);
workers
.retain
(|
w
|
w
.url
()
!=
url
);
if
workers
.len
()
==
initial_len
{
if
workers
.len
()
==
initial_len
{
return
Err
(
PDRouterError
::
WorkerNotFound
{
return
Err
(
PDRouterError
::
WorkerNotFound
{
...
@@ -200,9 +163,6 @@ impl PDRouter {
...
@@ -200,9 +163,6 @@ impl PDRouter {
});
});
}
}
// Remove from load tracking
self
.load_tracking
.remove
(
url
);
info!
(
"Removed decode server: {}"
,
url
);
info!
(
"Removed decode server: {}"
,
url
);
Ok
(
format!
(
"Successfully removed decode server: {}"
,
url
))
Ok
(
format!
(
"Successfully removed decode server: {}"
,
url
))
}
}
...
@@ -214,41 +174,32 @@ impl PDRouter {
...
@@ -214,41 +174,32 @@ impl PDRouter {
timeout_secs
:
u64
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
interval_secs
:
u64
,
)
->
Result
<
Self
,
String
>
{
)
->
Result
<
Self
,
String
>
{
// Convert URLs to
EngineInfo
// Convert URLs to
Worker trait objects
let
prefill_workers
:
Vec
<
EngineInfo
>
=
prefill_urls
let
prefill_workers
:
Vec
<
Box
<
dyn
Worker
>
>
=
prefill_urls
.into_iter
()
.into_iter
()
.map
(|(
url
,
port
)|
EngineInfo
::
new
_prefill
(
url
,
port
))
.map
(|(
url
,
port
)|
WorkerFactory
::
create
_prefill
(
url
,
port
))
.collect
();
.collect
();
let
decode_workers
:
Vec
<
EngineInfo
>
=
decode_urls
let
decode_workers
:
Vec
<
Box
<
dyn
Worker
>
>
=
decode_urls
.into_iter
()
.into_iter
()
.map
(
EngineInfo
::
new
_decode
)
.map
(
WorkerFactory
::
create
_decode
)
.collect
();
.collect
();
// Wait for PD workers to be healthy
// Wait for PD workers to be healthy
let
all_urls
:
Vec
<
String
>
=
prefill_workers
let
all_urls
:
Vec
<
String
>
=
prefill_workers
.iter
()
.iter
()
.chain
(
decode_workers
.iter
())
.chain
(
decode_workers
.iter
())
.map
(|
engine
|
engine
.url
.clone
())
.map
(|
worker
|
worker
.url
()
.to_string
())
.collect
();
.collect
();
crate
::
router
::
Router
::
wait_for_healthy_workers
(
&
all_urls
,
timeout_secs
,
interval_secs
)
?
;
crate
::
router
::
Router
::
wait_for_healthy_workers
(
&
all_urls
,
timeout_secs
,
interval_secs
)
?
;
// Initialize load tracking with atomic counters
let
load_tracking
=
Arc
::
new
(
dashmap
::
DashMap
::
new
());
for
engine
in
&
prefill_workers
{
load_tracking
.insert
(
engine
.url
.clone
(),
Arc
::
new
(
AtomicUsize
::
new
(
0
)));
}
for
engine
in
&
decode_workers
{
load_tracking
.insert
(
engine
.url
.clone
(),
Arc
::
new
(
AtomicUsize
::
new
(
0
)));
}
// Initialize cache-aware components if needed
// Initialize cache-aware components if needed
let
prefill_tree
=
match
&
selection_policy
{
let
prefill_tree
=
match
&
selection_policy
{
PDSelectionPolicy
::
CacheAware
{
..
}
=>
{
PDSelectionPolicy
::
CacheAware
{
..
}
=>
{
let
tree
=
Arc
::
new
(
Mutex
::
new
(
Tree
::
new
()));
let
tree
=
Arc
::
new
(
Mutex
::
new
(
Tree
::
new
()));
// Initialize tree with prefill workers
// Initialize tree with prefill workers
for
engine
in
&
prefill_workers
{
for
worker
in
&
prefill_workers
{
tree
.lock
()
.unwrap
()
.insert
(
""
,
&
engine
.url
);
tree
.lock
()
.unwrap
()
.insert
(
""
,
worker
.url
()
);
}
}
Some
(
tree
)
Some
(
tree
)
}
}
...
@@ -283,17 +234,27 @@ impl PDRouter {
...
@@ -283,17 +234,27 @@ impl PDRouter {
None
None
};
};
let
prefill_workers
=
Arc
::
new
(
RwLock
::
new
(
prefill_workers
));
let
decode_workers
=
Arc
::
new
(
RwLock
::
new
(
decode_workers
));
// Start health checkers for both worker pools
let
prefill_health_checker
=
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
prefill_workers
),
interval_secs
);
let
decode_health_checker
=
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
decode_workers
),
interval_secs
);
Ok
(
PDRouter
{
Ok
(
PDRouter
{
prefill_workers
:
Arc
::
new
(
RwLock
::
new
(
prefill_workers
))
,
prefill_workers
,
decode_workers
:
Arc
::
new
(
RwLock
::
new
(
decode_workers
))
,
decode_workers
,
selection_policy
,
selection_policy
,
load_tracking
,
prefill_tree
,
prefill_tree
,
timeout_secs
,
timeout_secs
,
interval_secs
,
interval_secs
,
worker_loads
,
worker_loads
,
load_monitor_handle
,
load_monitor_handle
,
http_client
,
http_client
,
_
prefill_health_checker
:
Some
(
prefill_health_checker
),
_
decode_health_checker
:
Some
(
decode_health_checker
),
})
})
}
}
...
@@ -330,11 +291,13 @@ impl PDRouter {
...
@@ -330,11 +291,13 @@ impl PDRouter {
// Log routing decision
// Log routing decision
info!
(
info!
(
"PD routing: {} -> prefill={}, decode={}"
,
"PD routing: {} -> prefill={}, decode={}"
,
route
,
prefill
.url
,
decode
.url
route
,
prefill
.url
(),
decode
.url
()
);
);
// Add bootstrap info using the trait method
// Add bootstrap info using the trait method
if
let
Err
(
e
)
=
typed_req
.add_bootstrap_info
(
&
prefill
)
{
if
let
Err
(
e
)
=
typed_req
.add_bootstrap_info
(
prefill
.as_ref
()
)
{
error!
(
"Failed to add bootstrap info: {}"
,
e
);
error!
(
"Failed to add bootstrap info: {}"
,
e
);
counter!
(
"sgl_router_pd_errors_total"
,
"error"
=>
"bootstrap_injection"
)
.increment
(
1
);
counter!
(
"sgl_router_pd_errors_total"
,
"error"
=>
"bootstrap_injection"
)
.increment
(
1
);
return
HttpResponse
::
InternalServerError
()
return
HttpResponse
::
InternalServerError
()
...
@@ -356,8 +319,8 @@ impl PDRouter {
...
@@ -356,8 +319,8 @@ impl PDRouter {
req
,
req
,
json_with_bootstrap
,
json_with_bootstrap
,
route
,
route
,
&
prefill
,
prefill
.as_ref
()
,
&
decode
,
decode
.as_ref
()
,
is_stream
,
is_stream
,
return_logprob
,
return_logprob
,
start
,
start
,
...
@@ -397,11 +360,13 @@ impl PDRouter {
...
@@ -397,11 +360,13 @@ impl PDRouter {
// Log routing decision
// Log routing decision
info!
(
info!
(
"PD routing: {} -> prefill={}, decode={}"
,
"PD routing: {} -> prefill={}, decode={}"
,
route
,
prefill
.url
,
decode
.url
route
,
prefill
.url
(),
decode
.url
()
);
);
// Add bootstrap info using the trait method
// Add bootstrap info using the trait method
if
let
Err
(
e
)
=
typed_req
.add_bootstrap_info
(
&
prefill
)
{
if
let
Err
(
e
)
=
typed_req
.add_bootstrap_info
(
prefill
.as_ref
()
)
{
error!
(
"Failed to add bootstrap info: {}"
,
e
);
error!
(
"Failed to add bootstrap info: {}"
,
e
);
counter!
(
"sgl_router_pd_errors_total"
,
"error"
=>
"bootstrap_injection"
)
.increment
(
1
);
counter!
(
"sgl_router_pd_errors_total"
,
"error"
=>
"bootstrap_injection"
)
.increment
(
1
);
return
HttpResponse
::
InternalServerError
()
return
HttpResponse
::
InternalServerError
()
...
@@ -423,8 +388,8 @@ impl PDRouter {
...
@@ -423,8 +388,8 @@ impl PDRouter {
req
,
req
,
json_with_bootstrap
,
json_with_bootstrap
,
route
,
route
,
&
prefill
,
prefill
.as_ref
()
,
&
decode
,
decode
.as_ref
()
,
is_stream
,
is_stream
,
return_logprob
,
return_logprob
,
start
,
start
,
...
@@ -440,22 +405,23 @@ impl PDRouter {
...
@@ -440,22 +405,23 @@ impl PDRouter {
req
:
&
HttpRequest
,
req
:
&
HttpRequest
,
json_request
:
serde_json
::
Value
,
json_request
:
serde_json
::
Value
,
route
:
&
str
,
route
:
&
str
,
prefill
:
&
EngineInfo
,
prefill
:
&
dyn
Worker
,
decode
:
&
EngineInfo
,
decode
:
&
dyn
Worker
,
is_stream
:
bool
,
is_stream
:
bool
,
return_logprob
:
bool
,
return_logprob
:
bool
,
start_time
:
Instant
,
start_time
:
Instant
,
)
->
HttpResponse
{
)
->
HttpResponse
{
// Update load tracking for both workers
// Update load tracking for both workers
let
_
guard
=
LoadGuard
::
new
(
let
_
guard
=
WorkerLoadGuard
::
new_multi
(
vec!
[
prefill
,
decode
]);
&
self
.load_tracking
,
vec!
[
prefill
.url
.clone
(),
decode
.url
.clone
()],
);
// Build requests using .json() method
// Build requests using .json() method
let
mut
prefill_request
=
client
.post
(
prefill
.api_path
(
route
))
.json
(
&
json_request
);
let
mut
prefill_request
=
client
.post
(
api_path
(
prefill
.url
(),
route
))
.json
(
&
json_request
);
let
mut
decode_request
=
client
.post
(
decode
.api_path
(
route
))
.json
(
&
json_request
);
let
mut
decode_request
=
client
.post
(
api_path
(
decode
.url
(),
route
))
.json
(
&
json_request
);
// Copy headers from original request
// Copy headers from original request
for
(
name
,
value
)
in
crate
::
router
::
copy_request_headers
(
req
)
{
for
(
name
,
value
)
in
crate
::
router
::
copy_request_headers
(
req
)
{
...
@@ -474,9 +440,9 @@ impl PDRouter {
...
@@ -474,9 +440,9 @@ impl PDRouter {
histogram!
(
"sgl_router_pd_request_duration_seconds"
,
"route"
=>
route
.to_string
())
histogram!
(
"sgl_router_pd_request_duration_seconds"
,
"route"
=>
route
.to_string
())
.record
(
duration
.as_secs_f64
());
.record
(
duration
.as_secs_f64
());
counter!
(
"sgl_router_pd_requests_total"
,
"route"
=>
route
.to_string
())
.increment
(
1
);
counter!
(
"sgl_router_pd_requests_total"
,
"route"
=>
route
.to_string
())
.increment
(
1
);
counter!
(
"sgl_router_pd_prefill_requests_total"
,
"worker"
=>
prefill
.url
.to_string
())
counter!
(
"sgl_router_pd_prefill_requests_total"
,
"worker"
=>
prefill
.url
()
.to_string
())
.increment
(
1
);
.increment
(
1
);
counter!
(
"sgl_router_pd_decode_requests_total"
,
"worker"
=>
decode
.url
.to_string
())
counter!
(
"sgl_router_pd_decode_requests_total"
,
"worker"
=>
decode
.url
()
.to_string
())
.increment
(
1
);
.increment
(
1
);
// Process decode response
// Process decode response
...
@@ -486,10 +452,11 @@ impl PDRouter {
...
@@ -486,10 +452,11 @@ impl PDRouter {
.unwrap_or
(
actix_web
::
http
::
StatusCode
::
INTERNAL_SERVER_ERROR
);
.unwrap_or
(
actix_web
::
http
::
StatusCode
::
INTERNAL_SERVER_ERROR
);
if
!
status
.is_success
()
{
if
!
status
.is_success
()
{
counter!
(
"sgl_router_pd_decode_errors_total"
,
"worker"
=>
decode
.url
.to_string
())
.increment
(
1
);
counter!
(
"sgl_router_pd_decode_errors_total"
,
"worker"
=>
decode
.url
()
.to_string
())
.increment
(
1
);
error!
(
error!
(
"Decode server {} returned error status: {}"
,
"Decode server {} returned error status: {}"
,
decode
.url
,
status
decode
.url
(),
status
);
);
// Return the error response from decode server
// Return the error response from decode server
...
@@ -508,9 +475,10 @@ impl PDRouter {
...
@@ -508,9 +475,10 @@ impl PDRouter {
if
let
Err
(
e
)
=
&
prefill_result
{
if
let
Err
(
e
)
=
&
prefill_result
{
error!
(
error!
(
"Prefill server {} failed (non-critical): {}"
,
"Prefill server {} failed (non-critical): {}"
,
prefill
.url
,
e
prefill
.url
(),
e
);
);
counter!
(
"sgl_router_pd_prefill_errors_total"
,
"worker"
=>
prefill
.url
.to_string
())
.increment
(
1
);
counter!
(
"sgl_router_pd_prefill_errors_total"
,
"worker"
=>
prefill
.url
()
.to_string
())
.increment
(
1
);
}
}
if
is_stream
{
if
is_stream
{
...
@@ -559,7 +527,7 @@ impl PDRouter {
...
@@ -559,7 +527,7 @@ impl PDRouter {
HttpResponse
::
build
(
status
)
HttpResponse
::
build
(
status
)
.insert_header
((
CONTENT_TYPE
,
HeaderValue
::
from_static
(
"text/event-stream"
)))
.insert_header
((
CONTENT_TYPE
,
HeaderValue
::
from_static
(
"text/event-stream"
)))
.streaming
({
.streaming
({
let
decode_url
=
decode
.url
.clone
();
let
decode_url
=
decode
.url
()
.to_string
();
res
.bytes_stream
()
.map_err
(
move
|
e
|
{
res
.bytes_stream
()
.map_err
(
move
|
e
|
{
error!
(
"Stream error from decode server {}: {}"
,
decode_url
,
e
);
error!
(
"Stream error from decode server {}: {}"
,
decode_url
,
e
);
counter!
(
"sgl_router_pd_stream_errors_total"
,
"worker"
=>
decode_url
.to_string
())
.increment
(
1
);
counter!
(
"sgl_router_pd_stream_errors_total"
,
"worker"
=>
decode_url
.to_string
())
.increment
(
1
);
...
@@ -587,7 +555,7 @@ impl PDRouter {
...
@@ -587,7 +555,7 @@ impl PDRouter {
}
}
Err
(
e
)
=>
{
Err
(
e
)
=>
{
error!
(
"Decode request failed: {}"
,
e
);
error!
(
"Decode request failed: {}"
,
e
);
counter!
(
"sgl_router_pd_decode_errors_total"
,
"worker"
=>
decode
.url
.to_string
())
counter!
(
"sgl_router_pd_decode_errors_total"
,
"worker"
=>
decode
.url
()
.to_string
())
.increment
(
1
);
.increment
(
1
);
HttpResponse
::
BadGateway
()
.body
(
format!
(
"Decode server error: {}"
,
e
))
HttpResponse
::
BadGateway
()
.body
(
format!
(
"Decode server error: {}"
,
e
))
}
}
...
@@ -652,7 +620,7 @@ impl PDRouter {
...
@@ -652,7 +620,7 @@ impl PDRouter {
async
fn
select_pd_pair
(
async
fn
select_pd_pair
(
&
self
,
&
self
,
_
client
:
&
reqwest
::
Client
,
_
client
:
&
reqwest
::
Client
,
)
->
Result
<
(
EngineInfo
,
EngineInfo
),
String
>
{
)
->
Result
<
(
Box
<
dyn
Worker
>
,
Box
<
dyn
Worker
>
),
String
>
{
// Check we have workers
// Check we have workers
if
self
if
self
.prefill_workers
.prefill_workers
...
@@ -681,17 +649,17 @@ impl PDRouter {
...
@@ -681,17 +649,17 @@ impl PDRouter {
}
}
}
}
fn
select_random
(
&
self
)
->
Result
<
(
EngineInfo
,
EngineInfo
),
String
>
{
fn
select_random
(
&
self
)
->
Result
<
(
Box
<
dyn
Worker
>
,
Box
<
dyn
Worker
>
),
String
>
{
let
prefill_list
=
self
.prefill_workers
.read
()
.map_err
(|
_
|
"Lock error"
)
?
;
let
prefill_list
=
self
.prefill_workers
.read
()
.map_err
(|
_
|
"Lock error"
)
?
;
let
decode_list
=
self
.decode_workers
.read
()
.map_err
(|
_
|
"Lock error"
)
?
;
let
decode_list
=
self
.decode_workers
.read
()
.map_err
(|
_
|
"Lock error"
)
?
;
let
prefill
=
prefill_list
[
rand
::
random
::
<
usize
>
()
%
prefill_list
.len
()]
.clone
();
let
prefill
=
prefill_list
[
rand
::
random
::
<
usize
>
()
%
prefill_list
.len
()]
.clone
_worker
();
let
decode
=
decode_list
[
rand
::
random
::
<
usize
>
()
%
decode_list
.len
()]
.clone
();
let
decode
=
decode_list
[
rand
::
random
::
<
usize
>
()
%
decode_list
.len
()]
.clone
_worker
();
Ok
((
prefill
,
decode
))
Ok
((
prefill
,
decode
))
}
}
async
fn
select_power_of_two
(
&
self
)
->
Result
<
(
EngineInfo
,
EngineInfo
),
String
>
{
async
fn
select_power_of_two
(
&
self
)
->
Result
<
(
Box
<
dyn
Worker
>
,
Box
<
dyn
Worker
>
),
String
>
{
let
prefill_list
=
self
.prefill_workers
.read
()
.map_err
(|
_
|
"Lock error"
)
?
;
let
prefill_list
=
self
.prefill_workers
.read
()
.map_err
(|
_
|
"Lock error"
)
?
;
let
decode_list
=
self
.decode_workers
.read
()
.map_err
(|
_
|
"Lock error"
)
?
;
let
decode_list
=
self
.decode_workers
.read
()
.map_err
(|
_
|
"Lock error"
)
?
;
...
@@ -700,33 +668,45 @@ impl PDRouter {
...
@@ -700,33 +668,45 @@ impl PDRouter {
let
loads
=
self
.worker_loads
.borrow
();
let
loads
=
self
.worker_loads
.borrow
();
let
p1_load
=
loads
.get
(
&
prefill_list
[
p1_idx
]
.url
)
.copied
()
.unwrap_or
(
0
);
let
p1_load
=
loads
let
p2_load
=
loads
.get
(
&
prefill_list
[
p2_idx
]
.url
)
.copied
()
.unwrap_or
(
0
);
.get
(
prefill_list
[
p1_idx
]
.url
())
let
d1_load
=
loads
.get
(
&
decode_list
[
d1_idx
]
.url
)
.copied
()
.unwrap_or
(
0
);
.copied
()
let
d2_load
=
loads
.get
(
&
decode_list
[
d2_idx
]
.url
)
.copied
()
.unwrap_or
(
0
);
.unwrap_or
(
isize
::
MAX
);
let
p2_load
=
loads
.get
(
prefill_list
[
p2_idx
]
.url
())
.copied
()
.unwrap_or
(
isize
::
MAX
);
let
d1_load
=
loads
.get
(
decode_list
[
d1_idx
]
.url
())
.copied
()
.unwrap_or
(
isize
::
MAX
);
let
d2_load
=
loads
.get
(
decode_list
[
d2_idx
]
.url
())
.copied
()
.unwrap_or
(
isize
::
MAX
);
info!
(
info!
(
"Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}"
,
"Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}"
,
prefill_list
[
p1_idx
]
.url
,
prefill_list
[
p1_idx
]
.url
()
,
p1_load
,
p1_load
,
prefill_list
[
p2_idx
]
.url
,
prefill_list
[
p2_idx
]
.url
()
,
p2_load
,
p2_load
,
decode_list
[
d1_idx
]
.url
,
decode_list
[
d1_idx
]
.url
()
,
d1_load
,
d1_load
,
decode_list
[
d2_idx
]
.url
,
decode_list
[
d2_idx
]
.url
()
,
d2_load
d2_load
);
);
let
selected_prefill
=
if
p1_load
<=
p2_load
{
let
selected_prefill
=
if
p1_load
<=
p2_load
{
prefill_list
[
p1_idx
]
.clone
()
prefill_list
[
p1_idx
]
.clone
_worker
()
}
else
{
}
else
{
prefill_list
[
p2_idx
]
.clone
()
prefill_list
[
p2_idx
]
.clone
_worker
()
};
};
let
selected_decode
=
if
d1_load
<=
d2_load
{
let
selected_decode
=
if
d1_load
<=
d2_load
{
decode_list
[
d1_idx
]
.clone
()
decode_list
[
d1_idx
]
.clone
_worker
()
}
else
{
}
else
{
decode_list
[
d2_idx
]
.clone
()
decode_list
[
d2_idx
]
.clone
_worker
()
};
};
Ok
((
selected_prefill
,
selected_decode
))
Ok
((
selected_prefill
,
selected_decode
))
...
@@ -868,11 +848,11 @@ impl PDRouter {
...
@@ -868,11 +848,11 @@ impl PDRouter {
let
mut
worker_infos
=
Vec
::
new
();
let
mut
worker_infos
=
Vec
::
new
();
for
worker
in
self
.prefill_workers
.read
()
.unwrap
()
.iter
()
{
for
worker
in
self
.prefill_workers
.read
()
.unwrap
()
.iter
()
{
worker_infos
.push
((
worker
.url
.clone
(),
"prefill"
));
worker_infos
.push
((
worker
.url
()
.to_string
(),
"prefill"
));
}
}
for
worker
in
self
.decode_workers
.read
()
.unwrap
()
.iter
()
{
for
worker
in
self
.decode_workers
.read
()
.unwrap
()
.iter
()
{
worker_infos
.push
((
worker
.url
.clone
(),
"decode"
));
worker_infos
.push
((
worker
.url
()
.to_string
(),
"decode"
));
}
}
// Create tasks with URL tracking
// Create tasks with URL tracking
...
@@ -922,7 +902,7 @@ impl PDRouter {
...
@@ -922,7 +902,7 @@ impl PDRouter {
pub
async
fn
get_server_info
(
&
self
,
client
:
&
reqwest
::
Client
)
->
HttpResponse
{
pub
async
fn
get_server_info
(
&
self
,
client
:
&
reqwest
::
Client
)
->
HttpResponse
{
// Get info from the first decode server to match sglang's server info format
// Get info from the first decode server to match sglang's server info format
let
first_decode_url
=
if
let
Ok
(
workers
)
=
self
.decode_workers
.read
()
{
let
first_decode_url
=
if
let
Ok
(
workers
)
=
self
.decode_workers
.read
()
{
workers
.first
()
.map
(|
w
|
w
.url
.clone
())
workers
.first
()
.map
(|
w
|
w
.url
()
.to_string
())
}
else
{
}
else
{
return
HttpResponse
::
InternalServerError
()
.body
(
"Failed to access decode workers"
);
return
HttpResponse
::
InternalServerError
()
.body
(
"Failed to access decode workers"
);
};
};
...
@@ -967,7 +947,7 @@ impl PDRouter {
...
@@ -967,7 +947,7 @@ impl PDRouter {
pub
async
fn
get_models
(
&
self
,
client
:
&
reqwest
::
Client
,
req
:
&
HttpRequest
)
->
HttpResponse
{
pub
async
fn
get_models
(
&
self
,
client
:
&
reqwest
::
Client
,
req
:
&
HttpRequest
)
->
HttpResponse
{
// Get first prefill worker URL to avoid holding lock across await
// Get first prefill worker URL to avoid holding lock across await
let
first_worker_url
=
if
let
Ok
(
workers
)
=
self
.prefill_workers
.read
()
{
let
first_worker_url
=
if
let
Ok
(
workers
)
=
self
.prefill_workers
.read
()
{
workers
.first
()
.map
(|
w
|
w
.url
.clone
())
workers
.first
()
.map
(|
w
|
w
.url
()
.to_string
())
}
else
{
}
else
{
return
HttpResponse
::
InternalServerError
()
.body
(
"Failed to access prefill workers"
);
return
HttpResponse
::
InternalServerError
()
.body
(
"Failed to access prefill workers"
);
};
};
...
@@ -1005,14 +985,14 @@ impl PDRouter {
...
@@ -1005,14 +985,14 @@ impl PDRouter {
.read
()
.read
()
.unwrap
()
.unwrap
()
.iter
()
.iter
()
.map
(|
w
|
w
.url
.clone
())
.map
(|
w
|
w
.url
()
.to_string
())
.collect
();
.collect
();
let
d_urls
:
Vec
<
_
>
=
self
let
d_urls
:
Vec
<
_
>
=
self
.decode_workers
.decode_workers
.read
()
.read
()
.unwrap
()
.unwrap
()
.iter
()
.iter
()
.map
(|
w
|
w
.url
.clone
())
.map
(|
w
|
w
.url
()
.to_string
())
.collect
();
.collect
();
let
mut
prefill_loads
=
Vec
::
new
();
let
mut
prefill_loads
=
Vec
::
new
();
...
@@ -1048,7 +1028,7 @@ impl PDRouter {
...
@@ -1048,7 +1028,7 @@ impl PDRouter {
// Get model info from the first prefill server (matches original Rust PDLB behavior)
// Get model info from the first prefill server (matches original Rust PDLB behavior)
// Get first prefill worker URL to avoid holding lock across await
// Get first prefill worker URL to avoid holding lock across await
let
first_worker_url
=
if
let
Ok
(
workers
)
=
self
.prefill_workers
.read
()
{
let
first_worker_url
=
if
let
Ok
(
workers
)
=
self
.prefill_workers
.read
()
{
workers
.first
()
.map
(|
w
|
w
.url
.clone
())
workers
.first
()
.map
(|
w
|
w
.url
()
.to_string
())
}
else
{
}
else
{
return
HttpResponse
::
InternalServerError
()
.body
(
"Failed to access prefill workers"
);
return
HttpResponse
::
InternalServerError
()
.body
(
"Failed to access prefill workers"
);
};
};
...
@@ -1084,13 +1064,13 @@ impl PDRouter {
...
@@ -1084,13 +1064,13 @@ impl PDRouter {
// Flush cache on all prefill servers
// Flush cache on all prefill servers
for
worker
in
self
.prefill_workers
.read
()
.unwrap
()
.iter
()
{
for
worker
in
self
.prefill_workers
.read
()
.unwrap
()
.iter
()
{
let
url
=
format!
(
"{}/flush_cache"
,
worker
.url
);
let
url
=
format!
(
"{}/flush_cache"
,
worker
.url
()
);
tasks
.push
(
client
.post
(
&
url
)
.send
());
tasks
.push
(
client
.post
(
&
url
)
.send
());
}
}
// Flush cache on all decode servers
// Flush cache on all decode servers
for
worker
in
self
.decode_workers
.read
()
.unwrap
()
.iter
()
{
for
worker
in
self
.decode_workers
.read
()
.unwrap
()
.iter
()
{
let
url
=
format!
(
"{}/flush_cache"
,
worker
.url
);
let
url
=
format!
(
"{}/flush_cache"
,
worker
.url
()
);
tasks
.push
(
client
.post
(
&
url
)
.send
());
tasks
.push
(
client
.post
(
&
url
)
.send
());
}
}
...
...
sgl-router/src/pd_types.rs
View file @
f2d5c492
// Essential PDLB types extracted for PD routing
// Essential PDLB types extracted for PD routing
use
crate
::
core
::{
Worker
,
WorkerType
};
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
Value
;
use
serde_json
::
Value
;
...
@@ -28,52 +29,21 @@ pub enum PDRouterError {
...
@@ -28,52 +29,21 @@ pub enum PDRouterError {
Timeout
{
url
:
String
},
Timeout
{
url
:
String
},
}
}
#[derive(Debug,
Clone)]
// Helper functions for workers
pub
enum
EngineType
{
pub
fn
api_path
(
url
:
&
str
,
api_path
:
&
str
)
->
String
{
Prefill
,
if
api_path
.starts_with
(
"/"
)
{
Decode
,
format!
(
"{}{}"
,
url
,
api_path
)
}
}
else
{
format!
(
"{}/{}"
,
url
,
api_path
)
#[derive(Debug,
Clone)]
pub
struct
EngineInfo
{
pub
engine_type
:
EngineType
,
pub
url
:
String
,
pub
bootstrap_port
:
Option
<
u16
>
,
}
impl
EngineInfo
{
pub
fn
new_prefill
(
url
:
String
,
bootstrap_port
:
Option
<
u16
>
)
->
Self
{
EngineInfo
{
engine_type
:
EngineType
::
Prefill
,
url
,
bootstrap_port
,
}
}
pub
fn
new_decode
(
url
:
String
)
->
Self
{
EngineInfo
{
engine_type
:
EngineType
::
Decode
,
url
,
bootstrap_port
:
None
,
}
}
pub
fn
api_path
(
&
self
,
api_path
:
&
str
)
->
String
{
if
api_path
.starts_with
(
"/"
)
{
format!
(
"{}{}"
,
self
.url
,
api_path
)
}
else
{
format!
(
"{}/{}"
,
self
.url
,
api_path
)
}
}
}
}
pub
fn
get_hostname
(
&
self
)
->
String
{
pub
fn
get_hostname
(
url
:
&
str
)
->
String
{
// Simple hostname extraction without external dependencies
// Simple hostname extraction without external dependencies
let
url
=
self
let
url
=
url
.url
.trim_start_matches
(
"http://"
)
.trim_start_matches
(
"http://"
)
.trim_start_matches
(
"https://"
);
.trim_start_matches
(
"https://"
);
url
.split
(
':'
)
.next
()
.unwrap_or
(
"localhost"
)
.to_string
()
url
.split
(
':'
)
.next
()
.unwrap_or
(
"localhost"
)
.to_string
()
}
}
}
// PD-specific routing policies
// PD-specific routing policies
...
@@ -112,12 +82,21 @@ pub trait Bootstrap: Send + Sync {
...
@@ -112,12 +82,21 @@ pub trait Bootstrap: Send + Sync {
bootstrap_room
:
BootstrapRoom
,
bootstrap_room
:
BootstrapRoom
,
);
);
fn
add_bootstrap_info
(
&
mut
self
,
prefill_
info
:
&
EngineInfo
)
->
Result
<
(),
String
>
{
fn
add_bootstrap_info
(
&
mut
self
,
prefill_
worker
:
&
dyn
Worker
)
->
Result
<
(),
String
>
{
let
batch_size
=
self
.get_batch_size
()
?
;
let
batch_size
=
self
.get_batch_size
()
?
;
// Extract bootstrap port from prefill worker if it's a prefill type
let
bootstrap_port
=
match
prefill_worker
.worker_type
()
{
WorkerType
::
Prefill
{
bootstrap_port
}
=>
bootstrap_port
,
_
=>
None
,
};
let
hostname
=
get_hostname
(
prefill_worker
.url
());
if
let
Some
(
batch_size
)
=
batch_size
{
if
let
Some
(
batch_size
)
=
batch_size
{
self
.set_bootstrap_info
(
self
.set_bootstrap_info
(
BootstrapHost
::
Batch
(
vec!
[
prefill_info
.get_
hostname
()
;
batch_size
]),
BootstrapHost
::
Batch
(
vec!
[
hostname
;
batch_size
]),
BootstrapPort
::
Batch
(
vec!
[
prefill_info
.
bootstrap_port
;
batch_size
]),
BootstrapPort
::
Batch
(
vec!
[
bootstrap_port
;
batch_size
]),
// Use high-quality random numbers to minimize collision risk
// Use high-quality random numbers to minimize collision risk
BootstrapRoom
::
Batch
(
BootstrapRoom
::
Batch
(
(
0
..
batch_size
)
(
0
..
batch_size
)
...
@@ -132,8 +111,8 @@ pub trait Bootstrap: Send + Sync {
...
@@ -132,8 +111,8 @@ pub trait Bootstrap: Send + Sync {
);
);
}
else
{
}
else
{
self
.set_bootstrap_info
(
self
.set_bootstrap_info
(
BootstrapHost
::
Single
(
prefill_info
.get_
hostname
()
),
BootstrapHost
::
Single
(
hostname
),
BootstrapPort
::
Single
(
prefill_info
.
bootstrap_port
),
BootstrapPort
::
Single
(
bootstrap_port
),
BootstrapRoom
::
Single
({
BootstrapRoom
::
Single
({
// Use high-quality random number for single requests too
// Use high-quality random number for single requests too
let
r1
=
rand
::
random
::
<
u64
>
();
let
r1
=
rand
::
random
::
<
u64
>
();
...
...
sgl-router/src/router.rs
View file @
f2d5c492
use
crate
::
core
::{
HealthChecker
,
Worker
,
WorkerFactory
};
use
crate
::
pd_router
::
PDRouter
;
use
crate
::
pd_router
::
PDRouter
;
use
crate
::
pd_types
::
PDSelectionPolicy
;
use
crate
::
pd_types
::
PDSelectionPolicy
;
use
crate
::
tree
::
Tree
;
use
crate
::
tree
::
Tree
;
...
@@ -5,7 +6,6 @@ use ::metrics::{counter, gauge, histogram};
...
@@ -5,7 +6,6 @@ use ::metrics::{counter, gauge, histogram};
use
actix_web
::
http
::
header
::{
HeaderValue
,
CONTENT_TYPE
};
use
actix_web
::
http
::
header
::{
HeaderValue
,
CONTENT_TYPE
};
use
actix_web
::{
HttpRequest
,
HttpResponse
};
use
actix_web
::{
HttpRequest
,
HttpResponse
};
use
futures_util
::{
StreamExt
,
TryStreamExt
};
use
futures_util
::{
StreamExt
,
TryStreamExt
};
use
std
::
collections
::
HashMap
;
use
std
::
fmt
::
Debug
;
use
std
::
fmt
::
Debug
;
use
std
::
sync
::
atomic
::
AtomicUsize
;
use
std
::
sync
::
atomic
::
AtomicUsize
;
use
std
::
sync
::{
Arc
,
Mutex
,
RwLock
};
use
std
::
sync
::{
Arc
,
Mutex
,
RwLock
};
...
@@ -30,15 +30,17 @@ pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
...
@@ -30,15 +30,17 @@ pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
#[derive(Debug)]
#[derive(Debug)]
pub
enum
Router
{
pub
enum
Router
{
RoundRobin
{
RoundRobin
{
worker
_url
s
:
Arc
<
RwLock
<
Vec
<
String
>>>
,
workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>
>>>
,
current_index
:
AtomicUsize
,
current_index
:
AtomicUsize
,
timeout_secs
:
u64
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
interval_secs
:
u64
,
_
health_checker
:
Option
<
HealthChecker
>
,
},
},
Random
{
Random
{
worker
_url
s
:
Arc
<
RwLock
<
Vec
<
String
>>>
,
workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>
>>>
,
timeout_secs
:
u64
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
interval_secs
:
u64
,
_
health_checker
:
Option
<
HealthChecker
>
,
},
},
PrefillDecode
{
PrefillDecode
{
pd_router
:
Arc
<
PDRouter
>
,
pd_router
:
Arc
<
PDRouter
>
,
...
@@ -104,16 +106,15 @@ pub enum Router {
...
@@ -104,16 +106,15 @@ pub enum Router {
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
during the next eviction cycle.
*/
*/
worker
_url
s
:
Arc
<
RwLock
<
Vec
<
String
>>>
,
workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>
>>>
,
tree
:
Arc
<
Mutex
<
Tree
>>
,
tree
:
Arc
<
Mutex
<
Tree
>>
,
running_queue
:
Arc
<
Mutex
<
HashMap
<
String
,
usize
>>>
,
processed_queue
:
Arc
<
Mutex
<
HashMap
<
String
,
usize
>>>
,
cache_threshold
:
f32
,
cache_threshold
:
f32
,
balance_abs_threshold
:
usize
,
balance_abs_threshold
:
usize
,
balance_rel_threshold
:
f32
,
balance_rel_threshold
:
f32
,
timeout_secs
:
u64
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
interval_secs
:
u64
,
_
eviction_thread
:
Option
<
thread
::
JoinHandle
<
()
>>
,
_
eviction_thread
:
Option
<
thread
::
JoinHandle
<
()
>>
,
_
health_checker
:
Option
<
HealthChecker
>
,
},
},
}
}
...
@@ -192,25 +193,43 @@ impl Router {
...
@@ -192,25 +193,43 @@ impl Router {
}
}
}
}
// Create Worker trait objects from URLs
let
workers
:
Vec
<
Box
<
dyn
Worker
>>
=
worker_urls
.iter
()
.map
(|
url
|
WorkerFactory
::
create_regular
(
url
.clone
()))
.collect
();
// Create router based on policy...
// Create router based on policy...
Ok
(
match
policy_config
{
Ok
(
match
policy_config
{
PolicyConfig
::
RandomConfig
{
PolicyConfig
::
RandomConfig
{
timeout_secs
,
timeout_secs
,
interval_secs
,
interval_secs
,
}
=>
Router
::
Random
{
}
=>
{
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
)),
let
workers
=
Arc
::
new
(
RwLock
::
new
(
workers
));
timeout_secs
,
let
health_checker
=
interval_secs
,
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
workers
),
interval_secs
);
},
Router
::
Random
{
workers
,
timeout_secs
,
interval_secs
,
_
health_checker
:
Some
(
health_checker
),
}
}
PolicyConfig
::
RoundRobinConfig
{
PolicyConfig
::
RoundRobinConfig
{
timeout_secs
,
timeout_secs
,
interval_secs
,
interval_secs
,
}
=>
Router
::
RoundRobin
{
}
=>
{
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
)),
let
workers
=
Arc
::
new
(
RwLock
::
new
(
workers
));
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
let
health_checker
=
timeout_secs
,
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
workers
),
interval_secs
);
interval_secs
,
Router
::
RoundRobin
{
},
workers
,
current_index
:
std
::
sync
::
atomic
::
AtomicUsize
::
new
(
0
),
timeout_secs
,
interval_secs
,
_
health_checker
:
Some
(
health_checker
),
}
}
PolicyConfig
::
CacheAwareConfig
{
PolicyConfig
::
CacheAwareConfig
{
cache_threshold
,
cache_threshold
,
balance_abs_threshold
,
balance_abs_threshold
,
...
@@ -220,24 +239,12 @@ impl Router {
...
@@ -220,24 +239,12 @@ impl Router {
timeout_secs
,
timeout_secs
,
interval_secs
,
interval_secs
,
}
=>
{
}
=>
{
let
mut
running_queue
=
HashMap
::
new
();
for
url
in
&
worker_urls
{
running_queue
.insert
(
url
.clone
(),
0
);
}
let
mut
processed_queue
=
HashMap
::
new
();
for
url
in
&
worker_urls
{
processed_queue
.insert
(
url
.clone
(),
0
);
}
let
tree
=
Arc
::
new
(
Mutex
::
new
(
Tree
::
new
()));
let
tree
=
Arc
::
new
(
Mutex
::
new
(
Tree
::
new
()));
let
running_queue
=
Arc
::
new
(
Mutex
::
new
(
running_queue
));
let
processed_queue
=
Arc
::
new
(
Mutex
::
new
(
processed_queue
));
// Create background eviction thread
// Create background eviction thread
let
tree_clone
=
Arc
::
clone
(
&
tree
);
let
tree_clone
=
Arc
::
clone
(
&
tree
);
let
processed_queue_clone
=
Arc
::
clone
(
&
processed_queue
);
let
workers
=
Arc
::
new
(
RwLock
::
new
(
workers
)
);
let
running_queue
_clone
=
Arc
::
clone
(
&
running_queue
);
let
workers
_clone
=
Arc
::
clone
(
&
workers
);
let
eviction_thread
=
thread
::
spawn
(
move
||
{
let
eviction_thread
=
thread
::
spawn
(
move
||
{
loop
{
loop
{
// Sleep for the specified interval
// Sleep for the specified interval
...
@@ -246,32 +253,41 @@ impl Router {
...
@@ -246,32 +253,41 @@ impl Router {
let
locked_tree_clone
=
tree_clone
.lock
()
.unwrap
();
let
locked_tree_clone
=
tree_clone
.lock
()
.unwrap
();
// Run eviction
// Run eviction
locked_tree_clone
.evict_tenant_by_size
(
max_tree_size
);
locked_tree_clone
.evict_tenant_by_size
(
max_tree_size
);
drop
(
locked_tree_clone
);
// Print the process queue
let
locked_processed_queue
=
processed_queue_clone
.lock
()
.unwrap
();
// Log worker loads and processed requests
info!
(
"Processed Queue: {:?}"
,
locked_processed_queue
);
let
workers_guard
=
workers_clone
.read
()
.unwrap
();
let
loads
:
Vec
<
(
String
,
usize
)
>
=
workers_guard
// Print the running queue
.iter
()
let
locked_running_queue
=
running_queue_clone
.lock
()
.unwrap
();
.map
(|
w
|
(
w
.url
()
.to_string
(),
w
.load
()))
info!
(
"Running Queue: {:?}"
,
locked_running_queue
);
.collect
();
info!
(
"Worker loads: {:?}"
,
loads
);
let
processed
:
Vec
<
(
String
,
usize
)
>
=
workers_guard
.iter
()
.map
(|
w
|
(
w
.url
()
.to_string
(),
w
.processed_requests
()))
.collect
();
info!
(
"Processed requests: {:?}"
,
processed
);
}
}
});
});
for
url
in
&
worker
_urls
{
for
worker
in
worker
s
.read
()
.unwrap
()
.iter
()
{
tree
.lock
()
.unwrap
()
.insert
(
""
,
url
);
tree
.lock
()
.unwrap
()
.insert
(
""
,
worker
.
url
()
);
}
}
let
health_checker
=
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
workers
),
interval_secs
);
Router
::
CacheAware
{
Router
::
CacheAware
{
worker
_urls
:
Arc
::
new
(
RwLock
::
new
(
worker_urls
))
,
worker
s
,
tree
,
tree
,
running_queue
,
processed_queue
,
cache_threshold
,
cache_threshold
,
balance_abs_threshold
,
balance_abs_threshold
,
balance_rel_threshold
,
balance_rel_threshold
,
timeout_secs
,
timeout_secs
,
interval_secs
,
interval_secs
,
_
eviction_thread
:
Some
(
eviction_thread
),
_
eviction_thread
:
Some
(
eviction_thread
),
_
health_checker
:
Some
(
health_checker
),
}
}
}
}
PolicyConfig
::
PrefillDecodeConfig
{
PolicyConfig
::
PrefillDecodeConfig
{
...
@@ -297,16 +313,18 @@ impl Router {
...
@@ -297,16 +313,18 @@ impl Router {
})
})
}
}
/// Get
a reference to the worker URLs shared across thread
s
/// Get
the current list of worker URL
s
pub
fn
get_worker_urls
(
&
self
)
->
Arc
<
RwLock
<
Vec
<
String
>
>>
{
pub
fn
get_worker_urls
(
&
self
)
->
Vec
<
String
>
{
match
self
{
match
self
{
Router
::
RoundRobin
{
worker_urls
,
..
}
=>
Arc
::
clone
(
worker_urls
),
Router
::
RoundRobin
{
workers
,
..
}
Router
::
Random
{
worker_urls
,
..
}
=>
Arc
::
clone
(
worker_urls
),
|
Router
::
Random
{
workers
,
..
}
Router
::
CacheAware
{
worker_urls
,
..
}
=>
Arc
::
clone
(
worker_urls
),
|
Router
::
CacheAware
{
workers
,
..
}
=>
workers
Router
::
PrefillDecode
{
..
}
=>
{
.read
()
// For PD mode, return empty list since we manage workers differently
.unwrap
()
Arc
::
new
(
RwLock
::
new
(
Vec
::
new
()))
.iter
()
}
.map
(|
w
|
w
.url
()
.to_string
())
.collect
(),
Router
::
PrefillDecode
{
..
}
=>
Vec
::
new
(),
}
}
}
}
...
@@ -373,13 +391,14 @@ impl Router {
...
@@ -373,13 +391,14 @@ impl Router {
fn
select_first_worker
(
&
self
)
->
Result
<
String
,
String
>
{
fn
select_first_worker
(
&
self
)
->
Result
<
String
,
String
>
{
match
self
{
match
self
{
Router
::
RoundRobin
{
worker_urls
,
..
}
Router
::
RoundRobin
{
workers
,
..
}
|
Router
::
Random
{
worker_urls
,
..
}
|
Router
::
Random
{
workers
,
..
}
|
Router
::
CacheAware
{
worker_urls
,
..
}
=>
{
|
Router
::
CacheAware
{
workers
,
..
}
=>
{
if
worker_urls
.read
()
.unwrap
()
.is_empty
()
{
let
workers_guard
=
workers
.read
()
.unwrap
();
if
workers_guard
.is_empty
()
{
Err
(
"No workers are available"
.to_string
())
Err
(
"No workers are available"
.to_string
())
}
else
{
}
else
{
Ok
(
worker
_urls
.read
()
.unwrap
()[
0
]
.clone
())
Ok
(
worker
s_guard
[
0
]
.url
()
.to_string
())
}
}
}
}
Router
::
PrefillDecode
{
..
}
=>
{
Router
::
PrefillDecode
{
..
}
=>
{
...
@@ -514,7 +533,7 @@ impl Router {
...
@@ -514,7 +533,7 @@ impl Router {
return
HttpResponse
::
NotImplemented
()
return
HttpResponse
::
NotImplemented
()
.body
(
"route_to_all not implemented for PrefillDecode mode"
);
.body
(
"route_to_all not implemented for PrefillDecode mode"
);
}
}
_
=>
self
.get_worker_urls
()
.read
()
.unwrap
()
.clone
()
,
_
=>
self
.get_worker_urls
(),
};
};
// Send requests to all workers concurrently
// Send requests to all workers concurrently
...
@@ -562,7 +581,7 @@ impl Router {
...
@@ -562,7 +581,7 @@ impl Router {
}
}
}
}
let
urls
=
self
.get_worker_urls
()
.read
()
.unwrap
()
.clone
()
;
let
urls
=
self
.get_worker_urls
();
let
prefill_urls
:
Vec
<
String
>
=
Vec
::
new
();
let
prefill_urls
:
Vec
<
String
>
=
Vec
::
new
();
let
decode_urls
=
urls
;
let
decode_urls
=
urls
;
...
@@ -631,6 +650,24 @@ impl Router {
...
@@ -631,6 +650,24 @@ impl Router {
.increment
(
1
);
.increment
(
1
);
}
}
// For CacheAware router, increment load before request
let
load_incremented
=
match
self
{
Router
::
CacheAware
{
workers
,
..
}
=>
{
let
workers_guard
=
workers
.read
()
.unwrap
();
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
&
worker_url
)
{
worker
.increment_load
();
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
worker
.load
()
as
f64
);
true
}
else
{
false
}
}
_
=>
false
,
};
// Send typed request directly
// Send typed request directly
let
response
=
self
let
response
=
self
.send_typed_request
(
.send_typed_request
(
...
@@ -640,6 +677,7 @@ impl Router {
...
@@ -640,6 +677,7 @@ impl Router {
route
,
route
,
&
worker_url
,
&
worker_url
,
is_stream
,
is_stream
,
load_incremented
,
)
)
.await
;
.await
;
...
@@ -684,44 +722,47 @@ impl Router {
...
@@ -684,44 +722,47 @@ impl Router {
}
}
}
}
// Helper method to select worker from text
// Helper method to select worker from text
(returns index for RoundRobin/Random, URL for CacheAware)
fn
select_generate_worker_from_text
(
&
self
,
text
:
&
str
)
->
String
{
fn
select_generate_worker_from_text
(
&
self
,
text
:
&
str
)
->
String
{
match
self
{
match
self
{
Router
::
RoundRobin
{
Router
::
RoundRobin
{
worker
_url
s
,
workers
,
current_index
,
current_index
,
..
..
}
=>
{
}
=>
{
let
workers_guard
=
workers
.read
()
.unwrap
();
let
idx
=
current_index
let
idx
=
current_index
.fetch_update
(
.fetch_update
(
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
std
::
sync
::
atomic
::
Ordering
::
SeqCst
,
|
x
|
Some
((
x
+
1
)
%
worker
_urls
.read
()
.unwrap
()
.len
()),
|
x
|
Some
((
x
+
1
)
%
worker
s_guard
.len
()),
)
)
.unwrap
();
.unwrap
();
worker
_urls
.read
()
.unwrap
()[
idx
]
.clone
()
worker
s_guard
[
idx
]
.url
()
.to_string
()
}
}
Router
::
Random
{
worker_urls
,
..
}
=>
worker_urls
.read
()
.unwrap
()
Router
::
Random
{
workers
,
..
}
=>
{
[
rand
::
random
::
<
usize
>
()
%
worker_urls
.read
()
.unwrap
()
.len
()]
let
workers_guard
=
workers
.read
()
.unwrap
();
.clone
(),
workers_guard
[
rand
::
random
::
<
usize
>
()
%
workers_guard
.len
()]
.url
()
.to_string
()
}
Router
::
CacheAware
{
Router
::
CacheAware
{
worker
_url
s
,
workers
,
tree
,
tree
,
running_queue
,
processed_queue
,
cache_threshold
,
cache_threshold
,
balance_abs_threshold
,
balance_abs_threshold
,
balance_rel_threshold
,
balance_rel_threshold
,
..
..
}
=>
{
}
=>
{
let
tree
=
tree
.lock
()
.unwrap
();
let
tree
=
tree
.lock
()
.unwrap
();
let
mut
running_queue
=
running_queue
.lock
()
.unwrap
();
let
workers_guard
=
workers
.read
()
.unwrap
();
// Get current load statistics
// Get current load statistics from workers
let
max_load
=
*
running_queue
.values
()
.max
()
.unwrap_or
(
&
0
);
let
loads
:
Vec
<
usize
>
=
workers_guard
.iter
()
.map
(|
w
|
w
.load
())
.collect
();
let
min_load
=
*
running_queue
.values
()
.min
()
.unwrap_or
(
&
0
);
let
max_load
=
*
loads
.iter
()
.max
()
.unwrap_or
(
&
0
);
let
min_load
=
*
loads
.iter
()
.min
()
.unwrap_or
(
&
0
);
// Load is considered imbalanced if:
// Load is considered imbalanced if:
// 1. (max - min) > abs_threshold AND
// 1. (max - min) > abs_threshold AND
...
@@ -731,11 +772,16 @@ impl Router {
...
@@ -731,11 +772,16 @@ impl Router {
let
selected_url
=
if
is_imbalanced
{
let
selected_url
=
if
is_imbalanced
{
// Log load balancing trigger and current queue state
// Log load balancing trigger and current queue state
let
worker_loads
:
Vec
<
(
String
,
usize
)
>
=
workers_guard
.iter
()
.map
(|
w
|
(
w
.url
()
.to_string
(),
w
.load
()))
.collect
();
info!
(
info!
(
"Load balancing triggered due to workload imbalance:
\n
\
"Load balancing triggered due to workload imbalance:
\n
\
Max load: {}, Min load: {}
\n
\
Max load: {}, Min load: {}
\n
\
Current
running queue
: {:?}"
,
Current
worker loads
: {:?}"
,
max_load
,
min_load
,
running_queue
max_load
,
min_load
,
worker_loads
);
);
counter!
(
"sgl_router_load_balancing_events_total"
)
.increment
(
1
);
counter!
(
"sgl_router_load_balancing_events_total"
)
.increment
(
1
);
...
@@ -743,11 +789,11 @@ impl Router {
...
@@ -743,11 +789,11 @@ impl Router {
gauge!
(
"sgl_router_min_load"
)
.set
(
min_load
as
f64
);
gauge!
(
"sgl_router_min_load"
)
.set
(
min_load
as
f64
);
// Use shortest queue routing when load is imbalanced
// Use shortest queue routing when load is imbalanced
running_queue
workers_guard
.iter
()
.iter
()
.min_by_key
(|
(
_u
rl
,
&
count
)|
count
)
.min_by_key
(|
w
|
w
.load
()
)
.map
(|
(
url
,
_
)|
url
.clone
())
.map
(|
w
|
w
.url
()
.to_string
())
.unwrap_or_else
(||
worker
_urls
.read
()
.unwrap
()[
0
]
.clone
())
.unwrap_or_else
(||
worker
s_guard
[
0
]
.url
()
.to_string
())
}
else
{
}
else
{
// Use cache-aware routing when load is balanced
// Use cache-aware routing when load is balanced
let
(
matched_text
,
matched_worker
)
=
tree
.prefix_match
(
&
text
);
let
(
matched_text
,
matched_worker
)
=
tree
.prefix_match
(
&
text
);
...
@@ -763,18 +809,12 @@ impl Router {
...
@@ -763,18 +809,12 @@ impl Router {
}
}
};
};
// Update queues and tree
// Find the selected worker and increment processed counter only
*
running_queue
.get_mut
(
&
selected_url
)
.unwrap
()
+=
1
;
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
&
selected_url
)
{
worker
.increment_processed
();
*
processed_queue
counter!
(
"sgl_router_processed_requests_total"
,
"worker"
=>
selected_url
.to_string
())
.lock
()
.increment
(
1
);
.unwrap
()
}
.get_mut
(
&
selected_url
)
.unwrap
()
+=
1
;
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
selected_url
.to_string
())
.set
(
*
running_queue
.get
(
&
selected_url
)
.unwrap
()
as
f64
);
counter!
(
"sgl_router_processed_requests_total"
,
"worker"
=>
selected_url
.to_string
())
.increment
(
1
);
tree
.insert
(
&
text
,
&
selected_url
);
tree
.insert
(
&
text
,
&
selected_url
);
...
@@ -796,6 +836,7 @@ impl Router {
...
@@ -796,6 +836,7 @@ impl Router {
route
:
&
str
,
route
:
&
str
,
worker_url
:
&
str
,
worker_url
:
&
str
,
is_stream
:
bool
,
is_stream
:
bool
,
load_incremented
:
bool
,
// Whether load was incremented for this request
)
->
HttpResponse
{
)
->
HttpResponse
{
let
start
=
Instant
::
now
();
let
start
=
Instant
::
now
();
...
@@ -820,6 +861,22 @@ impl Router {
...
@@ -820,6 +861,22 @@ impl Router {
Ok
(
res
)
=>
res
,
Ok
(
res
)
=>
res
,
Err
(
e
)
=>
{
Err
(
e
)
=>
{
error!
(
"Failed to send request to {}: {}"
,
worker_url
,
e
);
error!
(
"Failed to send request to {}: {}"
,
worker_url
,
e
);
// Decrement load on error for CacheAware router
if
load_incremented
{
if
let
Router
::
CacheAware
{
workers
,
..
}
=
self
{
if
let
Ok
(
workers_guard
)
=
workers
.read
()
{
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
worker_url
)
{
worker
.decrement_load
();
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
worker
.load
()
as
f64
);
}
}
}
}
return
HttpResponse
::
InternalServerError
()
.body
(
format!
(
"Request failed: {}"
,
e
));
return
HttpResponse
::
InternalServerError
()
.body
(
format!
(
"Request failed: {}"
,
e
));
}
}
};
};
...
@@ -837,13 +894,15 @@ impl Router {
...
@@ -837,13 +894,15 @@ impl Router {
}
}
};
};
// Then decrement running queue counter if using CacheAware
// Decrement load counter for non-streaming CacheAware requests
if
let
Router
::
CacheAware
{
running_queue
,
..
}
=
self
{
if
load_incremented
&&
!
is_stream
{
if
let
Ok
(
mut
queue
)
=
running_queue
.lock
()
{
if
let
Router
::
CacheAware
{
workers
,
..
}
=
self
{
if
let
Some
(
count
)
=
queue
.get_mut
(
worker_url
)
{
if
let
Ok
(
workers_guard
)
=
workers
.read
()
{
*
count
=
count
.saturating_sub
(
1
);
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
worker_url
)
{
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
worker
.decrement_load
();
.set
(
*
count
as
f64
);
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
worker
.load
()
as
f64
);
}
}
}
}
}
}
}
...
@@ -855,8 +914,9 @@ impl Router {
...
@@ -855,8 +914,9 @@ impl Router {
counter!
(
"sgl_router_requests_total"
,
"route"
=>
route
.to_string
())
.increment
(
1
);
counter!
(
"sgl_router_requests_total"
,
"route"
=>
route
.to_string
())
.increment
(
1
);
response
response
}
else
if
let
Router
::
CacheAware
{
running_queue
,
..
}
=
self
{
}
else
if
let
Router
::
CacheAware
{
workers
,
..
}
=
self
{
let
running_queue
=
Arc
::
clone
(
running_queue
);
// For streaming with CacheAware router, we need to manually decrement when done
let
workers
=
Arc
::
clone
(
workers
);
let
worker_url
=
worker_url
.to_string
();
let
worker_url
=
worker_url
.to_string
();
HttpResponse
::
build
(
status
)
HttpResponse
::
build
(
status
)
...
@@ -867,21 +927,28 @@ impl Router {
...
@@ -867,21 +927,28 @@ impl Router {
actix_web
::
error
::
ErrorInternalServerError
(
"Failed to read stream"
)
actix_web
::
error
::
ErrorInternalServerError
(
"Failed to read stream"
)
})
})
.inspect
(
move
|
bytes
|
{
.inspect
(
move
|
bytes
|
{
let
bytes
=
bytes
.as_ref
()
.unwrap
();
if
let
Ok
(
bytes
)
=
bytes
{
if
bytes
if
bytes
.as_ref
()
.as_ref
()
.windows
(
12
)
.windows
(
12
)
.any
(|
window
|
window
==
b
"data: [DONE]"
)
.any
(|
window
|
window
==
b
"data: [DONE]"
)
{
{
let
mut
locked_queue
=
running_queue
.lock
()
.unwrap
();
if
let
Ok
(
workers_guard
)
=
workers
.read
()
{
let
count
=
locked_queue
.get_mut
(
&
worker_url
)
.unwrap
();
if
let
Some
(
worker
)
=
*
count
=
count
.saturating_sub
(
1
);
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
&
worker_url
)
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
*
count
as
f64
);
{
debug!
(
"Streaming is done!!"
)
worker
.decrement_load
();
gauge!
(
"sgl_router_running_requests"
,
"worker"
=>
worker_url
.to_string
())
.set
(
worker
.load
()
as
f64
);
debug!
(
"Streaming is done!!"
)
}
}
}
}
}
}),
}),
)
)
}
else
{
}
else
{
// For non-CacheAware routers, just stream without load tracking
HttpResponse
::
build
(
status
)
HttpResponse
::
build
(
status
)
.insert_header
((
CONTENT_TYPE
,
HeaderValue
::
from_static
(
"text/event-stream"
)))
.insert_header
((
CONTENT_TYPE
,
HeaderValue
::
from_static
(
"text/event-stream"
)))
.streaming
(
res
.bytes_stream
()
.map_err
(|
_
|
{
.streaming
(
res
.bytes_stream
()
.map_err
(|
_
|
{
...
@@ -935,43 +1002,27 @@ impl Router {
...
@@ -935,43 +1002,27 @@ impl Router {
Ok
(
res
)
=>
{
Ok
(
res
)
=>
{
if
res
.status
()
.is_success
()
{
if
res
.status
()
.is_success
()
{
match
self
{
match
self
{
Router
::
RoundRobin
{
worker
_url
s
,
..
}
Router
::
RoundRobin
{
workers
,
..
}
|
Router
::
Random
{
worker
_url
s
,
..
}
|
Router
::
Random
{
workers
,
..
}
|
Router
::
CacheAware
{
worker
_url
s
,
..
}
=>
{
|
Router
::
CacheAware
{
workers
,
..
}
=>
{
info!
(
"Worker {} health check passed"
,
worker_url
);
info!
(
"Worker {} health check passed"
,
worker_url
);
let
mut
urls
=
worker
_url
s
.write
()
.unwrap
();
let
mut
workers_guard
=
workers
.write
()
.unwrap
();
if
urls
.contains
(
&
worker_url
.to_string
()
)
{
if
workers_guard
.iter
()
.any
(|
w
|
w
.url
()
==
worker_url
)
{
return
Err
(
format!
(
"Worker {} already exists"
,
worker_url
));
return
Err
(
format!
(
"Worker {} already exists"
,
worker_url
));
}
}
info!
(
"Added worker: {}"
,
worker_url
);
info!
(
"Added worker: {}"
,
worker_url
);
urls
.push
(
worker_url
.to_string
());
let
new_worker
=
gauge!
(
"sgl_router_active_workers"
)
.set
(
urls
.len
()
as
f64
);
WorkerFactory
::
create_regular
(
worker_url
.to_string
());
workers_guard
.push
(
new_worker
);
gauge!
(
"sgl_router_active_workers"
)
.set
(
workers_guard
.len
()
as
f64
);
}
}
Router
::
PrefillDecode
{
..
}
=>
{
Router
::
PrefillDecode
{
..
}
=>
{
return
Err
(
"Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods."
.to_string
());
return
Err
(
"Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods."
.to_string
());
}
}
}
}
// If cache aware, initialize the queues for the new worker
// If cache aware, add worker to tree
if
let
Router
::
CacheAware
{
if
let
Router
::
CacheAware
{
tree
,
..
}
=
self
{
running_queue
,
processed_queue
,
tree
,
..
}
=
self
{
// Add worker to running queue with initial count of 0
running_queue
.lock
()
.unwrap
()
.insert
(
worker_url
.to_string
(),
0
);
// Add worker to processed queue with initial count of 0
processed_queue
.lock
()
.unwrap
()
.insert
(
worker_url
.to_string
(),
0
);
// Add worker to tree
// Add worker to tree
tree
.lock
()
.unwrap
()
.insert
(
""
,
worker_url
);
tree
.lock
()
.unwrap
()
.insert
(
""
,
worker_url
);
}
}
...
@@ -1013,14 +1064,14 @@ impl Router {
...
@@ -1013,14 +1064,14 @@ impl Router {
pub
fn
remove_worker
(
&
self
,
worker_url
:
&
str
)
{
pub
fn
remove_worker
(
&
self
,
worker_url
:
&
str
)
{
match
self
{
match
self
{
Router
::
RoundRobin
{
worker
_url
s
,
..
}
Router
::
RoundRobin
{
workers
,
..
}
|
Router
::
Random
{
worker
_url
s
,
..
}
|
Router
::
Random
{
workers
,
..
}
|
Router
::
CacheAware
{
worker
_url
s
,
..
}
=>
{
|
Router
::
CacheAware
{
workers
,
..
}
=>
{
let
mut
urls
=
worker
_url
s
.write
()
.unwrap
();
let
mut
workers_guard
=
workers
.write
()
.unwrap
();
if
let
Some
(
index
)
=
urls
.iter
()
.position
(|
url
|
url
==
&
worker_url
)
{
if
let
Some
(
index
)
=
workers_guard
.iter
()
.position
(|
w
|
w
.
url
()
==
worker_url
)
{
urls
.remove
(
index
);
workers_guard
.remove
(
index
);
info!
(
"Removed worker: {}"
,
worker_url
);
info!
(
"Removed worker: {}"
,
worker_url
);
gauge!
(
"sgl_router_active_workers"
)
.set
(
urls
.len
()
as
f64
);
gauge!
(
"sgl_router_active_workers"
)
.set
(
workers_guard
.len
()
as
f64
);
}
else
{
}
else
{
warn!
(
"Worker {} not found, skipping removal"
,
worker_url
);
warn!
(
"Worker {} not found, skipping removal"
,
worker_url
);
return
;
return
;
...
@@ -1033,26 +1084,9 @@ impl Router {
...
@@ -1033,26 +1084,9 @@ impl Router {
}
}
// if cache aware, remove the worker from the tree
// if cache aware, remove the worker from the tree
if
let
Router
::
CacheAware
{
if
let
Router
::
CacheAware
{
tree
,
..
}
=
self
{
tree
,
running_queue
,
processed_queue
,
..
}
=
self
{
tree
.lock
()
.unwrap
()
.remove_tenant
(
&
worker_url
);
tree
.lock
()
.unwrap
()
.remove_tenant
(
&
worker_url
);
running_queue
info!
(
"Removed worker from tree: {}"
,
worker_url
);
.lock
()
.unwrap
()
.remove
(
&
worker_url
.to_string
());
processed_queue
.lock
()
.unwrap
()
.remove
(
&
worker_url
.to_string
());
info!
(
"Removed worker from tree and cleaned up queues: {}"
,
worker_url
);
}
}
}
}
...
@@ -1241,21 +1275,22 @@ mod tests {
...
@@ -1241,21 +1275,22 @@ mod tests {
use
crate
::
service_discovery
::
PodType
;
use
crate
::
service_discovery
::
PodType
;
fn
create_test_regular_router
()
->
Router
{
fn
create_test_regular_router
()
->
Router
{
let
workers
=
vec!
[
WorkerFactory
::
create_regular
(
"http://worker1:8080"
.to_string
()),
WorkerFactory
::
create_regular
(
"http://worker2:8080"
.to_string
()),
];
Router
::
Random
{
Router
::
Random
{
worker_urls
:
Arc
::
new
(
RwLock
::
new
(
vec!
[
workers
:
Arc
::
new
(
RwLock
::
new
(
workers
)),
"http://worker1:8080"
.to_string
(),
"http://worker2:8080"
.to_string
(),
])),
timeout_secs
:
5
,
timeout_secs
:
5
,
interval_secs
:
1
,
interval_secs
:
1
,
_
health_checker
:
None
,
}
}
}
}
#[test]
#[test]
fn
test_router_get_worker_urls_regular
()
{
fn
test_router_get_worker_urls_regular
()
{
let
router
=
create_test_regular_router
();
let
router
=
create_test_regular_router
();
let
worker_urls
=
router
.get_worker_urls
();
let
urls
=
router
.get_worker_urls
();
let
urls
=
worker_urls
.read
()
.unwrap
();
assert_eq!
(
urls
.len
(),
2
);
assert_eq!
(
urls
.len
(),
2
);
assert
!
(
urls
.contains
(
&
"http://worker1:8080"
.to_string
()));
assert
!
(
urls
.contains
(
&
"http://worker1:8080"
.to_string
()));
...
...
sgl-router/src/server.rs
View file @
f2d5c492
...
@@ -236,8 +236,7 @@ async fn add_worker(
...
@@ -236,8 +236,7 @@ async fn add_worker(
#[get(
"/list_workers"
)]
#[get(
"/list_workers"
)]
async
fn
list_workers
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
async
fn
list_workers
(
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
let
workers
=
data
.router
.get_worker_urls
();
let
worker_list
=
data
.router
.get_worker_urls
();
let
worker_list
=
workers
.read
()
.unwrap
()
.clone
();
HttpResponse
::
Ok
()
.json
(
serde_json
::
json!
({
"urls"
:
worker_list
}))
HttpResponse
::
Ok
()
.json
(
serde_json
::
json!
({
"urls"
:
worker_list
}))
}
}
...
@@ -381,7 +380,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
...
@@ -381,7 +380,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
info!
(
"✅ Serving router on {}:{}"
,
config
.host
,
config
.port
);
info!
(
"✅ Serving router on {}:{}"
,
config
.host
,
config
.port
);
info!
(
info!
(
"✅ Serving workers on {:?}"
,
"✅ Serving workers on {:?}"
,
app_state
.router
.get_worker_urls
()
.read
()
.unwrap
()
app_state
.router
.get_worker_urls
()
);
);
HttpServer
::
new
(
move
||
{
HttpServer
::
new
(
move
||
{
...
...
sgl-router/src/service_discovery.rs
View file @
f2d5c492
...
@@ -547,11 +547,12 @@ mod tests {
...
@@ -547,11 +547,12 @@ mod tests {
// Helper to create a Router instance for testing event handlers
// Helper to create a Router instance for testing event handlers
fn
create_test_router
()
->
Arc
<
Router
>
{
fn
create_test_router
()
->
Arc
<
Router
>
{
let
worker
_url
s
=
Arc
::
new
(
RwLock
::
new
(
Vec
::
new
()));
let
workers
=
Arc
::
new
(
RwLock
::
new
(
Vec
::
new
()));
Arc
::
new
(
Router
::
Random
{
Arc
::
new
(
Router
::
Random
{
worker
_url
s
,
workers
,
timeout_secs
:
5
,
timeout_secs
:
5
,
interval_secs
:
1
,
interval_secs
:
1
,
_
health_checker
:
None
,
})
})
}
}
...
@@ -878,8 +879,6 @@ mod tests {
...
@@ -878,8 +879,6 @@ mod tests {
assert
!
(
!
tracked_pods
.lock
()
.unwrap
()
.contains
(
&
pod_info
));
assert
!
(
!
tracked_pods
.lock
()
.unwrap
()
.contains
(
&
pod_info
));
assert
!
(
!
router
assert
!
(
!
router
.get_worker_urls
()
.get_worker_urls
()
.read
()
.unwrap
()
.contains
(
&
pod_info
.worker_url
(
port
)));
.contains
(
&
pod_info
.worker_url
(
port
)));
}
}
...
@@ -907,7 +906,7 @@ mod tests {
...
@@ -907,7 +906,7 @@ mod tests {
.await
;
.await
;
assert
!
(
tracked_pods
.lock
()
.unwrap
()
.is_empty
());
assert
!
(
tracked_pods
.lock
()
.unwrap
()
.is_empty
());
assert
!
(
router
.get_worker_urls
()
.
read
()
.unwrap
()
.
is_empty
());
assert
!
(
router
.get_worker_urls
()
.is_empty
());
}
}
#[tokio::test]
#[tokio::test]
...
...
sgl-router/tests/test_pd_routing.rs
View file @
f2d5c492
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
mod
test_pd_routing
{
mod
test_pd_routing
{
use
rand
::
Rng
;
use
rand
::
Rng
;
use
serde_json
::
json
;
use
serde_json
::
json
;
use
sglang_router_rs
::
pd_types
::
{
EngineInfo
,
EngineType
,
PDSelectionPolicy
}
;
use
sglang_router_rs
::
pd_types
::
PDSelectionPolicy
;
use
sglang_router_rs
::
router
::{
PolicyConfig
,
Router
};
use
sglang_router_rs
::
router
::{
PolicyConfig
,
Router
};
// Test-only struct to help validate PD request parsing
// Test-only struct to help validate PD request parsing
...
@@ -51,40 +51,35 @@ mod test_pd_routing {
...
@@ -51,40 +51,35 @@ mod test_pd_routing {
// ========================================================================
// ========================================================================
#[test]
#[test]
fn
test_engine_info_creation
()
{
fn
test_worker_types
()
{
// Test EngineInfo creation for prefill servers
use
sglang_router_rs
::
core
::{
WorkerFactory
,
WorkerType
};
let
prefill_engine
=
EngineInfo
::
new_prefill
(
"http://prefill:8080"
.to_string
(),
Some
(
9000
));
match
prefill_engine
.engine_type
{
// Test worker creation for prefill servers
EngineType
::
Prefill
=>
(),
let
prefill_worker
=
_
=>
panic!
(
"Expected Prefill engine type"
),
WorkerFactory
::
create_prefill
(
"http://prefill:8080"
.to_string
(),
Some
(
9000
));
assert_eq!
(
prefill_worker
.url
(),
"http://prefill:8080"
);
match
prefill_worker
.worker_type
()
{
WorkerType
::
Prefill
{
bootstrap_port
}
=>
{
assert_eq!
(
bootstrap_port
,
Some
(
9000
));
}
_
=>
panic!
(
"Expected Prefill worker type"
),
}
}
assert_eq!
(
prefill_engine
.url
,
"http://prefill:8080"
);
assert_eq!
(
prefill_engine
.bootstrap_port
,
Some
(
9000
));
// Test worker creation for decode servers
assert_eq!
(
prefill_engine
.get_hostname
(),
"prefill"
);
let
decode_worker
=
WorkerFactory
::
create_decode
(
"http://decode:8080"
.to_string
());
assert_eq!
(
decode_worker
.url
(),
"http://decode:8080"
);
// Test EngineInfo creation for decode servers
match
decode_worker
.worker_type
()
{
let
decode_engine
=
EngineInfo
::
new_decode
(
"http://decode:8080"
.to_string
());
WorkerType
::
Decode
=>
(),
match
decode_engine
.engine_type
{
_
=>
panic!
(
"Expected Decode worker type"
),
EngineType
::
Decode
=>
(),
_
=>
panic!
(
"Expected Decode engine type"
),
}
}
assert_eq!
(
decode_engine
.url
,
"http://decode:8080"
);
assert_eq!
(
decode_engine
.bootstrap_port
,
None
);
assert_eq!
(
decode_engine
.get_hostname
(),
"decode"
);
// Test API path generation
// Test regular worker creation
assert_eq!
(
let
regular_worker
=
WorkerFactory
::
create_regular
(
"http://regular:8080"
.to_string
());
prefill_engine
.api_path
(
"/generate"
),
assert_eq!
(
regular_worker
.url
(),
"http://regular:8080"
);
"http://prefill:8080/generate"
match
regular_worker
.worker_type
()
{
);
WorkerType
::
Regular
=>
(),
assert_eq!
(
_
=>
panic!
(
"Expected Regular worker type"
),
prefill_engine
.api_path
(
"health"
),
}
"http://prefill:8080/health"
);
assert_eq!
(
decode_engine
.api_path
(
"/v1/chat/completions"
),
"http://decode:8080/v1/chat/completions"
);
}
}
#[test]
#[test]
...
@@ -230,6 +225,9 @@ mod test_pd_routing {
...
@@ -230,6 +225,9 @@ mod test_pd_routing {
#[test]
#[test]
fn
test_bootstrap_injection_simulation
()
{
fn
test_bootstrap_injection_simulation
()
{
use
sglang_router_rs
::
core
::{
WorkerFactory
,
WorkerType
};
use
sglang_router_rs
::
pd_types
::
get_hostname
;
// Since we can't test the actual inject_bootstrap_fields function here
// Since we can't test the actual inject_bootstrap_fields function here
// (it's private in the router module), we'll test the expected behavior
// (it's private in the router module), we'll test the expected behavior
...
@@ -240,15 +238,24 @@ mod test_pd_routing {
...
@@ -240,15 +238,24 @@ mod test_pd_routing {
"temperature"
:
0.7
"temperature"
:
0.7
});
});
// Create a prefill worker to simulate injection
let
prefill_worker
=
WorkerFactory
::
create_prefill
(
"http://prefill1:8080"
.to_string
(),
Some
(
9000
));
// Extract bootstrap port from worker type
let
bootstrap_port
=
match
prefill_worker
.worker_type
()
{
WorkerType
::
Prefill
{
bootstrap_port
}
=>
bootstrap_port
,
_
=>
None
,
};
// Simulate what inject_bootstrap_fields would do
// Simulate what inject_bootstrap_fields would do
let
prefill_info
=
EngineInfo
::
new_prefill
(
"http://prefill1:8080"
.to_string
(),
Some
(
9000
));
single_json
[
"bootstrap_host"
]
=
json!
(
get_hostname
(
prefill_worker
.url
()));
single_json
[
"bootstrap_host"
]
=
json!
(
prefill_info
.get_hostname
());
single_json
[
"bootstrap_port"
]
=
json!
(
bootstrap_port
);
single_json
[
"bootstrap_port"
]
=
json!
(
prefill_info
.bootstrap_port
);
single_json
[
"bootstrap_room"
]
=
json!
(
12345u64
);
// Random room ID
single_json
[
"bootstrap_room"
]
=
json!
(
12345u64
);
// Random room ID
// Verify bootstrap fields are added correctly
// Verify bootstrap fields are added correctly
assert_eq!
(
single_json
[
"bootstrap_host"
],
"prefill1"
);
assert_eq!
(
single_json
[
"bootstrap_host"
],
"prefill1"
);
assert_eq!
(
single_json
[
"bootstrap_port"
],
9000
);
assert_eq!
(
single_json
[
"bootstrap_port"
],
json!
(
Some
(
9000
)
))
;
assert
!
(
single_json
[
"bootstrap_room"
]
.is_u64
());
assert
!
(
single_json
[
"bootstrap_room"
]
.is_u64
());
assert_eq!
(
single_json
[
"temperature"
],
0.7
);
// Original field preserved
assert_eq!
(
single_json
[
"temperature"
],
0.7
);
// Original field preserved
...
@@ -259,8 +266,9 @@ mod test_pd_routing {
...
@@ -259,8 +266,9 @@ mod test_pd_routing {
});
});
let
batch_size
=
3
;
let
batch_size
=
3
;
batch_json
[
"bootstrap_host"
]
=
json!
(
vec!
[
prefill_info
.get_hostname
();
batch_size
]);
let
hostname
=
get_hostname
(
prefill_worker
.url
());
batch_json
[
"bootstrap_port"
]
=
json!
(
vec!
[
prefill_info
.bootstrap_port
;
batch_size
]);
batch_json
[
"bootstrap_host"
]
=
json!
(
vec!
[
hostname
;
batch_size
]);
batch_json
[
"bootstrap_port"
]
=
json!
(
vec!
[
bootstrap_port
;
batch_size
]);
batch_json
[
"bootstrap_room"
]
=
json!
(
vec!
[
111u64
,
222u64
,
333u64
]);
batch_json
[
"bootstrap_room"
]
=
json!
(
vec!
[
111u64
,
222u64
,
333u64
]);
// Verify batch bootstrap fields
// Verify batch bootstrap fields
...
@@ -306,7 +314,9 @@ mod test_pd_routing {
...
@@ -306,7 +314,9 @@ mod test_pd_routing {
}
}
#[test]
#[test]
fn
test_engine_info_hostname_extraction
()
{
fn
test_hostname_extraction
()
{
use
sglang_router_rs
::
pd_types
::
get_hostname
;
// Test various URL formats
// Test various URL formats
let
test_cases
=
vec!
[
let
test_cases
=
vec!
[
(
"http://localhost:8080"
,
"localhost"
),
(
"http://localhost:8080"
,
"localhost"
),
...
@@ -318,8 +328,7 @@ mod test_pd_routing {
...
@@ -318,8 +328,7 @@ mod test_pd_routing {
];
];
for
(
url
,
expected_hostname
)
in
test_cases
{
for
(
url
,
expected_hostname
)
in
test_cases
{
let
engine
=
EngineInfo
::
new_prefill
(
url
.to_string
(),
None
);
assert_eq!
(
get_hostname
(
url
),
expected_hostname
);
assert_eq!
(
engine
.get_hostname
(),
expected_hostname
);
}
}
}
}
...
@@ -652,6 +661,9 @@ mod test_pd_routing {
...
@@ -652,6 +661,9 @@ mod test_pd_routing {
#[test]
#[test]
fn
test_bootstrap_injection_with_benchmark_requests
()
{
fn
test_bootstrap_injection_with_benchmark_requests
()
{
use
sglang_router_rs
::
core
::{
WorkerFactory
,
WorkerType
};
use
sglang_router_rs
::
pd_types
::
get_hostname
;
// Test bootstrap injection with actual benchmark request patterns
// Test bootstrap injection with actual benchmark request patterns
let
mut
benchmark_request
=
json!
({
let
mut
benchmark_request
=
json!
({
"input_ids"
:
vec!
[
vec!
[
1
,
2
,
3
,
4
];
16
],
// Batch size 16
"input_ids"
:
vec!
[
vec!
[
1
,
2
,
3
,
4
];
16
],
// Batch size 16
...
@@ -664,12 +676,20 @@ mod test_pd_routing {
...
@@ -664,12 +676,20 @@ mod test_pd_routing {
"stream"
:
true
"stream"
:
true
});
});
// Simulate bootstrap injection
// Create a prefill worker to simulate injection
let
prefill_info
=
EngineInfo
::
new_prefill
(
"http://prefill:8080"
.to_string
(),
Some
(
9000
));
let
prefill_worker
=
WorkerFactory
::
create_prefill
(
"http://prefill:8080"
.to_string
(),
Some
(
9000
));
// Extract bootstrap port from worker type
let
bootstrap_port
=
match
prefill_worker
.worker_type
()
{
WorkerType
::
Prefill
{
bootstrap_port
}
=>
bootstrap_port
,
_
=>
None
,
};
let
batch_size
=
16
;
let
batch_size
=
16
;
let
hostname
=
get_hostname
(
prefill_worker
.url
());
benchmark_request
[
"bootstrap_host"
]
=
json!
(
vec!
[
prefill_info
.get_
hostname
()
;
batch_size
]);
benchmark_request
[
"bootstrap_host"
]
=
json!
(
vec!
[
hostname
;
batch_size
]);
benchmark_request
[
"bootstrap_port"
]
=
json!
(
vec!
[
prefill_info
.
bootstrap_port
;
batch_size
]);
benchmark_request
[
"bootstrap_port"
]
=
json!
(
vec!
[
bootstrap_port
;
batch_size
]);
benchmark_request
[
"bootstrap_room"
]
=
benchmark_request
[
"bootstrap_room"
]
=
json!
((
0
..
batch_size
)
.map
(|
_
|
12345u64
)
.collect
::
<
Vec
<
_
>>
());
json!
((
0
..
batch_size
)
.map
(|
_
|
12345u64
)
.collect
::
<
Vec
<
_
>>
());
...
@@ -770,6 +790,9 @@ mod test_pd_routing {
...
@@ -770,6 +790,9 @@ mod test_pd_routing {
#[test]
#[test]
fn
test_large_batch_bootstrap_injection
()
{
fn
test_large_batch_bootstrap_injection
()
{
use
sglang_router_rs
::
core
::{
WorkerFactory
,
WorkerType
};
use
sglang_router_rs
::
pd_types
::
get_hostname
;
// Test bootstrap injection performance with very large batches
// Test bootstrap injection performance with very large batches
// This simulates the bench_one_batch_server.py scenario
// This simulates the bench_one_batch_server.py scenario
let
large_batch_sizes
=
vec!
[
1024
,
4096
,
8192
];
let
large_batch_sizes
=
vec!
[
1024
,
4096
,
8192
];
...
@@ -787,14 +810,19 @@ mod test_pd_routing {
...
@@ -787,14 +810,19 @@ mod test_pd_routing {
"stream"
:
true
"stream"
:
true
});
});
// Simulate bootstrap injection
// Create a prefill worker to simulate injection
let
prefill_info
=
let
prefill_worker
=
EngineInfo
::
new_prefill
(
"http://prefill:8080"
.to_string
(),
Some
(
9000
));
WorkerFactory
::
create_prefill
(
"http://prefill:8080"
.to_string
(),
Some
(
9000
));
// Extract bootstrap port from worker type
let
bootstrap_port
=
match
prefill_worker
.worker_type
()
{
WorkerType
::
Prefill
{
bootstrap_port
}
=>
bootstrap_port
,
_
=>
None
,
};
let
hostname
=
get_hostname
(
prefill_worker
.url
());
large_batch_request
[
"bootstrap_host"
]
=
large_batch_request
[
"bootstrap_host"
]
=
json!
(
vec!
[
hostname
;
batch_size
]);
json!
(
vec!
[
prefill_info
.get_hostname
();
batch_size
]);
large_batch_request
[
"bootstrap_port"
]
=
json!
(
vec!
[
bootstrap_port
;
batch_size
]);
large_batch_request
[
"bootstrap_port"
]
=
json!
(
vec!
[
prefill_info
.bootstrap_port
;
batch_size
]);
large_batch_request
[
"bootstrap_room"
]
=
json!
((
0
..
batch_size
)
large_batch_request
[
"bootstrap_room"
]
=
json!
((
0
..
batch_size
)
.map
(|
_
|
rand
::
thread_rng
()
.gen
::
<
u64
>
())
.map
(|
_
|
rand
::
thread_rng
()
.gen
::
<
u64
>
())
.collect
::
<
Vec
<
_
>>
());
.collect
::
<
Vec
<
_
>>
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment