Unverified Commit 56321e9f authored by Jimmy's avatar Jimmy Committed by GitHub
Browse files

[Router]fix: fix get_load missing api_key (#10385)

parent 12d6cf18
......@@ -129,7 +129,9 @@ def test_dp_aware_worker_expansion_and_api_key(
# Attach worker; router should expand to dp_size logical workers
r = requests.post(
f"{router_url}/add_worker", params={"url": worker_url}, timeout=180
f"{router_url}/add_worker",
params={"url": worker_url, "api_key": api_key},
timeout=180,
)
r.raise_for_status()
......
......@@ -139,7 +139,8 @@ def create_app(args: argparse.Namespace) -> FastAPI:
)
@app.get("/get_load")
async def get_load():
async def get_load(request: Request):
check_api_key(request)
return JSONResponse({"load": _inflight})
def make_json_response(obj: dict, status_code: int = 200) -> JSONResponse:
......
......@@ -24,7 +24,8 @@ static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's URL
fn url(&self) -> &str;
/// Get the worker's API key
fn api_key(&self) -> &Option<String>;
/// Get the worker's type (Regular, Prefill, or Decode)
fn worker_type(&self) -> WorkerType;
......@@ -323,6 +324,8 @@ pub struct WorkerMetadata {
pub labels: std::collections::HashMap<String, String>,
/// Health check configuration
pub health_config: HealthConfig,
/// API key
pub api_key: Option<String>,
}
/// Basic worker implementation
......@@ -379,6 +382,10 @@ impl Worker for BasicWorker {
&self.metadata.url
}
fn api_key(&self) -> &Option<String> {
&self.metadata.api_key
}
fn worker_type(&self) -> WorkerType {
self.metadata.worker_type.clone()
}
......@@ -548,6 +555,10 @@ impl Worker for DPAwareWorker {
self.base_worker.url()
}
fn api_key(&self) -> &Option<String> {
self.base_worker.api_key()
}
fn worker_type(&self) -> WorkerType {
self.base_worker.worker_type()
}
......@@ -650,19 +661,21 @@ impl WorkerFactory {
dp_rank: usize,
dp_size: usize,
worker_type: WorkerType,
api_key: Option<String>,
) -> Box<dyn Worker> {
Box::new(
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
.worker_type(worker_type)
.build(),
)
let mut builder =
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size).worker_type(worker_type);
if let Some(api_key) = api_key {
builder = builder.api_key(api_key);
}
Box::new(builder.build())
}
#[allow(dead_code)]
/// Get DP size from a worker
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
if let Some(key) = api_key {
if let Some(key) = &api_key {
req_builder = req_builder.bearer_auth(key);
}
......@@ -708,14 +721,18 @@ impl WorkerFactory {
}
/// Convert a list of worker URLs to worker trait objects
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
pub fn urls_to_workers(urls: Vec<String>, api_key: Option<String>) -> Vec<Box<dyn Worker>> {
urls.into_iter()
.map(|url| {
Box::new(
BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Regular)
.build(),
) as Box<dyn Worker>
let worker_builder = BasicWorkerBuilder::new(url).worker_type(WorkerType::Regular);
let worker = if let Some(ref api_key) = api_key {
worker_builder.api_key(api_key.clone()).build()
} else {
worker_builder.build()
};
Box::new(worker) as Box<dyn Worker>
})
.collect()
}
......@@ -961,6 +978,7 @@ mod tests {
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
assert_eq!(worker.url(), "http://test:8080");
assert_eq!(worker.worker_type(), WorkerType::Regular);
......@@ -998,6 +1016,7 @@ mod tests {
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.health_config(custom_config.clone())
.api_key("test_api_key")
.build();
assert_eq!(worker.metadata().health_config.timeout_secs, 15);
......@@ -1011,6 +1030,7 @@ mod tests {
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
assert_eq!(worker.url(), "http://worker1:8080");
}
......@@ -1020,6 +1040,7 @@ mod tests {
use crate::core::BasicWorkerBuilder;
let regular = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
assert_eq!(regular.worker_type(), WorkerType::Regular);
......@@ -1027,6 +1048,7 @@ mod tests {
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(9090),
})
.api_key("test_api_key")
.build();
assert_eq!(
prefill.worker_type(),
......@@ -1037,6 +1059,7 @@ mod tests {
let decode = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Decode)
.api_key("test_api_key")
.build();
assert_eq!(decode.worker_type(), WorkerType::Decode);
}
......@@ -1065,6 +1088,7 @@ mod tests {
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
// Initial load is 0
......@@ -1350,7 +1374,7 @@ mod tests {
fn test_urls_to_workers() {
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
let workers = urls_to_workers(urls);
let workers = urls_to_workers(urls, Some("test_api_key".to_string()));
assert_eq!(workers.len(), 2);
assert_eq!(workers[0].url(), "http://w1:8080");
assert_eq!(workers[1].url(), "http://w2:8080");
......@@ -1547,6 +1571,7 @@ mod tests {
1,
4,
WorkerType::Regular,
Some("test_api_key".to_string()),
);
assert_eq!(worker.url(), "http://worker1:8080@1");
......@@ -1565,6 +1590,7 @@ mod tests {
WorkerType::Prefill {
bootstrap_port: Some(8090),
},
Some("test_api_key".to_string()),
);
assert_eq!(worker.url(), "http://worker1:8080@0");
......@@ -1680,8 +1706,13 @@ mod tests {
.worker_type(WorkerType::Decode)
.build(),
);
let dp_aware_regular =
WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular);
let dp_aware_regular = WorkerFactory::create_dp_aware(
"http://dp:8080".to_string(),
0,
2,
WorkerType::Regular,
Some("test_api_key".to_string()),
);
let dp_aware_prefill = WorkerFactory::create_dp_aware(
"http://dp-prefill:8080".to_string(),
1,
......@@ -1689,12 +1720,14 @@ mod tests {
WorkerType::Prefill {
bootstrap_port: None,
},
Some("test_api_key".to_string()),
);
let dp_aware_decode = WorkerFactory::create_dp_aware(
"http://dp-decode:8080".to_string(),
0,
4,
WorkerType::Decode,
Some("test_api_key".to_string()),
);
let workers: Vec<Box<dyn Worker>> = vec![
......
......@@ -11,6 +11,7 @@ pub struct BasicWorkerBuilder {
url: String,
// Optional fields with defaults
api_key: Option<String>,
worker_type: WorkerType,
connection_mode: ConnectionMode,
labels: HashMap<String, String>,
......@@ -24,6 +25,7 @@ impl BasicWorkerBuilder {
pub fn new(url: impl Into<String>) -> Self {
Self {
url: url.into(),
api_key: None,
worker_type: WorkerType::Regular,
connection_mode: ConnectionMode::Http,
labels: HashMap::new(),
......@@ -37,6 +39,7 @@ impl BasicWorkerBuilder {
pub fn new_with_type(url: impl Into<String>, worker_type: WorkerType) -> Self {
Self {
url: url.into(),
api_key: None,
worker_type,
connection_mode: ConnectionMode::Http,
labels: HashMap::new(),
......@@ -46,6 +49,12 @@ impl BasicWorkerBuilder {
}
}
/// Set the API key
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}
/// Set the worker type (Regular, Prefill, or Decode)
pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
self.worker_type = worker_type;
......@@ -98,6 +107,7 @@ impl BasicWorkerBuilder {
let metadata = WorkerMetadata {
url: self.url.clone(),
api_key: self.api_key,
worker_type: self.worker_type,
connection_mode: self.connection_mode,
labels: self.labels,
......@@ -121,6 +131,7 @@ impl BasicWorkerBuilder {
pub struct DPAwareWorkerBuilder {
// Required fields
base_url: String,
api_key: Option<String>,
dp_rank: usize,
dp_size: usize,
......@@ -138,6 +149,7 @@ impl DPAwareWorkerBuilder {
pub fn new(base_url: impl Into<String>, dp_rank: usize, dp_size: usize) -> Self {
Self {
base_url: base_url.into(),
api_key: None,
dp_rank,
dp_size,
worker_type: WorkerType::Regular,
......@@ -158,6 +170,7 @@ impl DPAwareWorkerBuilder {
) -> Self {
Self {
base_url: base_url.into(),
api_key: None,
dp_rank,
dp_size,
worker_type,
......@@ -169,6 +182,12 @@ impl DPAwareWorkerBuilder {
}
}
/// Set the API key
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}
/// Set the worker type (Regular, Prefill, or Decode)
pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
self.worker_type = worker_type;
......@@ -228,6 +247,10 @@ impl DPAwareWorkerBuilder {
if let Some(client) = self.grpc_client {
builder = builder.grpc_client(client);
}
// Add API key if provided
if let Some(api_key) = self.api_key {
builder = builder.api_key(api_key);
}
let base_worker = builder.build();
......@@ -382,6 +405,7 @@ mod tests {
.connection_mode(ConnectionMode::Http)
.labels(labels.clone())
.health_config(health_config.clone())
.api_key("test_api_key")
.build();
assert_eq!(worker.url(), "http://localhost:8080@3");
......
......@@ -256,6 +256,18 @@ impl WorkerRegistry {
.collect()
}
pub fn get_all_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.workers
.iter()
.map(|entry| {
(
entry.value().url().to_string(),
entry.value().api_key().clone(),
)
})
.collect()
}
/// Get all model IDs with workers
pub fn get_models(&self) -> Vec<String> {
self.model_workers
......@@ -442,6 +454,7 @@ mod tests {
.worker_type(WorkerType::Regular)
.labels(labels)
.circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(),
);
......@@ -477,6 +490,7 @@ mod tests {
.worker_type(WorkerType::Regular)
.labels(labels1)
.circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(),
);
......@@ -487,6 +501,7 @@ mod tests {
.worker_type(WorkerType::Regular)
.labels(labels2)
.circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(),
);
......@@ -497,6 +512,7 @@ mod tests {
.worker_type(WorkerType::Regular)
.labels(labels3)
.circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(),
);
......
......@@ -465,11 +465,13 @@ mod tests {
Arc::new(
BasicWorkerBuilder::new("http://w1:8000")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(),
),
Arc::new(
BasicWorkerBuilder::new("http://w2:8000")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(),
),
];
......
......@@ -129,16 +129,19 @@ mod tests {
Arc::new(
BasicWorkerBuilder::new("http://w1:8000")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(),
),
Arc::new(
BasicWorkerBuilder::new("http://w2:8000")
.worker_type(WorkerType::Regular)
.api_key("test_api_key2")
.build(),
),
Arc::new(
BasicWorkerBuilder::new("http://w3:8000")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(),
),
];
......
......@@ -11,6 +11,10 @@ pub struct WorkerConfigRequest {
/// Worker URL (required)
pub url: String,
/// Worker API key (optional)
#[serde(skip_serializing_if = "Option::is_none")]
pub api_key: Option<String>,
/// Model ID (optional, will query from server if not provided)
#[serde(skip_serializing_if = "Option::is_none")]
pub model_id: Option<String>,
......
......@@ -353,7 +353,11 @@ impl RouterTrait for GrpcPDRouter {
#[async_trait]
impl WorkerManagement for GrpcPDRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
Err("Not implemented".to_string())
}
......
......@@ -282,7 +282,11 @@ impl RouterTrait for GrpcRouter {
#[async_trait]
impl WorkerManagement for GrpcRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
Err("Not implemented".to_string())
}
......
......@@ -67,7 +67,11 @@ impl OpenAIRouter {
#[async_trait]
impl super::super::WorkerManagement for OpenAIRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
Err("Cannot add workers to OpenAI router".to_string())
}
......
......@@ -46,6 +46,8 @@ pub struct PDRouter {
pub prefill_client: Client,
pub retry_config: RetryConfig,
pub circuit_breaker_config: CircuitBreakerConfig,
pub api_key: Option<String>,
// Channel for sending prefill responses to background workers for draining
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
}
......@@ -113,21 +115,25 @@ impl PDRouter {
(results, errors)
}
fn _get_worker_url_and_key(&self, w: &Arc<dyn Worker>) -> (String, Option<String>) {
(w.url().to_string(), w.api_key().clone())
}
// Helper to get prefill worker URLs
fn get_prefill_worker_urls(&self) -> Vec<String> {
fn get_prefill_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.worker_registry
.get_prefill_workers()
.iter()
.map(|w| w.url().to_string())
.map(|w| self._get_worker_url_and_key(w))
.collect()
}
// Helper to get decode worker URLs
fn get_decode_worker_urls(&self) -> Vec<String> {
fn get_decode_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.worker_registry
.get_decode_workers()
.iter()
.map(|w| w.url().to_string())
.map(|w| self._get_worker_url_and_key(w))
.collect()
}
......@@ -208,6 +214,7 @@ impl PDRouter {
pub async fn add_prefill_server(
&self,
url: String,
api_key: Option<String>,
bootstrap_port: Option<u16>,
) -> Result<String, PDRouterError> {
// Wait for the new server to be healthy
......@@ -220,10 +227,15 @@ impl PDRouter {
// Create Worker for the new prefill server with circuit breaker configuration
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let worker = BasicWorkerBuilder::new(url.clone())
let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Prefill { bootstrap_port })
.circuit_breaker_config(self.circuit_breaker_config.clone())
.build();
.circuit_breaker_config(self.circuit_breaker_config.clone());
let worker = if let Some(api_key) = api_key {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc: Arc<dyn Worker> = Arc::new(worker);
......@@ -243,7 +255,11 @@ impl PDRouter {
Ok(format!("Successfully added prefill server: {}", url))
}
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
pub async fn add_decode_server(
&self,
url: String,
api_key: Option<String>,
) -> Result<String, PDRouterError> {
// Wait for the new server to be healthy
self.wait_for_server_health(&url).await?;
......@@ -254,10 +270,15 @@ impl PDRouter {
// Create Worker for the new decode server with circuit breaker configuration
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let worker = BasicWorkerBuilder::new(url.clone())
let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Decode)
.circuit_breaker_config(self.circuit_breaker_config.clone())
.build();
.circuit_breaker_config(self.circuit_breaker_config.clone());
let worker = if let Some(api_key) = api_key {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc: Arc<dyn Worker> = Arc::new(worker);
......@@ -366,6 +387,12 @@ impl PDRouter {
.chain(decode_workers.iter())
.map(|w| w.url().to_string())
.collect();
// Get all worker API keys for monitoring
let all_api_keys: Vec<Option<String>> = prefill_workers
.iter()
.chain(decode_workers.iter())
.map(|w| w.api_key().clone())
.collect();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
......@@ -387,6 +414,7 @@ impl PDRouter {
let load_monitor_handle =
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
let monitor_urls = all_urls.clone();
let monitor_api_keys = all_api_keys.clone();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let monitor_client = ctx.client.clone();
let prefill_policy_clone = Arc::clone(&prefill_policy);
......@@ -395,6 +423,7 @@ impl PDRouter {
Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads_with_client(
monitor_urls,
monitor_api_keys,
tx,
monitor_interval,
monitor_client,
......@@ -500,6 +529,7 @@ impl PDRouter {
prefill_drain_tx,
retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config,
api_key: ctx.router_config.api_key.clone(),
})
}
......@@ -1150,6 +1180,7 @@ impl PDRouter {
// Background task to monitor worker loads with shared client
async fn monitor_worker_loads_with_client(
worker_urls: Vec<String>,
worker_api_keys: Vec<Option<String>>,
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64,
client: Client,
......@@ -1161,11 +1192,13 @@ impl PDRouter {
let futures: Vec<_> = worker_urls
.iter()
.map(|url| {
.zip(worker_api_keys.iter())
.map(|(url, api_key)| {
let client = client.clone();
let url = url.clone();
let api_key = api_key.clone();
async move {
let load = get_worker_load(&client, &url).await.unwrap_or(0);
let load = get_worker_load(&client, &url, &api_key).await.unwrap_or(0);
(url, load)
}
})
......@@ -1515,8 +1548,16 @@ impl PDRouter {
// Helper functions
async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> {
match client.get(format!("{}/get_load", worker_url)).send().await {
async fn get_worker_load(
client: &Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
let mut req_builder = client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) {
Ok(data) => data
......@@ -1550,7 +1591,11 @@ async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> {
#[async_trait]
impl WorkerManagement for PDRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
// For PD router, we don't support adding workers via this generic method
Err(
"PD router requires specific add_prefill_server or add_decode_server methods"
......@@ -1956,9 +2001,9 @@ impl RouterTrait for PDRouter {
let mut errors = Vec::new();
// Process prefill workers
let prefill_urls = self.get_prefill_worker_urls();
for worker_url in prefill_urls {
match get_worker_load(&self.client, &worker_url).await {
let prefill_urls_with_key = self.get_prefill_worker_urls_with_api_key();
for (worker_url, api_key) in prefill_urls_with_key {
match get_worker_load(&self.client, &worker_url, &api_key).await {
Some(load) => {
loads.insert(format!("prefill_{}", worker_url), load);
}
......@@ -1969,9 +2014,9 @@ impl RouterTrait for PDRouter {
}
// Process decode workers
let decode_urls = self.get_decode_worker_urls();
for worker_url in decode_urls {
match get_worker_load(&self.client, &worker_url).await {
let decode_urls_with_key = self.get_decode_worker_urls_with_api_key();
for (worker_url, api_key) in decode_urls_with_key {
match get_worker_load(&self.client, &worker_url, &api_key).await {
Some(load) => {
loads.insert(format!("decode_{}", worker_url), load);
}
......@@ -2069,12 +2114,14 @@ mod tests {
prefill_drain_tx: mpsc::channel(100).0,
retry_config: RetryConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(),
api_key: Some("test_api_key".to_string()),
}
}
fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> {
let worker = BasicWorkerBuilder::new(url)
.worker_type(worker_type)
.api_key("test_api_key")
.build();
worker.set_healthy(healthy);
Box::new(worker)
......
......@@ -38,6 +38,7 @@ pub struct Router {
worker_startup_timeout_secs: u64,
worker_startup_check_interval_secs: u64,
dp_aware: bool,
#[allow(dead_code)]
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig,
......@@ -71,7 +72,6 @@ impl Router {
};
// Cache-aware policies are initialized in WorkerInitializer
// Setup load monitoring for PowerOfTwo policy
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
......@@ -82,6 +82,14 @@ impl Router {
// Check if default policy is power_of_two for load monitoring
let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone();
let monitor_api_keys = monitor_urls
.iter()
.map(|url| {
ctx.worker_registry
.get_by_url(url)
.and_then(|w| w.api_key().clone())
})
.collect::<Vec<Option<String>>>();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let policy_clone = default_policy.clone();
let client_clone = ctx.client.clone();
......@@ -89,6 +97,7 @@ impl Router {
Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads(
monitor_urls,
monitor_api_keys,
tx,
monitor_interval,
policy_clone,
......@@ -912,7 +921,11 @@ impl Router {
}
}
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
pub async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
let start_time = std::time::Instant::now();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.worker_startup_timeout_secs))
......@@ -938,7 +951,7 @@ impl Router {
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
let url_vec = vec![String::from(worker_url)];
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
let mut worker_added: bool = false;
for dp_url in &dp_url_vec {
......@@ -948,10 +961,18 @@ impl Router {
}
info!("Added worker: {}", dp_url);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker = BasicWorkerBuilder::new(dp_url.to_string())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(self.circuit_breaker_config.clone())
.build();
let new_worker_builder =
BasicWorkerBuilder::new(dp_url.to_string())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(
self.circuit_breaker_config.clone(),
);
let new_worker = if let Some(api_key) = api_key {
new_worker_builder.api_key(api_key).build()
} else {
new_worker_builder.build()
};
let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone());
......@@ -978,10 +999,16 @@ impl Router {
info!("Added worker: {}", worker_url);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker = BasicWorkerBuilder::new(worker_url.to_string())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(self.circuit_breaker_config.clone())
.build();
let new_worker_builder =
BasicWorkerBuilder::new(worker_url.to_string())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(self.circuit_breaker_config.clone());
let new_worker = if let Some(api_key) = api_key {
new_worker_builder.api_key(api_key).build()
} else {
new_worker_builder.build()
};
let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone());
......@@ -1094,7 +1121,7 @@ impl Router {
}
}
async fn get_worker_load(&self, worker_url: &str) -> Option<isize> {
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
......@@ -1109,12 +1136,12 @@ impl Router {
worker_url
};
match self
.client
.get(format!("{}/get_load", worker_url))
.send()
.await
{
let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
......@@ -1149,6 +1176,7 @@ impl Router {
// Background task to monitor worker loads
async fn monitor_worker_loads(
worker_urls: Vec<String>,
worker_api_keys: Vec<Option<String>>,
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64,
policy: Arc<dyn LoadBalancingPolicy>,
......@@ -1160,8 +1188,8 @@ impl Router {
interval.tick().await;
let mut loads = HashMap::new();
for url in &worker_urls {
if let Some(load) = Self::get_worker_load_static(&client, url).await {
for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
loads.insert(url.clone(), load);
}
}
......@@ -1179,7 +1207,11 @@ impl Router {
}
// Static version of get_worker_load for use in monitoring task
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
async fn get_worker_load_static(
client: &reqwest::Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
let worker_url = if worker_url.contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
......@@ -1194,7 +1226,11 @@ impl Router {
worker_url
};
match client.get(format!("{}/get_load", worker_url)).send().await {
let mut req_builder = client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
......@@ -1250,8 +1286,12 @@ use async_trait::async_trait;
#[async_trait]
impl WorkerManagement for Router {
async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
Router::add_worker(self, worker_url).await
async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
Router::add_worker(self, worker_url, api_key).await
}
fn remove_worker(&self, worker_url: &str) {
......@@ -1457,12 +1497,12 @@ impl RouterTrait for Router {
}
async fn get_worker_loads(&self) -> Response {
let urls = self.get_worker_urls();
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
let mut loads = Vec::new();
// Get loads from all workers
for url in &urls {
let load = self.get_worker_load(url).await.unwrap_or(-1);
for (url, api_key) in &urls_with_key {
let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
loads.push(serde_json::json!({
"worker": url,
"load": load
......@@ -1521,9 +1561,11 @@ mod tests {
// Register test workers
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
worker_registry.register(Arc::new(worker1));
worker_registry.register(Arc::new(worker2));
......
......@@ -33,7 +33,11 @@ pub use http::{openai_router, pd_router, pd_types, router};
#[async_trait]
pub trait WorkerManagement: Send + Sync {
/// Add a worker to the router
async fn add_worker(&self, worker_url: &str) -> Result<String, String>;
async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String>;
/// Remove a worker from the router
fn remove_worker(&self, worker_url: &str);
......
......@@ -161,7 +161,7 @@ impl RouterManager {
let model_id = if let Some(model_id) = config.model_id {
model_id
} else {
match self.query_server_info(&config.url).await {
match self.query_server_info(&config.url, &config.api_key).await {
Ok(info) => {
// Extract model_id from server info
info.model_id
......@@ -208,29 +208,44 @@ impl RouterManager {
}
let worker = match config.worker_type.as_deref() {
Some("prefill") => Box::new(
BasicWorkerBuilder::new(config.url.clone())
Some("prefill") => {
let mut builder = BasicWorkerBuilder::new(config.url.clone())
.worker_type(WorkerType::Prefill {
bootstrap_port: config.bootstrap_port,
})
.labels(labels.clone())
.circuit_breaker_config(CircuitBreakerConfig::default())
.build(),
) as Box<dyn Worker>,
Some("decode") => Box::new(
BasicWorkerBuilder::new(config.url.clone())
.circuit_breaker_config(CircuitBreakerConfig::default());
if let Some(api_key) = config.api_key.clone() {
builder = builder.api_key(api_key);
}
Box::new(builder.build()) as Box<dyn Worker>
}
Some("decode") => {
let mut builder = BasicWorkerBuilder::new(config.url.clone())
.worker_type(WorkerType::Decode)
.labels(labels.clone())
.circuit_breaker_config(CircuitBreakerConfig::default())
.build(),
) as Box<dyn Worker>,
_ => Box::new(
BasicWorkerBuilder::new(config.url.clone())
.circuit_breaker_config(CircuitBreakerConfig::default());
if let Some(api_key) = config.api_key.clone() {
builder = builder.api_key(api_key);
}
Box::new(builder.build()) as Box<dyn Worker>
}
_ => {
let mut builder = BasicWorkerBuilder::new(config.url.clone())
.worker_type(WorkerType::Regular)
.labels(labels.clone())
.circuit_breaker_config(CircuitBreakerConfig::default())
.build(),
) as Box<dyn Worker>,
.circuit_breaker_config(CircuitBreakerConfig::default());
if let Some(api_key) = config.api_key.clone() {
builder = builder.api_key(api_key);
}
Box::new(builder.build()) as Box<dyn Worker>
}
};
// Register worker
......@@ -346,10 +361,18 @@ impl RouterManager {
}
/// Query server info from a worker URL
async fn query_server_info(&self, url: &str) -> Result<ServerInfo, String> {
async fn query_server_info(
&self,
url: &str,
api_key: &Option<String>,
) -> Result<ServerInfo, String> {
let info_url = format!("{}/get_server_info", url.trim_end_matches('/'));
match self.client.get(&info_url).send().await {
let mut req_builder = self.client.get(&info_url);
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(response) => {
if response.status().is_success() {
response
......@@ -477,10 +500,15 @@ impl RouterManager {
#[async_trait]
impl WorkerManagement for RouterManager {
/// Add a worker - in multi-router mode, this adds to the registry
async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
// Create a basic worker config request
let config = WorkerConfigRequest {
url: worker_url.to_string(),
api_key: api_key.clone(),
model_id: None,
worker_type: None,
priority: None,
......
......@@ -27,8 +27,12 @@ impl WorkerInitializer {
match &config.mode {
RoutingMode::Regular { worker_urls } => {
// use router's api_key, repeat for each worker
let worker_api_keys: Vec<Option<String>> =
worker_urls.iter().map(|_| config.api_key.clone()).collect();
Self::create_regular_workers(
worker_urls,
&worker_api_keys,
&config.connection_mode,
config,
worker_registry,
......@@ -41,8 +45,16 @@ impl WorkerInitializer {
decode_urls,
..
} => {
// use router's api_key, repeat for each prefill/decode worker
let prefill_api_keys: Vec<Option<String>> = prefill_urls
.iter()
.map(|_| config.api_key.clone())
.collect();
let decode_api_keys: Vec<Option<String>> =
decode_urls.iter().map(|_| config.api_key.clone()).collect();
Self::create_prefill_workers(
prefill_urls,
&prefill_api_keys,
&config.connection_mode,
config,
worker_registry,
......@@ -51,6 +63,7 @@ impl WorkerInitializer {
.await?;
Self::create_decode_workers(
decode_urls,
&decode_api_keys,
&config.connection_mode,
config,
worker_registry,
......@@ -79,6 +92,7 @@ impl WorkerInitializer {
/// Create regular workers for standard routing mode
async fn create_regular_workers(
urls: &[String],
api_keys: &[Option<String>],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
......@@ -109,14 +123,18 @@ impl WorkerInitializer {
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls {
for (url, api_key) in urls.iter().zip(api_keys.iter()) {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Regular)
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.build();
.health_config(health_config.clone());
let worker = if let Some(api_key) = api_key.clone() {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
......@@ -148,6 +166,7 @@ impl WorkerInitializer {
/// Create prefill workers for disaggregated routing mode
async fn create_prefill_workers(
prefill_entries: &[(String, Option<u16>)],
api_keys: &[Option<String>],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
......@@ -181,16 +200,20 @@ impl WorkerInitializer {
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for (url, bootstrap_port) in prefill_entries {
for ((url, bootstrap_port), api_key) in prefill_entries.iter().zip(api_keys.iter()) {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Prefill {
bootstrap_port: *bootstrap_port,
})
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.build();
.health_config(health_config.clone());
let worker = if let Some(api_key) = api_key.clone() {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
......@@ -227,6 +250,7 @@ impl WorkerInitializer {
/// Create decode workers for disaggregated routing mode
async fn create_decode_workers(
urls: &[String],
api_keys: &[Option<String>],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
......@@ -257,14 +281,18 @@ impl WorkerInitializer {
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls {
for (url, api_key) in urls.iter().zip(api_keys.iter()) {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Decode)
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.build();
.health_config(health_config.clone());
let worker = if let Some(api_key) = api_key.clone() {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
......
......@@ -282,15 +282,16 @@ async fn v1_responses_list_input_items(
// ---------- Worker management endpoints (Legacy) ----------
#[derive(Deserialize)]
struct UrlQuery {
struct AddWorkerQuery {
url: String,
api_key: Option<String>,
}
async fn add_worker(
State(state): State<Arc<AppState>>,
Query(UrlQuery { url }): Query<UrlQuery>,
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
) -> Response {
match state.router.add_worker(&url).await {
match state.router.add_worker(&url, &api_key).await {
Ok(message) => (StatusCode::OK, message).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
}
......@@ -303,7 +304,7 @@ async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
async fn remove_worker(
State(state): State<Arc<AppState>>,
Query(UrlQuery { url }): Query<UrlQuery>,
Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
) -> Response {
state.router.remove_worker(&url);
(
......@@ -337,7 +338,7 @@ async fn create_worker(
}
} else {
// In single router mode, use the router's add_worker with basic config
match state.router.add_worker(&config.url).await {
match state.router.add_worker(&config.url, &config.api_key).await {
Ok(message) => {
let response = WorkerApiResponse {
success: true,
......
......@@ -389,16 +389,20 @@ async fn handle_pod_event(
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
match &pod_info.pod_type {
Some(PodType::Prefill) => pd_router
.add_prefill_server(worker_url.clone(), pod_info.bootstrap_port)
.add_prefill_server(
worker_url.clone(),
pd_router.api_key.clone(),
pod_info.bootstrap_port,
)
.await
.map_err(|e| e.to_string()),
Some(PodType::Decode) => pd_router
.add_decode_server(worker_url.clone())
.add_decode_server(worker_url.clone(), pd_router.api_key.clone())
.await
.map_err(|e| e.to_string()),
Some(PodType::Regular) | None => {
// Fall back to regular add_worker for regular pods
router.add_worker(&worker_url).await
router.add_worker(&worker_url, &pd_router.api_key).await
}
}
} else {
......@@ -406,7 +410,8 @@ async fn handle_pod_event(
}
} else {
// Regular mode or no pod type specified
router.add_worker(&worker_url).await
// In pod, no need api key
router.add_worker(&worker_url, &None).await
};
match result {
......
......@@ -18,6 +18,7 @@ fn test_backward_compatibility_with_empty_model_id() {
// Create workers with empty model_id (simulating existing routers)
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
// No model_id label - should default to "unknown"
......@@ -25,6 +26,7 @@ fn test_backward_compatibility_with_empty_model_id() {
labels2.insert("model_id".to_string(), "unknown".to_string());
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.labels(labels2)
.build();
......@@ -59,6 +61,7 @@ fn test_mixed_model_ids() {
// Create workers with different model_id scenarios
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
// No model_id label - defaults to "unknown" which goes to "default" tree
......@@ -67,6 +70,7 @@ fn test_mixed_model_ids() {
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular)
.labels(labels2)
.api_key("test_api_key")
.build();
let mut labels3 = HashMap::new();
......@@ -123,10 +127,12 @@ fn test_remove_worker_by_url_backward_compat() {
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.labels(labels1)
.api_key("test_api_key")
.build();
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
// No model_id label - defaults to "unknown"
......
......@@ -41,6 +41,7 @@ async fn test_policy_registry_with_router_manager() {
let _worker1_config = WorkerConfigRequest {
url: "http://worker1:8000".to_string(),
model_id: Some("llama-3".to_string()),
api_key: Some("test_api_key".to_string()),
worker_type: None,
priority: None,
cost: None,
......@@ -66,6 +67,7 @@ async fn test_policy_registry_with_router_manager() {
let _worker2_config = WorkerConfigRequest {
url: "http://worker2:8000".to_string(),
model_id: Some("llama-3".to_string()),
api_key: Some("test_api_key".to_string()),
worker_type: None,
priority: None,
cost: None,
......@@ -86,6 +88,7 @@ async fn test_policy_registry_with_router_manager() {
let _worker3_config = WorkerConfigRequest {
url: "http://worker3:8000".to_string(),
model_id: Some("gpt-4".to_string()),
api_key: Some("test_api_key".to_string()),
worker_type: None,
priority: None,
cost: None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment