Unverified Commit ee704e62 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] add auth middleware for api key auth (#10826)

parent f4e3ebeb
...@@ -69,6 +69,7 @@ rmcp = { version = "0.6.3", features = ["client", "server", ...@@ -69,6 +69,7 @@ rmcp = { version = "0.6.3", features = ["client", "server",
"reqwest", "reqwest",
"auth"] } "auth"] }
serde_yaml = "0.9" serde_yaml = "0.9"
subtle = "2.6"
# gRPC and Protobuf dependencies # gRPC and Protobuf dependencies
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] } tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
......
...@@ -331,6 +331,79 @@ python -m sglang_router.launch_router \ ...@@ -331,6 +331,79 @@ python -m sglang_router.launch_router \
--prometheus-port 9090 --prometheus-port 9090
``` ```
### API Key Authentication
The router supports multi-level API key authentication for both the router itself and individual workers:
#### Router API Key
Protect access to the router endpoints:
```bash
python -m sglang_router.launch_router \
--api-key "your-router-api-key" \
--worker-urls http://worker1:8000 http://worker2:8000
```
When a router API key is set, clients must send it as a Bearer token in the `Authorization` header:
```bash
curl -H "Authorization: Bearer your-router-api-key" http://localhost:8080/v1/chat/completions
```
#### Worker API Keys
Workers can have their own API keys for authentication:
```bash
# Workers specified in --worker-urls automatically inherit the router's API key
python -m sglang_router.launch_router \
--api-key "shared-api-key" \
--worker-urls http://worker1:8000 http://worker2:8000
# Both workers will use "shared-api-key" for authentication
# Adding workers dynamically WITHOUT inheriting router's key
curl -X POST "http://localhost:8080/add_worker?url=http://worker3:8000"
# WARNING: This worker has NO API key even though router has one!
# Adding workers with specific API keys dynamically
curl -X POST "http://localhost:8080/add_worker?url=http://worker3:8000&api_key=worker3-specific-key"
```
#### Security Configurations
1. **No Authentication** (default):
- Router and workers accessible without keys
- Suitable for trusted environments
2. **Router-only Authentication**:
- Clients need key to access router
- Router can access workers freely
3. **Worker-only Authentication**:
- Router accessible without key
- Each worker requires authentication
```bash
# Add workers with their API keys
curl -X POST "http://localhost:8080/add_worker?url=http://worker:8000&api_key=worker-key"
```
4. **Full Authentication**:
- Router requires key from clients
- Each worker requires its own key
```bash
# Start router with its key
python -m sglang_router.launch_router --api-key "router-key"
# Add workers with their keys
curl -H "Authorization: Bearer router-key" \
-X POST "http://localhost:8080/add_worker?url=http://worker:8000&api_key=worker-key"
```
#### Important Notes
- **Initial Workers**: Workers specified in `--worker-urls` automatically inherit the router's API key
- **Dynamic Workers**: When adding workers via API, you must explicitly specify their API keys - they do NOT inherit the router's key
- **Security Warning**: When adding workers without API keys while the router has one configured, a warning will be logged
- **Common Pitfall**: If router and workers use the same API key, you must still specify the key when adding workers dynamically
### Command Line Arguments Reference ### Command Line Arguments Reference
#### Service Discovery #### Service Discovery
...@@ -349,6 +422,9 @@ python -m sglang_router.launch_router \ ...@@ -349,6 +422,9 @@ python -m sglang_router.launch_router \
- `--prefill-policy`: Separate routing policy for prefill nodes (optional, overrides `--policy` for prefill) - `--prefill-policy`: Separate routing policy for prefill nodes (optional, overrides `--policy` for prefill)
- `--decode-policy`: Separate routing policy for decode nodes (optional, overrides `--policy` for decode) - `--decode-policy`: Separate routing policy for decode nodes (optional, overrides `--policy` for decode)
#### Authentication
- `--api-key`: API key for router authentication (clients must provide this as Bearer token)
## Development ## Development
### Build Process ### Build Process
......
...@@ -131,31 +131,29 @@ def test_dp_aware_worker_expansion_and_api_key( ...@@ -131,31 +131,29 @@ def test_dp_aware_worker_expansion_and_api_key(
r = requests.post( r = requests.post(
f"{router_url}/add_worker", f"{router_url}/add_worker",
params={"url": worker_url, "api_key": api_key}, params={"url": worker_url, "api_key": api_key},
headers={"Authorization": f"Bearer {api_key}"},
timeout=180, timeout=180,
) )
r.raise_for_status() r.raise_for_status()
r = requests.get(f"{router_url}/list_workers", timeout=30) r = requests.get(
f"{router_url}/list_workers",
headers={"Authorization": f"Bearer {api_key}"},
timeout=30,
)
r.raise_for_status() r.raise_for_status()
urls = r.json().get("urls", []) urls = r.json().get("urls", [])
assert len(urls) == 2 assert len(urls) == 2
assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"} assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}
# TODO: Router currently doesn't enforce API key authentication on incoming requests. # Verify API key enforcement
# It only adds the API key to outgoing requests to workers. # 1) Without Authorization -> Should get 401 Unauthorized
# Need to implement auth middleware to properly protect router endpoints.
# For now, both requests succeed (200) regardless of client authentication.
# Verify API key enforcement path-through
# 1) Without Authorization -> Currently 200 (should be 401 after auth middleware added)
r = requests.post( r = requests.post(
f"{router_url}/v1/completions", f"{router_url}/v1/completions",
json={"model": e2e_model, "prompt": "hi", "max_tokens": 1}, json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
timeout=60, timeout=60,
) )
assert ( assert r.status_code == 401
r.status_code == 200
) # TODO: Change to 401 after auth middleware implementation
# 2) With correct Authorization -> 200 # 2) With correct Authorization -> 200
r = requests.post( r = requests.post(
......
use axum::{ use axum::{
extract::Request, extract::State, http::HeaderValue, http::StatusCode, middleware::Next, body::Body, extract::Request, extract::State, http::header, http::HeaderValue,
response::IntoResponse, response::Response, http::StatusCode, middleware::Next, response::IntoResponse, response::Response,
}; };
use rand::Rng; use rand::Rng;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::time::Instant; use std::time::Instant;
use subtle::ConstantTimeEq;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tower::{Layer, Service}; use tower::{Layer, Service};
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer}; use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
...@@ -17,6 +18,49 @@ pub use crate::core::token_bucket::TokenBucket; ...@@ -17,6 +18,49 @@ pub use crate::core::token_bucket::TokenBucket;
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::server::AppState; use crate::server::AppState;
/// Authentication settings passed to [`auth_middleware`] as its axum state.
#[derive(Clone)]
pub struct AuthConfig {
/// Key that incoming requests must present as a Bearer token.
/// When `None`, `auth_middleware` forwards every request without checking it.
pub api_key: Option<String>,
}
/// Axum middleware enforcing Bearer-token authentication on a route.
///
/// If `auth_config.api_key` is `None`, the request is forwarded unchanged.
/// Otherwise the request must carry an `Authorization` header of the form
/// `Bearer <token>` whose token equals the configured key.
///
/// # Errors
///
/// Returns `StatusCode::UNAUTHORIZED` when a key is configured and the header
/// is missing, is not valid visible ASCII, does not start with `"Bearer "`,
/// or carries a token that differs from the configured key.
pub async fn auth_middleware(
    State(auth_config): State<AuthConfig>,
    request: Request<Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    // No key configured: authentication is disabled, pass the request through.
    let Some(expected_key) = auth_config.api_key.as_deref() else {
        return Ok(next.run(request).await);
    };

    // Pull the token out of `Authorization: Bearer <token>`; any other shape
    // (absent header, non-string value, different scheme) is rejected.
    let presented_token = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "))
        .ok_or(StatusCode::UNAUTHORIZED)?;

    let presented = presented_token.as_bytes();
    let expected = expected_key.as_bytes();

    // The length comparison is not constant-time, but it is needed before the
    // byte-wise check; the contents themselves are compared in constant time
    // to avoid leaking key material through timing.
    let lengths_match = presented.len() == expected.len();
    if !(lengths_match && presented.ct_eq(expected).unwrap_u8() == 1) {
        return Err(StatusCode::UNAUTHORIZED);
    }

    Ok(next.run(request).await)
}
/// Generate OpenAI-compatible request ID based on endpoint /// Generate OpenAI-compatible request ID based on endpoint
fn generate_request_id(path: &str) -> String { fn generate_request_id(path: &str) -> String {
let prefix = if path.contains("/chat/completions") { let prefix = if path.contains("/chat/completions") {
......
...@@ -4,7 +4,7 @@ use crate::{ ...@@ -4,7 +4,7 @@ use crate::{
data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage}, data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage},
logging::{self, LoggingConfig}, logging::{self, LoggingConfig},
metrics::{self, PrometheusConfig}, metrics::{self, PrometheusConfig},
middleware::{self, QueuedRequest, TokenBucket}, middleware::{self, AuthConfig, QueuedRequest, TokenBucket},
policies::PolicyRegistry, policies::PolicyRegistry,
protocols::{ protocols::{
spec::{ spec::{
...@@ -275,6 +275,16 @@ async fn add_worker( ...@@ -275,6 +275,16 @@ async fn add_worker(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>, Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
) -> Response { ) -> Response {
// Warn if router has API key but worker is being added without one
if state.context.router_config.api_key.is_some() && api_key.is_none() {
warn!(
"Adding worker {} without API key while router has API key configured. \
Worker will be accessible without authentication. \
If the worker requires the same API key as the router, please specify it explicitly.",
url
);
}
let result = WorkerManager::add_worker(&url, &api_key, &state.context).await; let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
match result { match result {
...@@ -312,6 +322,16 @@ async fn create_worker( ...@@ -312,6 +322,16 @@ async fn create_worker(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(config): Json<WorkerConfigRequest>, Json(config): Json<WorkerConfigRequest>,
) -> Response { ) -> Response {
// Warn if router has API key but worker is being added without one
if state.context.router_config.api_key.is_some() && config.api_key.is_none() {
warn!(
"Adding worker {} without API key while router has API key configured. \
Worker will be accessible without authentication. \
If the worker requires the same API key as the router, please specify it explicitly.",
config.url
);
}
let result = WorkerManager::add_worker_from_config(&config, &state.context).await; let result = WorkerManager::add_worker_from_config(&config, &state.context).await;
match result { match result {
...@@ -423,6 +443,7 @@ pub struct ServerConfig { ...@@ -423,6 +443,7 @@ pub struct ServerConfig {
pub fn build_app( pub fn build_app(
app_state: Arc<AppState>, app_state: Arc<AppState>,
auth_config: AuthConfig,
max_payload_size: usize, max_payload_size: usize,
request_id_headers: Vec<String>, request_id_headers: Vec<String>,
cors_allowed_origins: Vec<String>, cors_allowed_origins: Vec<String>,
...@@ -448,6 +469,10 @@ pub fn build_app( ...@@ -448,6 +469,10 @@ pub fn build_app(
.route_layer(axum::middleware::from_fn_with_state( .route_layer(axum::middleware::from_fn_with_state(
app_state.clone(), app_state.clone(),
middleware::concurrency_limit_middleware, middleware::concurrency_limit_middleware,
))
.route_layer(axum::middleware::from_fn_with_state(
auth_config.clone(),
middleware::auth_middleware,
)); ));
let public_routes = Router::new() let public_routes = Router::new()
...@@ -464,13 +489,21 @@ pub fn build_app( ...@@ -464,13 +489,21 @@ pub fn build_app(
.route("/remove_worker", post(remove_worker)) .route("/remove_worker", post(remove_worker))
.route("/list_workers", get(list_workers)) .route("/list_workers", get(list_workers))
.route("/flush_cache", post(flush_cache)) .route("/flush_cache", post(flush_cache))
.route("/get_loads", get(get_loads)); .route("/get_loads", get(get_loads))
.route_layer(axum::middleware::from_fn_with_state(
auth_config.clone(),
middleware::auth_middleware,
));
let worker_routes = Router::new() let worker_routes = Router::new()
.route("/workers", post(create_worker)) .route("/workers", post(create_worker))
.route("/workers", get(list_workers_rest)) .route("/workers", get(list_workers_rest))
.route("/workers/{url}", get(get_worker)) .route("/workers/{url}", get(get_worker))
.route("/workers/{url}", delete(delete_worker)); .route("/workers/{url}", delete(delete_worker))
.route_layer(axum::middleware::from_fn_with_state(
auth_config.clone(),
middleware::auth_middleware,
));
Router::new() Router::new()
.merge(protected_routes) .merge(protected_routes)
...@@ -629,8 +662,13 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -629,8 +662,13 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
] ]
}); });
let auth_config = AuthConfig {
api_key: config.router_config.api_key.clone(),
};
let app = build_app( let app = build_app(
app_state, app_state,
auth_config,
config.max_payload_size, config.max_payload_size,
request_id_headers, request_id_headers,
config.router_config.cors_allowed_origins.clone(), config.router_config.cors_allowed_origins.clone(),
......
...@@ -2,6 +2,7 @@ use axum::Router; ...@@ -2,6 +2,7 @@ use axum::Router;
use reqwest::Client; use reqwest::Client;
use sglang_router_rs::{ use sglang_router_rs::{
config::RouterConfig, config::RouterConfig,
middleware::AuthConfig,
routers::RouterTrait, routers::RouterTrait,
server::{build_app, AppContext, AppState}, server::{build_app, AppContext, AppState},
}; };
...@@ -43,9 +44,15 @@ pub fn create_test_app( ...@@ -43,9 +44,15 @@ pub fn create_test_app(
] ]
}); });
// Create auth config from router config
let auth_config = AuthConfig {
api_key: router_config.api_key.clone(),
};
// Use the actual server's build_app function // Use the actual server's build_app function
build_app( build_app(
app_state, app_state,
auth_config,
router_config.max_payload_size, router_config.max_payload_size,
request_id_headers, request_id_headers,
router_config.cors_allowed_origins.clone(), router_config.cors_allowed_origins.clone(),
...@@ -79,9 +86,15 @@ pub fn create_test_app_with_context( ...@@ -79,9 +86,15 @@ pub fn create_test_app_with_context(
] ]
}); });
// Create auth config from router config
let auth_config = AuthConfig {
api_key: router_config.api_key.clone(),
};
// Use the actual server's build_app function // Use the actual server's build_app function
build_app( build_app(
app_state, app_state,
auth_config,
router_config.max_payload_size, router_config.max_payload_size,
request_id_headers, request_id_headers,
router_config.cors_allowed_origins.clone(), router_config.cors_allowed_origins.clone(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment