Unverified Commit 56321e9f authored by Jimmy's avatar Jimmy Committed by GitHub
Browse files

[Router]fix: fix get_load missing api_key (#10385)

parent 12d6cf18
...@@ -129,7 +129,9 @@ def test_dp_aware_worker_expansion_and_api_key( ...@@ -129,7 +129,9 @@ def test_dp_aware_worker_expansion_and_api_key(
# Attach worker; router should expand to dp_size logical workers # Attach worker; router should expand to dp_size logical workers
r = requests.post( r = requests.post(
f"{router_url}/add_worker", params={"url": worker_url}, timeout=180 f"{router_url}/add_worker",
params={"url": worker_url, "api_key": api_key},
timeout=180,
) )
r.raise_for_status() r.raise_for_status()
......
...@@ -139,7 +139,8 @@ def create_app(args: argparse.Namespace) -> FastAPI: ...@@ -139,7 +139,8 @@ def create_app(args: argparse.Namespace) -> FastAPI:
) )
@app.get("/get_load") @app.get("/get_load")
async def get_load(): async def get_load(request: Request):
check_api_key(request)
return JSONResponse({"load": _inflight}) return JSONResponse({"load": _inflight})
def make_json_response(obj: dict, status_code: int = 200) -> JSONResponse: def make_json_response(obj: dict, status_code: int = 200) -> JSONResponse:
......
...@@ -24,7 +24,8 @@ static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| { ...@@ -24,7 +24,8 @@ static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
pub trait Worker: Send + Sync + fmt::Debug { pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's URL /// Get the worker's URL
fn url(&self) -> &str; fn url(&self) -> &str;
/// Get the worker's API key
fn api_key(&self) -> &Option<String>;
/// Get the worker's type (Regular, Prefill, or Decode) /// Get the worker's type (Regular, Prefill, or Decode)
fn worker_type(&self) -> WorkerType; fn worker_type(&self) -> WorkerType;
...@@ -323,6 +324,8 @@ pub struct WorkerMetadata { ...@@ -323,6 +324,8 @@ pub struct WorkerMetadata {
pub labels: std::collections::HashMap<String, String>, pub labels: std::collections::HashMap<String, String>,
/// Health check configuration /// Health check configuration
pub health_config: HealthConfig, pub health_config: HealthConfig,
/// API key
pub api_key: Option<String>,
} }
/// Basic worker implementation /// Basic worker implementation
...@@ -379,6 +382,10 @@ impl Worker for BasicWorker { ...@@ -379,6 +382,10 @@ impl Worker for BasicWorker {
&self.metadata.url &self.metadata.url
} }
fn api_key(&self) -> &Option<String> {
&self.metadata.api_key
}
fn worker_type(&self) -> WorkerType { fn worker_type(&self) -> WorkerType {
self.metadata.worker_type.clone() self.metadata.worker_type.clone()
} }
...@@ -548,6 +555,10 @@ impl Worker for DPAwareWorker { ...@@ -548,6 +555,10 @@ impl Worker for DPAwareWorker {
self.base_worker.url() self.base_worker.url()
} }
fn api_key(&self) -> &Option<String> {
self.base_worker.api_key()
}
fn worker_type(&self) -> WorkerType { fn worker_type(&self) -> WorkerType {
self.base_worker.worker_type() self.base_worker.worker_type()
} }
...@@ -650,19 +661,21 @@ impl WorkerFactory { ...@@ -650,19 +661,21 @@ impl WorkerFactory {
dp_rank: usize, dp_rank: usize,
dp_size: usize, dp_size: usize,
worker_type: WorkerType, worker_type: WorkerType,
api_key: Option<String>,
) -> Box<dyn Worker> { ) -> Box<dyn Worker> {
Box::new( let mut builder =
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size) DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size).worker_type(worker_type);
.worker_type(worker_type) if let Some(api_key) = api_key {
.build(), builder = builder.api_key(api_key);
) }
Box::new(builder.build())
} }
#[allow(dead_code)] #[allow(dead_code)]
/// Get DP size from a worker /// Get DP size from a worker
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> { async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url)); let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
if let Some(key) = api_key { if let Some(key) = &api_key {
req_builder = req_builder.bearer_auth(key); req_builder = req_builder.bearer_auth(key);
} }
...@@ -708,14 +721,18 @@ impl WorkerFactory { ...@@ -708,14 +721,18 @@ impl WorkerFactory {
} }
/// Convert a list of worker URLs to worker trait objects /// Convert a list of worker URLs to worker trait objects
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> { pub fn urls_to_workers(urls: Vec<String>, api_key: Option<String>) -> Vec<Box<dyn Worker>> {
urls.into_iter() urls.into_iter()
.map(|url| { .map(|url| {
Box::new( let worker_builder = BasicWorkerBuilder::new(url).worker_type(WorkerType::Regular);
BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Regular) let worker = if let Some(ref api_key) = api_key {
.build(), worker_builder.api_key(api_key.clone()).build()
) as Box<dyn Worker> } else {
worker_builder.build()
};
Box::new(worker) as Box<dyn Worker>
}) })
.collect() .collect()
} }
...@@ -961,6 +978,7 @@ mod tests { ...@@ -961,6 +978,7 @@ mod tests {
use crate::core::BasicWorkerBuilder; use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080") let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(); .build();
assert_eq!(worker.url(), "http://test:8080"); assert_eq!(worker.url(), "http://test:8080");
assert_eq!(worker.worker_type(), WorkerType::Regular); assert_eq!(worker.worker_type(), WorkerType::Regular);
...@@ -998,6 +1016,7 @@ mod tests { ...@@ -998,6 +1016,7 @@ mod tests {
let worker = BasicWorkerBuilder::new("http://test:8080") let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.health_config(custom_config.clone()) .health_config(custom_config.clone())
.api_key("test_api_key")
.build(); .build();
assert_eq!(worker.metadata().health_config.timeout_secs, 15); assert_eq!(worker.metadata().health_config.timeout_secs, 15);
...@@ -1011,6 +1030,7 @@ mod tests { ...@@ -1011,6 +1030,7 @@ mod tests {
use crate::core::BasicWorkerBuilder; use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://worker1:8080") let worker = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(); .build();
assert_eq!(worker.url(), "http://worker1:8080"); assert_eq!(worker.url(), "http://worker1:8080");
} }
...@@ -1020,6 +1040,7 @@ mod tests { ...@@ -1020,6 +1040,7 @@ mod tests {
use crate::core::BasicWorkerBuilder; use crate::core::BasicWorkerBuilder;
let regular = BasicWorkerBuilder::new("http://test:8080") let regular = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(); .build();
assert_eq!(regular.worker_type(), WorkerType::Regular); assert_eq!(regular.worker_type(), WorkerType::Regular);
...@@ -1027,6 +1048,7 @@ mod tests { ...@@ -1027,6 +1048,7 @@ mod tests {
.worker_type(WorkerType::Prefill { .worker_type(WorkerType::Prefill {
bootstrap_port: Some(9090), bootstrap_port: Some(9090),
}) })
.api_key("test_api_key")
.build(); .build();
assert_eq!( assert_eq!(
prefill.worker_type(), prefill.worker_type(),
...@@ -1037,6 +1059,7 @@ mod tests { ...@@ -1037,6 +1059,7 @@ mod tests {
let decode = BasicWorkerBuilder::new("http://test:8080") let decode = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Decode) .worker_type(WorkerType::Decode)
.api_key("test_api_key")
.build(); .build();
assert_eq!(decode.worker_type(), WorkerType::Decode); assert_eq!(decode.worker_type(), WorkerType::Decode);
} }
...@@ -1065,6 +1088,7 @@ mod tests { ...@@ -1065,6 +1088,7 @@ mod tests {
use crate::core::BasicWorkerBuilder; use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080") let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(); .build();
// Initial load is 0 // Initial load is 0
...@@ -1350,7 +1374,7 @@ mod tests { ...@@ -1350,7 +1374,7 @@ mod tests {
fn test_urls_to_workers() { fn test_urls_to_workers() {
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()]; let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
let workers = urls_to_workers(urls); let workers = urls_to_workers(urls, Some("test_api_key".to_string()));
assert_eq!(workers.len(), 2); assert_eq!(workers.len(), 2);
assert_eq!(workers[0].url(), "http://w1:8080"); assert_eq!(workers[0].url(), "http://w1:8080");
assert_eq!(workers[1].url(), "http://w2:8080"); assert_eq!(workers[1].url(), "http://w2:8080");
...@@ -1547,6 +1571,7 @@ mod tests { ...@@ -1547,6 +1571,7 @@ mod tests {
1, 1,
4, 4,
WorkerType::Regular, WorkerType::Regular,
Some("test_api_key".to_string()),
); );
assert_eq!(worker.url(), "http://worker1:8080@1"); assert_eq!(worker.url(), "http://worker1:8080@1");
...@@ -1565,6 +1590,7 @@ mod tests { ...@@ -1565,6 +1590,7 @@ mod tests {
WorkerType::Prefill { WorkerType::Prefill {
bootstrap_port: Some(8090), bootstrap_port: Some(8090),
}, },
Some("test_api_key".to_string()),
); );
assert_eq!(worker.url(), "http://worker1:8080@0"); assert_eq!(worker.url(), "http://worker1:8080@0");
...@@ -1680,8 +1706,13 @@ mod tests { ...@@ -1680,8 +1706,13 @@ mod tests {
.worker_type(WorkerType::Decode) .worker_type(WorkerType::Decode)
.build(), .build(),
); );
let dp_aware_regular = let dp_aware_regular = WorkerFactory::create_dp_aware(
WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular); "http://dp:8080".to_string(),
0,
2,
WorkerType::Regular,
Some("test_api_key".to_string()),
);
let dp_aware_prefill = WorkerFactory::create_dp_aware( let dp_aware_prefill = WorkerFactory::create_dp_aware(
"http://dp-prefill:8080".to_string(), "http://dp-prefill:8080".to_string(),
1, 1,
...@@ -1689,12 +1720,14 @@ mod tests { ...@@ -1689,12 +1720,14 @@ mod tests {
WorkerType::Prefill { WorkerType::Prefill {
bootstrap_port: None, bootstrap_port: None,
}, },
Some("test_api_key".to_string()),
); );
let dp_aware_decode = WorkerFactory::create_dp_aware( let dp_aware_decode = WorkerFactory::create_dp_aware(
"http://dp-decode:8080".to_string(), "http://dp-decode:8080".to_string(),
0, 0,
4, 4,
WorkerType::Decode, WorkerType::Decode,
Some("test_api_key".to_string()),
); );
let workers: Vec<Box<dyn Worker>> = vec![ let workers: Vec<Box<dyn Worker>> = vec![
......
...@@ -11,6 +11,7 @@ pub struct BasicWorkerBuilder { ...@@ -11,6 +11,7 @@ pub struct BasicWorkerBuilder {
url: String, url: String,
// Optional fields with defaults // Optional fields with defaults
api_key: Option<String>,
worker_type: WorkerType, worker_type: WorkerType,
connection_mode: ConnectionMode, connection_mode: ConnectionMode,
labels: HashMap<String, String>, labels: HashMap<String, String>,
...@@ -24,6 +25,7 @@ impl BasicWorkerBuilder { ...@@ -24,6 +25,7 @@ impl BasicWorkerBuilder {
pub fn new(url: impl Into<String>) -> Self { pub fn new(url: impl Into<String>) -> Self {
Self { Self {
url: url.into(), url: url.into(),
api_key: None,
worker_type: WorkerType::Regular, worker_type: WorkerType::Regular,
connection_mode: ConnectionMode::Http, connection_mode: ConnectionMode::Http,
labels: HashMap::new(), labels: HashMap::new(),
...@@ -37,6 +39,7 @@ impl BasicWorkerBuilder { ...@@ -37,6 +39,7 @@ impl BasicWorkerBuilder {
pub fn new_with_type(url: impl Into<String>, worker_type: WorkerType) -> Self { pub fn new_with_type(url: impl Into<String>, worker_type: WorkerType) -> Self {
Self { Self {
url: url.into(), url: url.into(),
api_key: None,
worker_type, worker_type,
connection_mode: ConnectionMode::Http, connection_mode: ConnectionMode::Http,
labels: HashMap::new(), labels: HashMap::new(),
...@@ -46,6 +49,12 @@ impl BasicWorkerBuilder { ...@@ -46,6 +49,12 @@ impl BasicWorkerBuilder {
} }
} }
/// Set the API key
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}
/// Set the worker type (Regular, Prefill, or Decode) /// Set the worker type (Regular, Prefill, or Decode)
pub fn worker_type(mut self, worker_type: WorkerType) -> Self { pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
self.worker_type = worker_type; self.worker_type = worker_type;
...@@ -98,6 +107,7 @@ impl BasicWorkerBuilder { ...@@ -98,6 +107,7 @@ impl BasicWorkerBuilder {
let metadata = WorkerMetadata { let metadata = WorkerMetadata {
url: self.url.clone(), url: self.url.clone(),
api_key: self.api_key,
worker_type: self.worker_type, worker_type: self.worker_type,
connection_mode: self.connection_mode, connection_mode: self.connection_mode,
labels: self.labels, labels: self.labels,
...@@ -121,6 +131,7 @@ impl BasicWorkerBuilder { ...@@ -121,6 +131,7 @@ impl BasicWorkerBuilder {
pub struct DPAwareWorkerBuilder { pub struct DPAwareWorkerBuilder {
// Required fields // Required fields
base_url: String, base_url: String,
api_key: Option<String>,
dp_rank: usize, dp_rank: usize,
dp_size: usize, dp_size: usize,
...@@ -138,6 +149,7 @@ impl DPAwareWorkerBuilder { ...@@ -138,6 +149,7 @@ impl DPAwareWorkerBuilder {
pub fn new(base_url: impl Into<String>, dp_rank: usize, dp_size: usize) -> Self { pub fn new(base_url: impl Into<String>, dp_rank: usize, dp_size: usize) -> Self {
Self { Self {
base_url: base_url.into(), base_url: base_url.into(),
api_key: None,
dp_rank, dp_rank,
dp_size, dp_size,
worker_type: WorkerType::Regular, worker_type: WorkerType::Regular,
...@@ -158,6 +170,7 @@ impl DPAwareWorkerBuilder { ...@@ -158,6 +170,7 @@ impl DPAwareWorkerBuilder {
) -> Self { ) -> Self {
Self { Self {
base_url: base_url.into(), base_url: base_url.into(),
api_key: None,
dp_rank, dp_rank,
dp_size, dp_size,
worker_type, worker_type,
...@@ -169,6 +182,12 @@ impl DPAwareWorkerBuilder { ...@@ -169,6 +182,12 @@ impl DPAwareWorkerBuilder {
} }
} }
/// Set the API key
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}
/// Set the worker type (Regular, Prefill, or Decode) /// Set the worker type (Regular, Prefill, or Decode)
pub fn worker_type(mut self, worker_type: WorkerType) -> Self { pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
self.worker_type = worker_type; self.worker_type = worker_type;
...@@ -228,6 +247,10 @@ impl DPAwareWorkerBuilder { ...@@ -228,6 +247,10 @@ impl DPAwareWorkerBuilder {
if let Some(client) = self.grpc_client { if let Some(client) = self.grpc_client {
builder = builder.grpc_client(client); builder = builder.grpc_client(client);
} }
// Add API key if provided
if let Some(api_key) = self.api_key {
builder = builder.api_key(api_key);
}
let base_worker = builder.build(); let base_worker = builder.build();
...@@ -382,6 +405,7 @@ mod tests { ...@@ -382,6 +405,7 @@ mod tests {
.connection_mode(ConnectionMode::Http) .connection_mode(ConnectionMode::Http)
.labels(labels.clone()) .labels(labels.clone())
.health_config(health_config.clone()) .health_config(health_config.clone())
.api_key("test_api_key")
.build(); .build();
assert_eq!(worker.url(), "http://localhost:8080@3"); assert_eq!(worker.url(), "http://localhost:8080@3");
......
...@@ -256,6 +256,18 @@ impl WorkerRegistry { ...@@ -256,6 +256,18 @@ impl WorkerRegistry {
.collect() .collect()
} }
pub fn get_all_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.workers
.iter()
.map(|entry| {
(
entry.value().url().to_string(),
entry.value().api_key().clone(),
)
})
.collect()
}
/// Get all model IDs with workers /// Get all model IDs with workers
pub fn get_models(&self) -> Vec<String> { pub fn get_models(&self) -> Vec<String> {
self.model_workers self.model_workers
...@@ -442,6 +454,7 @@ mod tests { ...@@ -442,6 +454,7 @@ mod tests {
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.labels(labels) .labels(labels)
.circuit_breaker_config(CircuitBreakerConfig::default()) .circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(), .build(),
); );
...@@ -477,6 +490,7 @@ mod tests { ...@@ -477,6 +490,7 @@ mod tests {
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.labels(labels1) .labels(labels1)
.circuit_breaker_config(CircuitBreakerConfig::default()) .circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(), .build(),
); );
...@@ -487,6 +501,7 @@ mod tests { ...@@ -487,6 +501,7 @@ mod tests {
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.labels(labels2) .labels(labels2)
.circuit_breaker_config(CircuitBreakerConfig::default()) .circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(), .build(),
); );
...@@ -497,6 +512,7 @@ mod tests { ...@@ -497,6 +512,7 @@ mod tests {
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.labels(labels3) .labels(labels3)
.circuit_breaker_config(CircuitBreakerConfig::default()) .circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(), .build(),
); );
......
...@@ -465,11 +465,13 @@ mod tests { ...@@ -465,11 +465,13 @@ mod tests {
Arc::new( Arc::new(
BasicWorkerBuilder::new("http://w1:8000") BasicWorkerBuilder::new("http://w1:8000")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(), .build(),
), ),
Arc::new( Arc::new(
BasicWorkerBuilder::new("http://w2:8000") BasicWorkerBuilder::new("http://w2:8000")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(), .build(),
), ),
]; ];
......
...@@ -129,16 +129,19 @@ mod tests { ...@@ -129,16 +129,19 @@ mod tests {
Arc::new( Arc::new(
BasicWorkerBuilder::new("http://w1:8000") BasicWorkerBuilder::new("http://w1:8000")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(), .build(),
), ),
Arc::new( Arc::new(
BasicWorkerBuilder::new("http://w2:8000") BasicWorkerBuilder::new("http://w2:8000")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key2")
.build(), .build(),
), ),
Arc::new( Arc::new(
BasicWorkerBuilder::new("http://w3:8000") BasicWorkerBuilder::new("http://w3:8000")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(), .build(),
), ),
]; ];
......
...@@ -11,6 +11,10 @@ pub struct WorkerConfigRequest { ...@@ -11,6 +11,10 @@ pub struct WorkerConfigRequest {
/// Worker URL (required) /// Worker URL (required)
pub url: String, pub url: String,
/// Worker API key (optional)
#[serde(skip_serializing_if = "Option::is_none")]
pub api_key: Option<String>,
/// Model ID (optional, will query from server if not provided) /// Model ID (optional, will query from server if not provided)
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub model_id: Option<String>, pub model_id: Option<String>,
......
...@@ -353,7 +353,11 @@ impl RouterTrait for GrpcPDRouter { ...@@ -353,7 +353,11 @@ impl RouterTrait for GrpcPDRouter {
#[async_trait] #[async_trait]
impl WorkerManagement for GrpcPDRouter { impl WorkerManagement for GrpcPDRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> { async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
Err("Not implemented".to_string()) Err("Not implemented".to_string())
} }
......
...@@ -282,7 +282,11 @@ impl RouterTrait for GrpcRouter { ...@@ -282,7 +282,11 @@ impl RouterTrait for GrpcRouter {
#[async_trait] #[async_trait]
impl WorkerManagement for GrpcRouter { impl WorkerManagement for GrpcRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> { async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
Err("Not implemented".to_string()) Err("Not implemented".to_string())
} }
......
...@@ -67,7 +67,11 @@ impl OpenAIRouter { ...@@ -67,7 +67,11 @@ impl OpenAIRouter {
#[async_trait] #[async_trait]
impl super::super::WorkerManagement for OpenAIRouter { impl super::super::WorkerManagement for OpenAIRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> { async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
Err("Cannot add workers to OpenAI router".to_string()) Err("Cannot add workers to OpenAI router".to_string())
} }
......
...@@ -46,6 +46,8 @@ pub struct PDRouter { ...@@ -46,6 +46,8 @@ pub struct PDRouter {
pub prefill_client: Client, pub prefill_client: Client,
pub retry_config: RetryConfig, pub retry_config: RetryConfig,
pub circuit_breaker_config: CircuitBreakerConfig, pub circuit_breaker_config: CircuitBreakerConfig,
pub api_key: Option<String>,
// Channel for sending prefill responses to background workers for draining // Channel for sending prefill responses to background workers for draining
prefill_drain_tx: mpsc::Sender<reqwest::Response>, prefill_drain_tx: mpsc::Sender<reqwest::Response>,
} }
...@@ -113,21 +115,25 @@ impl PDRouter { ...@@ -113,21 +115,25 @@ impl PDRouter {
(results, errors) (results, errors)
} }
fn _get_worker_url_and_key(&self, w: &Arc<dyn Worker>) -> (String, Option<String>) {
(w.url().to_string(), w.api_key().clone())
}
// Helper to get prefill worker URLs // Helper to get prefill worker URLs
fn get_prefill_worker_urls(&self) -> Vec<String> { fn get_prefill_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.worker_registry self.worker_registry
.get_prefill_workers() .get_prefill_workers()
.iter() .iter()
.map(|w| w.url().to_string()) .map(|w| self._get_worker_url_and_key(w))
.collect() .collect()
} }
// Helper to get decode worker URLs // Helper to get decode worker URLs
fn get_decode_worker_urls(&self) -> Vec<String> { fn get_decode_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.worker_registry self.worker_registry
.get_decode_workers() .get_decode_workers()
.iter() .iter()
.map(|w| w.url().to_string()) .map(|w| self._get_worker_url_and_key(w))
.collect() .collect()
} }
...@@ -208,6 +214,7 @@ impl PDRouter { ...@@ -208,6 +214,7 @@ impl PDRouter {
pub async fn add_prefill_server( pub async fn add_prefill_server(
&self, &self,
url: String, url: String,
api_key: Option<String>,
bootstrap_port: Option<u16>, bootstrap_port: Option<u16>,
) -> Result<String, PDRouterError> { ) -> Result<String, PDRouterError> {
// Wait for the new server to be healthy // Wait for the new server to be healthy
...@@ -220,10 +227,15 @@ impl PDRouter { ...@@ -220,10 +227,15 @@ impl PDRouter {
// Create Worker for the new prefill server with circuit breaker configuration // Create Worker for the new prefill server with circuit breaker configuration
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let worker = BasicWorkerBuilder::new(url.clone()) let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Prefill { bootstrap_port }) .worker_type(WorkerType::Prefill { bootstrap_port })
.circuit_breaker_config(self.circuit_breaker_config.clone()) .circuit_breaker_config(self.circuit_breaker_config.clone());
.build();
let worker = if let Some(api_key) = api_key {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc: Arc<dyn Worker> = Arc::new(worker); let worker_arc: Arc<dyn Worker> = Arc::new(worker);
...@@ -243,7 +255,11 @@ impl PDRouter { ...@@ -243,7 +255,11 @@ impl PDRouter {
Ok(format!("Successfully added prefill server: {}", url)) Ok(format!("Successfully added prefill server: {}", url))
} }
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> { pub async fn add_decode_server(
&self,
url: String,
api_key: Option<String>,
) -> Result<String, PDRouterError> {
// Wait for the new server to be healthy // Wait for the new server to be healthy
self.wait_for_server_health(&url).await?; self.wait_for_server_health(&url).await?;
...@@ -254,10 +270,15 @@ impl PDRouter { ...@@ -254,10 +270,15 @@ impl PDRouter {
// Create Worker for the new decode server with circuit breaker configuration // Create Worker for the new decode server with circuit breaker configuration
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let worker = BasicWorkerBuilder::new(url.clone()) let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Decode) .worker_type(WorkerType::Decode)
.circuit_breaker_config(self.circuit_breaker_config.clone()) .circuit_breaker_config(self.circuit_breaker_config.clone());
.build();
let worker = if let Some(api_key) = api_key {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc: Arc<dyn Worker> = Arc::new(worker); let worker_arc: Arc<dyn Worker> = Arc::new(worker);
...@@ -366,6 +387,12 @@ impl PDRouter { ...@@ -366,6 +387,12 @@ impl PDRouter {
.chain(decode_workers.iter()) .chain(decode_workers.iter())
.map(|w| w.url().to_string()) .map(|w| w.url().to_string())
.collect(); .collect();
// Get all worker API keys for monitoring
let all_api_keys: Vec<Option<String>> = prefill_workers
.iter()
.chain(decode_workers.iter())
.map(|w| w.api_key().clone())
.collect();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config(); let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
...@@ -387,6 +414,7 @@ impl PDRouter { ...@@ -387,6 +414,7 @@ impl PDRouter {
let load_monitor_handle = let load_monitor_handle =
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" { if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
let monitor_urls = all_urls.clone(); let monitor_urls = all_urls.clone();
let monitor_api_keys = all_api_keys.clone();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs; let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let monitor_client = ctx.client.clone(); let monitor_client = ctx.client.clone();
let prefill_policy_clone = Arc::clone(&prefill_policy); let prefill_policy_clone = Arc::clone(&prefill_policy);
...@@ -395,6 +423,7 @@ impl PDRouter { ...@@ -395,6 +423,7 @@ impl PDRouter {
Some(Arc::new(tokio::spawn(async move { Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads_with_client( Self::monitor_worker_loads_with_client(
monitor_urls, monitor_urls,
monitor_api_keys,
tx, tx,
monitor_interval, monitor_interval,
monitor_client, monitor_client,
...@@ -500,6 +529,7 @@ impl PDRouter { ...@@ -500,6 +529,7 @@ impl PDRouter {
prefill_drain_tx, prefill_drain_tx,
retry_config: ctx.router_config.effective_retry_config(), retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config, circuit_breaker_config: core_cb_config,
api_key: ctx.router_config.api_key.clone(),
}) })
} }
...@@ -1150,6 +1180,7 @@ impl PDRouter { ...@@ -1150,6 +1180,7 @@ impl PDRouter {
// Background task to monitor worker loads with shared client // Background task to monitor worker loads with shared client
async fn monitor_worker_loads_with_client( async fn monitor_worker_loads_with_client(
worker_urls: Vec<String>, worker_urls: Vec<String>,
worker_api_keys: Vec<Option<String>>,
tx: tokio::sync::watch::Sender<HashMap<String, isize>>, tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64, interval_secs: u64,
client: Client, client: Client,
...@@ -1161,11 +1192,13 @@ impl PDRouter { ...@@ -1161,11 +1192,13 @@ impl PDRouter {
let futures: Vec<_> = worker_urls let futures: Vec<_> = worker_urls
.iter() .iter()
.map(|url| { .zip(worker_api_keys.iter())
.map(|(url, api_key)| {
let client = client.clone(); let client = client.clone();
let url = url.clone(); let url = url.clone();
let api_key = api_key.clone();
async move { async move {
let load = get_worker_load(&client, &url).await.unwrap_or(0); let load = get_worker_load(&client, &url, &api_key).await.unwrap_or(0);
(url, load) (url, load)
} }
}) })
...@@ -1515,8 +1548,16 @@ impl PDRouter { ...@@ -1515,8 +1548,16 @@ impl PDRouter {
// Helper functions // Helper functions
async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> { async fn get_worker_load(
match client.get(format!("{}/get_load", worker_url)).send().await { client: &Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
let mut req_builder = client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await { Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) { Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) {
Ok(data) => data Ok(data) => data
...@@ -1550,7 +1591,11 @@ async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> { ...@@ -1550,7 +1591,11 @@ async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> {
#[async_trait] #[async_trait]
impl WorkerManagement for PDRouter { impl WorkerManagement for PDRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> { async fn add_worker(
&self,
_worker_url: &str,
_api_key: &Option<String>,
) -> Result<String, String> {
// For PD router, we don't support adding workers via this generic method // For PD router, we don't support adding workers via this generic method
Err( Err(
"PD router requires specific add_prefill_server or add_decode_server methods" "PD router requires specific add_prefill_server or add_decode_server methods"
...@@ -1956,9 +2001,9 @@ impl RouterTrait for PDRouter { ...@@ -1956,9 +2001,9 @@ impl RouterTrait for PDRouter {
let mut errors = Vec::new(); let mut errors = Vec::new();
// Process prefill workers // Process prefill workers
let prefill_urls = self.get_prefill_worker_urls(); let prefill_urls_with_key = self.get_prefill_worker_urls_with_api_key();
for worker_url in prefill_urls { for (worker_url, api_key) in prefill_urls_with_key {
match get_worker_load(&self.client, &worker_url).await { match get_worker_load(&self.client, &worker_url, &api_key).await {
Some(load) => { Some(load) => {
loads.insert(format!("prefill_{}", worker_url), load); loads.insert(format!("prefill_{}", worker_url), load);
} }
...@@ -1969,9 +2014,9 @@ impl RouterTrait for PDRouter { ...@@ -1969,9 +2014,9 @@ impl RouterTrait for PDRouter {
} }
// Process decode workers // Process decode workers
let decode_urls = self.get_decode_worker_urls(); let decode_urls_with_key = self.get_decode_worker_urls_with_api_key();
for worker_url in decode_urls { for (worker_url, api_key) in decode_urls_with_key {
match get_worker_load(&self.client, &worker_url).await { match get_worker_load(&self.client, &worker_url, &api_key).await {
Some(load) => { Some(load) => {
loads.insert(format!("decode_{}", worker_url), load); loads.insert(format!("decode_{}", worker_url), load);
} }
...@@ -2069,12 +2114,14 @@ mod tests { ...@@ -2069,12 +2114,14 @@ mod tests {
prefill_drain_tx: mpsc::channel(100).0, prefill_drain_tx: mpsc::channel(100).0,
retry_config: RetryConfig::default(), retry_config: RetryConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(), circuit_breaker_config: CircuitBreakerConfig::default(),
api_key: Some("test_api_key".to_string()),
} }
} }
fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> { fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> {
let worker = BasicWorkerBuilder::new(url) let worker = BasicWorkerBuilder::new(url)
.worker_type(worker_type) .worker_type(worker_type)
.api_key("test_api_key")
.build(); .build();
worker.set_healthy(healthy); worker.set_healthy(healthy);
Box::new(worker) Box::new(worker)
......
...@@ -38,6 +38,7 @@ pub struct Router { ...@@ -38,6 +38,7 @@ pub struct Router {
worker_startup_timeout_secs: u64, worker_startup_timeout_secs: u64,
worker_startup_check_interval_secs: u64, worker_startup_check_interval_secs: u64,
dp_aware: bool, dp_aware: bool,
#[allow(dead_code)]
api_key: Option<String>, api_key: Option<String>,
retry_config: RetryConfig, retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig, circuit_breaker_config: CircuitBreakerConfig,
...@@ -71,7 +72,6 @@ impl Router { ...@@ -71,7 +72,6 @@ impl Router {
}; };
// Cache-aware policies are initialized in WorkerInitializer // Cache-aware policies are initialized in WorkerInitializer
// Setup load monitoring for PowerOfTwo policy // Setup load monitoring for PowerOfTwo policy
let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx); let worker_loads = Arc::new(rx);
...@@ -82,6 +82,14 @@ impl Router { ...@@ -82,6 +82,14 @@ impl Router {
// Check if default policy is power_of_two for load monitoring // Check if default policy is power_of_two for load monitoring
let load_monitor_handle = if default_policy.name() == "power_of_two" { let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone(); let monitor_urls = worker_urls.clone();
let monitor_api_keys = monitor_urls
.iter()
.map(|url| {
ctx.worker_registry
.get_by_url(url)
.and_then(|w| w.api_key().clone())
})
.collect::<Vec<Option<String>>>();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs; let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let policy_clone = default_policy.clone(); let policy_clone = default_policy.clone();
let client_clone = ctx.client.clone(); let client_clone = ctx.client.clone();
...@@ -89,6 +97,7 @@ impl Router { ...@@ -89,6 +97,7 @@ impl Router {
Some(Arc::new(tokio::spawn(async move { Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads( Self::monitor_worker_loads(
monitor_urls, monitor_urls,
monitor_api_keys,
tx, tx,
monitor_interval, monitor_interval,
policy_clone, policy_clone,
...@@ -912,7 +921,11 @@ impl Router { ...@@ -912,7 +921,11 @@ impl Router {
} }
} }
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> { pub async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.worker_startup_timeout_secs)) .timeout(Duration::from_secs(self.worker_startup_timeout_secs))
...@@ -938,7 +951,7 @@ impl Router { ...@@ -938,7 +951,7 @@ impl Router {
// Need to contact the worker to extract the dp_size, // Need to contact the worker to extract the dp_size,
// and add them as multiple workers // and add them as multiple workers
let url_vec = vec![String::from(worker_url)]; let url_vec = vec![String::from(worker_url)];
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key) let dp_url_vec = Self::get_dp_aware_workers(&url_vec, api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?; .map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
let mut worker_added: bool = false; let mut worker_added: bool = false;
for dp_url in &dp_url_vec { for dp_url in &dp_url_vec {
...@@ -948,10 +961,18 @@ impl Router { ...@@ -948,10 +961,18 @@ impl Router {
} }
info!("Added worker: {}", dp_url); info!("Added worker: {}", dp_url);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker = BasicWorkerBuilder::new(dp_url.to_string()) let new_worker_builder =
.worker_type(WorkerType::Regular) BasicWorkerBuilder::new(dp_url.to_string())
.circuit_breaker_config(self.circuit_breaker_config.clone()) .worker_type(WorkerType::Regular)
.build(); .circuit_breaker_config(
self.circuit_breaker_config.clone(),
);
let new_worker = if let Some(api_key) = api_key {
new_worker_builder.api_key(api_key).build()
} else {
new_worker_builder.build()
};
let worker_arc = Arc::new(new_worker); let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone()); self.worker_registry.register(worker_arc.clone());
...@@ -978,10 +999,16 @@ impl Router { ...@@ -978,10 +999,16 @@ impl Router {
info!("Added worker: {}", worker_url); info!("Added worker: {}", worker_url);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint // TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker = BasicWorkerBuilder::new(worker_url.to_string()) let new_worker_builder =
.worker_type(WorkerType::Regular) BasicWorkerBuilder::new(worker_url.to_string())
.circuit_breaker_config(self.circuit_breaker_config.clone()) .worker_type(WorkerType::Regular)
.build(); .circuit_breaker_config(self.circuit_breaker_config.clone());
let new_worker = if let Some(api_key) = api_key {
new_worker_builder.api_key(api_key).build()
} else {
new_worker_builder.build()
};
let worker_arc = Arc::new(new_worker); let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone()); self.worker_registry.register(worker_arc.clone());
...@@ -1094,7 +1121,7 @@ impl Router { ...@@ -1094,7 +1121,7 @@ impl Router {
} }
} }
async fn get_worker_load(&self, worker_url: &str) -> Option<isize> { async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
let worker_url = if self.dp_aware { let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank" // Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) { let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
...@@ -1109,12 +1136,12 @@ impl Router { ...@@ -1109,12 +1136,12 @@ impl Router {
worker_url worker_url
}; };
match self let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
.client if let Some(key) = api_key {
.get(format!("{}/get_load", worker_url)) req_builder = req_builder.bearer_auth(key);
.send() }
.await
{ match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await { Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) { Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data Ok(data) => data
...@@ -1149,6 +1176,7 @@ impl Router { ...@@ -1149,6 +1176,7 @@ impl Router {
// Background task to monitor worker loads // Background task to monitor worker loads
async fn monitor_worker_loads( async fn monitor_worker_loads(
worker_urls: Vec<String>, worker_urls: Vec<String>,
worker_api_keys: Vec<Option<String>>,
tx: tokio::sync::watch::Sender<HashMap<String, isize>>, tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64, interval_secs: u64,
policy: Arc<dyn LoadBalancingPolicy>, policy: Arc<dyn LoadBalancingPolicy>,
...@@ -1160,8 +1188,8 @@ impl Router { ...@@ -1160,8 +1188,8 @@ impl Router {
interval.tick().await; interval.tick().await;
let mut loads = HashMap::new(); let mut loads = HashMap::new();
for url in &worker_urls { for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
if let Some(load) = Self::get_worker_load_static(&client, url).await { if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
loads.insert(url.clone(), load); loads.insert(url.clone(), load);
} }
} }
...@@ -1179,7 +1207,11 @@ impl Router { ...@@ -1179,7 +1207,11 @@ impl Router {
} }
// Static version of get_worker_load for use in monitoring task // Static version of get_worker_load for use in monitoring task
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> { async fn get_worker_load_static(
client: &reqwest::Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
let worker_url = if worker_url.contains("@") { let worker_url = if worker_url.contains("@") {
// Need to extract the URL from "http://host:port@dp_rank" // Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) { let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
...@@ -1194,7 +1226,11 @@ impl Router { ...@@ -1194,7 +1226,11 @@ impl Router {
worker_url worker_url
}; };
match client.get(format!("{}/get_load", worker_url)).send().await { let mut req_builder = client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await { Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) { Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data Ok(data) => data
...@@ -1250,8 +1286,12 @@ use async_trait::async_trait; ...@@ -1250,8 +1286,12 @@ use async_trait::async_trait;
#[async_trait] #[async_trait]
impl WorkerManagement for Router { impl WorkerManagement for Router {
async fn add_worker(&self, worker_url: &str) -> Result<String, String> { async fn add_worker(
Router::add_worker(self, worker_url).await &self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
Router::add_worker(self, worker_url, api_key).await
} }
fn remove_worker(&self, worker_url: &str) { fn remove_worker(&self, worker_url: &str) {
...@@ -1457,12 +1497,12 @@ impl RouterTrait for Router { ...@@ -1457,12 +1497,12 @@ impl RouterTrait for Router {
} }
async fn get_worker_loads(&self) -> Response { async fn get_worker_loads(&self) -> Response {
let urls = self.get_worker_urls(); let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
let mut loads = Vec::new(); let mut loads = Vec::new();
// Get loads from all workers // Get loads from all workers
for url in &urls { for (url, api_key) in &urls_with_key {
let load = self.get_worker_load(url).await.unwrap_or(-1); let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
loads.push(serde_json::json!({ loads.push(serde_json::json!({
"worker": url, "worker": url,
"load": load "load": load
...@@ -1521,9 +1561,11 @@ mod tests { ...@@ -1521,9 +1561,11 @@ mod tests {
// Register test workers // Register test workers
let worker1 = BasicWorkerBuilder::new("http://worker1:8080") let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(); .build();
let worker2 = BasicWorkerBuilder::new("http://worker2:8080") let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(); .build();
worker_registry.register(Arc::new(worker1)); worker_registry.register(Arc::new(worker1));
worker_registry.register(Arc::new(worker2)); worker_registry.register(Arc::new(worker2));
......
...@@ -33,7 +33,11 @@ pub use http::{openai_router, pd_router, pd_types, router}; ...@@ -33,7 +33,11 @@ pub use http::{openai_router, pd_router, pd_types, router};
#[async_trait] #[async_trait]
pub trait WorkerManagement: Send + Sync { pub trait WorkerManagement: Send + Sync {
/// Add a worker to the router /// Add a worker to the router
async fn add_worker(&self, worker_url: &str) -> Result<String, String>; async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String>;
/// Remove a worker from the router /// Remove a worker from the router
fn remove_worker(&self, worker_url: &str); fn remove_worker(&self, worker_url: &str);
......
...@@ -161,7 +161,7 @@ impl RouterManager { ...@@ -161,7 +161,7 @@ impl RouterManager {
let model_id = if let Some(model_id) = config.model_id { let model_id = if let Some(model_id) = config.model_id {
model_id model_id
} else { } else {
match self.query_server_info(&config.url).await { match self.query_server_info(&config.url, &config.api_key).await {
Ok(info) => { Ok(info) => {
// Extract model_id from server info // Extract model_id from server info
info.model_id info.model_id
...@@ -208,29 +208,44 @@ impl RouterManager { ...@@ -208,29 +208,44 @@ impl RouterManager {
} }
let worker = match config.worker_type.as_deref() { let worker = match config.worker_type.as_deref() {
Some("prefill") => Box::new( Some("prefill") => {
BasicWorkerBuilder::new(config.url.clone()) let mut builder = BasicWorkerBuilder::new(config.url.clone())
.worker_type(WorkerType::Prefill { .worker_type(WorkerType::Prefill {
bootstrap_port: config.bootstrap_port, bootstrap_port: config.bootstrap_port,
}) })
.labels(labels.clone()) .labels(labels.clone())
.circuit_breaker_config(CircuitBreakerConfig::default()) .circuit_breaker_config(CircuitBreakerConfig::default());
.build(),
) as Box<dyn Worker>, if let Some(api_key) = config.api_key.clone() {
Some("decode") => Box::new( builder = builder.api_key(api_key);
BasicWorkerBuilder::new(config.url.clone()) }
Box::new(builder.build()) as Box<dyn Worker>
}
Some("decode") => {
let mut builder = BasicWorkerBuilder::new(config.url.clone())
.worker_type(WorkerType::Decode) .worker_type(WorkerType::Decode)
.labels(labels.clone()) .labels(labels.clone())
.circuit_breaker_config(CircuitBreakerConfig::default()) .circuit_breaker_config(CircuitBreakerConfig::default());
.build(),
) as Box<dyn Worker>, if let Some(api_key) = config.api_key.clone() {
_ => Box::new( builder = builder.api_key(api_key);
BasicWorkerBuilder::new(config.url.clone()) }
Box::new(builder.build()) as Box<dyn Worker>
}
_ => {
let mut builder = BasicWorkerBuilder::new(config.url.clone())
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.labels(labels.clone()) .labels(labels.clone())
.circuit_breaker_config(CircuitBreakerConfig::default()) .circuit_breaker_config(CircuitBreakerConfig::default());
.build(),
) as Box<dyn Worker>, if let Some(api_key) = config.api_key.clone() {
builder = builder.api_key(api_key);
}
Box::new(builder.build()) as Box<dyn Worker>
}
}; };
// Register worker // Register worker
...@@ -346,10 +361,18 @@ impl RouterManager { ...@@ -346,10 +361,18 @@ impl RouterManager {
} }
/// Query server info from a worker URL /// Query server info from a worker URL
async fn query_server_info(&self, url: &str) -> Result<ServerInfo, String> { async fn query_server_info(
&self,
url: &str,
api_key: &Option<String>,
) -> Result<ServerInfo, String> {
let info_url = format!("{}/get_server_info", url.trim_end_matches('/')); let info_url = format!("{}/get_server_info", url.trim_end_matches('/'));
match self.client.get(&info_url).send().await { let mut req_builder = self.client.get(&info_url);
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(response) => { Ok(response) => {
if response.status().is_success() { if response.status().is_success() {
response response
...@@ -477,10 +500,15 @@ impl RouterManager { ...@@ -477,10 +500,15 @@ impl RouterManager {
#[async_trait] #[async_trait]
impl WorkerManagement for RouterManager { impl WorkerManagement for RouterManager {
/// Add a worker - in multi-router mode, this adds to the registry /// Add a worker - in multi-router mode, this adds to the registry
async fn add_worker(&self, worker_url: &str) -> Result<String, String> { async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
// Create a basic worker config request // Create a basic worker config request
let config = WorkerConfigRequest { let config = WorkerConfigRequest {
url: worker_url.to_string(), url: worker_url.to_string(),
api_key: api_key.clone(),
model_id: None, model_id: None,
worker_type: None, worker_type: None,
priority: None, priority: None,
......
...@@ -27,8 +27,12 @@ impl WorkerInitializer { ...@@ -27,8 +27,12 @@ impl WorkerInitializer {
match &config.mode { match &config.mode {
RoutingMode::Regular { worker_urls } => { RoutingMode::Regular { worker_urls } => {
// use router's api_key, repeat for each worker
let worker_api_keys: Vec<Option<String>> =
worker_urls.iter().map(|_| config.api_key.clone()).collect();
Self::create_regular_workers( Self::create_regular_workers(
worker_urls, worker_urls,
&worker_api_keys,
&config.connection_mode, &config.connection_mode,
config, config,
worker_registry, worker_registry,
...@@ -41,8 +45,16 @@ impl WorkerInitializer { ...@@ -41,8 +45,16 @@ impl WorkerInitializer {
decode_urls, decode_urls,
.. ..
} => { } => {
// use router's api_key, repeat for each prefill/decode worker
let prefill_api_keys: Vec<Option<String>> = prefill_urls
.iter()
.map(|_| config.api_key.clone())
.collect();
let decode_api_keys: Vec<Option<String>> =
decode_urls.iter().map(|_| config.api_key.clone()).collect();
Self::create_prefill_workers( Self::create_prefill_workers(
prefill_urls, prefill_urls,
&prefill_api_keys,
&config.connection_mode, &config.connection_mode,
config, config,
worker_registry, worker_registry,
...@@ -51,6 +63,7 @@ impl WorkerInitializer { ...@@ -51,6 +63,7 @@ impl WorkerInitializer {
.await?; .await?;
Self::create_decode_workers( Self::create_decode_workers(
decode_urls, decode_urls,
&decode_api_keys,
&config.connection_mode, &config.connection_mode,
config, config,
worker_registry, worker_registry,
...@@ -79,6 +92,7 @@ impl WorkerInitializer { ...@@ -79,6 +92,7 @@ impl WorkerInitializer {
/// Create regular workers for standard routing mode /// Create regular workers for standard routing mode
async fn create_regular_workers( async fn create_regular_workers(
urls: &[String], urls: &[String],
api_keys: &[Option<String>],
config_connection_mode: &ConfigConnectionMode, config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig, config: &RouterConfig,
registry: &Arc<WorkerRegistry>, registry: &Arc<WorkerRegistry>,
...@@ -109,14 +123,18 @@ impl WorkerInitializer { ...@@ -109,14 +123,18 @@ impl WorkerInitializer {
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new(); let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls { for (url, api_key) in urls.iter().zip(api_keys.iter()) {
// TODO: Add DP-aware support when we have dp_rank/dp_size info // TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone()) let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.connection_mode(connection_mode.clone()) .connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone()) .circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone()) .health_config(health_config.clone());
.build(); let worker = if let Some(api_key) = api_key.clone() {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc = Arc::new(worker) as Arc<dyn Worker>; let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id(); let model_id = worker_arc.model_id();
...@@ -148,6 +166,7 @@ impl WorkerInitializer { ...@@ -148,6 +166,7 @@ impl WorkerInitializer {
/// Create prefill workers for disaggregated routing mode /// Create prefill workers for disaggregated routing mode
async fn create_prefill_workers( async fn create_prefill_workers(
prefill_entries: &[(String, Option<u16>)], prefill_entries: &[(String, Option<u16>)],
api_keys: &[Option<String>],
config_connection_mode: &ConfigConnectionMode, config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig, config: &RouterConfig,
registry: &Arc<WorkerRegistry>, registry: &Arc<WorkerRegistry>,
...@@ -181,16 +200,20 @@ impl WorkerInitializer { ...@@ -181,16 +200,20 @@ impl WorkerInitializer {
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new(); let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for (url, bootstrap_port) in prefill_entries { for ((url, bootstrap_port), api_key) in prefill_entries.iter().zip(api_keys.iter()) {
// TODO: Add DP-aware support when we have dp_rank/dp_size info // TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone()) let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Prefill { .worker_type(WorkerType::Prefill {
bootstrap_port: *bootstrap_port, bootstrap_port: *bootstrap_port,
}) })
.connection_mode(connection_mode.clone()) .connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone()) .circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone()) .health_config(health_config.clone());
.build(); let worker = if let Some(api_key) = api_key.clone() {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc = Arc::new(worker) as Arc<dyn Worker>; let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id(); let model_id = worker_arc.model_id();
...@@ -227,6 +250,7 @@ impl WorkerInitializer { ...@@ -227,6 +250,7 @@ impl WorkerInitializer {
/// Create decode workers for disaggregated routing mode /// Create decode workers for disaggregated routing mode
async fn create_decode_workers( async fn create_decode_workers(
urls: &[String], urls: &[String],
api_keys: &[Option<String>],
config_connection_mode: &ConfigConnectionMode, config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig, config: &RouterConfig,
registry: &Arc<WorkerRegistry>, registry: &Arc<WorkerRegistry>,
...@@ -257,14 +281,18 @@ impl WorkerInitializer { ...@@ -257,14 +281,18 @@ impl WorkerInitializer {
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new(); let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls { for (url, api_key) in urls.iter().zip(api_keys.iter()) {
// TODO: Add DP-aware support when we have dp_rank/dp_size info // TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone()) let worker_builder = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Decode) .worker_type(WorkerType::Decode)
.connection_mode(connection_mode.clone()) .connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone()) .circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone()) .health_config(health_config.clone());
.build(); let worker = if let Some(api_key) = api_key.clone() {
worker_builder.api_key(api_key).build()
} else {
worker_builder.build()
};
let worker_arc = Arc::new(worker) as Arc<dyn Worker>; let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id(); let model_id = worker_arc.model_id();
......
...@@ -282,15 +282,16 @@ async fn v1_responses_list_input_items( ...@@ -282,15 +282,16 @@ async fn v1_responses_list_input_items(
// ---------- Worker management endpoints (Legacy) ---------- // ---------- Worker management endpoints (Legacy) ----------
#[derive(Deserialize)] #[derive(Deserialize)]
struct UrlQuery { struct AddWorkerQuery {
url: String, url: String,
api_key: Option<String>,
} }
async fn add_worker( async fn add_worker(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Query(UrlQuery { url }): Query<UrlQuery>, Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
) -> Response { ) -> Response {
match state.router.add_worker(&url).await { match state.router.add_worker(&url, &api_key).await {
Ok(message) => (StatusCode::OK, message).into_response(), Ok(message) => (StatusCode::OK, message).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(), Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
} }
...@@ -303,7 +304,7 @@ async fn list_workers(State(state): State<Arc<AppState>>) -> Response { ...@@ -303,7 +304,7 @@ async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
async fn remove_worker( async fn remove_worker(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Query(UrlQuery { url }): Query<UrlQuery>, Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
) -> Response { ) -> Response {
state.router.remove_worker(&url); state.router.remove_worker(&url);
( (
...@@ -337,7 +338,7 @@ async fn create_worker( ...@@ -337,7 +338,7 @@ async fn create_worker(
} }
} else { } else {
// In single router mode, use the router's add_worker with basic config // In single router mode, use the router's add_worker with basic config
match state.router.add_worker(&config.url).await { match state.router.add_worker(&config.url, &config.api_key).await {
Ok(message) => { Ok(message) => {
let response = WorkerApiResponse { let response = WorkerApiResponse {
success: true, success: true,
......
...@@ -389,16 +389,20 @@ async fn handle_pod_event( ...@@ -389,16 +389,20 @@ async fn handle_pod_event(
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() { if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
match &pod_info.pod_type { match &pod_info.pod_type {
Some(PodType::Prefill) => pd_router Some(PodType::Prefill) => pd_router
.add_prefill_server(worker_url.clone(), pod_info.bootstrap_port) .add_prefill_server(
worker_url.clone(),
pd_router.api_key.clone(),
pod_info.bootstrap_port,
)
.await .await
.map_err(|e| e.to_string()), .map_err(|e| e.to_string()),
Some(PodType::Decode) => pd_router Some(PodType::Decode) => pd_router
.add_decode_server(worker_url.clone()) .add_decode_server(worker_url.clone(), pd_router.api_key.clone())
.await .await
.map_err(|e| e.to_string()), .map_err(|e| e.to_string()),
Some(PodType::Regular) | None => { Some(PodType::Regular) | None => {
// Fall back to regular add_worker for regular pods // Fall back to regular add_worker for regular pods
router.add_worker(&worker_url).await router.add_worker(&worker_url, &pd_router.api_key).await
} }
} }
} else { } else {
...@@ -406,7 +410,8 @@ async fn handle_pod_event( ...@@ -406,7 +410,8 @@ async fn handle_pod_event(
} }
} else { } else {
// Regular mode or no pod type specified // Regular mode or no pod type specified
router.add_worker(&worker_url).await // In pod, no need api key
router.add_worker(&worker_url, &None).await
}; };
match result { match result {
......
...@@ -18,6 +18,7 @@ fn test_backward_compatibility_with_empty_model_id() { ...@@ -18,6 +18,7 @@ fn test_backward_compatibility_with_empty_model_id() {
// Create workers with empty model_id (simulating existing routers) // Create workers with empty model_id (simulating existing routers)
let worker1 = BasicWorkerBuilder::new("http://worker1:8080") let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(); .build();
// No model_id label - should default to "unknown" // No model_id label - should default to "unknown"
...@@ -25,6 +26,7 @@ fn test_backward_compatibility_with_empty_model_id() { ...@@ -25,6 +26,7 @@ fn test_backward_compatibility_with_empty_model_id() {
labels2.insert("model_id".to_string(), "unknown".to_string()); labels2.insert("model_id".to_string(), "unknown".to_string());
let worker2 = BasicWorkerBuilder::new("http://worker2:8080") let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.labels(labels2) .labels(labels2)
.build(); .build();
...@@ -59,6 +61,7 @@ fn test_mixed_model_ids() { ...@@ -59,6 +61,7 @@ fn test_mixed_model_ids() {
// Create workers with different model_id scenarios // Create workers with different model_id scenarios
let worker1 = BasicWorkerBuilder::new("http://worker1:8080") let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(); .build();
// No model_id label - defaults to "unknown" which goes to "default" tree // No model_id label - defaults to "unknown" which goes to "default" tree
...@@ -67,6 +70,7 @@ fn test_mixed_model_ids() { ...@@ -67,6 +70,7 @@ fn test_mixed_model_ids() {
let worker2 = BasicWorkerBuilder::new("http://worker2:8080") let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.labels(labels2) .labels(labels2)
.api_key("test_api_key")
.build(); .build();
let mut labels3 = HashMap::new(); let mut labels3 = HashMap::new();
...@@ -123,10 +127,12 @@ fn test_remove_worker_by_url_backward_compat() { ...@@ -123,10 +127,12 @@ fn test_remove_worker_by_url_backward_compat() {
let worker1 = BasicWorkerBuilder::new("http://worker1:8080") let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.labels(labels1) .labels(labels1)
.api_key("test_api_key")
.build(); .build();
let worker2 = BasicWorkerBuilder::new("http://worker2:8080") let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build(); .build();
// No model_id label - defaults to "unknown" // No model_id label - defaults to "unknown"
......
...@@ -41,6 +41,7 @@ async fn test_policy_registry_with_router_manager() { ...@@ -41,6 +41,7 @@ async fn test_policy_registry_with_router_manager() {
let _worker1_config = WorkerConfigRequest { let _worker1_config = WorkerConfigRequest {
url: "http://worker1:8000".to_string(), url: "http://worker1:8000".to_string(),
model_id: Some("llama-3".to_string()), model_id: Some("llama-3".to_string()),
api_key: Some("test_api_key".to_string()),
worker_type: None, worker_type: None,
priority: None, priority: None,
cost: None, cost: None,
...@@ -66,6 +67,7 @@ async fn test_policy_registry_with_router_manager() { ...@@ -66,6 +67,7 @@ async fn test_policy_registry_with_router_manager() {
let _worker2_config = WorkerConfigRequest { let _worker2_config = WorkerConfigRequest {
url: "http://worker2:8000".to_string(), url: "http://worker2:8000".to_string(),
model_id: Some("llama-3".to_string()), model_id: Some("llama-3".to_string()),
api_key: Some("test_api_key".to_string()),
worker_type: None, worker_type: None,
priority: None, priority: None,
cost: None, cost: None,
...@@ -86,6 +88,7 @@ async fn test_policy_registry_with_router_manager() { ...@@ -86,6 +88,7 @@ async fn test_policy_registry_with_router_manager() {
let _worker3_config = WorkerConfigRequest { let _worker3_config = WorkerConfigRequest {
url: "http://worker3:8000".to_string(), url: "http://worker3:8000".to_string(),
model_id: Some("gpt-4".to_string()), model_id: Some("gpt-4".to_string()),
api_key: Some("test_api_key".to_string()),
worker_type: None, worker_type: None,
priority: None, priority: None,
cost: None, cost: None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment