"git@developer.sourcefind.cn:change/sglang.git" did not exist on "54fb1c80c0d7bbf100d4efc84d1aad4bee094ff0"
Unverified commit 40e0082d, authored by Simo Lin and committed by GitHub
Browse files

[router] add worker self discovery for metadata (#11638)

parent e9e120ac
...@@ -11,6 +11,7 @@ use crate::core::{ ...@@ -11,6 +11,7 @@ use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, HealthConfig, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, HealthConfig,
Worker, WorkerFactory, WorkerRegistry, WorkerType, Worker, WorkerFactory, WorkerRegistry, WorkerType,
}; };
use crate::grpc_client::SglangSchedulerClient;
use crate::policies::PolicyRegistry; use crate::policies::PolicyRegistry;
use crate::protocols::worker_spec::{ use crate::protocols::worker_spec::{
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult, FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
...@@ -21,6 +22,7 @@ use once_cell::sync::Lazy; ...@@ -21,6 +22,7 @@ use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::{watch, Mutex}; use tokio::sync::{watch, Mutex};
...@@ -55,6 +57,21 @@ pub struct DpInfo { ...@@ -55,6 +57,21 @@ pub struct DpInfo {
pub model_id: String, pub model_id: String,
} }
/// Worker discovery results gathered from backend endpoints.
///
/// Filled in by the `discover_*_metadata` probes and consumed while a worker
/// is being built: `labels` seeds the worker's label map, and `grpc_client`
/// (when present) is handed to the worker builder.
#[derive(Default)]
struct WorkerDiscovery {
    /// Metadata labels reported by the backend (for example `model_id` or
    /// `model_path`), keyed by label name.
    labels: HashMap<String, String>,
    /// Client connected during gRPC discovery. Stays `None` for HTTP workers
    /// and when the gRPC connection attempt failed.
    grpc_client: Option<SglangSchedulerClient>,
}

impl WorkerDiscovery {
    /// Creates an empty result: no labels and no gRPC client.
    fn new() -> Self {
        // `Option<T>` is `Default` for every `T`, so the derive needs nothing
        // from `SglangSchedulerClient` itself.
        Self::default()
    }
}
/// Unified worker management /// Unified worker management
pub struct WorkerManager; pub struct WorkerManager;
...@@ -318,7 +335,8 @@ impl WorkerManager { ...@@ -318,7 +335,8 @@ impl WorkerManager {
None, None,
circuit_breaker_config.clone(), circuit_breaker_config.clone(),
health_config.clone(), health_config.clone(),
); )
.await;
Self::register_worker(worker, registry, &mut registered_workers, policy_registry); Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
} }
} }
...@@ -363,7 +381,8 @@ impl WorkerManager { ...@@ -363,7 +381,8 @@ impl WorkerManager {
None, None,
circuit_breaker_config.clone(), circuit_breaker_config.clone(),
health_config.clone(), health_config.clone(),
); )
.await;
Self::register_worker(worker, registry, &mut registered_workers, policy_registry); Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
} }
...@@ -408,7 +427,8 @@ impl WorkerManager { ...@@ -408,7 +427,8 @@ impl WorkerManager {
None, None,
circuit_breaker_config.clone(), circuit_breaker_config.clone(),
health_config.clone(), health_config.clone(),
); )
.await;
Self::register_worker(worker, registry, &mut registered_workers, policy_registry); Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
} }
...@@ -448,7 +468,8 @@ impl WorkerManager { ...@@ -448,7 +468,8 @@ impl WorkerManager {
None, None,
circuit_breaker_config.clone(), circuit_breaker_config.clone(),
health_config.clone(), health_config.clone(),
); )
.await;
Self::register_worker(worker, registry, &mut registered_workers, policy_registry); Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
info!( info!(
"Registered gRPC worker at {} (will connect on first use)", "Registered gRPC worker at {} (will connect on first use)",
...@@ -497,7 +518,8 @@ impl WorkerManager { ...@@ -497,7 +518,8 @@ impl WorkerManager {
None, None,
circuit_breaker_config.clone(), circuit_breaker_config.clone(),
health_config.clone(), health_config.clone(),
); )
.await;
Self::register_worker( Self::register_worker(
worker, worker,
registry, registry,
...@@ -522,7 +544,8 @@ impl WorkerManager { ...@@ -522,7 +544,8 @@ impl WorkerManager {
None, None,
circuit_breaker_config.clone(), circuit_breaker_config.clone(),
health_config.clone(), health_config.clone(),
); )
.await;
Self::register_worker( Self::register_worker(
worker, worker,
registry, registry,
...@@ -563,12 +586,9 @@ impl WorkerManager { ...@@ -563,12 +586,9 @@ impl WorkerManager {
} }
let mut labels = config.labels.clone(); let mut labels = config.labels.clone();
// Use provided model_id or default to "unknown" if let Some(model_id) = &config.model_id {
let model_id = config labels.insert("model_id".to_string(), model_id.clone());
.model_id }
.clone()
.unwrap_or_else(|| "unknown".to_string());
labels.insert("model_id".to_string(), model_id.clone());
if let Some(priority) = config.priority { if let Some(priority) = config.priority {
labels.insert("priority".to_string(), priority.to_string()); labels.insert("priority".to_string(), priority.to_string());
} }
...@@ -620,12 +640,14 @@ impl WorkerManager { ...@@ -620,12 +640,14 @@ impl WorkerManager {
Some(labels.clone()), Some(labels.clone()),
circuit_breaker_config, circuit_breaker_config,
health_config, health_config,
); )
.await;
worker.set_healthy(false); worker.set_healthy(false);
context.worker_registry.register(worker.clone()); context.worker_registry.register(worker.clone());
let policy_hint = labels.get("policy").map(|s| s.as_str()); let policy_hint = labels.get("policy").map(|s| s.as_str());
let model_id = worker.model_id().to_string();
context context
.policy_registry .policy_registry
.on_worker_added(&model_id, policy_hint); .on_worker_added(&model_id, policy_hint);
...@@ -793,7 +815,8 @@ impl WorkerManager { ...@@ -793,7 +815,8 @@ impl WorkerManager {
labels, labels,
circuit_breaker_config, circuit_breaker_config,
health_config, health_config,
); )
.await;
let model_id = worker.model_id().to_string(); let model_id = worker.model_id().to_string();
context.worker_registry.register(worker.clone()); context.worker_registry.register(worker.clone());
...@@ -893,7 +916,7 @@ impl WorkerManager { ...@@ -893,7 +916,7 @@ impl WorkerManager {
} }
/// Create a basic worker /// Create a basic worker
fn create_basic_worker( async fn create_basic_worker(
url: String, url: String,
worker_type: WorkerType, worker_type: WorkerType,
connection_mode: ConnectionMode, connection_mode: ConnectionMode,
...@@ -902,6 +925,16 @@ impl WorkerManager { ...@@ -902,6 +925,16 @@ impl WorkerManager {
circuit_breaker_config: CircuitBreakerConfig, circuit_breaker_config: CircuitBreakerConfig,
health_config: HealthConfig, health_config: HealthConfig,
) -> Arc<dyn Worker> { ) -> Arc<dyn Worker> {
let discovery =
Self::discover_worker_metadata(&url, &connection_mode, api_key.as_deref()).await;
let mut final_labels = discovery.labels;
if let Some(custom_labels) = labels {
for (key, value) in custom_labels {
final_labels.insert(key, value);
}
}
let mut builder = BasicWorkerBuilder::new(url) let mut builder = BasicWorkerBuilder::new(url)
.worker_type(worker_type) .worker_type(worker_type)
.connection_mode(connection_mode) .connection_mode(connection_mode)
...@@ -912,8 +945,12 @@ impl WorkerManager { ...@@ -912,8 +945,12 @@ impl WorkerManager {
builder = builder.api_key(key); builder = builder.api_key(key);
} }
if let Some(worker_labels) = labels { if !final_labels.is_empty() {
builder = builder.labels(worker_labels); builder = builder.labels(final_labels);
}
if let Some(client) = discovery.grpc_client {
builder = builder.grpc_client(client);
} }
let worker = builder.build(); let worker = builder.build();
...@@ -1084,6 +1121,306 @@ impl WorkerManager { ...@@ -1084,6 +1121,306 @@ impl WorkerManager {
} }
} }
/// Probe a backend for its metadata before the worker is registered.
///
/// Dispatches on `connection_mode`: gRPC workers are queried through
/// `discover_grpc_metadata` (which takes no API key), HTTP workers through
/// `discover_http_metadata` with the optional `api_key`.
async fn discover_worker_metadata(
    url: &str,
    connection_mode: &ConnectionMode,
    api_key: Option<&str>,
) -> WorkerDiscovery {
    // Kept as an exhaustive match so a new connection mode fails to compile
    // here instead of silently falling through to one of the probes.
    match connection_mode {
        ConnectionMode::Grpc { .. } => {
            Self::discover_grpc_metadata(url).await
        }
        ConnectionMode::Http => {
            Self::discover_http_metadata(url, api_key).await
        }
    }
}
/// Collect metadata labels from an HTTP worker.
///
/// Queries the worker's model-info and server-info endpoints (via
/// `get_model_info` / `get_server_info`) and copies the fields that are
/// present into the label map. Either request may fail: the failure is logged
/// and discovery continues with whatever was gathered. Server-info values are
/// written second, so its `model_path` replaces the one from model-info.
/// The returned discovery never carries a gRPC client.
async fn discover_http_metadata(url: &str, api_key: Option<&str>) -> WorkerDiscovery {
    let mut found = WorkerDiscovery::new();

    match Self::get_model_info(url, api_key).await {
        Err(e) => {
            warn!(
                "Worker discovery: failed to fetch HTTP model info from {}: {}",
                url, e
            );
        }
        Ok(info) => {
            // String-valued fields are copied verbatim; empty strings are skipped.
            for key in [
                "model_path",
                "tokenizer_path",
                "served_model_name",
                "weight_version",
                "model_type",
                "preferred_sampling_params",
            ] {
                match info.get(key).and_then(|v| v.as_str()) {
                    Some(text) if !text.is_empty() => {
                        found.labels.insert(key.to_string(), text.to_string());
                    }
                    _ => {}
                }
            }

            // The boolean flag is recorded whenever it is present, true or false.
            if let Some(flag) = info.get("is_generation").and_then(|v| v.as_bool()) {
                found
                    .labels
                    .insert("is_generation".to_string(), flag.to_string());
            }

            // Integer limits are recorded whenever they parse as i64.
            for key in ["max_context_length", "max_req_input_len"] {
                if let Some(limit) = info.get(key).and_then(|v| v.as_i64()) {
                    found.labels.insert(key.to_string(), limit.to_string());
                }
            }
        }
    }

    match Self::get_server_info(url, api_key).await {
        Err(e) => {
            warn!(
                "Worker discovery: failed to fetch HTTP server info from {}: {}",
                url, e
            );
        }
        Ok(server) => {
            // The reported model id may be a path; reduce it to an identifier.
            if let Some(raw_id) = server.model_id.filter(|id| !id.is_empty()) {
                found.labels.insert(
                    "model_id".to_string(),
                    Self::normalize_model_identifier(&raw_id),
                );
            }

            let text_labels = [
                ("model_path", server.model_path),
                ("server_version", server.version),
            ];
            for (key, value) in text_labels {
                if let Some(text) = value.filter(|t| !t.is_empty()) {
                    found.labels.insert(key.to_string(), text);
                }
            }

            let numeric_labels = [
                (
                    "max_total_tokens",
                    server.max_total_tokens.map(|n| n.to_string()),
                ),
                (
                    "max_prefill_tokens",
                    server.max_prefill_tokens.map(|n| n.to_string()),
                ),
                (
                    "max_running_requests",
                    server.max_running_requests.map(|n| n.to_string()),
                ),
            ];
            for (key, value) in numeric_labels {
                if let Some(text) = value {
                    found.labels.insert(key.to_string(), text);
                }
            }
        }
    }

    // Fill in `model_id` from the served name or model path if it is still missing.
    Self::finalize_model_id(&mut found.labels);
    found
}
/// Collect metadata labels from a gRPC worker.
///
/// Connects to `url` with `SglangSchedulerClient`, asks it for model info and
/// copies the populated fields into the label map. The connected client is
/// returned in `grpc_client` so the worker can reuse it.
///
/// Failure handling:
/// - if the connection fails, a warning is logged and an empty discovery
///   (no labels, no client) is returned;
/// - if the model-info call fails, a warning is logged and the discovery is
///   returned with the client but without model labels.
async fn discover_grpc_metadata(url: &str) -> WorkerDiscovery {
    let mut discovery = WorkerDiscovery::new();

    let client = match SglangSchedulerClient::connect(url).await {
        Ok(client) => client,
        Err(e) => {
            warn!(
                "Worker discovery: failed to connect to gRPC worker {}: {}",
                url, e
            );
            return discovery;
        }
    };

    match client.get_model_info().await {
        Ok(model_info) => {
            // String fields are copied only when non-empty.
            let text_fields = [
                ("model_path", &model_info.model_path),
                ("tokenizer_path", &model_info.tokenizer_path),
                ("served_model_name", &model_info.served_model_name),
                ("weight_version", &model_info.weight_version),
                ("model_type", &model_info.model_type),
                (
                    "preferred_sampling_params",
                    &model_info.preferred_sampling_params,
                ),
            ];
            for (key, value) in text_fields {
                if !value.is_empty() {
                    discovery.labels.insert(key.to_string(), value.clone());
                }
            }

            // The served name doubles as the model id once normalised.
            if !model_info.served_model_name.is_empty() {
                let normalized =
                    Self::normalize_model_identifier(&model_info.served_model_name);
                discovery.labels.insert("model_id".to_string(), normalized);
            }

            discovery.labels.insert(
                "is_generation".to_string(),
                model_info.is_generation.to_string(),
            );

            // Numeric fields are recorded only when positive.
            if model_info.max_context_length > 0 {
                discovery.labels.insert(
                    "max_context_length".to_string(),
                    model_info.max_context_length.to_string(),
                );
            }
            if model_info.max_req_input_len > 0 {
                discovery.labels.insert(
                    "max_req_input_len".to_string(),
                    model_info.max_req_input_len.to_string(),
                );
            }
            if model_info.vocab_size > 0 {
                discovery
                    .labels
                    .insert("vocab_size".to_string(), model_info.vocab_size.to_string());
            }
        }
        Err(e) => {
            warn!(
                "Worker discovery: failed to fetch gRPC model info from {}: {}",
                url, e
            );
        }
    }

    // Always finalise, exactly as the HTTP path does. `finalize_model_id`
    // returns immediately when a non-blank `model_id` is already present, and
    // otherwise falls back to the model path. Guarding on mere key presence
    // would leave a blank id (e.g. from a whitespace-only served name) in place.
    Self::finalize_model_id(&mut discovery.labels);

    discovery.grpc_client = Some(client);
    discovery
}
/// Turn a worker-reported model identifier into a plain id.
///
/// Surrounding whitespace is removed. If what remains contains a `/` or `\`
/// it is treated as a path and reduced with `derive_model_id_from_path`;
/// otherwise the trimmed text is returned unchanged.
fn normalize_model_identifier(value: &str) -> String {
    let candidate = value.trim();
    let looks_like_path = candidate.contains(['/', '\\']);
    if !looks_like_path {
        return candidate.to_owned();
    }
    Self::derive_model_id_from_path(candidate)
}
/// Ensure `labels` carries a usable `model_id`, deriving one if necessary.
///
/// Precedence:
/// 1. an existing non-blank `model_id` is left untouched;
/// 2. otherwise a non-blank `served_model_name` is normalised and used;
/// 3. otherwise a non-blank `model_path` is reduced to its last component,
///    and used only if that result is non-empty.
///
/// If none of these apply the map is left as it was.
fn finalize_model_id(labels: &mut HashMap<String, String>) {
    if matches!(labels.get("model_id"), Some(id) if !id.trim().is_empty()) {
        return;
    }

    let from_served_name = labels
        .get("served_model_name")
        .filter(|name| !name.trim().is_empty())
        .map(|name| Self::normalize_model_identifier(name));
    if let Some(id) = from_served_name {
        labels.insert("model_id".to_string(), id);
        return;
    }

    let from_path = labels
        .get("model_path")
        .filter(|path| !path.trim().is_empty())
        .map(|path| Self::derive_model_id_from_path(path))
        .filter(|derived| !derived.is_empty());
    if let Some(id) = from_path {
        labels.insert("model_id".to_string(), id);
    }
}
/// Derive a short model identifier from a filesystem-style model path.
///
/// Trailing `/` and `\` characters are stripped and the final path component
/// is returned. Both characters are treated as separators on every platform,
/// so a Windows-style path reported by a worker is reduced correctly even
/// when this code runs on Unix.
///
/// Fallbacks:
/// - input that is empty or consists only of separators is returned as is;
/// - when the path has no final component (for example it ends in `..`), the
///   separator-trimmed input is returned.
fn derive_model_id_from_path(path: &str) -> String {
    let trimmed = path.trim_end_matches(['/', '\\']);
    if trimmed.is_empty() {
        return path.to_string();
    }

    // `Path` splits only on the host platform's separator, so on Unix a path
    // such as `C:\models\foo` would come back whole. Unify separators first.
    let unified = trimmed.replace('\\', "/");
    match Path::new(&unified).file_name().and_then(|name| name.to_str()) {
        Some(name) if !name.is_empty() => name.to_string(),
        _ => trimmed.to_string(),
    }
}
/// Parse server info from JSON response /// Parse server info from JSON response
fn parse_server_info(json: Value) -> Result<ServerInfo, String> { fn parse_server_info(json: Value) -> Result<ServerInfo, String> {
Ok(ServerInfo { Ok(ServerInfo {
...@@ -1499,6 +1836,7 @@ impl Drop for LoadMonitor { ...@@ -1499,6 +1836,7 @@ impl Drop for LoadMonitor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::collections::HashMap;
#[test] #[test]
fn test_parse_server_info() { fn test_parse_server_info() {
...@@ -1533,4 +1871,49 @@ mod tests { ...@@ -1533,4 +1871,49 @@ mod tests {
assert_eq!(info.model_id, None); assert_eq!(info.model_id, None);
assert_eq!(info.dp_size, None); assert_eq!(info.dp_size, None);
} }
#[test]
fn test_derive_model_id_from_path() {
    // Only the last component of an absolute path should survive.
    assert_eq!(
        WorkerManager::derive_model_id_from_path("/raid/models/meta-llama/Llama-3.1-8B-Instruct"),
        "Llama-3.1-8B-Instruct"
    );
}
#[test]
fn test_derive_model_id_trailing_slash() {
    // A trailing separator must not produce an empty identifier.
    assert_eq!(
        WorkerManager::derive_model_id_from_path("/models/foo/bar/"),
        "bar"
    );
}
#[test]
fn test_finalize_model_id_prefers_existing() {
    // An explicit model_id wins over anything that could be derived.
    let mut labels: HashMap<String, String> = HashMap::from([
        ("model_id".to_string(), "manual-id".to_string()),
        ("served_model_name".to_string(), "auto-id".to_string()),
    ]);

    WorkerManager::finalize_model_id(&mut labels);

    assert_eq!(
        labels.get("model_id").map(String::as_str),
        Some("manual-id")
    );
}
#[test]
fn test_finalize_model_id_prefers_served_name() {
    // With no model_id present, the served model name supplies it.
    let mut labels: HashMap<String, String> = HashMap::from([(
        "served_model_name".to_string(),
        "served-name".to_string(),
    )]);

    WorkerManager::finalize_model_id(&mut labels);

    assert_eq!(
        labels.get("model_id").map(String::as_str),
        Some("served-name")
    );
}
#[test]
fn test_finalize_model_id_falls_back_to_path() {
    // With neither model_id nor served name, the model path's last component is used.
    let mut labels: HashMap<String, String> =
        HashMap::from([("model_path".to_string(), "/models/alpha".to_string())]);

    WorkerManager::finalize_model_id(&mut labels);

    assert_eq!(labels.get("model_id").map(String::as_str), Some("alpha"));
}
#[test]
fn test_normalize_model_identifier_from_path() {
    // A path-like identifier is reduced to its final component.
    assert_eq!(
        WorkerManager::normalize_model_identifier("/raid/models/foo/bar-model"),
        "bar-model"
    );
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment