Unverified commit 2479b894, authored by Chang Su, committed by GitHub
Browse files

[router][grpc] Simplify model_id determination (#11684)

parent 54644572
...@@ -22,7 +22,6 @@ use once_cell::sync::Lazy; ...@@ -22,7 +22,6 @@ use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::{watch, Mutex}; use tokio::sync::{watch, Mutex};
...@@ -1228,8 +1227,7 @@ impl WorkerManager { ...@@ -1228,8 +1227,7 @@ impl WorkerManager {
Ok(server_info) => { Ok(server_info) => {
if let Some(model_id) = server_info.model_id { if let Some(model_id) = server_info.model_id {
if !model_id.is_empty() { if !model_id.is_empty() {
let normalized = Self::normalize_model_identifier(&model_id); discovery.labels.insert("model_id".to_string(), model_id);
discovery.labels.insert("model_id".to_string(), normalized);
} }
} }
if let Some(model_path) = server_info.model_path { if let Some(model_path) = server_info.model_path {
...@@ -1309,9 +1307,9 @@ impl WorkerManager { ...@@ -1309,9 +1307,9 @@ impl WorkerManager {
"served_model_name".to_string(), "served_model_name".to_string(),
model_info.served_model_name.clone(), model_info.served_model_name.clone(),
); );
let normalized = discovery
Self::normalize_model_identifier(&model_info.served_model_name); .labels
discovery.labels.insert("model_id".to_string(), normalized); .insert("model_id".to_string(), model_info.served_model_name);
} }
if !model_info.weight_version.is_empty() { if !model_info.weight_version.is_empty() {
discovery.labels.insert( discovery.labels.insert(
...@@ -1368,15 +1366,6 @@ impl WorkerManager { ...@@ -1368,15 +1366,6 @@ impl WorkerManager {
discovery discovery
} }
/// Canonicalise a raw model identifier.
///
/// Surrounding whitespace is stripped first. If what remains contains a
/// forward slash or a backslash it is treated as a filesystem path and handed
/// to `derive_model_id_from_path`; otherwise the trimmed text is returned as
/// an owned `String`.
fn normalize_model_identifier(value: &str) -> String {
    let candidate = value.trim();
    let looks_like_path = candidate.chars().any(|c| matches!(c, '/' | '\\'));
    if !looks_like_path {
        // Plain names only need the whitespace removed.
        return candidate.to_string();
    }
    Self::derive_model_id_from_path(candidate)
}
fn finalize_model_id(labels: &mut HashMap<String, String>) { fn finalize_model_id(labels: &mut HashMap<String, String>) {
let has_model_id = labels let has_model_id = labels
.get("model_id") .get("model_id")
...@@ -1386,41 +1375,20 @@ impl WorkerManager { ...@@ -1386,41 +1375,20 @@ impl WorkerManager {
return; return;
} }
if let Some(served_name) = labels.get("served_model_name").cloned() { if let Some(served_name) = labels.get("served_model_name") {
if !served_name.trim().is_empty() { if !served_name.trim().is_empty() {
let normalized = Self::normalize_model_identifier(&served_name); labels.insert("model_id".to_string(), served_name.clone());
labels.insert("model_id".to_string(), normalized);
return; return;
} }
} }
if let Some(model_path) = labels.get("model_path").cloned() { if let Some(model_path) = labels.get("model_path") {
if !model_path.trim().is_empty() { if !model_path.trim().is_empty() {
let derived = Self::derive_model_id_from_path(&model_path); labels.insert("model_id".to_string(), model_path.clone());
if !derived.is_empty() {
labels.insert("model_id".to_string(), derived);
}
} }
} }
} }
/// Derive a short model id from a filesystem-style path by keeping only its
/// final component.
///
/// Trailing `/` and `\` characters are ignored. Both characters are treated
/// as separators on every platform: `Path::file_name` only splits on `\` on
/// Windows, so a Windows-style path would otherwise be returned unchanged on
/// Unix hosts.
///
/// # Returns
/// * The final path component when one exists (e.g. `"/models/foo/bar/"`
///   yields `"bar"`).
/// * The original `path`, untouched, when it consists solely of separators
///   (or is empty).
/// * The separator-trimmed input when no final component can be extracted
///   (for example when the path ends in `..`).
fn derive_model_id_from_path(path: &str) -> String {
    let trimmed = path.trim_end_matches(['/', '\\']);
    if trimmed.is_empty() {
        return path.to_string();
    }

    // Normalise backslashes so component splitting is host-independent.
    let unified = trimmed.replace('\\', "/");
    match Path::new(&unified).file_name().and_then(|name| name.to_str()) {
        Some(name) if !name.is_empty() => name.to_string(),
        _ => trimmed.to_string(),
    }
}
/// Parse server info from JSON response /// Parse server info from JSON response
fn parse_server_info(json: Value) -> Result<ServerInfo, String> { fn parse_server_info(json: Value) -> Result<ServerInfo, String> {
Ok(ServerInfo { Ok(ServerInfo {
...@@ -1872,20 +1840,6 @@ mod tests { ...@@ -1872,20 +1840,6 @@ mod tests {
assert_eq!(info.dp_size, None); assert_eq!(info.dp_size, None);
} }
#[test]
fn test_derive_model_id_from_path() {
    // Only the last component of an absolute path should be kept as the id.
    assert_eq!(
        WorkerManager::derive_model_id_from_path("/raid/models/meta-llama/Llama-3.1-8B-Instruct"),
        "Llama-3.1-8B-Instruct"
    );
}
#[test]
fn test_derive_model_id_trailing_slash() {
    // A trailing separator must not produce an empty final component.
    let model_id = WorkerManager::derive_model_id_from_path("/models/foo/bar/");
    assert_eq!(model_id, "bar");
}
#[test] #[test]
fn test_finalize_model_id_prefers_existing() { fn test_finalize_model_id_prefers_existing() {
let mut labels = HashMap::new(); let mut labels = HashMap::new();
...@@ -1908,12 +1862,6 @@ mod tests { ...@@ -1908,12 +1862,6 @@ mod tests {
let mut labels = HashMap::new(); let mut labels = HashMap::new();
labels.insert("model_path".to_string(), "/models/alpha".to_string()); labels.insert("model_path".to_string(), "/models/alpha".to_string());
WorkerManager::finalize_model_id(&mut labels); WorkerManager::finalize_model_id(&mut labels);
assert_eq!(labels.get("model_id").unwrap(), "alpha"); assert_eq!(labels.get("model_id").unwrap(), "/models/alpha");
}
#[test]
fn test_normalize_model_identifier_from_path() {
let normalized = WorkerManager::normalize_model_identifier("/raid/models/foo/bar-model");
assert_eq!(normalized, "bar-model");
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment