Unverified Commit 71f94eda authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: lora - centralize lora cache key, restructure folders, s3 resiliency (#4644)

parent 4c1bc4ee
......@@ -335,6 +335,7 @@ mod tests {
endpoint: format!("ep{}", i),
instance_id: i,
card_json: serde_json::json!({"model": "test"}),
model_suffix: None,
};
metadata.register_model_card(instance).unwrap();
}
......
......@@ -79,6 +79,9 @@ pub enum DiscoverySpec {
/// This allows lib/runtime to remain independent of lib/llm types
/// DiscoverySpec.from_model() and DiscoveryInstance.deserialize_model() are ergonomic helpers to create and deserialize the model card.
card_json: serde_json::Value,
/// Optional suffix appended after instance_id in the key path (e.g., for LoRA adapters)
/// Key format: {namespace}/{component}/{endpoint}/{instance_id}[/{model_suffix}]
model_suffix: Option<String>,
},
}
......@@ -91,6 +94,21 @@ impl DiscoverySpec {
endpoint: String,
card: &T,
) -> Result<Self>
where
T: Serialize,
{
Self::from_model_with_suffix(namespace, component, endpoint, card, None)
}
/// Creates a Model discovery spec with an optional suffix (e.g., for LoRA adapters)
/// The suffix is appended after the instance_id in the key path
pub fn from_model_with_suffix<T>(
namespace: String,
component: String,
endpoint: String,
card: &T,
model_suffix: Option<String>,
) -> Result<Self>
where
T: Serialize,
{
......@@ -100,6 +118,7 @@ impl DiscoverySpec {
component,
endpoint,
card_json,
model_suffix,
})
}
......@@ -123,12 +142,14 @@ impl DiscoverySpec {
component,
endpoint,
card_json,
model_suffix,
} => DiscoveryInstance::Model {
namespace,
component,
endpoint,
instance_id,
card_json,
model_suffix,
},
}
}
......@@ -149,6 +170,9 @@ pub enum DiscoveryInstance {
/// ModelDeploymentCard serialized as JSON
/// This allows lib/runtime to remain independent of lib/llm types
card_json: serde_json::Value,
/// Optional suffix appended after instance_id in the key path (e.g., for LoRA adapters)
#[serde(default, skip_serializing_if = "Option::is_none")]
model_suffix: Option<String>,
},
}
......
......@@ -70,6 +70,9 @@ pub struct DistributedRuntime {
// Health Status
system_health: Arc<parking_lot::Mutex<SystemHealth>>,
// Local endpoint registry for in-process calls
local_endpoint_registry: crate::local_endpoint_registry::LocalEndpointRegistry,
// This hierarchy's own metrics registry
metrics_registry: MetricsRegistry,
}
......@@ -195,6 +198,7 @@ impl DistributedRuntime {
metrics_registry: crate::MetricsRegistry::new(),
system_health,
request_plane,
local_endpoint_registry: crate::local_endpoint_registry::LocalEndpointRegistry::new(),
};
if let Some(nats_client_for_metrics) = nats_client_for_metrics {
......@@ -316,6 +320,13 @@ impl DistributedRuntime {
self.system_health.clone()
}
/// Borrow the local endpoint registry used for in-process endpoint calls.
///
/// The returned reference points at the registry owned by this runtime; nothing is
/// cloned or created here.
pub fn local_endpoint_registry(
&self,
) -> &crate::local_endpoint_registry::LocalEndpointRegistry {
&self.local_endpoint_registry
}
/// Returns the instance id reported by this runtime's discovery client.
pub fn connection_id(&self) -> u64 {
self.discovery_client.instance_id()
}
......
......@@ -25,6 +25,7 @@ pub mod compute;
pub mod discovery;
pub mod engine;
pub mod health_check;
pub mod local_endpoint_registry;
pub mod system_status_server;
pub use system_status_server::SystemStatusServerInfo;
pub mod distributed;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Local Endpoint Registry
//!
//! Provides a registry for locally registered endpoints that can be called in-process
//! without going through the network stack.
use crate::engine::AsyncEngine;
use dashmap::DashMap;
use std::sync::Arc;
/// Type alias for a shared (`Arc`), type-erased async engine used for local calls.
///
/// The engine takes a single `serde_json::Value` request (`SingleIn`) and yields
/// `Annotated<serde_json::Value>` items (`ManyOut`), reporting failures as `anyhow::Error`.
/// The `Send + Sync` bounds are part of the trait object type.
pub type LocalAsyncEngine = Arc<
dyn AsyncEngine<
crate::pipeline::SingleIn<serde_json::Value>,
crate::pipeline::ManyOut<crate::protocols::annotated::Annotated<serde_json::Value>>,
anyhow::Error,
> + Send
+ Sync,
>;
/// Registry for locally registered endpoints
///
/// This registry stores endpoints that are registered locally (in the same process)
/// and allows them to be called directly without going through the network transport layer.
///
/// The map lives behind an `Arc`, so a clone of the registry refers to the same
/// underlying set of engines rather than to a copy.
#[derive(Clone, Default)]
pub struct LocalEndpointRegistry {
/// Map of endpoint name to the async engine registered under that name
engines: Arc<DashMap<String, LocalAsyncEngine>>,
}
impl LocalEndpointRegistry {
    /// Build an empty registry.
    ///
    /// Equivalent to [`LocalEndpointRegistry::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Store `engine` under `endpoint_name`.
    ///
    /// # Arguments
    ///
    /// * `endpoint_name` - Key the engine is stored under (e.g., "load_lora", "generate")
    /// * `engine` - The async engine that serves requests for this endpoint
    ///
    /// An engine already stored under the same name is replaced.
    pub fn register(&self, endpoint_name: String, engine: LocalAsyncEngine) {
        tracing::debug!("Registering local endpoint: {}", endpoint_name);
        self.engines.insert(endpoint_name, engine);
    }

    /// Look up the engine stored under `endpoint_name`.
    ///
    /// # Returns
    ///
    /// A new `Arc` handle to the engine if one is registered, `None` otherwise.
    pub fn get(&self, endpoint_name: &str) -> Option<LocalAsyncEngine> {
        let entry = self.engines.get(endpoint_name)?;
        Some(Arc::clone(entry.value()))
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// TODO: (DEP-635) this file should be renamed to system_http_server.rs:
// it is used not only for status and health, but also for other things such as LoRA management.
use crate::config::HealthStatus;
use crate::config::environment_names::logging as env_logging;
use crate::config::environment_names::runtime::canary as env_canary;
......@@ -9,7 +12,15 @@ use crate::logging::make_request_span;
use crate::metrics::MetricsHierarchy;
use crate::metrics::prometheus_names::{nats_client, nats_service};
use crate::traits::DistributedRuntimeProvider;
use axum::{Router, http::StatusCode, response::IntoResponse, routing::get};
use axum::{
Router,
extract::{Json, Path, State},
http::StatusCode,
response::IntoResponse,
routing::{delete, get, post},
};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
......@@ -88,6 +99,35 @@ impl SystemStatusState {
}
}
/// Request body for POST /v1/loras
///
/// Both fields are forwarded to the local `load_lora` endpoint by the POST handler.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LoadLoraRequest {
/// Name of the LoRA adapter to load; sent on as `"lora_name"`
pub lora_name: String,
/// Where the adapter is loaded from; sent on as `"source"`
pub source: LoraSource,
}
/// Source information for loading a LoRA
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LoraSource {
/// URI of the adapter, passed through unchanged to the `load_lora` endpoint.
/// NOTE(review): the supported URI schemes are not visible here — confirm with the backend.
pub uri: String,
}
/// Response body for LoRA operations
///
/// Every `Option` field is left out of the serialized JSON when it is `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LoraResponse {
/// Outcome of the operation; the HTTP handlers set this to `"error"` when the call fails
pub status: String,
/// Human-readable detail; the handlers place the error text here on failure
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
/// Name of the LoRA the operation referred to, when there is one
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_name: Option<String>,
/// Identifier taken from the endpoint's reply.
/// NOTE(review): presumably assigned by the backend — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_id: Option<u64>,
/// Free-form JSON taken from the endpoint's reply.
/// NOTE(review): presumably the adapter listing for GET /v1/loras — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub loras: Option<serde_json::Value>,
/// Count taken from the endpoint's reply.
/// NOTE(review): presumably the number of entries in `loras` — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub count: Option<usize>,
}
/// Start system status server with metrics support
pub async fn spawn_system_status_server(
host: &str,
......@@ -111,7 +151,12 @@ pub async fn spawn_system_status_server(
.live_path()
.to_string();
let app = Router::new()
// Check if LoRA feature is enabled
let lora_enabled = std::env::var(crate::config::environment_names::llm::DYN_LORA_ENABLED)
.map(|v| v.to_lowercase() == "true")
.unwrap_or(false);
let mut app = Router::new()
.route(
&health_path,
get({
......@@ -139,7 +184,32 @@ pub async fn spawn_system_status_server(
let state = Arc::clone(&server_state);
move || metadata_handler(state)
}),
)
);
// Add LoRA routes only if DYN_LORA_ENABLED is set to true
if lora_enabled {
app = app
.route(
"/v1/loras",
get({
let state = Arc::clone(&server_state);
move || list_loras_handler(State(state))
})
.post({
let state = Arc::clone(&server_state);
move |body| load_lora_handler(State(state), body)
}),
)
.route(
"/v1/loras/{*lora_name}",
delete({
let state = Arc::clone(&server_state);
move |path| unload_lora_handler(State(state), path)
}),
);
}
let app = app
.fallback(|| async {
tracing::info!("[fallback handler] called");
(StatusCode::NOT_FOUND, "Route not found").into_response()
......@@ -269,6 +339,192 @@ async fn metadata_handler(state: Arc<SystemStatusState>) -> impl IntoResponse {
}
}
/// Handler for POST /v1/loras - Load a LoRA adapter
#[tracing::instrument(skip_all, level = "debug")]
async fn load_lora_handler(
State(state): State<Arc<SystemStatusState>>,
Json(request): Json<LoadLoraRequest>,
) -> impl IntoResponse {
tracing::info!("Loading LoRA: {}", request.lora_name);
// Call the load_lora endpoint for each available backend
match call_lora_endpoint(
state.drt(),
"load_lora",
json!({
"lora_name": request.lora_name,
"source": {
"uri": request.source.uri
},
}),
)
.await
{
Ok(response) => {
tracing::info!("LoRA loaded successfully: {}", request.lora_name);
(StatusCode::OK, Json(response))
}
Err(e) => {
tracing::error!("Failed to load LoRA {}: {}", request.lora_name, e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(LoraResponse {
status: "error".to_string(),
message: Some(e.to_string()),
lora_name: Some(request.lora_name),
lora_id: None,
loras: None,
count: None,
}),
)
}
}
}
/// Handler for DELETE /v1/loras/{*lora_name} - Unload a LoRA adapter
///
/// The adapter name is taken from the wildcard path segment, so it may itself contain
/// `/`. It is forwarded to the locally registered `unload_lora` endpoint. Replies with
/// `200 OK` and the endpoint's response when the call succeeds, or
/// `500 Internal Server Error` with an `"error"` status body when it fails.
#[tracing::instrument(skip_all, level = "debug")]
async fn unload_lora_handler(
    State(state): State<Arc<SystemStatusState>>,
    Path(lora_name): Path<String>,
) -> impl IntoResponse {
    // Drop one leading slash if the wildcard capture carries it.
    let lora_name = lora_name
        .strip_prefix('/')
        .unwrap_or(&lora_name)
        .to_string();

    tracing::info!("Unloading LoRA: {}", lora_name);

    // `json!` serializes through a borrow, so the name does not need to be cloned
    // to remain available for the log lines and the error body below.
    match call_lora_endpoint(
        state.drt(),
        "unload_lora",
        json!({
            "lora_name": lora_name,
        }),
    )
    .await
    {
        Ok(response) => {
            tracing::info!("LoRA unloaded successfully: {}", lora_name);
            (StatusCode::OK, Json(response))
        }
        Err(e) => {
            tracing::error!("Failed to unload LoRA {}: {}", lora_name, e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(LoraResponse {
                    status: "error".to_string(),
                    message: Some(e.to_string()),
                    lora_name: Some(lora_name),
                    lora_id: None,
                    loras: None,
                    count: None,
                }),
            )
        }
    }
}
/// Handler for GET /v1/loras - List all LoRA adapters
#[tracing::instrument(skip_all, level = "debug")]
async fn list_loras_handler(State(state): State<Arc<SystemStatusState>>) -> impl IntoResponse {
tracing::info!("Listing all LoRAs");
// Call the list_loras endpoint for each available backend
match call_lora_endpoint(state.drt(), "list_loras", json!({})).await {
Ok(response) => {
tracing::info!("Successfully retrieved LoRA list");
(StatusCode::OK, Json(response))
}
Err(e) => {
tracing::error!("Failed to list LoRAs: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(LoraResponse {
status: "error".to_string(),
message: Some(e.to_string()),
lora_name: None,
lora_id: None,
loras: None,
count: None,
}),
)
}
}
}
/// Helper function to call a LoRA management endpoint locally via in-process registry
///
/// This function ONLY uses the local endpoint registry for direct in-process calls.
/// It does NOT fall back to network discovery if the endpoint is not found.
async fn call_lora_endpoint(
drt: &crate::DistributedRuntime,
endpoint_name: &str,
request_body: serde_json::Value,
) -> anyhow::Result<LoraResponse> {
use crate::engine::AsyncEngine;
tracing::debug!("Calling local endpoint: '{}'", endpoint_name);
// Get the endpoint from the local registry (in-process call only)
let local_registry = drt.local_endpoint_registry();
let engine = local_registry
.get(endpoint_name)
.ok_or_else(|| {
anyhow::anyhow!(
"Endpoint '{}' not found in local registry. Make sure it's registered with .register_local_engine()",
endpoint_name
)
})?;
tracing::debug!(
"Found endpoint '{}' in local registry, calling directly",
endpoint_name
);
// Call the engine directly without going through the network stack
let request = crate::pipeline::SingleIn::new(request_body);
let mut stream = engine.generate(request).await?;
// Get the first response
if let Some(response) = stream.next().await {
let response_data = response.data.unwrap_or_default();
// Try structured deserialization first, fall back to manual field extraction
let lora_response = serde_json::from_value::<LoraResponse>(response_data.clone())
.unwrap_or_else(|_| parse_lora_response(&response_data));
return Ok(lora_response);
}
anyhow::bail!("No response received from endpoint '{}'", endpoint_name)
}
/// Helper to parse response data into LoraResponse
fn parse_lora_response(response_data: &serde_json::Value) -> LoraResponse {
LoraResponse {
status: response_data
.get("status")
.and_then(|s| s.as_str())
.unwrap_or("success")
.to_string(),
message: response_data
.get("message")
.and_then(|m| m.as_str())
.map(|s| s.to_string()),
lora_name: response_data
.get("lora_name")
.and_then(|n| n.as_str())
.map(|s| s.to_string()),
lora_id: response_data.get("lora_id").and_then(|id| id.as_u64()),
loras: response_data.get("loras").cloned(),
count: response_data
.get("count")
.and_then(|c| c.as_u64())
.map(|c| c as usize),
}
}
// Regular tests: cargo test system_status_server --lib
#[cfg(test)]
mod tests {
......@@ -636,7 +892,7 @@ mod integration_tests {
// Now create a namespace, component, and endpoint to make the system healthy
let namespace = drt.namespace("ns1234").unwrap();
let component = namespace.component("comp1234").unwrap();
let mut component = namespace.component("comp1234").unwrap();
// Create a simple test handler
use crate::pipeline::{async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, SingleIn};
......@@ -662,6 +918,7 @@ mod integration_tests {
// Start the service and endpoint with a health check payload
// This will automatically register the endpoint for health monitoring
tokio::spawn(async move {
component.add_stats_service().await.unwrap();
let _ = component.endpoint(ENDPOINT_NAME)
.endpoint_builder()
.handler(ingress)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment