Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
90313fb0
Unverified
Commit
90313fb0
authored
Aug 26, 2025
by
Chang Su
Committed by
GitHub
Aug 26, 2025
Browse files
[router] add token bucket rate limiter (#9656)
parent
3578eb1e
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
533 additions
and
10 deletions
+533
-10
sgl-router/py_src/sglang_router/launch_router.py
sgl-router/py_src/sglang_router/launch_router.py
+42
-0
sgl-router/py_src/sglang_router/router.py
sgl-router/py_src/sglang_router/router.py
+10
-1
sgl-router/src/config/types.rs
sgl-router/src/config/types.rs
+21
-0
sgl-router/src/core/mod.rs
sgl-router/src/core/mod.rs
+1
-0
sgl-router/src/core/token_bucket.rs
sgl-router/src/core/token_bucket.rs
+195
-0
sgl-router/src/lib.rs
sgl-router/src/lib.rs
+17
-2
sgl-router/src/main.rs
sgl-router/src/main.rs
+3
-0
sgl-router/src/middleware.rs
sgl-router/src/middleware.rs
+189
-2
sgl-router/src/server.rs
sgl-router/src/server.rs
+30
-4
sgl-router/tests/api_endpoints_test.rs
sgl-router/tests/api_endpoints_test.rs
+12
-0
sgl-router/tests/common/mod.rs
sgl-router/tests/common/mod.rs
+1
-0
sgl-router/tests/common/test_app.rs
sgl-router/tests/common/test_app.rs
+2
-0
sgl-router/tests/request_formats_test.rs
sgl-router/tests/request_formats_test.rs
+3
-0
sgl-router/tests/streaming_tests.rs
sgl-router/tests/streaming_tests.rs
+3
-0
sgl-router/tests/test_pd_routing.rs
sgl-router/tests/test_pd_routing.rs
+4
-1
No files found.
sgl-router/py_src/sglang_router/launch_router.py
View file @
90313fb0
...
...
@@ -72,6 +72,12 @@ class RouterArgs:
request_timeout_secs
:
int
=
1800
# Max concurrent requests for rate limiting
max_concurrent_requests
:
int
=
256
# Queue size for pending requests when max concurrent limit reached
queue_size
:
int
=
100
# Maximum time (in seconds) a request can wait in queue before timing out
queue_timeout_secs
:
int
=
60
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
rate_limit_tokens_per_second
:
Optional
[
int
]
=
None
# CORS allowed origins
cors_allowed_origins
:
List
[
str
]
=
dataclasses
.
field
(
default_factory
=
list
)
# Retry configuration
...
...
@@ -402,6 +408,24 @@ class RouterArgs:
default
=
RouterArgs
.
max_concurrent_requests
,
help
=
"Maximum number of concurrent requests allowed (for rate limiting)"
,
)
parser
.
add_argument
(
f
"--
{
prefix
}
queue-size"
,
type
=
int
,
default
=
RouterArgs
.
queue_size
,
help
=
"Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)"
,
)
parser
.
add_argument
(
f
"--
{
prefix
}
queue-timeout-secs"
,
type
=
int
,
default
=
RouterArgs
.
queue_timeout_secs
,
help
=
"Maximum time (in seconds) a request can wait in queue before timing out"
,
)
parser
.
add_argument
(
f
"--
{
prefix
}
rate-limit-tokens-per-second"
,
type
=
int
,
default
=
RouterArgs
.
rate_limit_tokens_per_second
,
help
=
"Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests"
,
)
parser
.
add_argument
(
f
"--
{
prefix
}
cors-allowed-origins"
,
type
=
str
,
...
...
@@ -478,6 +502,21 @@ class RouterArgs:
f
"
{
prefix
}
max_concurrent_requests"
,
RouterArgs
.
max_concurrent_requests
,
),
queue_size
=
getattr
(
args
,
f
"
{
prefix
}
queue_size"
,
RouterArgs
.
queue_size
,
),
queue_timeout_secs
=
getattr
(
args
,
f
"
{
prefix
}
queue_timeout_secs"
,
RouterArgs
.
queue_timeout_secs
,
),
rate_limit_tokens_per_second
=
getattr
(
args
,
f
"
{
prefix
}
rate_limit_tokens_per_second"
,
RouterArgs
.
rate_limit_tokens_per_second
,
),
cors_allowed_origins
=
getattr
(
args
,
f
"
{
prefix
}
cors_allowed_origins"
,
[]),
retry_max_retries
=
getattr
(
args
,
f
"
{
prefix
}
retry_max_retries"
),
retry_initial_backoff_ms
=
getattr
(
args
,
f
"
{
prefix
}
retry_initial_backoff_ms"
),
...
...
@@ -700,6 +739,9 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
),
request_id_headers
=
router_args
.
request_id_headers
,
max_concurrent_requests
=
router_args
.
max_concurrent_requests
,
queue_size
=
router_args
.
queue_size
,
queue_timeout_secs
=
router_args
.
queue_timeout_secs
,
rate_limit_tokens_per_second
=
router_args
.
rate_limit_tokens_per_second
,
cors_allowed_origins
=
router_args
.
cors_allowed_origins
,
retry_max_retries
=
router_args
.
retry_max_retries
,
retry_initial_backoff_ms
=
router_args
.
retry_initial_backoff_ms
,
...
...
sgl-router/py_src/sglang_router/router.py
View file @
90313fb0
...
...
@@ -64,7 +64,10 @@ class Router:
bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
Default: 'sglang.ai/bootstrap-port'
request_timeout_secs: Request timeout in seconds. Default: 600
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 64
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 256
queue_size: Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately). Default: 100
queue_timeout_secs: Maximum time (in seconds) a request can wait in queue before timing out. Default: 60
rate_limit_tokens_per_second: Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests. Default: None
cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: []
health_failure_threshold: Number of consecutive health check failures before marking worker unhealthy. Default: 3
health_success_threshold: Number of consecutive health check successes before marking worker healthy. Default: 2
...
...
@@ -108,6 +111,9 @@ class Router:
prefill_policy
:
Optional
[
PolicyType
]
=
None
,
decode_policy
:
Optional
[
PolicyType
]
=
None
,
max_concurrent_requests
:
int
=
256
,
queue_size
:
int
=
100
,
queue_timeout_secs
:
int
=
60
,
rate_limit_tokens_per_second
:
Optional
[
int
]
=
None
,
cors_allowed_origins
:
List
[
str
]
=
None
,
retry_max_retries
:
int
=
5
,
retry_initial_backoff_ms
:
int
=
50
,
...
...
@@ -169,6 +175,9 @@ class Router:
prefill_policy
=
prefill_policy
,
decode_policy
=
decode_policy
,
max_concurrent_requests
=
max_concurrent_requests
,
queue_size
=
queue_size
,
queue_timeout_secs
=
queue_timeout_secs
,
rate_limit_tokens_per_second
=
rate_limit_tokens_per_second
,
cors_allowed_origins
=
cors_allowed_origins
,
retry_max_retries
=
retry_max_retries
,
retry_initial_backoff_ms
=
retry_initial_backoff_ms
,
...
...
sgl-router/src/config/types.rs
View file @
90313fb0
...
...
@@ -37,6 +37,12 @@ pub struct RouterConfig {
pub
request_id_headers
:
Option
<
Vec
<
String
>>
,
/// Maximum concurrent requests allowed (for rate limiting)
pub
max_concurrent_requests
:
usize
,
/// Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)
pub
queue_size
:
usize
,
/// Maximum time (in seconds) a request can wait in queue before timing out
pub
queue_timeout_secs
:
u64
,
/// Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
pub
rate_limit_tokens_per_second
:
Option
<
usize
>
,
/// CORS allowed origins
pub
cors_allowed_origins
:
Vec
<
String
>
,
/// Retry configuration
...
...
@@ -320,6 +326,9 @@ impl Default for RouterConfig {
log_level
:
None
,
request_id_headers
:
None
,
max_concurrent_requests
:
256
,
queue_size
:
100
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
cors_allowed_origins
:
vec!
[],
retry
:
RetryConfig
::
default
(),
circuit_breaker
:
CircuitBreakerConfig
::
default
(),
...
...
@@ -466,6 +475,9 @@ mod tests {
disable_circuit_breaker
:
false
,
health_check
:
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
queue_size
:
100
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
};
let
json
=
serde_json
::
to_string
(
&
config
)
.unwrap
();
...
...
@@ -899,6 +911,9 @@ mod tests {
disable_circuit_breaker
:
false
,
health_check
:
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
queue_size
:
100
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
};
assert
!
(
config
.mode
.is_pd_mode
());
...
...
@@ -956,6 +971,9 @@ mod tests {
disable_circuit_breaker
:
false
,
health_check
:
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
queue_size
:
100
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
};
assert
!
(
!
config
.mode
.is_pd_mode
());
...
...
@@ -1009,6 +1027,9 @@ mod tests {
disable_circuit_breaker
:
false
,
health_check
:
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
queue_size
:
100
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
};
assert
!
(
config
.has_service_discovery
());
...
...
sgl-router/src/core/mod.rs
View file @
90313fb0
...
...
@@ -9,6 +9,7 @@
pub
mod
circuit_breaker
;
pub
mod
error
;
pub
mod
retry
;
pub
mod
token_bucket
;
pub
mod
worker
;
// Re-export commonly used types at the module level
...
...
sgl-router/src/core/token_bucket.rs
0 → 100644
View file @
90313fb0
use
std
::
sync
::
Arc
;
use
std
::
time
::{
Duration
,
Instant
};
use
tokio
::
sync
::{
Mutex
,
Notify
};
use
tracing
::{
debug
,
trace
};
/// Token bucket for rate limiting
///
/// This implementation provides:
/// - Smooth rate limiting with configurable refill rate
/// - Burst capacity handling
/// - Fair queuing for waiting requests
#[derive(Clone)]
pub struct TokenBucket {
    // Shared mutable state: current token count and last-refill timestamp.
    inner: Arc<Mutex<TokenBucketInner>>,
    // Wakes tasks blocked in `acquire` when tokens are returned.
    notify: Arc<Notify>,
    // Maximum tokens the bucket can hold (burst capacity).
    capacity: f64,
    // tokens per second
    refill_rate: f64,
}

struct TokenBucketInner {
    // Tokens currently available; fractional values are allowed because
    // refill is computed from elapsed wall-clock time.
    tokens: f64,
    // Timestamp of the most recent refill computation.
    last_refill: Instant,
}
impl
TokenBucket
{
/// Create a new token bucket
///
/// # Arguments
/// * `capacity` - Maximum number of tokens (burst capacity)
/// * `refill_rate` - Tokens added per second
pub
fn
new
(
capacity
:
usize
,
refill_rate
:
usize
)
->
Self
{
let
capacity
=
capacity
as
f64
;
let
refill_rate
=
refill_rate
as
f64
;
// Ensure refill_rate is not zero to prevent division by zero
let
refill_rate
=
if
refill_rate
>
0.0
{
refill_rate
}
else
{
1.0
// Default to 1 token per second if zero
};
Self
{
inner
:
Arc
::
new
(
Mutex
::
new
(
TokenBucketInner
{
tokens
:
capacity
,
// Start full
last_refill
:
Instant
::
now
(),
})),
notify
:
Arc
::
new
(
Notify
::
new
()),
capacity
,
refill_rate
,
}
}
/// Try to acquire tokens immediately
pub
async
fn
try_acquire
(
&
self
,
tokens
:
f64
)
->
Result
<
(),
()
>
{
let
mut
inner
=
self
.inner
.lock
()
.await
;
// Refill tokens based on elapsed time
let
now
=
Instant
::
now
();
let
elapsed
=
now
.duration_since
(
inner
.last_refill
)
.as_secs_f64
();
let
refill_amount
=
elapsed
*
self
.refill_rate
;
inner
.tokens
=
(
inner
.tokens
+
refill_amount
)
.min
(
self
.capacity
);
inner
.last_refill
=
now
;
trace!
(
"Token bucket: {} tokens available, requesting {}"
,
inner
.tokens
,
tokens
);
if
inner
.tokens
>=
tokens
{
inner
.tokens
-=
tokens
;
debug!
(
"Token bucket: acquired {} tokens, {} remaining"
,
tokens
,
inner
.tokens
);
Ok
(())
}
else
{
Err
(())
}
}
/// Acquire tokens, waiting if necessary
pub
async
fn
acquire
(
&
self
,
tokens
:
f64
)
->
Result
<
(),
tokio
::
time
::
error
::
Elapsed
>
{
// First try to acquire immediately
if
self
.try_acquire
(
tokens
)
.await
.is_ok
()
{
return
Ok
(());
}
// Calculate wait time
let
wait_time
=
{
let
inner
=
self
.inner
.lock
()
.await
;
let
tokens_needed
=
tokens
-
inner
.tokens
;
let
wait_secs
=
tokens_needed
/
self
.refill_rate
;
Duration
::
from_secs_f64
(
wait_secs
)
};
debug!
(
"Token bucket: waiting {:?} for {} tokens"
,
wait_time
,
tokens
);
// Wait for tokens to be available
tokio
::
time
::
timeout
(
wait_time
,
async
{
loop
{
// Check if we can acquire now
if
self
.try_acquire
(
tokens
)
.await
.is_ok
()
{
return
;
}
// Wait for notification or small interval
tokio
::
select!
{
_
=
self
.notify
.notified
()
=>
{},
_
=
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
10
))
=>
{},
}
}
})
.await
?
;
Ok
(())
}
/// Acquire tokens with custom timeout
pub
async
fn
acquire_timeout
(
&
self
,
tokens
:
f64
,
timeout
:
Duration
,
)
->
Result
<
(),
tokio
::
time
::
error
::
Elapsed
>
{
tokio
::
time
::
timeout
(
timeout
,
self
.acquire
(
tokens
))
.await
?
}
/// Return tokens to the bucket (for cancelled requests)
pub
async
fn
return_tokens
(
&
self
,
tokens
:
f64
)
{
let
mut
inner
=
self
.inner
.lock
()
.await
;
inner
.tokens
=
(
inner
.tokens
+
tokens
)
.min
(
self
.capacity
);
self
.notify
.notify_waiters
();
debug!
(
"Token bucket: returned {} tokens, {} available"
,
tokens
,
inner
.tokens
);
}
/// Get current available tokens (for monitoring)
pub
async
fn
available_tokens
(
&
self
)
->
f64
{
let
mut
inner
=
self
.inner
.lock
()
.await
;
// Refill before checking
let
now
=
Instant
::
now
();
let
elapsed
=
now
.duration_since
(
inner
.last_refill
)
.as_secs_f64
();
let
refill_amount
=
elapsed
*
self
.refill_rate
;
inner
.tokens
=
(
inner
.tokens
+
refill_amount
)
.min
(
self
.capacity
);
inner
.last_refill
=
now
;
inner
.tokens
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[tokio::test]
async
fn
test_token_bucket_basic
()
{
let
bucket
=
TokenBucket
::
new
(
10
,
5
);
// 10 capacity, 5 per second
// Should succeed - bucket starts full
assert
!
(
bucket
.try_acquire
(
5.0
)
.await
.is_ok
());
assert
!
(
bucket
.try_acquire
(
5.0
)
.await
.is_ok
());
// Should fail - no tokens left
assert
!
(
bucket
.try_acquire
(
1.0
)
.await
.is_err
());
// Wait for refill
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
300
))
.await
;
// Should have ~1.5 tokens now
assert
!
(
bucket
.try_acquire
(
1.0
)
.await
.is_ok
());
}
#[tokio::test]
async
fn
test_token_bucket_refill
()
{
let
bucket
=
TokenBucket
::
new
(
10
,
10
);
// 10 capacity, 10 per second
// Use all tokens
assert
!
(
bucket
.try_acquire
(
10.0
)
.await
.is_ok
());
// Wait for partial refill
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
500
))
.await
;
// Should have ~5 tokens
let
available
=
bucket
.available_tokens
()
.await
;
assert
!
((
4.0
..=
6.0
)
.contains
(
&
available
));
}
}
sgl-router/src/lib.rs
View file @
90313fb0
...
...
@@ -85,6 +85,9 @@ struct Router {
health_check_endpoint
:
String
,
// IGW (Inference Gateway) configuration
enable_igw
:
bool
,
queue_size
:
usize
,
queue_timeout_secs
:
u64
,
rate_limit_tokens_per_second
:
Option
<
usize
>
,
}
impl
Router
{
...
...
@@ -176,6 +179,9 @@ impl Router {
log_level
:
self
.log_level
.clone
(),
request_id_headers
:
self
.request_id_headers
.clone
(),
max_concurrent_requests
:
self
.max_concurrent_requests
,
queue_size
:
self
.queue_size
,
queue_timeout_secs
:
self
.queue_timeout_secs
,
rate_limit_tokens_per_second
:
self
.rate_limit_tokens_per_second
,
cors_allowed_origins
:
self
.cors_allowed_origins
.clone
(),
retry
:
config
::
RetryConfig
{
max_retries
:
self
.retry_max_retries
,
...
...
@@ -190,8 +196,8 @@ impl Router {
timeout_duration_secs
:
self
.cb_timeout_duration_secs
,
window_duration_secs
:
self
.cb_window_duration_secs
,
},
disable_retries
:
false
,
disable_circuit_breaker
:
false
,
disable_retries
:
self
.disable_retries
,
disable_circuit_breaker
:
self
.disable_circuit_breaker
,
health_check
:
config
::
HealthCheckConfig
{
failure_threshold
:
self
.health_failure_threshold
,
success_threshold
:
self
.health_success_threshold
,
...
...
@@ -263,6 +269,9 @@ impl Router {
health_check_endpoint
=
String
::
from
(
"/health"
),
// IGW defaults
enable_igw
=
false
,
queue_size
=
100
,
queue_timeout_secs
=
60
,
rate_limit_tokens_per_second
=
None
,
))]
#[allow(clippy::too_many_arguments)]
fn
new
(
...
...
@@ -317,6 +326,9 @@ impl Router {
health_check_interval_secs
:
u64
,
health_check_endpoint
:
String
,
enable_igw
:
bool
,
queue_size
:
usize
,
queue_timeout_secs
:
u64
,
rate_limit_tokens_per_second
:
Option
<
usize
>
,
)
->
PyResult
<
Self
>
{
Ok
(
Router
{
host
,
...
...
@@ -370,6 +382,9 @@ impl Router {
health_check_interval_secs
,
health_check_endpoint
,
enable_igw
,
queue_size
,
queue_timeout_secs
,
rate_limit_tokens_per_second
,
})
}
...
...
sgl-router/src/main.rs
View file @
90313fb0
...
...
@@ -394,6 +394,8 @@ impl CliArgs {
Some
(
self
.request_id_headers
.clone
())
},
max_concurrent_requests
:
self
.max_concurrent_requests
,
queue_size
:
100
,
// Default queue size
queue_timeout_secs
:
60
,
// Default timeout
cors_allowed_origins
:
self
.cors_allowed_origins
.clone
(),
retry
:
RetryConfig
{
max_retries
:
self
.retry_max_retries
,
...
...
@@ -418,6 +420,7 @@ impl CliArgs {
endpoint
:
self
.health_check_endpoint
.clone
(),
},
enable_igw
:
self
.enable_igw
,
rate_limit_tokens_per_second
:
None
,
})
}
...
...
sgl-router/src/middleware.rs
View file @
90313fb0
use
axum
::{
extract
::
Request
,
http
::
HeaderValue
,
response
::
Response
};
use
axum
::{
extract
::
Request
,
extract
::
State
,
http
::
HeaderValue
,
http
::
StatusCode
,
middleware
::
Next
,
response
::
IntoResponse
,
response
::
Response
,
};
use
rand
::
Rng
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
std
::
time
::
Instant
;
use
tokio
::
sync
::{
mpsc
,
oneshot
};
use
tower
::{
Layer
,
Service
};
use
tower_http
::
trace
::{
MakeSpan
,
OnRequest
,
OnResponse
,
TraceLayer
};
use
tracing
::{
field
::
Empty
,
info_span
,
Span
};
use
tracing
::{
debug
,
error
,
field
::
Empty
,
info
,
info_span
,
warn
,
Span
};
pub
use
crate
::
core
::
token_bucket
::
TokenBucket
;
use
crate
::
server
::
AppState
;
/// Generate OpenAI-compatible request ID based on endpoint
fn
generate_request_id
(
path
:
&
str
)
->
String
{
...
...
@@ -313,3 +322,181 @@ pub fn log_request(entry: RequestLogEntry) {
);
}
}
// ============ Concurrency Limiting with Queue Support ============

/// Request queue entry
pub struct QueuedRequest {
    /// Time when the request was queued
    queued_at: Instant,
    /// Channel to send the permit back when acquired
    permit_tx: oneshot::Sender<Result<(), StatusCode>>,
}

/// Queue metrics for monitoring
///
/// NOTE(review): nothing in this file reads or updates these counters —
/// confirm they are wired up elsewhere, otherwise this struct is dead code.
#[derive(Debug, Default)]
pub struct QueueMetrics {
    // Total number of requests ever queued.
    pub total_queued: std::sync::atomic::AtomicU64,
    // Number of requests currently waiting in the queue.
    pub current_queued: std::sync::atomic::AtomicU64,
    // Requests that timed out before a token became available.
    pub total_timeout: std::sync::atomic::AtomicU64,
    // Requests rejected outright (e.g. queue full).
    pub total_rejected: std::sync::atomic::AtomicU64,
}
/// Queue processor that handles queued requests
pub struct QueueProcessor {
    // Shared token bucket permits are drawn from.
    token_bucket: Arc<TokenBucket>,
    // Receiving end of the queue fed by the concurrency middleware.
    queue_rx: mpsc::Receiver<QueuedRequest>,
    // Maximum time a request may spend queued before being rejected.
    queue_timeout: Duration,
}

impl QueueProcessor {
    /// Bundle the token bucket, queue receiver and timeout into a processor.
    pub fn new(
        token_bucket: Arc<TokenBucket>,
        queue_rx: mpsc::Receiver<QueuedRequest>,
        queue_timeout: Duration,
    ) -> Self {
        Self {
            token_bucket,
            queue_rx,
            queue_timeout,
        }
    }

    /// Drain the queue until every sender handle is dropped.
    ///
    /// For each queued request: reject it if it already exceeded
    /// `queue_timeout`, grant a permit immediately if a token is free, or
    /// spawn a task that waits (up to the remaining timeout) for one.
    pub async fn run(mut self) {
        info!("Starting concurrency queue processor");

        // Process requests in a single task to reduce overhead
        while let Some(queued) = self.queue_rx.recv().await {
            // Check timeout immediately
            let elapsed = queued.queued_at.elapsed();
            if elapsed >= self.queue_timeout {
                warn!("Request already timed out in queue");
                // Send failures are ignored: the waiting request may already
                // have been dropped (client disconnected).
                let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
                continue;
            }

            let remaining_timeout = self.queue_timeout - elapsed;

            // Try to acquire token for this request
            if self.token_bucket.try_acquire(1.0).await.is_ok() {
                // Got token immediately
                debug!("Queue: acquired token immediately for queued request");
                let _ = queued.permit_tx.send(Ok(()));
            } else {
                // Need to wait for token
                let token_bucket = self.token_bucket.clone();

                // Spawn task only when we actually need to wait.
                // NOTE(review): spawning means later queue entries can overtake
                // this one when a token frees up — ordering is not strictly
                // FIFO under contention; confirm that is acceptable.
                tokio::spawn(async move {
                    if token_bucket
                        .acquire_timeout(1.0, remaining_timeout)
                        .await
                        .is_ok()
                    {
                        debug!("Queue: acquired token after waiting");
                        let _ = queued.permit_tx.send(Ok(()));
                    } else {
                        warn!("Queue: request timed out waiting for token");
                        let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
                    }
                });
            }
        }

        warn!("Concurrency queue processor shutting down");
    }
}
/// State for the concurrency limiter
pub struct ConcurrencyLimiter {
    // Sender half of the request queue; `None` when queuing is disabled.
    pub queue_tx: Option<mpsc::Sender<QueuedRequest>>,
}

impl ConcurrencyLimiter {
    /// Create new concurrency limiter with optional queue
    ///
    /// A `queue_size` of zero disables queuing entirely: the returned limiter
    /// carries no sender and no `QueueProcessor` is created.
    pub fn new(
        token_bucket: Arc<TokenBucket>,
        queue_size: usize,
        queue_timeout: Duration,
    ) -> (Self, Option<QueueProcessor>) {
        // Guard clause: no queue requested.
        if queue_size == 0 {
            return (Self { queue_tx: None }, None);
        }

        let (sender, receiver) = mpsc::channel(queue_size);
        let worker = QueueProcessor::new(token_bucket, receiver, queue_timeout);
        (
            Self {
                queue_tx: Some(sender),
            },
            Some(worker),
        )
    }
}
/// Middleware function for concurrency limiting with optional queuing
///
/// Fast path: take a token from the rate limiter and run the request.
/// Slow path: if no token is free, park the request in the queue (when
/// enabled) and wait for the queue processor to grant a permit; otherwise
/// respond 429 immediately.
pub async fn concurrency_limit_middleware(
    State(app_state): State<Arc<AppState>>,
    request: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let token_bucket = app_state.context.rate_limiter.clone();

    // Try to acquire token immediately
    if token_bucket.try_acquire(1.0).await.is_ok() {
        debug!("Acquired token immediately");
        let response = next.run(request).await;
        // Return the token to the bucket.
        // NOTE(review): giving the token back after completion makes this
        // behave as an in-flight concurrency cap rather than a pure
        // tokens-per-second rate limit — confirm that is the intent.
        token_bucket.return_tokens(1.0).await;
        response
    } else {
        // No tokens available, try to queue if enabled
        if let Some(queue_tx) = &app_state.concurrency_queue_tx {
            debug!("No tokens available, attempting to queue request");

            // Create a channel for the token response
            let (permit_tx, permit_rx) = oneshot::channel();
            let queued = QueuedRequest {
                queued_at: Instant::now(),
                permit_tx,
            };

            // Try to send to queue
            match queue_tx.try_send(queued) {
                Ok(_) => {
                    // Wait for token from queue processor
                    match permit_rx.await {
                        Ok(Ok(())) => {
                            debug!("Acquired token from queue");
                            let response = next.run(request).await;
                            // Return the token to the bucket
                            token_bucket.return_tokens(1.0).await;
                            response
                        }
                        Ok(Err(status)) => {
                            // Queue processor rejected the request
                            // (e.g. queue timeout).
                            warn!("Queue returned error status: {}", status);
                            status.into_response()
                        }
                        Err(_) => {
                            // Processor dropped the sender without replying.
                            error!("Queue response channel closed");
                            StatusCode::INTERNAL_SERVER_ERROR.into_response()
                        }
                    }
                }
                Err(_) => {
                    // Bounded queue is at capacity.
                    warn!("Request queue is full, returning 429");
                    StatusCode::TOO_MANY_REQUESTS.into_response()
                }
            }
        } else {
            warn!("No tokens available and queuing is disabled, returning 429");
            StatusCode::TOO_MANY_REQUESTS.into_response()
        }
    }
}
sgl-router/src/server.rs
View file @
90313fb0
use
crate
::
config
::
RouterConfig
;
use
crate
::
logging
::{
self
,
LoggingConfig
};
use
crate
::
metrics
::{
self
,
PrometheusConfig
};
use
crate
::
middleware
::
TokenBucket
;
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
routers
::{
RouterFactory
,
RouterTrait
};
use
crate
::
service_discovery
::{
start_service_discovery
,
ServiceDiscoveryConfig
};
...
...
@@ -25,7 +26,7 @@ use tracing::{error, info, warn, Level};
pub
struct
AppContext
{
pub
client
:
Client
,
pub
router_config
:
RouterConfig
,
pub
concurrency
_limiter
:
Arc
<
t
ok
io
::
sync
::
Semaphore
>
,
pub
rate
_limiter
:
Arc
<
T
ok
enBucket
>
,
// Future dependencies can be added here
}
...
...
@@ -34,12 +35,14 @@ impl AppContext {
router_config
:
RouterConfig
,
client
:
Client
,
max_concurrent_requests
:
usize
,
rate_limit_tokens_per_second
:
Option
<
usize
>
,
)
->
Self
{
let
concurrency_limiter
=
Arc
::
new
(
tokio
::
sync
::
Semaphore
::
new
(
max_concurrent_requests
));
let
rate_limit_tokens
=
rate_limit_tokens_per_second
.unwrap_or
(
max_concurrent_requests
);
let
rate_limiter
=
Arc
::
new
(
TokenBucket
::
new
(
max_concurrent_requests
,
rate_limit_tokens
));
Self
{
client
,
router_config
,
concurrency
_limiter
,
rate
_limiter
,
}
}
}
...
...
@@ -48,6 +51,7 @@ impl AppContext {
pub
struct
AppState
{
pub
router
:
Arc
<
dyn
RouterTrait
>
,
pub
context
:
Arc
<
AppContext
>
,
pub
concurrency_queue_tx
:
Option
<
tokio
::
sync
::
mpsc
::
Sender
<
crate
::
middleware
::
QueuedRequest
>>
,
}
// Fallback handler for unmatched routes
...
...
@@ -186,7 +190,11 @@ pub fn build_app(
let
protected_routes
=
Router
::
new
()
.route
(
"/generate"
,
post
(
generate
))
.route
(
"/v1/chat/completions"
,
post
(
v1_chat_completions
))
.route
(
"/v1/completions"
,
post
(
v1_completions
));
.route
(
"/v1/completions"
,
post
(
v1_completions
))
.route_layer
(
axum
::
middleware
::
from_fn_with_state
(
app_state
.clone
(),
crate
::
middleware
::
concurrency_limit_middleware
,
));
let
public_routes
=
Router
::
new
()
.route
(
"/liveness"
,
get
(
liveness
))
...
...
@@ -282,15 +290,33 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
config
.router_config
.clone
(),
client
.clone
(),
config
.router_config.max_concurrent_requests
,
config
.router_config.rate_limit_tokens_per_second
,
));
// Create router with the context
let
router
=
RouterFactory
::
create_router
(
&
app_context
)
.await
?
;
// Set up concurrency limiter with queue if configured
let
(
limiter
,
processor
)
=
crate
::
middleware
::
ConcurrencyLimiter
::
new
(
app_context
.rate_limiter
.clone
(),
config
.router_config.queue_size
,
Duration
::
from_secs
(
config
.router_config.queue_timeout_secs
),
);
// Start queue processor if enabled
if
let
Some
(
processor
)
=
processor
{
tokio
::
spawn
(
processor
.run
());
info!
(
"Started request queue with size: {}, timeout: {}s"
,
config
.router_config.queue_size
,
config
.router_config.queue_timeout_secs
);
}
// Create app state with router and context
let
app_state
=
Arc
::
new
(
AppState
{
router
:
Arc
::
from
(
router
),
context
:
app_context
.clone
(),
concurrency_queue_tx
:
limiter
.queue_tx
.clone
(),
});
let
router_arc
=
Arc
::
clone
(
&
app_state
.router
);
...
...
sgl-router/tests/api_endpoints_test.rs
View file @
90313fb0
...
...
@@ -45,6 +45,9 @@ impl TestContext {
log_level
:
None
,
request_id_headers
:
None
,
max_concurrent_requests
:
64
,
queue_size
:
0
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
cors_allowed_origins
:
vec!
[],
retry
:
RetryConfig
::
default
(),
circuit_breaker
:
CircuitBreakerConfig
::
default
(),
...
...
@@ -1088,6 +1091,9 @@ mod error_tests {
log_level
:
None
,
request_id_headers
:
None
,
max_concurrent_requests
:
64
,
queue_size
:
0
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
cors_allowed_origins
:
vec!
[],
retry
:
RetryConfig
::
default
(),
circuit_breaker
:
CircuitBreakerConfig
::
default
(),
...
...
@@ -1440,6 +1446,9 @@ mod pd_mode_tests {
log_level
:
None
,
request_id_headers
:
None
,
max_concurrent_requests
:
64
,
queue_size
:
0
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
cors_allowed_origins
:
vec!
[],
retry
:
RetryConfig
::
default
(),
circuit_breaker
:
CircuitBreakerConfig
::
default
(),
...
...
@@ -1596,6 +1605,9 @@ mod request_id_tests {
log_level
:
None
,
request_id_headers
:
Some
(
vec!
[
"custom-id"
.to_string
(),
"trace-id"
.to_string
()]),
max_concurrent_requests
:
64
,
queue_size
:
0
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
cors_allowed_origins
:
vec!
[],
retry
:
RetryConfig
::
default
(),
circuit_breaker
:
CircuitBreakerConfig
::
default
(),
...
...
sgl-router/tests/common/mod.rs
View file @
90313fb0
...
...
@@ -16,6 +16,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
config
.clone
(),
reqwest
::
Client
::
new
(),
config
.max_concurrent_requests
,
config
.rate_limit_tokens_per_second
,
))
}
...
...
sgl-router/tests/common/test_app.rs
View file @
90313fb0
...
...
@@ -19,12 +19,14 @@ pub fn create_test_app(
router_config
.clone
(),
client
,
router_config
.max_concurrent_requests
,
router_config
.rate_limit_tokens_per_second
,
));
// Create AppState with the test router and context
let
app_state
=
Arc
::
new
(
AppState
{
router
,
context
:
app_context
,
concurrency_queue_tx
:
None
,
// No queue for tests
});
// Configure request ID headers (use defaults if not specified)
...
...
sgl-router/tests/request_formats_test.rs
View file @
90313fb0
...
...
@@ -36,6 +36,9 @@ impl TestContext {
log_level
:
None
,
request_id_headers
:
None
,
max_concurrent_requests
:
64
,
queue_size
:
0
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
cors_allowed_origins
:
vec!
[],
retry
:
RetryConfig
::
default
(),
circuit_breaker
:
CircuitBreakerConfig
::
default
(),
...
...
sgl-router/tests/streaming_tests.rs
View file @
90313fb0
...
...
@@ -37,6 +37,9 @@ impl TestContext {
log_level
:
None
,
request_id_headers
:
None
,
max_concurrent_requests
:
64
,
queue_size
:
0
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
cors_allowed_origins
:
vec!
[],
retry
:
RetryConfig
::
default
(),
circuit_breaker
:
CircuitBreakerConfig
::
default
(),
...
...
sgl-router/tests/test_pd_routing.rs
View file @
90313fb0
...
...
@@ -178,6 +178,8 @@ mod test_pd_routing {
log_level
:
None
,
request_id_headers
:
None
,
max_concurrent_requests
:
64
,
queue_size
:
0
,
queue_timeout_secs
:
60
,
cors_allowed_origins
:
vec!
[],
retry
:
RetryConfig
::
default
(),
circuit_breaker
:
CircuitBreakerConfig
::
default
(),
...
...
@@ -185,11 +187,12 @@ mod test_pd_routing {
disable_circuit_breaker
:
false
,
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
rate_limit_tokens_per_second
:
None
,
};
// Router creation will fail due to health checks, but config should be valid
let
app_context
=
sglang_router_rs
::
server
::
AppContext
::
new
(
config
,
reqwest
::
Client
::
new
(),
64
);
sglang_router_rs
::
server
::
AppContext
::
new
(
config
,
reqwest
::
Client
::
new
(),
64
,
None
);
let
app_context
=
std
::
sync
::
Arc
::
new
(
app_context
);
let
result
=
RouterFactory
::
create_router
(
&
app_context
)
.await
;
assert
!
(
result
.is_err
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment