Unverified Commit ebaf048d authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: plumb allowed_worker_ids through RoutingHints (#6580)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 4fbd8f1c
...@@ -919,7 +919,7 @@ impl KvRouter { ...@@ -919,7 +919,7 @@ impl KvRouter {
update_states, update_states,
lora_name, lora_name,
0.0, 0.0,
None, // allowed_worker_ids not exposed in Python API yet None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
) )
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -386,19 +386,13 @@ impl KvRouter { ...@@ -386,19 +386,13 @@ impl KvRouter {
}); });
let hash_elapsed = start.elapsed(); let hash_elapsed = start.elapsed();
let mut overlap_scores = self let overlap_scores = self
.indexer .indexer
.find_matches(block_hashes) .find_matches(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches")) .instrument(tracing::info_span!("kv_router.find_matches"))
.await?; .await?;
let find_matches_elapsed = start.elapsed(); let find_matches_elapsed = start.elapsed();
if let Some(ref allowed_ids) = allowed_worker_ids {
overlap_scores
.scores
.retain(|worker, _| allowed_ids.contains(&worker.worker_id));
}
// Compute seq_hashes only if scheduler needs it for active blocks tracking // Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| { let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
self.kv_router_config.compute_seq_hashes_for_tracking( self.kv_router_config.compute_seq_hashes_for_tracking(
......
...@@ -48,6 +48,14 @@ pub enum PrefillError { ...@@ -48,6 +48,14 @@ pub enum PrefillError {
NoDisaggregatedParams(String), NoDisaggregatedParams(String),
} }
/// Result of the prefill phase in `generate()`.
enum PrefillOutcome {
/// Bootstrap optimization: prefill spawned in background, bootstrap info ready
Bootstrap(BootstrapInfo),
/// Synchronous prefill completed with result
Completed(PrefillResult),
}
/// The inner router used by PrefillRouter /// The inner router used by PrefillRouter
#[derive(Clone)] #[derive(Clone)]
enum InnerPrefillRouter { enum InnerPrefillRouter {
...@@ -193,7 +201,7 @@ impl PrefillRouter { ...@@ -193,7 +201,7 @@ impl PrefillRouter {
"Activating prefill router" "Activating prefill router"
); );
// Store endpoint_id for later use in build_bootstrap_info // Store endpoint_id for later use in resolve_prefill_worker
let _ = self.endpoint_id.set(endpoint.id()); let _ = self.endpoint_id.set(endpoint.id());
// Start runtime config watcher for this endpoint (needed for get_disaggregated_endpoint) // Start runtime config watcher for this endpoint (needed for get_disaggregated_endpoint)
...@@ -270,16 +278,16 @@ impl PrefillRouter { ...@@ -270,16 +278,16 @@ impl PrefillRouter {
Ok(()) Ok(())
} }
/// Build bootstrap_info for disaggregated serving /// Select a prefill worker and resolve its bootstrap connection info.
/// If preselected_worker is provided (GAIE Stage 2), use it directly. /// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes). /// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
async fn build_bootstrap_info( async fn resolve_prefill_worker(
&self, &self,
req: &PreprocessedRequest, req: &PreprocessedRequest,
preselected_worker: Option<u64>, preselected_worker: Option<u64>,
) -> Option<(u64, u32, BootstrapInfo)> { ) -> Option<(u64, u32, BootstrapInfo)> {
let endpoint_id = self.endpoint_id.get()?; let endpoint_id = self.endpoint_id.get()?;
let _prefill_router = self.prefill_router.get()?; self.prefill_router.get()?;
// Worker selection // Worker selection
let (worker_id, dp_rank) = if let Some(id) = preselected_worker { let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
...@@ -299,6 +307,10 @@ impl PrefillRouter { ...@@ -299,6 +307,10 @@ impl PrefillRouter {
.as_ref() .as_ref()
.and_then(|r| r.priority_jump) .and_then(|r| r.priority_jump)
.unwrap_or(0.0); .unwrap_or(0.0);
let allowed_worker_ids = req
.routing
.as_ref()
.and_then(|r| r.allowed_worker_ids.clone());
let (routing_token_ids, block_mm_infos) = req.block_mm_routing_info(); let (routing_token_ids, block_mm_infos) = req.block_mm_routing_info();
match self match self
.query_prefill_worker( .query_prefill_worker(
...@@ -307,7 +319,7 @@ impl PrefillRouter { ...@@ -307,7 +319,7 @@ impl PrefillRouter {
false, false,
lora_name, lora_name,
priority_jump, priority_jump,
None, allowed_worker_ids,
) )
.await .await
{ {
...@@ -383,6 +395,13 @@ impl PrefillRouter { ...@@ -383,6 +395,13 @@ impl PrefillRouter {
)); ));
}; };
if let Some(err) = first_output.err() {
return Err(PrefillError::PrefillError(
"Prefill router returned error in output".to_string(),
Some(Box::new(err)),
));
}
let mut prompt_tokens_details = first_output let mut prompt_tokens_details = first_output
.data .data
.as_ref() .as_ref()
...@@ -400,13 +419,6 @@ impl PrefillRouter { ...@@ -400,13 +419,6 @@ impl PrefillRouter {
} }
} }
if let Some(err) = first_output.err() {
return Err(PrefillError::PrefillError(
"Prefill router returned error in output".to_string(),
Some(Box::new(err)),
));
}
let Some(output) = &first_output.data else { let Some(output) = &first_output.data else {
return Err(PrefillError::NoDisaggregatedParams( return Err(PrefillError::NoDisaggregatedParams(
"Prefill router output has no data field".to_string(), "Prefill router output has no data field".to_string(),
...@@ -499,7 +511,7 @@ impl PrefillRouter { ...@@ -499,7 +511,7 @@ impl PrefillRouter {
/// Query the best prefill worker without executing a request. /// Query the best prefill worker without executing a request.
/// Returns (worker_id, dp_rank). /// Returns (worker_id, dp_rank).
/// ///
/// This is the shared worker selection logic used by both `build_bootstrap_info` /// This is the shared worker selection logic used by both `resolve_prefill_worker`
/// and `query_route`. /// and `query_route`.
pub async fn query_prefill_worker( pub async fn query_prefill_worker(
&self, &self,
...@@ -597,7 +609,7 @@ impl ...@@ -597,7 +609,7 @@ impl
let mut prefill_req = req.clone(); let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1); prefill_req.stop_conditions.max_tokens = Some(1);
// Try build_bootstrap_info optimization: if we can get bootstrap info upfront, // Try to resolve prefill worker upfront: if we can get bootstrap info early,
// spawn prefill in background and proceed to decode immediately. // spawn prefill in background and proceed to decode immediately.
let preselected_worker = prefill_req let preselected_worker = prefill_req
.routing .routing
...@@ -606,7 +618,7 @@ impl ...@@ -606,7 +618,7 @@ impl
let prefill_result = async { let prefill_result = async {
if let Some((worker_id, dp_rank, bootstrap_info)) = self if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker) .resolve_prefill_worker(&prefill_req, preselected_worker)
.await .await
{ {
// Bootstrap optimization path: spawn prefill in background // Bootstrap optimization path: spawn prefill in background
...@@ -630,7 +642,7 @@ impl ...@@ -630,7 +642,7 @@ impl
// This allows set_phase(Decode) below to proceed only after prefill routing is done // This allows set_phase(Decode) below to proceed only after prefill routing is done
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit); self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit);
Ok((None, Some(worker_id), Some(bootstrap_info))) Ok(PrefillOutcome::Bootstrap(bootstrap_info))
} else { } else {
// Original prefill path: wait for prefill to complete // Original prefill path: wait for prefill to complete
tracing::debug!("Using original prefill path"); tracing::debug!("Using original prefill path");
...@@ -642,11 +654,9 @@ impl ...@@ -642,11 +654,9 @@ impl
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context()); engine_ctx.link_child(prefill_context.context());
let result = self.call_prefill(prefill_context).await; let (result, _worker_info) = self.call_prefill(prefill_context).await?;
result.map(|(result, worker_info)| { Ok(PrefillOutcome::Completed(result))
(Some(result), worker_info.map(|(id, _)| id), None)
})
} }
} }
.await; .await;
...@@ -662,7 +672,7 @@ impl ...@@ -662,7 +672,7 @@ impl
// Handle prefill result // Handle prefill result
match prefill_result { match prefill_result {
Ok((maybe_prefill_result, _prefill_worker_id, bootstrap_info)) => { Ok(outcome) => {
tracing::debug!("Prefill completed, proceeding to decode"); tracing::debug!("Prefill completed, proceeding to decode");
// Set phase to Decode for the decode request. // Set phase to Decode for the decode request.
...@@ -675,19 +685,18 @@ impl ...@@ -675,19 +685,18 @@ impl
let mut decode_req = req; let mut decode_req = req;
// Update request with prefill result match outcome {
if let Some(prefill_result) = maybe_prefill_result { PrefillOutcome::Bootstrap(info) => {
decode_req.prefill_result = Some(prefill_result); decode_req.bootstrap_info = Some(info);
}
PrefillOutcome::Completed(result) => {
decode_req.prefill_result = Some(result);
}
} }
// Restore original max_tokens for decode // Restore original max_tokens for decode
decode_req.stop_conditions.max_tokens = original_max_tokens; decode_req.stop_conditions.max_tokens = original_max_tokens;
// Inject bootstrap_info for decode worker
if let Some(info) = bootstrap_info {
decode_req.bootstrap_info = Some(info);
}
// Set router_config_override for decode: // Set router_config_override for decode:
// - overlap_score_weight = 0 (no KV cache overlap scoring for decode) // - overlap_score_weight = 0 (no KV cache overlap scoring for decode)
// - assume_kv_reuse = false (generate random hashes since decode workers // - assume_kv_reuse = false (generate random hashes since decode workers
......
...@@ -200,6 +200,7 @@ impl KvPushRouter { ...@@ -200,6 +200,7 @@ impl KvPushRouter {
let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0); let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0);
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0); let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens); let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
let allowed_worker_ids = routing.and_then(|r| r.allowed_worker_ids.clone());
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info(); let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
// Get pre-selected worker based on phase, with backend_instance_id as fallback // Get pre-selected worker based on phase, with backend_instance_id as fallback
...@@ -224,7 +225,7 @@ impl KvPushRouter { ...@@ -224,7 +225,7 @@ impl KvPushRouter {
!is_query_only, !is_query_only,
lora_name, lora_name,
priority_jump, priority_jump,
None, allowed_worker_ids,
) )
.await?; .await?;
......
...@@ -285,6 +285,7 @@ impl OpenAIPreprocessor { ...@@ -285,6 +285,7 @@ impl OpenAIPreprocessor {
priority_jump: hints.and_then(|h| h.latency_sensitivity), priority_jump: hints.and_then(|h| h.latency_sensitivity),
priority: hints.and_then(|h| h.priority), priority: hints.and_then(|h| h.priority),
lora_name, lora_name,
allowed_worker_ids: None,
}; };
builder.routing(Some(routing)); builder.routing(Some(routing));
} else if lora_name.is_some() { } else if lora_name.is_some() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
use derive_builder::Builder; use derive_builder::Builder;
...@@ -9,7 +10,7 @@ use serde::{Deserialize, Serialize}; ...@@ -9,7 +10,7 @@ use serde::{Deserialize, Serialize};
use super::timing::RequestTracker; use super::timing::RequestTracker;
use super::{OutputOptions, SamplingOptions, StopConditions}; use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride; use crate::kv_router::RouterConfigOverride;
use crate::kv_router::protocols::BlockExtraInfo; use crate::kv_router::protocols::{BlockExtraInfo, WorkerId};
use crate::preprocessor::media::RdmaMediaDataDescriptor; use crate::preprocessor::media::RdmaMediaDataDescriptor;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
...@@ -54,6 +55,11 @@ pub struct RoutingHints { ...@@ -54,6 +55,11 @@ pub struct RoutingHints {
/// Backend engine scheduling priority forwarded to the generate call. /// Backend engine scheduling priority forwarded to the generate call.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>, pub priority: Option<i32>,
/// Optional set of allowed worker IDs to restrict routing decisions (EPP).
/// When set, only workers in this set are considered during scoring.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub allowed_worker_ids: Option<HashSet<WorkerId>>,
} }
#[derive(Serialize, Deserialize, Debug, Clone, Default)] #[derive(Serialize, Deserialize, Debug, Clone, Default)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment