Unverified Commit 373e76c1 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat(lora): Add lora_name tracking to scheduling and sequence management (#5875)

parent 18d9d1fa
...@@ -950,6 +950,7 @@ pub unsafe extern "C" fn dynamo_router_add_request( ...@@ -950,6 +950,7 @@ pub unsafe extern "C" fn dynamo_router_add_request(
overlap_blocks, overlap_blocks,
None, None,
worker, worker,
None, // lora_name not exposed in C API yet
) )
.await; .await;
......
...@@ -1249,6 +1249,7 @@ impl KvPushRouter { ...@@ -1249,6 +1249,7 @@ impl KvPushRouter {
&token_ids, &token_ids,
router_config_override.as_ref(), router_config_override.as_ref(),
update_states, update_states,
None, // lora_name not exposed in Python API yet
) )
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -222,6 +222,8 @@ pub struct ActiveSequenceEvent { ...@@ -222,6 +222,8 @@ pub struct ActiveSequenceEvent {
pub worker: WorkerWithDpRank, pub worker: WorkerWithDpRank,
pub data: ActiveSequenceEventData, pub data: ActiveSequenceEventData,
pub router_id: u64, pub router_id: u64,
#[serde(default)]
pub lora_name: Option<String>,
} }
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
......
...@@ -486,12 +486,14 @@ impl KvRouter { ...@@ -486,12 +486,14 @@ impl KvRouter {
/// Given these tokens, find the worker with the best match in its KV cache. /// Given these tokens, find the worker with the best match in its KV cache.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks. /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking /// Now also takes optional context_id for request tracking
#[allow(clippy::too_many_arguments)]
pub async fn find_best_match( pub async fn find_best_match(
&self, &self,
context_id: Option<&str>, context_id: Option<&str>,
tokens: &[u32], tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
lora_name: Option<String>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> { ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
// Validate that context_id is provided when update_states is true // Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() { if update_states && context_id.is_none() {
...@@ -517,6 +519,7 @@ impl KvRouter { ...@@ -517,6 +519,7 @@ impl KvRouter {
overlap_scores.clone(), overlap_scores.clone(),
router_config_override, router_config_override,
update_states, update_states,
lora_name,
) )
.await?; .await?;
...@@ -531,6 +534,7 @@ impl KvRouter { ...@@ -531,6 +534,7 @@ impl KvRouter {
Ok((best_worker, overlap_amount)) Ok((best_worker, overlap_amount))
} }
#[allow(clippy::too_many_arguments)]
pub async fn add_request( pub async fn add_request(
&self, &self,
request_id: String, request_id: String,
...@@ -538,6 +542,7 @@ impl KvRouter { ...@@ -538,6 +542,7 @@ impl KvRouter {
overlap_blocks: u32, overlap_blocks: u32,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
lora_name: Option<String>,
) { ) {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
...@@ -554,6 +559,7 @@ impl KvRouter { ...@@ -554,6 +559,7 @@ impl KvRouter {
overlap_blocks, overlap_blocks,
expected_output_tokens, expected_output_tokens,
worker, worker,
lora_name,
) )
.await .await
{ {
...@@ -687,7 +693,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -687,7 +693,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
let response = match request { let response = match request {
RouterRequest::New { tokens } => { RouterRequest::New { tokens } => {
let (best_worker, overlap_blocks) = self let (best_worker, overlap_blocks) = self
.find_best_match(Some(&context_id), &tokens, None, true) .find_best_match(Some(&context_id), &tokens, None, true, None)
.await?; .await?;
RouterResponse::New { RouterResponse::New {
...@@ -744,6 +750,9 @@ impl KvPushRouter { ...@@ -744,6 +750,9 @@ impl KvPushRouter {
) -> Result<WorkerSelection, Error> { ) -> Result<WorkerSelection, Error> {
let routing = request.routing.as_ref(); let routing = request.routing.as_ref();
// Extract LORA name from routing hints
let lora_name = routing.and_then(|r| r.lora_name.clone());
// Get pre-selected worker based on phase, with backend_instance_id as fallback // Get pre-selected worker based on phase, with backend_instance_id as fallback
let Some(id) = (match phase { let Some(id) = (match phase {
RequestPhase::Prefill => { RequestPhase::Prefill => {
...@@ -763,6 +772,7 @@ impl KvPushRouter { ...@@ -763,6 +772,7 @@ impl KvPushRouter {
&request.token_ids, &request.token_ids,
request.router_config_override.as_ref(), request.router_config_override.as_ref(),
!is_query_only, !is_query_only,
lora_name,
) )
.await?; .await?;
...@@ -804,6 +814,7 @@ impl KvPushRouter { ...@@ -804,6 +814,7 @@ impl KvPushRouter {
overlap_blocks, overlap_blocks,
expected_output_tokens, expected_output_tokens,
worker, worker,
lora_name,
) )
.await; .await;
} else { } else {
......
...@@ -266,10 +266,12 @@ impl PrefillRouter { ...@@ -266,10 +266,12 @@ impl PrefillRouter {
InnerPrefillRouter::KvRouter(r) => r, InnerPrefillRouter::KvRouter(r) => r,
_ => return None, _ => return None,
}; };
// Extract LORA name from routing hints
let lora_name = req.routing.as_ref().and_then(|r| r.lora_name.clone());
match async { match async {
kv_router kv_router
.chooser .chooser
.find_best_match(None, &req.token_ids, None, false) .find_best_match(None, &req.token_ids, None, false, lora_name)
.await .await
} }
.instrument(tracing::info_span!("kv_find_best_match")) .instrument(tracing::info_span!("kv_find_best_match"))
......
...@@ -68,6 +68,8 @@ pub struct SchedulingRequest { ...@@ -68,6 +68,8 @@ pub struct SchedulingRequest {
pub router_config_override: Option<RouterConfigOverride>, pub router_config_override: Option<RouterConfigOverride>,
// Whether to update scheduler states (false for query_instance_id requests) // Whether to update scheduler states (false for query_instance_id requests)
pub update_states: bool, pub update_states: bool,
// LORA adapter name for this request, if any (supplied by the caller of `schedule`)
pub lora_name: Option<String>,
// Option to take it out to send the response without moving the struct // Option to take it out to send the response without moving the struct
resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>, resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>,
} }
...@@ -248,6 +250,7 @@ impl KvScheduler { ...@@ -248,6 +250,7 @@ impl KvScheduler {
selection.overlap_blocks, selection.overlap_blocks,
None, // expected_output_tokens not available in scheduler loop None, // expected_output_tokens not available in scheduler loop
selection.worker, selection.worker,
request.lora_name.clone(),
) )
.await .await
{ {
...@@ -272,6 +275,7 @@ impl KvScheduler { ...@@ -272,6 +275,7 @@ impl KvScheduler {
Ok(KvScheduler { request_tx, slots }) Ok(KvScheduler { request_tx, slots })
} }
#[allow(clippy::too_many_arguments)]
pub async fn schedule( pub async fn schedule(
&self, &self,
maybe_request_id: Option<String>, maybe_request_id: Option<String>,
...@@ -280,6 +284,7 @@ impl KvScheduler { ...@@ -280,6 +284,7 @@ impl KvScheduler {
overlaps: OverlapScores, overlaps: OverlapScores,
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
lora_name: Option<String>,
) -> Result<WorkerWithDpRank, KvSchedulerError> { ) -> Result<WorkerWithDpRank, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
...@@ -291,6 +296,7 @@ impl KvScheduler { ...@@ -291,6 +296,7 @@ impl KvScheduler {
prefill_tokens: HashMap::new(), prefill_tokens: HashMap::new(),
router_config_override: router_config_override.cloned(), router_config_override: router_config_override.cloned(),
update_states, update_states,
lora_name,
resp_tx: Some(resp_tx), // Wrap in Some() resp_tx: Some(resp_tx), // Wrap in Some()
}; };
...@@ -305,6 +311,7 @@ impl KvScheduler { ...@@ -305,6 +311,7 @@ impl KvScheduler {
Ok(response.best_worker) Ok(response.best_worker)
} }
#[allow(clippy::too_many_arguments)]
pub async fn add_request( pub async fn add_request(
&self, &self,
request_id: String, request_id: String,
...@@ -313,6 +320,7 @@ impl KvScheduler { ...@@ -313,6 +320,7 @@ impl KvScheduler {
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
lora_name: Option<String>,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
self.slots self.slots
.add_request( .add_request(
...@@ -322,6 +330,7 @@ impl KvScheduler { ...@@ -322,6 +330,7 @@ impl KvScheduler {
overlap, overlap,
expected_output_tokens, expected_output_tokens,
worker, worker,
lora_name,
) )
.await .await
} }
...@@ -378,6 +387,11 @@ impl KvScheduler { ...@@ -378,6 +387,11 @@ impl KvScheduler {
loads loads
} }
/// Get active request counts grouped by LORA name.
///
/// Forwards to `get_active_lora_counts` on `self.slots` and returns its map
/// unchanged; keys are LORA names and values are request counts.
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
self.slots.get_active_lora_counts()
}
} }
// Helper function for softmax sampling // Helper function for softmax sampling
......
...@@ -405,6 +405,7 @@ enum UpdateSequences { ...@@ -405,6 +405,7 @@ enum UpdateSequences {
pub struct ActiveSequencesMultiWorker { pub struct ActiveSequencesMultiWorker {
senders: Arc<DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>, senders: Arc<DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>, request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
request_to_lora: Arc<DashMap<RequestId, String>>,
handles: Arc<DashMap<WorkerWithDpRank, std::thread::JoinHandle<()>>>, handles: Arc<DashMap<WorkerWithDpRank, std::thread::JoinHandle<()>>>,
block_size: usize, block_size: usize,
component: Component, component: Component,
...@@ -429,6 +430,7 @@ impl ActiveSequencesMultiWorker { ...@@ -429,6 +430,7 @@ impl ActiveSequencesMultiWorker {
let senders = Arc::new(DashMap::new()); let senders = Arc::new(DashMap::new());
let handles = Arc::new(DashMap::new()); let handles = Arc::new(DashMap::new());
let request_to_worker = Arc::new(DashMap::new()); let request_to_worker = Arc::new(DashMap::new());
let request_to_lora = Arc::new(DashMap::new());
// Expand workers by their dp_rank // Expand workers by their dp_rank
for (worker_id, config) in workers_with_configs { for (worker_id, config) in workers_with_configs {
...@@ -452,6 +454,7 @@ impl ActiveSequencesMultiWorker { ...@@ -452,6 +454,7 @@ impl ActiveSequencesMultiWorker {
let multi_worker = Self { let multi_worker = Self {
senders: senders.clone(), senders: senders.clone(),
request_to_worker: request_to_worker.clone(), request_to_worker: request_to_worker.clone(),
request_to_lora: request_to_lora.clone(),
handles, handles,
block_size, block_size,
component: component.clone(), component: component.clone(),
...@@ -465,6 +468,7 @@ impl ActiveSequencesMultiWorker { ...@@ -465,6 +468,7 @@ impl ActiveSequencesMultiWorker {
if replica_sync { if replica_sync {
let senders_clone = senders.clone(); let senders_clone = senders.clone();
let request_to_worker_clone = request_to_worker.clone(); let request_to_worker_clone = request_to_worker.clone();
let request_to_lora_clone = request_to_lora.clone();
let component_clone = component.clone(); let component_clone = component.clone();
let router_id_clone = router_id; let router_id_clone = router_id;
let cancel_token = component.drt().runtime().child_token(); let cancel_token = component.drt().runtime().child_token();
...@@ -474,6 +478,7 @@ impl ActiveSequencesMultiWorker { ...@@ -474,6 +478,7 @@ impl ActiveSequencesMultiWorker {
if let Err(e) = Self::subscribe_to_events( if let Err(e) = Self::subscribe_to_events(
senders_clone, senders_clone,
request_to_worker_clone, request_to_worker_clone,
request_to_lora_clone,
component_clone, component_clone,
router_id_clone, router_id_clone,
cancel_token, cancel_token,
...@@ -603,6 +608,7 @@ impl ActiveSequencesMultiWorker { ...@@ -603,6 +608,7 @@ impl ActiveSequencesMultiWorker {
DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>, DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>,
>, >,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>, request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
request_to_lora: Arc<DashMap<RequestId, String>>,
component: Component, component: Component,
router_id: u64, router_id: u64,
cancel_token: CancellationToken, cancel_token: CancellationToken,
...@@ -642,6 +648,11 @@ impl ActiveSequencesMultiWorker { ...@@ -642,6 +648,11 @@ impl ActiveSequencesMultiWorker {
} => { } => {
request_to_worker.insert(event.request_id.clone(), event.worker); request_to_worker.insert(event.request_id.clone(), event.worker);
// Store lora_name mapping if present
if let Some(ref lora_name) = event.lora_name {
request_to_lora.insert(event.request_id.clone(), lora_name.clone());
}
if let Some(sender) = senders.get(&event.worker) { if let Some(sender) = senders.get(&event.worker) {
// For replicated events, we create a dummy response channel since we don't need to handle expired requests // For replicated events, we create a dummy response channel since we don't need to handle expired requests
let (resp_tx, _) = tokio::sync::oneshot::channel(); let (resp_tx, _) = tokio::sync::oneshot::channel();
...@@ -668,6 +679,8 @@ impl ActiveSequencesMultiWorker { ...@@ -668,6 +679,8 @@ impl ActiveSequencesMultiWorker {
request_id: event.request_id.clone(), request_id: event.request_id.clone(),
}); });
} }
// Clean up lora_name mapping
request_to_lora.remove(&event.request_id);
} }
ActiveSequenceEventData::MarkPrefillCompleted => { ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker) = request_to_worker.get(&event.request_id) if let Some(worker) = request_to_worker.get(&event.request_id)
...@@ -724,9 +737,22 @@ impl ActiveSequencesMultiWorker { ...@@ -724,9 +737,22 @@ impl ActiveSequencesMultiWorker {
} }
self.handles.remove(worker); self.handles.remove(worker);
// Collect request_ids to remove from request_to_lora
let requests_to_remove: Vec<RequestId> = self
.request_to_worker
.iter()
.filter(|entry| entry.value() == worker)
.map(|entry| entry.key().clone())
.collect();
// Clean up request_to_worker mappings for this worker // Clean up request_to_worker mappings for this worker
self.request_to_worker self.request_to_worker
.retain(|_request_id, mapped_worker| mapped_worker != worker); .retain(|_request_id, mapped_worker| mapped_worker != worker);
// Clean up request_to_lora mappings for removed requests
for request_id in requests_to_remove {
self.request_to_lora.remove(&request_id);
}
} }
// Add new workers // Add new workers
...@@ -742,6 +768,7 @@ impl ActiveSequencesMultiWorker { ...@@ -742,6 +768,7 @@ impl ActiveSequencesMultiWorker {
} }
} }
#[allow(clippy::too_many_arguments)]
pub async fn add_request( pub async fn add_request(
&self, &self,
request_id: RequestId, request_id: RequestId,
...@@ -750,6 +777,7 @@ impl ActiveSequencesMultiWorker { ...@@ -750,6 +777,7 @@ impl ActiveSequencesMultiWorker {
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
lora_name: Option<String>,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
// Check for worker existence // Check for worker existence
if !self.senders.contains_key(&worker) { if !self.senders.contains_key(&worker) {
...@@ -779,6 +807,7 @@ impl ActiveSequencesMultiWorker { ...@@ -779,6 +807,7 @@ impl ActiveSequencesMultiWorker {
expected_output_tokens, expected_output_tokens,
}, },
router_id: self.router_id, router_id: self.router_id,
lora_name: lora_name.clone(),
}; };
self.event_publisher.publish(&event).await?; self.event_publisher.publish(&event).await?;
} }
...@@ -786,6 +815,11 @@ impl ActiveSequencesMultiWorker { ...@@ -786,6 +815,11 @@ impl ActiveSequencesMultiWorker {
// Update local state with full WorkerWithDpRank // Update local state with full WorkerWithDpRank
self.request_to_worker.insert(request_id.clone(), worker); self.request_to_worker.insert(request_id.clone(), worker);
// Store lora_name for later use in Free/MarkPrefillCompleted events
if let Some(lora) = lora_name {
self.request_to_lora.insert(request_id.clone(), lora);
}
self.senders self.senders
.get(&worker) .get(&worker)
.unwrap() .unwrap()
...@@ -807,6 +841,7 @@ impl ActiveSequencesMultiWorker { ...@@ -807,6 +841,7 @@ impl ActiveSequencesMultiWorker {
// Remove expired requests from request_to_worker mapping // Remove expired requests from request_to_worker mapping
for expired_id in &removed_requests { for expired_id in &removed_requests {
self.request_to_worker.remove(expired_id); self.request_to_worker.remove(expired_id);
self.request_to_lora.remove(expired_id);
} }
// Publish ActiveLoad metrics for this worker // Publish ActiveLoad metrics for this worker
...@@ -833,11 +868,18 @@ impl ActiveSequencesMultiWorker { ...@@ -833,11 +868,18 @@ impl ActiveSequencesMultiWorker {
// Publish event only if replica_sync is enabled // Publish event only if replica_sync is enabled
if self.replica_sync { if self.replica_sync {
// Look up lora_name from mapping
let lora_name = self
.request_to_lora
.get(request_id)
.map(|entry| entry.value().clone());
let event = ActiveSequenceEvent { let event = ActiveSequenceEvent {
request_id: request_id.clone(), request_id: request_id.clone(),
worker, worker,
data: ActiveSequenceEventData::Free, data: ActiveSequenceEventData::Free,
router_id: self.router_id, router_id: self.router_id,
lora_name,
}; };
self.event_publisher.publish(&event).await?; self.event_publisher.publish(&event).await?;
} }
...@@ -852,6 +894,7 @@ impl ActiveSequencesMultiWorker { ...@@ -852,6 +894,7 @@ impl ActiveSequencesMultiWorker {
.map_err(|_| SequenceError::WorkerChannelClosed)?; .map_err(|_| SequenceError::WorkerChannelClosed)?;
self.request_to_worker.remove(request_id); self.request_to_worker.remove(request_id);
self.request_to_lora.remove(request_id);
// Publish ActiveLoad metrics for this worker // Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await; self.publish_active_load_for_worker(worker).await;
...@@ -882,11 +925,18 @@ impl ActiveSequencesMultiWorker { ...@@ -882,11 +925,18 @@ impl ActiveSequencesMultiWorker {
// Publish event only if replica_sync is enabled // Publish event only if replica_sync is enabled
if self.replica_sync { if self.replica_sync {
// Look up lora_name from mapping
let lora_name = self
.request_to_lora
.get(request_id)
.map(|entry| entry.value().clone());
let event = ActiveSequenceEvent { let event = ActiveSequenceEvent {
request_id: request_id.clone(), request_id: request_id.clone(),
worker, worker,
data: ActiveSequenceEventData::MarkPrefillCompleted, data: ActiveSequenceEventData::MarkPrefillCompleted,
router_id: self.router_id, router_id: self.router_id,
lora_name,
}; };
self.event_publisher.publish(&event).await?; self.event_publisher.publish(&event).await?;
} }
...@@ -1156,6 +1206,15 @@ impl ActiveSequencesMultiWorker { ...@@ -1156,6 +1206,15 @@ impl ActiveSequencesMultiWorker {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx }) self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx })
.await .await
} }
/// Tally the entries of `request_to_lora` by LORA adapter name.
///
/// Returns a map from adapter name to the number of request ids currently
/// mapped to that name. Requests with no entry in `request_to_lora` are not
/// counted, so the map only contains names with at least one request.
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
    self.request_to_lora
        .iter()
        .fold(HashMap::new(), |mut tally, mapping| {
            // One increment per request id; absent names start from zero.
            *tally.entry(mapping.value().clone()).or_default() += 1;
            tally
        })
}
} }
impl Drop for ActiveSequencesMultiWorker { impl Drop for ActiveSequencesMultiWorker {
...@@ -1264,6 +1323,7 @@ mod tests { ...@@ -1264,6 +1323,7 @@ mod tests {
0, // no overlap 0, // no overlap
None, // expected_output_tokens None, // expected_output_tokens
WorkerWithDpRank::new(0, 0), WorkerWithDpRank::new(0, 0),
None, // lora_name
) )
.await?; .await?;
...@@ -1276,6 +1336,7 @@ mod tests { ...@@ -1276,6 +1336,7 @@ mod tests {
0, // no overlap 0, // no overlap
None, // expected_output_tokens None, // expected_output_tokens
WorkerWithDpRank::new(0, 1), WorkerWithDpRank::new(0, 1),
None, // lora_name
) )
.await?; .await?;
...@@ -1288,6 +1349,7 @@ mod tests { ...@@ -1288,6 +1349,7 @@ mod tests {
0, // no overlap 0, // no overlap
None, // expected_output_tokens None, // expected_output_tokens
WorkerWithDpRank::new(1, 0), WorkerWithDpRank::new(1, 0),
None, // lora_name
) )
.await?; .await?;
...@@ -1423,6 +1485,7 @@ mod tests { ...@@ -1423,6 +1485,7 @@ mod tests {
0, // no overlap 0, // no overlap
None, // expected_output_tokens None, // expected_output_tokens
WorkerWithDpRank::from_worker_id(0), WorkerWithDpRank::from_worker_id(0),
None, // lora_name
) )
.await?; .await?;
...@@ -1435,6 +1498,7 @@ mod tests { ...@@ -1435,6 +1498,7 @@ mod tests {
0, // no overlap 0, // no overlap
None, // expected_output_tokens None, // expected_output_tokens
WorkerWithDpRank::from_worker_id(1), WorkerWithDpRank::from_worker_id(1),
None, // lora_name
) )
.await?; .await?;
...@@ -1447,6 +1511,7 @@ mod tests { ...@@ -1447,6 +1511,7 @@ mod tests {
0, // no overlap 0, // no overlap
None, // expected_output_tokens None, // expected_output_tokens
WorkerWithDpRank::from_worker_id(2), WorkerWithDpRank::from_worker_id(2),
None, // lora_name
) )
.await?; .await?;
......
...@@ -480,6 +480,7 @@ impl LocalModel { ...@@ -480,6 +480,7 @@ impl LocalModel {
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
self.card.model_type = model_type; self.card.model_type = model_type;
self.card.model_input = model_input; self.card.model_input = model_input;
self.card.lora_name = lora_name.map(|name| name.to_string());
// Compute model_suffix from lora_name if present // Compute model_suffix from lora_name if present
let model_suffix = lora_name.map(|name| Slug::slugify(name).to_string()); let model_suffix = lora_name.map(|name| Slug::slugify(name).to_string());
......
...@@ -230,6 +230,11 @@ pub struct ModelDeploymentCard { ...@@ -230,6 +230,11 @@ pub struct ModelDeploymentCard {
/// `Text` for engines that take care of pre-processing themselves. /// `Text` for engines that take care of pre-processing themselves.
pub model_input: ModelInput, pub model_input: ModelInput,
/// Optional LoRA adapter name for this model card.
/// Present when this card represents a LoRA adapter registered on top of a base model.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub lora_name: Option<String>,
/// User-defined metadata for custom worker behavior /// User-defined metadata for custom worker behavior
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub user_data: Option<serde_json::Value>, pub user_data: Option<serde_json::Value>,
...@@ -651,6 +656,7 @@ impl ModelDeploymentCard { ...@@ -651,6 +656,7 @@ impl ModelDeploymentCard {
migration_limit: 0, migration_limit: 0,
model_type: Default::default(), // set later model_type: Default::default(), // set later
model_input: Default::default(), // set later model_input: Default::default(), // set later
lora_name: None,
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
media_decoder: None, media_decoder: None,
......
...@@ -113,6 +113,7 @@ pub struct OpenAIPreprocessor { ...@@ -113,6 +113,7 @@ pub struct OpenAIPreprocessor {
formatter: Arc<dyn OAIPromptFormatter>, formatter: Arc<dyn OAIPromptFormatter>,
tokenizer: Arc<dyn Tokenizer>, tokenizer: Arc<dyn Tokenizer>,
model_info: Arc<dyn ModelInfo>, model_info: Arc<dyn ModelInfo>,
lora_name: Option<String>,
/// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser) /// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser)
runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig, runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
tool_call_parser: Option<String>, tool_call_parser: Option<String>,
...@@ -136,7 +137,8 @@ impl OpenAIPreprocessor { ...@@ -136,7 +137,8 @@ impl OpenAIPreprocessor {
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum().to_string(); let mdcsum = mdc.mdcsum().to_string();
let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer)); let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
let Some(model_info) = mdc.model_info else { let lora_name = mdc.lora_name.clone();
let Some(ref model_info) = mdc.model_info else {
anyhow::bail!( anyhow::bail!(
"Blank ModelDeploymentCard cannot be used for pre-processing, no model_info" "Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
); );
...@@ -144,6 +146,10 @@ impl OpenAIPreprocessor { ...@@ -144,6 +146,10 @@ impl OpenAIPreprocessor {
let model_info = model_info.get_model_info()?; let model_info = model_info.get_model_info()?;
let tool_call_parser = mdc.runtime_config.tool_call_parser.clone(); let tool_call_parser = mdc.runtime_config.tool_call_parser.clone();
if let Some(ref lora_name) = lora_name {
tracing::info!(model = %mdc.display_name, lora_name, "LoRA adapter detected in MDC");
}
// // Initialize runtime config from the ModelDeploymentCard // // Initialize runtime config from the ModelDeploymentCard
let runtime_config = mdc.runtime_config.clone(); let runtime_config = mdc.runtime_config.clone();
...@@ -158,6 +164,7 @@ impl OpenAIPreprocessor { ...@@ -158,6 +164,7 @@ impl OpenAIPreprocessor {
tokenizer, tokenizer,
model_info, model_info,
mdcsum, mdcsum,
lora_name,
runtime_config, runtime_config,
tool_call_parser, tool_call_parser,
#[cfg(feature = "media-nixl")] #[cfg(feature = "media-nixl")]
...@@ -237,6 +244,8 @@ impl OpenAIPreprocessor { ...@@ -237,6 +244,8 @@ impl OpenAIPreprocessor {
builder.output_options(request.extract_output_options()?); builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default()); builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone())); builder.mdc_sum(Some(self.mdcsum.clone()));
let lora_name = self.lora_name.clone();
// Extract routing hints from nvext if present // Extract routing hints from nvext if present
if let Some(nvext) = request.nvext() { if let Some(nvext) = request.nvext() {
// Build routing hints from nvext fields // Build routing hints from nvext fields
...@@ -247,8 +256,15 @@ impl OpenAIPreprocessor { ...@@ -247,8 +256,15 @@ impl OpenAIPreprocessor {
dp_rank: None, // dp_rank is set later in the pipeline dp_rank: None, // dp_rank is set later in the pipeline
enable_local_updates: nvext.enable_local_updates, enable_local_updates: nvext.enable_local_updates,
expected_output_tokens: nvext.expected_output_tokens, expected_output_tokens: nvext.expected_output_tokens,
lora_name,
}; };
builder.routing(Some(routing)); builder.routing(Some(routing));
} else if lora_name.is_some() {
// Ensure LoRA-aware routing still gets hints even when nvext is absent.
builder.routing(Some(RoutingHints {
lora_name,
..Default::default()
}));
} }
Ok(builder) Ok(builder)
......
...@@ -47,6 +47,11 @@ pub struct RoutingHints { ...@@ -47,6 +47,11 @@ pub struct RoutingHints {
/// Used as a hint for routing decisions to estimate resource requirements. /// Used as a hint for routing decisions to estimate resource requirements.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub expected_output_tokens: Option<u32>, pub expected_output_tokens: Option<u32>,
/// LORA adapter name for this request.
/// Used for LORA-aware routing and tracking.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub lora_name: Option<String>,
} }
#[derive(Serialize, Deserialize, Debug, Clone, Default)] #[derive(Serialize, Deserialize, Debug, Clone, Default)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment