Unverified commit b579a772, authored by Yan Ru Pei and committed by GitHub
Browse files

fix(llm): preserve unresolved dp rank routing (#8000)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent f3d3a8b3
......@@ -13,8 +13,9 @@ from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
# Sentinel value matching u32::MAX from prefill_router.rs SimpleRouter path,
# indicating no specific data-parallel rank was selected.
# Sentinel value matching u32::MAX from the C/Go prefill-routing ABI.
# This remains as a compatibility fallback for older callers that still encode
# an unresolved data-parallel rank in-band instead of omitting the field.
_DP_RANK_UNSET = 2**32 - 1
......
......@@ -120,7 +120,6 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc
}
prefillWorkerID := strconv.FormatUint(result.WorkerID, 10)
prefillDpRank := strconv.FormatUint(uint64(result.DpRank), 10)
logger.V(logutil.DEFAULT).Info("DynPrefillScorer: prefill worker selected",
"prefillWorkerID", prefillWorkerID,
"prefillDpRank", result.DpRank,
......@@ -134,7 +133,11 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc
req.Headers = map[string]string{}
}
req.Headers[PrefillWorkerIDHeader] = prefillWorkerID
req.Headers[PrefillDpRankHeader] = prefillDpRank
if result.DpRank != dynscorer.UnsetDpRank {
req.Headers[PrefillDpRankHeader] = strconv.FormatUint(uint64(result.DpRank), 10)
} else {
delete(req.Headers, PrefillDpRankHeader)
}
// Score: 1.0 for all pods. The label-filter has already restricted to prefill workers,
// and the FFI router's internal selection is authoritative.
......
......@@ -124,6 +124,10 @@ var (
routerHandlesMutex sync.RWMutex
)
// UnsetDpRank is the ABI sentinel used by the Rust C bindings when a prefill
// route selected a worker but left the DP rank unresolved.
const UnsetDpRank = ^uint32(0)
func loadDynamoConfig() {
ffiNamespace = getEnvOrDefault("DYN_NAMESPACE_PREFIX", getEnvOrDefault("DYN_NAMESPACE", "vllm-agg"))
ffiComponent = "backend" // This is not the same as DYN_COMPONENT=epp (in this case)
......
......@@ -27,6 +27,7 @@ use dynamo_runtime::transports::event_plane::EventSubscriber;
// Re-export worker type constants from timing.rs (single source of truth)
pub use crate::protocols::common::timing::{WORKER_TYPE_DECODE, WORKER_TYPE_PREFILL};
const UNSET_DP_RANK_LABEL: &str = "none";
/// Clean up all Prometheus metrics for a worker across the specified dp_ranks.
///
......@@ -44,6 +45,11 @@ fn cleanup_worker_metrics(worker_id: u64, dp_ranks: &[u32], worker_type: &str) {
let _ = WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE.remove_label_values(labels);
let _ = WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE.remove_label_values(labels);
}
let unset_labels = &[worker_id_str.as_str(), UNSET_DP_RANK_LABEL, worker_type];
let _ = WORKER_LAST_TIME_TO_FIRST_TOKEN_GAUGE.remove_label_values(unset_labels);
let _ = WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE.remove_label_values(unset_labels);
let _ = WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE.remove_label_values(unset_labels);
}
/// Scale factor for storing f64 thresholds as u32 (10000 = 4 decimal places)
......
......@@ -42,6 +42,7 @@ use super::RouteDoc;
/// Worker type label values for Prometheus timing metrics
pub use crate::discovery::{WORKER_TYPE_DECODE, WORKER_TYPE_PREFILL};
const UNSET_DP_RANK_LABEL: &str = "none";
/// Global Prometheus gauge for last observed TTFT per worker (in seconds)
/// Labels: worker_id, dp_rank, worker_type
......@@ -1342,7 +1343,7 @@ impl ResponseMetricCollector {
let worker_id_str = worker_id.to_string();
let dp_rank_str = self
.prefill_dp_rank
.map_or("0".to_string(), |r| r.to_string());
.map_or(UNSET_DP_RANK_LABEL.to_string(), |r| r.to_string());
let worker_type = self
.prefill_worker_type
.as_deref()
......@@ -1385,7 +1386,7 @@ impl ResponseMetricCollector {
let worker_id_str = worker_id.to_string();
let dp_rank_str = self
.decode_dp_rank
.map_or("0".to_string(), |r| r.to_string());
.map_or(UNSET_DP_RANK_LABEL.to_string(), |r| r.to_string());
let worker_type = self
.decode_worker_type
.as_deref()
......
......@@ -131,6 +131,7 @@ where
{
indexer: Indexer,
scheduler: KvScheduler<Sel>,
workers_with_configs: RuntimeConfigWatch,
block_size: u32,
kv_router_config: KvRouterConfig,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
......@@ -230,6 +231,7 @@ where
Ok(Self {
indexer,
scheduler,
workers_with_configs,
block_size,
kv_router_config,
prefill_load_estimator,
......@@ -473,6 +475,13 @@ where
self.scheduler.worker_type()
}
/// Look up the single global DP rank owned by `worker_id`, if there is exactly one.
///
/// Returns `None` when the worker has no entry in the current runtime-config
/// snapshot, or when its `data_parallel_size` is anything other than 1. Otherwise
/// returns the worker's `data_parallel_start_rank`.
pub fn unique_dp_rank_for_worker(&self, worker_id: WorkerId) -> Option<u32> {
    let snapshot = self.workers_with_configs.borrow();
    snapshot
        .get(&worker_id)
        .filter(|cfg| cfg.data_parallel_size == 1)
        .map(|cfg| cfg.data_parallel_start_rank)
}
pub fn add_output_block(
&self,
request_id: &str,
......
......@@ -123,7 +123,7 @@ impl PrefillRouter {
///
/// If `phase_transition_permit` is provided, it is dropped immediately after routing completes,
/// allowing subsequent `set_phase` calls to proceed. This preserves the current synchronization:
/// the prefill route must finish `record_worker_full` before the phase can change to Decode.
/// the prefill route must finish worker recording before the phase can change to Decode.
///
/// Returns (PrefillResult, Option<(worker_id, dp_rank)>).
pub(super) async fn execute_prefill(
......@@ -131,7 +131,7 @@ impl PrefillRouter {
request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
phase_transition_permit: Option<OwnedSemaphorePermit>,
) -> Result<(PrefillResult, Option<(u64, u32)>), PrefillError> {
) -> Result<(PrefillResult, Option<(u64, Option<u32>)>), PrefillError> {
let router = router.ok_or(PrefillError::NotActivated)?;
let mut prefill_response = router
.generate_to_worker(request, target_worker)
......@@ -143,7 +143,7 @@ impl PrefillRouter {
)
})?;
// Release the phase barrier now that routing completed and record_worker_full already ran.
// Release the phase barrier now that routing completed and worker recording already ran.
// Decode may proceed without waiting for prefill output streaming to finish.
drop(phase_transition_permit);
......@@ -201,8 +201,7 @@ impl PrefillRouter {
let dp_rank = worker_id_json
.get("prefill_dp_rank")
.and_then(|v| v.as_u64())
.map(|r| r as u32)
.unwrap_or(0);
.map(|r| r as u32);
Some((worker_id, dp_rank))
});
Ok((
......
......@@ -148,7 +148,7 @@ impl
link_child_context(&engine_ctx, prefill_req, request_id.as_str());
// Pass the phase barrier to the spawned task. It is released after routing
// completes so `record_worker_full` finishes before phase changes to Decode.
// completes so worker recording finishes before phase changes to Decode.
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_barrier);
Ok(PrefillOutcome::Bootstrap(bootstrap_info))
......
......@@ -34,8 +34,9 @@ pub struct KvPushRouter {
/// Result of worker selection containing instance ID, dp_rank, and overlap amount.
struct WorkerSelection {
instance_id: u64,
dp_rank: u32,
overlap_amount: u32,
backend_dp_rank: Option<u32>,
bookkeeping_dp_rank: Option<u32>,
overlap_amount: Option<u32>,
}
/// Drop guard that manages the full lifecycle of a routed request:
......@@ -46,6 +47,7 @@ struct WorkerSelection {
/// `Drop` impl fires and spawns a task to call `free()`.
struct RequestGuard {
chooser: Arc<KvRouter>,
scheduler_tracked: bool,
context_id: String,
tracker: Option<Arc<RequestTracker>>,
request_metrics: Arc<RouterRequestMetrics>,
......@@ -70,7 +72,9 @@ impl RequestGuard {
.map(|d| !d.token_ids.is_empty())
.unwrap_or(false);
if has_tokens {
if let Err(e) = self.chooser.mark_prefill_completed(&self.context_id).await {
if self.scheduler_tracked
&& let Err(e) = self.chooser.mark_prefill_completed(&self.context_id).await
{
tracing::warn!(
"Failed to mark prefill completed for request {}: {e}",
self.context_id
......@@ -130,7 +134,9 @@ impl RequestGuard {
async fn finish(&mut self) {
self.record_metrics();
if let Err(e) = self.chooser.free(&self.context_id).await {
if self.scheduler_tracked
&& let Err(e) = self.chooser.free(&self.context_id).await
{
tracing::warn!("Failed to free request {}: {e}", self.context_id);
}
self.freed = true;
......@@ -155,7 +161,7 @@ impl RequestGuard {
impl Drop for RequestGuard {
fn drop(&mut self) {
self.record_metrics();
if !self.freed {
if !self.freed && self.scheduler_tracked {
let chooser = self.chooser.clone();
let context_id = self.context_id.clone();
let Ok(handle) = tokio::runtime::Handle::try_current() else {
......@@ -198,7 +204,6 @@ impl KvPushRouter {
let routing = request.routing.as_ref();
let lora_name = routing.and_then(|r| r.lora_name.clone());
let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0);
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
let allowed_worker_ids = routing.and_then(|r| r.allowed_worker_ids.clone());
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
......@@ -213,6 +218,10 @@ impl KvPushRouter {
}
RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
};
let requested_dp_rank = match phase {
RequestPhase::Prefill => routing.and_then(|r| r.prefill_dp_rank.or(r.dp_rank)),
RequestPhase::Decode | RequestPhase::Aggregated => routing.and_then(|r| r.dp_rank),
};
let Some(id) = preselected_id else {
let _nvtx_kv = dynamo_nvtx_range!("route.kv_match");
......@@ -254,18 +263,23 @@ impl KvPushRouter {
return Ok(WorkerSelection {
instance_id: best_worker.worker_id,
dp_rank: best_worker.dp_rank,
overlap_amount,
backend_dp_rank: Some(best_worker.dp_rank),
bookkeeping_dp_rank: Some(best_worker.dp_rank),
overlap_amount: Some(overlap_amount),
});
};
let backend_dp_rank =
requested_dp_rank.or_else(|| self.chooser.unique_dp_rank_for_worker(id));
tracing::debug!(
worker_id = id,
dp_rank = dp_rank,
dp_rank = ?backend_dp_rank,
?phase,
"Routing to specified worker"
);
let (bookkeeping_dp_rank, overlap_amount) = if let Some(dp_rank) = backend_dp_rank {
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = self
.chooser
......@@ -295,14 +309,26 @@ impl KvPushRouter {
request_id = %context_id,
worker_id = id,
dp_rank = dp_rank,
"Skipping add_request - query or handled externally"
"Skipping add_request - query-only request"
);
}
(Some(dp_rank), Some(overlap_blocks))
} else {
tracing::debug!(
request_id = %context_id,
worker_id = id,
?phase,
"Routing to specified worker without resolved dp_rank; skipping scheduler bookkeeping"
);
(None, None)
};
Ok(WorkerSelection {
instance_id: id,
dp_rank,
overlap_amount: overlap_blocks,
backend_dp_rank,
bookkeeping_dp_rank,
overlap_amount,
})
}
}
......@@ -354,14 +380,17 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
.await?;
let WorkerSelection {
instance_id,
dp_rank,
backend_dp_rank,
bookkeeping_dp_rank,
overlap_amount,
} = selection;
let scheduler_tracked = !is_query_only && bookkeeping_dp_rank.is_some();
// In approximate mode (use_kv_events=false), record the routing decision
// so the indexer can track cache state based on routing decisions.
// This covers both pre-selected workers and find_best_match selections.
if !is_query_only && !self.chooser.kv_router_config().use_kv_events {
if let Some(dp_rank) = bookkeeping_dp_rank {
let lora_name = request.routing.as_ref().and_then(|r| r.lora_name.clone());
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
let worker = WorkerWithDpRank::new(instance_id, dp_rank);
......@@ -387,6 +416,13 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
"Failed to record routing decision in approximate mode"
);
}
} else {
tracing::debug!(
request_id = %context_id,
worker_id = instance_id,
"Skipping approximate-mode routing decision for unresolved dp_rank"
);
}
}
// Record routing metrics on tracker and observe ISL + prefill start.
......@@ -395,12 +431,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
if let Some(ref tracker) = request.tracker {
let (routing_token_ids, _) = request.block_mm_routing_info();
let isl_blocks = routing_token_ids.len().div_ceil(block_size);
if let Some(overlap_amount) = overlap_amount {
tracker.record_kv_hit(overlap_amount, isl_blocks);
}
tracker.record_isl(
routing_token_ids.len(),
overlap_amount as usize * block_size,
overlap_amount.map(|overlap| overlap as usize * block_size),
);
tracker.record_worker_full(instance_id, dp_rank, self.chooser.worker_type());
tracker.record_worker(instance_id, backend_dp_rank, self.chooser.worker_type());
tracker.record_router_queue_depth(self.chooser.pending_count());
if let Some(hit_rate) = tracker.kv_hit_rate() {
request_metrics.kv_hit_rate.observe(hit_rate);
......@@ -444,7 +482,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let tracker = request.tracker.clone();
let (mut backend_input, context) = request.into_parts();
backend_input.routing_mut().dp_rank = Some(dp_rank);
backend_input.routing_mut().dp_rank = backend_dp_rank;
let updated_request = context.map(|_| backend_input);
// Record prefill start right before pushing to backend (OnceLock: first call wins).
......@@ -460,8 +498,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
"kv_router.route_request",
request_id = %context_id,
worker_id = instance_id,
dp_rank = dp_rank,
overlap_blocks = overlap_amount,
dp_rank = ?backend_dp_rank,
overlap_blocks = ?overlap_amount,
phase = ?phase,
))
.await?;
......@@ -471,6 +509,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let wrapped_stream = Box::pin(async_stream::stream! {
let mut guard = RequestGuard {
chooser: chooser.clone(),
scheduler_tracked,
context_id: context_id.clone(),
tracker: tracker.clone(),
request_metrics: request_metrics.clone(),
......@@ -479,7 +518,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
freed: false,
prefill_marked: false,
first_token_recorded: false,
track_output_blocks,
track_output_blocks: scheduler_tracked && track_output_blocks,
current_total_blocks: isl_tokens.div_ceil(block_size),
isl_tokens,
block_size,
......
......@@ -8,7 +8,7 @@
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use parking_lot::Mutex;
......@@ -22,16 +22,12 @@ use crate::http::service::metrics::{
};
use crate::protocols::openai::nvext::WorkerIdInfo;
/// Sentinel value indicating no worker ID has been set.
/// We use 0 as the sentinel since valid worker IDs are non-zero lease IDs from etcd.
const NO_WORKER_ID: u64 = 0;
const NO_DP_RANK: u32 = u32::MAX;
/// Worker type constants for Prometheus metric labels.
/// These are stored in RequestTracker at routing time to avoid costly MDC lookups
/// when updating per-worker metrics (TTFT, ITL).
pub const WORKER_TYPE_PREFILL: &str = "prefill";
pub const WORKER_TYPE_DECODE: &str = "decode";
const UNSET_DP_RANK_LABEL: &str = "none";
/// Phase of the request in disaggregated serving.
///
......@@ -81,8 +77,8 @@ impl std::fmt::Display for RequestPhase {
/// phase's final finish naturally overwrites the prefill phase's earlier finish.
/// `phase` also uses a Mutex since it transitions across phases.
///
/// **`AtomicU64`/`AtomicU32`:** Used for frequently updated counters (`osl_tokens`)
/// and worker IDs/ranks where `OnceLock`'s heap overhead is unnecessary.
/// **`AtomicU64`:** Used for frequently updated counters (`osl_tokens`) and
/// accumulated detokenize timing, where lock-free updates are beneficial.
#[derive(Debug)]
pub struct RequestTracker {
/// When the request was received (monotonic clock for duration calculations)
......@@ -118,19 +114,17 @@ pub struct RequestTracker {
/// Output sequence length in tokens - updated atomically as tokens stream back
osl_tokens: AtomicU64,
/// Prefill worker ID (for disaggregated serving).
/// Uses atomic with compare-exchange for set-once semantics.
/// Value of 0 (NO_WORKER_ID) means not yet set.
prefill_worker_id: AtomicU64,
/// Prefill worker ID (for disaggregated serving) - set once when known.
prefill_worker_id: OnceLock<u64>,
/// Prefill DP rank. Value of u32::MAX (NO_DP_RANK) means not yet set.
prefill_dp_rank: AtomicU32,
/// Prefill DP rank - set once when known.
prefill_dp_rank: OnceLock<u32>,
/// Decode worker ID. Value of 0 (NO_WORKER_ID) means not yet set.
decode_worker_id: AtomicU64,
/// Decode worker ID - set once when known.
decode_worker_id: OnceLock<u64>,
/// Decode DP rank. Value of u32::MAX (NO_DP_RANK) means not yet set.
decode_dp_rank: AtomicU32,
/// Decode DP rank - set once when known.
decode_dp_rank: OnceLock<u32>,
/// Worker type for the prefill worker ("prefill" or "decode").
/// Stored at routing time to avoid MDC lookup when updating Prometheus metrics.
......@@ -149,7 +143,7 @@ pub struct RequestTracker {
/// Semaphore for coordinating phase transitions.
/// Acquiring a permit blocks subsequent set_phase calls until the permit is dropped.
/// This prevents race conditions in the bootstrap optimization path where prefill
/// runs in background and needs to complete record_worker_full before phase changes.
/// runs in background and needs to complete worker recording before phase changes.
phase_semaphore: Arc<Semaphore>,
/// How long it took to tokenize the input
......@@ -185,10 +179,10 @@ impl RequestTracker {
isl_tokens: OnceLock::new(),
cached_tokens: OnceLock::new(),
osl_tokens: AtomicU64::new(0),
prefill_worker_id: AtomicU64::new(NO_WORKER_ID),
prefill_dp_rank: AtomicU32::new(NO_DP_RANK),
decode_worker_id: AtomicU64::new(NO_WORKER_ID),
decode_dp_rank: AtomicU32::new(NO_DP_RANK),
prefill_worker_id: OnceLock::new(),
prefill_dp_rank: OnceLock::new(),
decode_worker_id: OnceLock::new(),
decode_dp_rank: OnceLock::new(),
prefill_worker_type: OnceLock::new(),
decode_worker_type: OnceLock::new(),
phase: Mutex::new(RequestPhase::Aggregated),
......@@ -220,11 +214,13 @@ impl RequestTracker {
overlap_set && isl_set
}
/// Record input sequence length in tokens and cached token count.
pub fn record_isl(&self, isl_tokens: usize, cached_tokens: usize) {
/// Record input sequence length in tokens and cached token count when known.
pub fn record_isl(&self, isl_tokens: usize, cached_tokens: Option<usize>) {
let _ = self.isl_tokens.set(isl_tokens);
if let Some(cached_tokens) = cached_tokens {
let _ = self.cached_tokens.set(cached_tokens);
}
}
pub fn isl_tokens(&self) -> Option<usize> {
self.isl_tokens.get().copied()
......@@ -321,31 +317,96 @@ impl RequestTracker {
*self.phase.lock()
}
/// Record worker ID, DP rank, and worker type based on the current phase.
///
/// Each slot is written exactly once by `KvPushRouter::generate()`:
/// - Prefill phase: stores as prefill worker
/// - Decode phase: stores as decode worker
/// - Aggregated phase: stores as both prefill and decode worker
pub fn record_worker_full(&self, instance_id: u64, dp_rank: u32, worker_type: &'static str) {
match self.phase() {
RequestPhase::Prefill => {
self.prefill_worker_id.store(instance_id, Ordering::Relaxed);
self.prefill_dp_rank.store(dp_rank, Ordering::Relaxed);
let _ = self.prefill_worker_type.set(worker_type);
fn record_once_u64(slot: &OnceLock<u64>, value: u64, field_name: &'static str) {
if let Some(existing) = slot.get() {
if *existing != value {
tracing::error!(
field = field_name,
existing = *existing,
new = value,
"Conflicting request tracker write"
);
}
return;
}
RequestPhase::Decode => {
self.decode_worker_id.store(instance_id, Ordering::Relaxed);
self.decode_dp_rank.store(dp_rank, Ordering::Relaxed);
let _ = self.decode_worker_type.set(worker_type);
let _ = slot.set(value);
}
/// Record `value` into a set-once `u32` slot, reporting conflicting writes.
///
/// The first write wins and is never overwritten. If the slot already holds a
/// different value, the conflict is logged via `tracing::error!`, with
/// `field_name` identifying which slot was involved. Re-recording the same value
/// is a silent no-op.
fn record_once_u32(slot: &OnceLock<u32>, value: u32, field_name: &'static str) {
    // `get_or_init` performs the check and the store as a single step, so a
    // conflicting write that lands concurrently is still observed and reported
    // instead of being dropped between a separate `get()` and `set()`.
    let existing = *slot.get_or_init(|| value);
    if existing != value {
        tracing::error!(
            field = field_name,
            existing = existing,
            new = value,
            "Conflicting request tracker write"
        );
    }
}
/// Record `value` into a set-once worker-type slot, reporting conflicting writes.
///
/// The first write wins and is never overwritten. If the slot already holds a
/// different label, the conflict is logged via `tracing::error!`, with
/// `field_name` identifying which slot was involved. Re-recording the same label
/// is a silent no-op.
fn record_once_worker_type(
    slot: &OnceLock<&'static str>,
    value: &'static str,
    field_name: &'static str,
) {
    // `get_or_init` performs the check and the store as a single step, so a
    // conflicting write that lands concurrently is still observed and reported
    // instead of being dropped between a separate `get()` and `set()`.
    let existing = *slot.get_or_init(|| value);
    if existing != value {
        tracing::error!(
            field = field_name,
            existing = existing,
            new = value,
            "Conflicting request tracker write"
        );
    }
}
fn record_prefill_worker(
&self,
instance_id: u64,
dp_rank: Option<u32>,
worker_type: &'static str,
) {
Self::record_once_u64(&self.prefill_worker_id, instance_id, "prefill_worker_id");
if let Some(rank) = dp_rank {
Self::record_once_u32(&self.prefill_dp_rank, rank, "prefill_dp_rank");
}
Self::record_once_worker_type(
&self.prefill_worker_type,
worker_type,
"prefill_worker_type",
);
}
fn record_decode_worker(
&self,
instance_id: u64,
dp_rank: Option<u32>,
worker_type: &'static str,
) {
Self::record_once_u64(&self.decode_worker_id, instance_id, "decode_worker_id");
if let Some(rank) = dp_rank {
Self::record_once_u32(&self.decode_dp_rank, rank, "decode_dp_rank");
}
Self::record_once_worker_type(&self.decode_worker_type, worker_type, "decode_worker_type");
}
/// Record worker ID, optional DP rank, and worker type based on the current phase.
///
/// Worker ID and type are recorded as soon as they are known. DP rank is recorded only
/// when it is concrete, allowing the unresolved rank to remain unset until later.
pub fn record_worker(&self, instance_id: u64, dp_rank: Option<u32>, worker_type: &'static str) {
match self.phase() {
RequestPhase::Prefill => self.record_prefill_worker(instance_id, dp_rank, worker_type),
RequestPhase::Decode => self.record_decode_worker(instance_id, dp_rank, worker_type),
RequestPhase::Aggregated => {
self.prefill_worker_id.store(instance_id, Ordering::Relaxed);
self.prefill_dp_rank.store(dp_rank, Ordering::Relaxed);
let _ = self.prefill_worker_type.set(worker_type);
self.decode_worker_id.store(instance_id, Ordering::Relaxed);
self.decode_dp_rank.store(dp_rank, Ordering::Relaxed);
let _ = self.decode_worker_type.set(worker_type);
self.record_prefill_worker(instance_id, dp_rank, worker_type);
self.record_decode_worker(instance_id, dp_rank, worker_type);
}
}
}
......@@ -415,26 +476,22 @@ impl RequestTracker {
/// Get the decode worker ID if recorded.
pub fn decode_worker_id(&self) -> Option<u64> {
let id = self.decode_worker_id.load(Ordering::SeqCst);
if id == NO_WORKER_ID { None } else { Some(id) }
self.decode_worker_id.get().copied()
}
/// Get the decode DP rank if recorded.
pub fn decode_dp_rank(&self) -> Option<u32> {
let rank = self.decode_dp_rank.load(Ordering::SeqCst);
if rank == NO_DP_RANK { None } else { Some(rank) }
self.decode_dp_rank.get().copied()
}
/// Get the prefill worker ID if recorded.
pub fn prefill_worker_id(&self) -> Option<u64> {
let id = self.prefill_worker_id.load(Ordering::SeqCst);
if id == NO_WORKER_ID { None } else { Some(id) }
self.prefill_worker_id.get().copied()
}
/// Get the prefill DP rank if recorded.
pub fn prefill_dp_rank(&self) -> Option<u32> {
let rank = self.prefill_dp_rank.load(Ordering::SeqCst);
if rank == NO_DP_RANK { None } else { Some(rank) }
self.prefill_dp_rank.get().copied()
}
/// Get the prefill worker type if recorded.
......@@ -456,7 +513,7 @@ impl RequestTracker {
let worker_id_str = worker_id.to_string();
let dp_rank_str = self
.prefill_dp_rank()
.map_or("0".to_string(), |r| r.to_string());
.map_or(UNSET_DP_RANK_LABEL.to_string(), |r| r.to_string());
let worker_type = self.prefill_worker_type().unwrap_or(WORKER_TYPE_PREFILL);
let labels = &[worker_id_str.as_str(), dp_rank_str.as_str(), worker_type];
......@@ -481,7 +538,7 @@ impl RequestTracker {
let worker_id_str = worker_id.to_string();
let dp_rank_str = self
.decode_dp_rank()
.map_or("0".to_string(), |r| r.to_string());
.map_or(UNSET_DP_RANK_LABEL.to_string(), |r| r.to_string());
let worker_type = self.decode_worker_type().unwrap_or(WORKER_TYPE_DECODE);
let labels = &[worker_id_str.as_str(), dp_rank_str.as_str(), worker_type];
......@@ -555,7 +612,7 @@ mod tests {
fn test_record_isl_osl() {
let tracker = RequestTracker::new();
tracker.record_isl(512, 256);
tracker.record_isl(512, Some(256));
assert_eq!(tracker.isl_tokens(), Some(512));
assert_eq!(tracker.cached_tokens(), Some(256));
......@@ -659,7 +716,7 @@ mod tests {
fn test_observe_first_token_gauges_no_panic_without_worker() {
let tracker = RequestTracker::new();
tracker.record_first_token();
tracker.record_isl(100, 50);
tracker.record_isl(100, Some(50));
// No worker recorded — should return early without panic
tracker.observe_first_token_gauges();
}
......@@ -677,10 +734,10 @@ mod tests {
#[test]
fn test_observe_first_token_gauges_with_worker() {
let tracker = RequestTracker::new();
tracker.record_worker_full(42, 0, WORKER_TYPE_PREFILL);
tracker.record_worker(42, Some(0), WORKER_TYPE_PREFILL);
thread::sleep(Duration::from_millis(5));
tracker.record_first_token();
tracker.record_isl(256, 128);
tracker.record_isl(256, Some(128));
tracker.observe_first_token_gauges();
......@@ -702,7 +759,7 @@ mod tests {
#[test]
fn test_observe_finish_gauges_with_worker() {
let tracker = RequestTracker::new();
tracker.record_worker_full(99, 1, WORKER_TYPE_DECODE);
tracker.record_worker(99, Some(1), WORKER_TYPE_DECODE);
tracker.record_first_token();
thread::sleep(Duration::from_millis(10));
tracker.record_osl(5);
......
......@@ -13,6 +13,7 @@ pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id";
pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id";
pub const HEADER_DP_RANK: &str = "x-dp-rank";
pub const HEADER_PREFILL_DP_RANK: &str = "x-prefill-dp-rank";
const UNSET_DP_RANK_SENTINEL: u32 = u32::MAX;
/// Apply routing overrides from HTTP headers to nvext.
///
......@@ -44,6 +45,7 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap)
.get(HEADER_PREFILL_DP_RANK)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u32>().ok());
let prefill_dp_rank = prefill_dp_rank.filter(|rank| *rank != UNSET_DP_RANK_SENTINEL);
if worker_id.is_none() && prefill_id.is_none() && dp_rank.is_none() && prefill_dp_rank.is_none()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment