Unverified commit ebaf048d, authored by Yan Ru Pei, committed by GitHub
Browse files

chore: plumb allowed_worker_ids through RoutingHints (#6580)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 4fbd8f1c
......@@ -919,7 +919,7 @@ impl KvRouter {
update_states,
lora_name,
0.0,
None, // allowed_worker_ids not exposed in Python API yet
None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
)
.await
.map_err(to_pyerr)?;
......
......@@ -386,19 +386,13 @@ impl KvRouter {
});
let hash_elapsed = start.elapsed();
let mut overlap_scores = self
let overlap_scores = self
.indexer
.find_matches(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches"))
.await?;
let find_matches_elapsed = start.elapsed();
if let Some(ref allowed_ids) = allowed_worker_ids {
overlap_scores
.scores
.retain(|worker, _| allowed_ids.contains(&worker.worker_id));
}
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
self.kv_router_config.compute_seq_hashes_for_tracking(
......
......@@ -48,6 +48,14 @@ pub enum PrefillError {
NoDisaggregatedParams(String),
}
/// Result of the prefill phase in `generate()`.
///
/// Each variant carries the single piece of data that is later attached to the
/// decode request, so the two prefill paths cannot be mixed up by the caller.
enum PrefillOutcome {
/// Bootstrap optimization: the prefill task was spawned in the background and the
/// bootstrap info is already available. It is stored in `decode_req.bootstrap_info`.
Bootstrap(BootstrapInfo),
/// Synchronous prefill: `call_prefill` was awaited to completion. Its result is
/// stored in `decode_req.prefill_result`.
Completed(PrefillResult),
}
/// The inner router used by PrefillRouter
#[derive(Clone)]
enum InnerPrefillRouter {
......@@ -193,7 +201,7 @@ impl PrefillRouter {
"Activating prefill router"
);
// Store endpoint_id for later use in build_bootstrap_info
// Store endpoint_id for later use in resolve_prefill_worker
let _ = self.endpoint_id.set(endpoint.id());
// Start runtime config watcher for this endpoint (needed for get_disaggregated_endpoint)
......@@ -270,16 +278,16 @@ impl PrefillRouter {
Ok(())
}
/// Build bootstrap_info for disaggregated serving
/// Select a prefill worker and resolve its bootstrap connection info.
/// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
async fn build_bootstrap_info(
async fn resolve_prefill_worker(
&self,
req: &PreprocessedRequest,
preselected_worker: Option<u64>,
) -> Option<(u64, u32, BootstrapInfo)> {
let endpoint_id = self.endpoint_id.get()?;
let _prefill_router = self.prefill_router.get()?;
self.prefill_router.get()?;
// Worker selection
let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
......@@ -299,6 +307,10 @@ impl PrefillRouter {
.as_ref()
.and_then(|r| r.priority_jump)
.unwrap_or(0.0);
let allowed_worker_ids = req
.routing
.as_ref()
.and_then(|r| r.allowed_worker_ids.clone());
let (routing_token_ids, block_mm_infos) = req.block_mm_routing_info();
match self
.query_prefill_worker(
......@@ -307,7 +319,7 @@ impl PrefillRouter {
false,
lora_name,
priority_jump,
None,
allowed_worker_ids,
)
.await
{
......@@ -383,6 +395,13 @@ impl PrefillRouter {
));
};
if let Some(err) = first_output.err() {
return Err(PrefillError::PrefillError(
"Prefill router returned error in output".to_string(),
Some(Box::new(err)),
));
}
let mut prompt_tokens_details = first_output
.data
.as_ref()
......@@ -400,13 +419,6 @@ impl PrefillRouter {
}
}
if let Some(err) = first_output.err() {
return Err(PrefillError::PrefillError(
"Prefill router returned error in output".to_string(),
Some(Box::new(err)),
));
}
let Some(output) = &first_output.data else {
return Err(PrefillError::NoDisaggregatedParams(
"Prefill router output has no data field".to_string(),
......@@ -499,7 +511,7 @@ impl PrefillRouter {
/// Query the best prefill worker without executing a request.
/// Returns (worker_id, dp_rank).
///
/// This is the shared worker selection logic used by both `build_bootstrap_info`
/// This is the shared worker selection logic used by both `resolve_prefill_worker`
/// and `query_route`.
pub async fn query_prefill_worker(
&self,
......@@ -597,7 +609,7 @@ impl
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
// Try build_bootstrap_info optimization: if we can get bootstrap info upfront,
// Try to resolve prefill worker upfront: if we can get bootstrap info early,
// spawn prefill in background and proceed to decode immediately.
let preselected_worker = prefill_req
.routing
......@@ -606,7 +618,7 @@ impl
let prefill_result = async {
if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.resolve_prefill_worker(&prefill_req, preselected_worker)
.await
{
// Bootstrap optimization path: spawn prefill in background
......@@ -630,7 +642,7 @@ impl
// This allows set_phase(Decode) below to proceed only after prefill routing is done
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit);
Ok((None, Some(worker_id), Some(bootstrap_info)))
Ok(PrefillOutcome::Bootstrap(bootstrap_info))
} else {
// Original prefill path: wait for prefill to complete
tracing::debug!("Using original prefill path");
......@@ -642,11 +654,9 @@ impl
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
let result = self.call_prefill(prefill_context).await;
let (result, _worker_info) = self.call_prefill(prefill_context).await?;
result.map(|(result, worker_info)| {
(Some(result), worker_info.map(|(id, _)| id), None)
})
Ok(PrefillOutcome::Completed(result))
}
}
.await;
......@@ -662,7 +672,7 @@ impl
// Handle prefill result
match prefill_result {
Ok((maybe_prefill_result, _prefill_worker_id, bootstrap_info)) => {
Ok(outcome) => {
tracing::debug!("Prefill completed, proceeding to decode");
// Set phase to Decode for the decode request.
......@@ -675,19 +685,18 @@ impl
let mut decode_req = req;
// Update request with prefill result
if let Some(prefill_result) = maybe_prefill_result {
decode_req.prefill_result = Some(prefill_result);
match outcome {
PrefillOutcome::Bootstrap(info) => {
decode_req.bootstrap_info = Some(info);
}
PrefillOutcome::Completed(result) => {
decode_req.prefill_result = Some(result);
}
}
// Restore original max_tokens for decode
decode_req.stop_conditions.max_tokens = original_max_tokens;
// Inject bootstrap_info for decode worker
if let Some(info) = bootstrap_info {
decode_req.bootstrap_info = Some(info);
}
// Set router_config_override for decode:
// - overlap_score_weight = 0 (no KV cache overlap scoring for decode)
// - assume_kv_reuse = false (generate random hashes since decode workers
......
......@@ -200,6 +200,7 @@ impl KvPushRouter {
let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0);
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
let allowed_worker_ids = routing.and_then(|r| r.allowed_worker_ids.clone());
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
// Get pre-selected worker based on phase, with backend_instance_id as fallback
......@@ -224,7 +225,7 @@ impl KvPushRouter {
!is_query_only,
lora_name,
priority_jump,
None,
allowed_worker_ids,
)
.await?;
......
......@@ -285,6 +285,7 @@ impl OpenAIPreprocessor {
priority_jump: hints.and_then(|h| h.latency_sensitivity),
priority: hints.and_then(|h| h.priority),
lora_name,
allowed_worker_ids: None,
};
builder.routing(Some(routing));
} else if lora_name.is_some() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet;
use std::sync::Arc;
use derive_builder::Builder;
......@@ -9,7 +10,7 @@ use serde::{Deserialize, Serialize};
use super::timing::RequestTracker;
use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride;
use crate::kv_router::protocols::BlockExtraInfo;
use crate::kv_router::protocols::{BlockExtraInfo, WorkerId};
use crate::preprocessor::media::RdmaMediaDataDescriptor;
use crate::protocols::TokenIdType;
......@@ -54,6 +55,11 @@ pub struct RoutingHints {
/// Backend engine scheduling priority forwarded to the generate call.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
/// Optional set of allowed worker IDs to restrict routing decisions (EPP).
/// When set, only workers in this set are considered during scoring.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub allowed_worker_ids: Option<HashSet<WorkerId>>,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment