Unverified Commit 953e5d7b authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat(lora): implement HRW-based LoRA allocation system (#5992)

parent e351c634
......@@ -374,13 +374,17 @@ fn register_llm<'p>(
.media_fetcher(media_fetcher.map(|m| m.inner));
let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Convert lora_identifier (Option<String>) to Option<LoraInfo>
let lora_info = lora_identifier
.as_ref()
.map(|name| llm_rs::model_card::LoraInfo {
name: name.clone(),
max_gpu_lora_count: None,
});
local_model
.attach(
&endpoint.inner,
model_type_obj,
model_input,
lora_identifier.as_deref(),
)
.attach(&endpoint.inner, model_type_obj, model_input, lora_info)
.await
.map_err(to_pyerr)?;
......
......@@ -447,14 +447,16 @@ impl LocalModel {
endpoint: &Endpoint,
model_type: ModelType,
model_input: ModelInput,
lora_name: Option<&str>,
lora_info: Option<crate::model_card::LoraInfo>,
) -> anyhow::Result<()> {
self.card.model_type = model_type;
self.card.model_input = model_input;
self.card.lora_name = lora_name.map(|name| name.to_string());
self.card.lora = lora_info.clone();
// Compute model_suffix from lora_name if present
let model_suffix = lora_name.map(|name| Slug::slugify(name).to_string());
let model_suffix = lora_info
.as_ref()
.map(|info| Slug::slugify(&info.name).to_string());
let suffix_for_log = model_suffix
.as_ref()
......
......@@ -5,13 +5,20 @@
//!
//! This module provides a minimal, extensible interface for downloading LoRA adapters
//! from various sources (local filesystem, S3, etc.) with automatic caching.
//! It also provides routing and lora allocation algorithms for distributing LoRA adapters
//! across workers in a cluster.
mod cache;
mod downloader;
pub mod load_estimator;
pub mod routing;
mod source;
pub use cache::LoRACache;
pub use downloader::LoRADownloader;
pub use load_estimator::{LoadEstimator, LoadEstimatorConfig, LoadSample};
pub use routing::{
AllocationAlgorithmType, LoraAllocator, LoraReplicaConfig, LoraRoutingTable, RendezvousHasher,
create_lora_allocator,
};
pub use source::{LoRASource, LocalLoRASource, S3LoRASource};
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::LoraAllocator;
use crate::kv_router::protocols::WorkerWithDpRank;
/// Rendezvous (HRW) hashing implementation for LoRA allocation
pub struct RendezvousHasher;
impl RendezvousHasher {
/// Compute hash score for a (lora_name, worker) pair using HRW hashing with blake3
pub fn compute_score(lora_name: &str, worker: WorkerWithDpRank) -> u64 {
let mut hasher = blake3::Hasher::new();
hasher.update(lora_name.as_bytes());
hasher.update(&worker.worker_id.to_le_bytes());
hasher.update(&worker.dp_rank.to_le_bytes());
let hash = hasher.finalize();
// Extract first 8 bytes as u64
let hash_bytes = hash.as_bytes();
let mut bytes_array = [0u8; 8];
bytes_array.copy_from_slice(&hash_bytes[..8]);
u64::from_le_bytes(bytes_array)
}
/// Rank workers by their hash scores for a given LoRA
/// Returns workers sorted by score in descending order (highest first).
pub fn rank_workers(
lora_name: &str,
workers: &[WorkerWithDpRank],
) -> Vec<(WorkerWithDpRank, u64)> {
let mut scores: Vec<_> = workers
.iter()
.map(|&w| (w, Self::compute_score(lora_name, w)))
.collect();
// Sort by score descending (highest score first)
scores.sort_by_key(|(_, score)| std::cmp::Reverse(*score));
scores
}
}
impl LoraAllocator for RendezvousHasher {
fn compute_replica_set(
&self,
lora_name: &str,
workers: &[WorkerWithDpRank],
replica_factor: usize,
) -> Vec<WorkerWithDpRank> {
if workers.is_empty() {
return Vec::new();
}
// Rank all workers and take top N
let ranked = Self::rank_workers(lora_name, workers);
ranked
.into_iter()
.take(replica_factor.min(workers.len()))
.map(|(w, _)| w)
.collect()
}
fn name(&self) -> &str {
"hrw"
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_workers(count: usize) -> Vec<WorkerWithDpRank> {
(0..count)
.map(|i| WorkerWithDpRank::new(i as u64, 0))
.collect()
}
#[test]
fn test_deterministic() {
let worker = WorkerWithDpRank::new(1, 0);
let lora_name = "test-lora";
// Same inputs should always produce same score
let score1 = RendezvousHasher::compute_score(lora_name, worker);
let score2 = RendezvousHasher::compute_score(lora_name, worker);
assert_eq!(score1, score2, "Same inputs should produce same score");
}
#[test]
fn test_stability_adding_workers() {
// Start with 3 workers
let workers_before = make_workers(3);
let hasher = RendezvousHasher;
let replica_set_before = hasher.compute_replica_set("test-lora", &workers_before, 2);
assert_eq!(replica_set_before.len(), 2);
// Add 2 more workers
let workers_after = make_workers(5);
let replica_set_after = hasher.compute_replica_set("test-lora", &workers_after, 2);
assert_eq!(replica_set_after.len(), 2);
let top2_after: Vec<_> = replica_set_after.iter().map(|w| w.worker_id).collect();
// The top 2 should be the same if they're still in top 2 after adding workers
// This tests stability property: adding workers shouldn't change existing placements
// (unless the new workers rank higher, which is expected behavior)
// At minimum, verify determinism: same inputs produce same outputs
let replica_set_after2 = hasher.compute_replica_set("test-lora", &workers_after, 2);
let top2_after2: Vec<_> = replica_set_after2.iter().map(|w| w.worker_id).collect();
assert_eq!(
top2_after, top2_after2,
"Same inputs should produce same outputs"
);
}
#[test]
fn test_stability_removing_workers() {
let hasher = RendezvousHasher;
// Start with 5 workers
let workers_5 = make_workers(5);
let set_5 = hasher.compute_replica_set("test-lora", &workers_5, 3);
assert_eq!(set_5.len(), 3);
// Remove worker 2 (keep 0,1,3,4)
let workers_4: Vec<_> = workers_5
.iter()
.filter(|w| w.worker_id != 2)
.copied()
.collect();
let set_4 = hasher.compute_replica_set("test-lora", &workers_4, 3);
assert_eq!(set_4.len(), 3);
// If worker 2 wasn't in the original top 3, the other selections should stay the same
if !set_5.iter().any(|w| w.worker_id == 2) {
// The workers that were in top 3 and are still available should still be in top 3
for worker in &set_5 {
if workers_4.contains(worker) {
assert!(
set_4.contains(worker),
"Worker {} was in top 3 and is still available, should remain in top 3",
worker.worker_id
);
}
}
}
}
#[test]
fn test_compute_replica_set_more_replicas_than_workers() {
let hasher = RendezvousHasher;
let workers = make_workers(3);
let result = hasher.compute_replica_set("test-lora", &workers, 10);
// Should return all workers when replica_factor > worker count
assert_eq!(result.len(), 3);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! LoRA Allocation Algorithms - HRW and Random
use crate::kv_router::protocols::WorkerWithDpRank;
use std::str::FromStr;
pub mod hrw;
pub mod table;
pub use hrw::RendezvousHasher;
pub use table::{LoraReplicaConfig, LoraRoutingTable};
/// Trait for LoRA allocation algorithms
pub trait LoraAllocator: Send + Sync {
/// Returns a list of workers that should host this LoRA, ordered by preference
fn compute_replica_set(
&self,
lora_name: &str,
workers: &[WorkerWithDpRank],
replica_factor: usize,
) -> Vec<WorkerWithDpRank>;
/// Name of this algorithm (for logging/metrics)
fn name(&self) -> &str;
}
/// Factory for creating allocation algorithms
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationAlgorithmType {
/// Rendezvous (Highest Random Weight) hashing
Hrw,
/// Random selection (for testing)
Random,
}
impl FromStr for AllocationAlgorithmType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"hrw" => Ok(Self::Hrw),
"random" => Ok(Self::Random),
_ => Err(format!("Unknown allocation algorithm type: {}", s)),
}
}
}
/// Create a LoRA allocation algorithm instance
pub fn create_lora_allocator(algo_type: AllocationAlgorithmType) -> Box<dyn LoraAllocator> {
match algo_type {
AllocationAlgorithmType::Hrw => Box::new(RendezvousHasher),
AllocationAlgorithmType::Random => Box::new(RandomAllocation),
}
}
/// Random allocation algorithm
struct RandomAllocation;
impl LoraAllocator for RandomAllocation {
fn compute_replica_set(
&self,
_lora_name: &str,
workers: &[WorkerWithDpRank],
_replica_factor: usize,
) -> Vec<WorkerWithDpRank> {
// Return all workers regardless of replica_factor
workers.to_vec()
}
fn name(&self) -> &str {
"random"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_lora_allocator() {
let hrw = create_lora_allocator(AllocationAlgorithmType::Hrw);
assert_eq!(hrw.name(), "hrw");
let random = create_lora_allocator(AllocationAlgorithmType::Random);
assert_eq!(random.name(), "random");
}
#[test]
fn test_random_allocation_basic() {
let random = RandomAllocation;
let workers = vec![
WorkerWithDpRank::new(1, 0),
WorkerWithDpRank::new(2, 0),
WorkerWithDpRank::new(3, 0),
];
// RandomAllocation returns all workers regardless of replica_factor
let result = random.compute_replica_set("test-lora", &workers, 2);
assert_eq!(result.len(), 3);
assert_eq!(result[0].worker_id, 1);
assert_eq!(result[1].worker_id, 2);
assert_eq!(result[2].worker_id, 3);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! LoRA Routing Table - Thread-safe data structure for storing LoRA allocation decisions.
use dashmap::DashMap;
use std::sync::Arc;
use std::time::Instant;
use crate::kv_router::protocols::WorkerWithDpRank;
/// Configuration for a single LoRA's allocation
#[derive(Debug, Clone)]
pub struct LoraReplicaConfig {
/// Name of the LoRA adapter
pub lora_name: String,
/// Number of replicas configured
pub replica_factor: usize,
/// Workers selected to host this LoRA (in preference order)
pub replica_set: Vec<WorkerWithDpRank>,
/// When this allocation was last updated
pub updated_at: Instant,
}
/// Thread-safe allocation table using DashMap for concurrent access
#[derive(Clone)]
pub struct LoraRoutingTable {
allocations: Arc<DashMap<String, LoraReplicaConfig>>,
}
impl LoraRoutingTable {
/// Create a new empty allocation table
pub fn new() -> Self {
Self {
allocations: Arc::new(DashMap::new()),
}
}
/// Get the replica set for a LoRA
pub fn get_replica_set(&self, lora_name: &str) -> Option<Vec<WorkerWithDpRank>> {
self.allocations
.get(lora_name)
.map(|entry| entry.replica_set.clone())
}
/// Get the full configuration for a LoRA
pub fn get_config(&self, lora_name: &str) -> Option<LoraReplicaConfig> {
self.allocations.get(lora_name).map(|entry| entry.clone())
}
/// Update or insert an allocation configuration
pub fn update_allocation(&self, lora_name: String, config: LoraReplicaConfig) {
self.allocations.insert(lora_name, config);
}
/// Remove a LoRA from the allocation table
pub fn remove_lora(&self, lora_name: &str) -> Option<LoraReplicaConfig> {
self.allocations.remove(lora_name).map(|(_, v)| v)
}
/// List all LoRA names in the allocation table
pub fn list_loras(&self) -> Vec<String> {
self.allocations
.iter()
.map(|entry| entry.key().clone())
.collect()
}
/// Get the number of LoRAs in the allocation table
pub fn len(&self) -> usize {
self.allocations.len()
}
/// Check if the table is empty
pub fn is_empty(&self) -> bool {
self.allocations.is_empty()
}
/// Clear all entries from the table
pub fn clear(&self) {
self.allocations.clear();
}
}
impl Default for LoraRoutingTable {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_workers(count: usize) -> Vec<WorkerWithDpRank> {
(0..count)
.map(|i| WorkerWithDpRank::new(i as u64, 0))
.collect()
}
#[test]
fn test_new_table_is_empty() {
let table = LoraRoutingTable::new();
assert!(table.is_empty());
assert_eq!(table.len(), 0);
assert_eq!(table.list_loras().len(), 0);
}
#[test]
fn test_insert_and_get() {
let table = LoraRoutingTable::new();
let workers = make_workers(3);
let config = LoraReplicaConfig {
lora_name: "test-lora".to_string(),
replica_factor: 2,
replica_set: workers[..2].to_vec(),
updated_at: Instant::now(),
};
table.update_allocation("test-lora".to_string(), config);
assert_eq!(table.len(), 1);
assert!(!table.is_empty());
let replica_set = table.get_replica_set("test-lora").unwrap();
assert_eq!(replica_set.len(), 2);
assert_eq!(replica_set[0].worker_id, 0);
assert_eq!(replica_set[1].worker_id, 1);
}
#[test]
fn test_get_nonexistent() {
let table = LoraRoutingTable::new();
assert!(table.get_replica_set("nonexistent").is_none());
assert!(table.get_config("nonexistent").is_none());
}
#[test]
fn test_update_existing() {
let table = LoraRoutingTable::new();
let workers = make_workers(3);
// Insert initial config
let config1 = LoraReplicaConfig {
lora_name: "test-lora".to_string(),
replica_factor: 1,
replica_set: workers[..1].to_vec(),
updated_at: Instant::now(),
};
table.update_allocation("test-lora".to_string(), config1);
// Update with new config
let config2 = LoraReplicaConfig {
lora_name: "test-lora".to_string(),
replica_factor: 2,
replica_set: workers[..2].to_vec(),
updated_at: Instant::now(),
};
table.update_allocation("test-lora".to_string(), config2);
// Should have new config
assert_eq!(table.len(), 1);
let replica_set = table.get_replica_set("test-lora").unwrap();
assert_eq!(replica_set.len(), 2);
}
#[test]
fn test_remove() {
let table = LoraRoutingTable::new();
let workers = make_workers(1);
let config = LoraReplicaConfig {
lora_name: "test-lora".to_string(),
replica_factor: 1,
replica_set: workers.clone(),
updated_at: Instant::now(),
};
table.update_allocation("test-lora".to_string(), config);
assert_eq!(table.len(), 1);
let removed = table.remove_lora("test-lora");
assert!(removed.is_some());
assert_eq!(table.len(), 0);
assert!(table.is_empty());
}
#[test]
fn test_list_loras() {
let table = LoraRoutingTable::new();
let workers = make_workers(1);
for i in 0..3 {
let config = LoraReplicaConfig {
lora_name: format!("lora-{}", i),
replica_factor: 1,
replica_set: workers.clone(),
updated_at: Instant::now(),
};
table.update_allocation(format!("lora-{}", i), config);
}
let loras = table.list_loras();
assert_eq!(loras.len(), 3);
assert!(loras.contains(&"lora-0".to_string()));
assert!(loras.contains(&"lora-1".to_string()));
assert!(loras.contains(&"lora-2".to_string()));
}
#[test]
fn test_clear() {
let table = LoraRoutingTable::new();
let workers = make_workers(1);
for i in 0..3 {
let config = LoraReplicaConfig {
lora_name: format!("lora-{}", i),
replica_factor: 1,
replica_set: workers.clone(),
updated_at: Instant::now(),
};
table.update_allocation(format!("lora-{}", i), config);
}
assert_eq!(table.len(), 3);
table.clear();
assert_eq!(table.len(), 0);
assert!(table.is_empty());
}
}
......@@ -230,10 +230,9 @@ pub struct ModelDeploymentCard {
/// `Text` for engines that take care of pre-processing themselves.
pub model_input: ModelInput,
/// Optional LoRA adapter name for this model card.
/// Present when this card represents a LoRA adapter registered on top of a base model.
/// LoRA metadata for routing
#[serde(default, skip_serializing_if = "Option::is_none")]
pub lora_name: Option<String>,
pub lora: Option<LoraInfo>,
/// User-defined metadata for custom worker behavior
#[serde(default, skip_serializing_if = "Option::is_none")]
......@@ -254,6 +253,17 @@ pub struct ModelDeploymentCard {
checksum: OnceLock<String>,
}
/// LoRA adapter information for routing decisions
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct LoraInfo {
/// LoRA adapter name (e.g., "customer-123-v2")
pub name: String,
/// Maximum number of LoRA adapters that can be loaded at once on a single GPU
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_gpu_lora_count: Option<u32>,
}
impl ModelDeploymentCard {
pub fn builder() -> ModelDeploymentCardBuilder {
ModelDeploymentCardBuilder::default()
......@@ -656,7 +666,7 @@ impl ModelDeploymentCard {
migration_limit: 0,
model_type: Default::default(), // set later
model_input: Default::default(), // set later
lora_name: None,
lora: None,
user_data: None,
runtime_config: ModelRuntimeConfig::default(),
media_decoder: None,
......
......@@ -157,7 +157,7 @@ impl OpenAIPreprocessor {
) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum().to_string();
let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
let lora_name = mdc.lora_name.clone();
let lora_name = mdc.lora.as_ref().map(|l| l.name.clone());
let Some(ref model_info) = mdc.model_info else {
anyhow::bail!(
"Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment