Unverified Commit 95a750f4 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(replay): refactor offline components into cleaner lanes (#7866)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 210bbf5d
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::time::Duration;
use tokio::time::Instant;
use super::single::RequestId;
#[derive(Debug, Clone, Copy)]
pub(super) struct PrefillLoadState {
pub(super) initial_effective_prefill_tokens: usize,
pub(super) expected_prefill_duration: Option<Duration>,
}
#[derive(Debug, Default)]
pub(super) struct PrefillLoadTracker {
pub(super) prefill_order: VecDeque<RequestId>,
pub(super) prefill_full_tokens_sum: usize,
pub(super) anchored_prefill: Option<(RequestId, Instant)>,
}
impl PrefillLoadTracker {
pub(super) fn insert(
&mut self,
request_id: &RequestId,
prefill: PrefillLoadState,
decay_now: Instant,
) {
self.prefill_full_tokens_sum += prefill.initial_effective_prefill_tokens;
let should_anchor = self.anchored_prefill.is_none();
self.prefill_order.push_back(request_id.clone());
if should_anchor {
self.anchored_prefill = Some((request_id.clone(), decay_now));
}
}
pub(super) fn remove(
&mut self,
request_id: &RequestId,
prefill: PrefillLoadState,
decay_now: Instant,
) {
self.prefill_full_tokens_sum = self
.prefill_full_tokens_sum
.checked_sub(prefill.initial_effective_prefill_tokens)
.expect("prefill_full_tokens_sum underflow");
let removed_front = self.prefill_order.front() == Some(request_id);
if removed_front {
let removed = self.prefill_order.pop_front();
debug_assert_eq!(removed.as_ref(), Some(request_id));
} else {
self.prefill_order
.retain(|queued_request_id| queued_request_id != request_id);
}
if self
.anchored_prefill
.as_ref()
.is_some_and(|(anchored_request_id, _)| anchored_request_id == request_id)
{
self.set_anchor_to_front(decay_now);
}
}
pub(super) fn set_anchor_to_front(&mut self, now: Instant) {
self.anchored_prefill = self
.prefill_order
.front()
.cloned()
.map(|request_id| (request_id, now));
}
}
This diff is collapsed.
......@@ -4,7 +4,7 @@
use std::{collections::HashSet, sync::Arc};
use dashmap::{DashMap, mapref::entry::Entry};
use dynamo_kv_router::{config::KvRouterConfig, protocols::WorkerId};
use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig, protocols::WorkerId};
use tokio::sync::oneshot;
use super::worker_monitor::LoadThresholdConfig;
......@@ -568,6 +568,7 @@ impl ModelManager {
endpoint: &Endpoint,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str,
model_name: Option<String>,
is_eagle: bool,
......@@ -604,6 +605,7 @@ impl ModelManager {
kv_cache_block_size,
selector,
kv_router_config,
prefill_load_estimator,
worker_type,
model_name,
is_eagle,
......
......@@ -7,6 +7,7 @@ use tokio::sync::mpsc::Sender;
use anyhow::Context as _;
use dashmap::DashSet;
use dynamo_kv_router::PrefillLoadEstimator;
use futures::StreamExt;
use dynamo_runtime::{
......@@ -74,6 +75,7 @@ pub struct ModelWatcher {
notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
metrics: Arc<Metrics>,
/// Guards against concurrent pipeline construction for the same (model, namespace).
registering_worker_sets: DashSet<String>,
......@@ -118,6 +120,7 @@ impl ModelWatcher {
router_config: RouterConfig,
migration_limit: u32,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
metrics: Arc<Metrics>,
) -> ModelWatcher {
Self {
......@@ -128,6 +131,7 @@ impl ModelWatcher {
notify_on_model: Notify::new(),
model_update_tx: None,
chat_engine_factory,
prefill_load_estimator,
metrics,
registering_worker_sets: DashSet::new(),
}
......@@ -465,6 +469,7 @@ impl ModelWatcher {
&endpoint,
card.kv_cache_block_size,
Some(self.router_config.kv_router_config.clone()),
self.prefill_load_estimator.clone(),
WORKER_TYPE_DECODE, // This is the decode router
Some(card.display_name.clone()),
card.runtime_config.enable_eagle,
......@@ -506,6 +511,7 @@ impl ModelWatcher {
self.router_config.router_mode,
card.kv_cache_block_size,
Some(prefill_config),
self.prefill_load_estimator.clone(),
self.router_config.enforce_disagg,
model_name.clone(),
namespace.clone(),
......
......@@ -12,7 +12,7 @@ use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig};
use dynamo_runtime::{discovery::ModelCardInstanceId, pipeline::RouterMode};
use crate::{
......@@ -68,6 +68,7 @@ pub enum EngineConfig {
Dynamic {
model: Box<LocalModel>,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
},
/// A Text engine receives text, does it's own tokenization and prompt formatting.
......
......@@ -94,7 +94,9 @@ pub async fn prepare_engine(
) -> anyhow::Result<PreparedEngine> {
match engine_config {
EngineConfig::Dynamic {
model: local_model, ..
model: local_model,
prefill_load_estimator,
..
} => {
let model_manager = Arc::new(ModelManager::new());
// Create metrics for migration tracking (not exposed via /metrics in Dynamic engine mode)
......@@ -105,6 +107,7 @@ pub async fn prepare_engine(
RouterConfig::default(),
local_model.migration_limit(),
None,
prefill_load_estimator,
metrics,
));
let discovery = distributed_runtime.discovery();
......
......@@ -33,7 +33,11 @@ pub async fn run(
}
let grpc_service = match engine_config {
EngineConfig::Dynamic { ref model, .. } => {
EngineConfig::Dynamic {
ref model,
ref prefill_load_estimator,
..
} => {
let grpc_service = grpc_service_builder.build()?;
let router_config = model.router_config();
let migration_limit = model.migration_limit();
......@@ -48,6 +52,7 @@ pub async fn run(
router_config.clone(),
migration_limit,
namespace_filter,
prefill_load_estimator.clone(),
)
.await?;
grpc_service
......@@ -111,6 +116,7 @@ async fn run_watcher(
router_config: RouterConfig,
migration_limit: u32,
namespace_filter: NamespaceFilter,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> anyhow::Result<()> {
// Create metrics for migration tracking (not exposed via /metrics in gRPC mode)
let metrics = Arc::new(Metrics::new());
......@@ -120,6 +126,7 @@ async fn run_watcher(
router_config,
migration_limit,
None,
prefill_load_estimator,
metrics,
);
tracing::debug!("Waiting for remote model");
......
......@@ -67,6 +67,7 @@ pub async fn run(
EngineConfig::Dynamic {
ref model,
ref chat_engine_factory,
ref prefill_load_estimator,
} => {
// Pass the discovery client so the /health endpoint can query active instances
http_service_builder =
......@@ -90,6 +91,7 @@ pub async fn run(
Arc::new(http_service.clone()),
http_service.state().metrics_clone(),
chat_engine_factory.clone(),
prefill_load_estimator.clone(),
)
.await?;
http_service
......@@ -167,6 +169,7 @@ async fn run_watcher(
http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new(
runtime.clone(),
......@@ -174,6 +177,7 @@ async fn run_watcher(
router_config,
migration_limit,
chat_engine_factory,
prefill_load_estimator,
metrics.clone(),
);
tracing::debug!("Waiting for remote model");
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
use dynamo_kv_router::{
PrefillLoadEstimator,
config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
indexer::KvRouterError,
protocols::KV_EVENT_SUBJECT,
protocols::{
BlockExtraInfo, BlockHashOptions, DpRank, RouterEvent, RouterRequest, RouterResponse,
TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq,
BlockExtraInfo, BlockHashOptions, DpRank, PrefillLoadHint, RouterEvent, RouterRequest,
RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq,
},
};
use dynamo_runtime::{
......@@ -111,6 +113,7 @@ where
scheduler: KvScheduler<Sel>,
block_size: u32,
kv_router_config: KvRouterConfig,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
cancellation_token: tokio_util::sync::CancellationToken,
client: Client,
is_eagle: bool,
......@@ -128,6 +131,7 @@ where
block_size: u32,
selector: Sel,
kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str,
model_name: Option<String>,
is_eagle: bool,
......@@ -159,6 +163,7 @@ where
workers_with_configs.clone(),
selector,
&kv_router_config,
prefill_load_estimator.clone(),
worker_type,
)
.await?;
......@@ -184,6 +189,7 @@ where
scheduler,
block_size,
kv_router_config,
prefill_load_estimator,
cancellation_token,
client,
is_eagle,
......@@ -345,6 +351,8 @@ where
let track_prefill_tokens = self
.kv_router_config
.track_prefill_tokens(router_config_override);
let prefill_load_hint =
self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens);
if let Err(e) = self
.scheduler
......@@ -355,6 +363,7 @@ where
overlap: overlap_blocks,
track_prefill_tokens,
expected_output_tokens,
prefill_load_hint,
worker,
lora_name,
})
......@@ -377,6 +386,42 @@ where
self.scheduler.pending_count()
}
fn prefill_load_hint_for(
&self,
isl_tokens: usize,
overlap_blocks: u32,
track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
if !track_prefill_tokens {
return None;
}
let prefix = (overlap_blocks as usize) * (self.block_size as usize);
let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 {
return None;
}
let Some(estimator) = &self.prefill_load_estimator else {
return None;
};
match estimator.predict_prefill_duration(1, effective_isl, prefix) {
Ok(expected_prefill_duration) => Some(PrefillLoadHint {
initial_effective_prefill_tokens: effective_isl,
expected_prefill_duration: Some(expected_prefill_duration),
}),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict prefill duration for direct add_request path: {error}"
);
None
}
}
}
/// Get the worker type for this router ("prefill" or "decode").
/// Used for Prometheus metric labeling.
pub fn worker_type(&self) -> &'static str {
......
......@@ -6,7 +6,7 @@ use std::sync::Arc;
use anyhow::Result;
use tokio::sync::oneshot;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig};
use dynamo_runtime::{
component::{Client, Endpoint},
pipeline::{PushRouter, RouterMode},
......@@ -37,6 +37,7 @@ impl PrefillRouter {
cancel_token: tokio_util::sync::CancellationToken::new(),
router_mode,
enforce_disagg,
prefill_load_estimator: None,
model_name: String::new(), // Not used for disabled router
namespace: String::new(), // Not used for disabled router
is_eagle: false,
......@@ -50,6 +51,7 @@ impl PrefillRouter {
router_mode: RouterMode,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
enforce_disagg: bool,
model_name: String,
namespace: String,
......@@ -65,6 +67,7 @@ impl PrefillRouter {
cancel_token: cancel_token.clone(),
router_mode,
enforce_disagg,
prefill_load_estimator,
model_name,
namespace,
is_eagle,
......@@ -85,6 +88,7 @@ impl PrefillRouter {
model_manager,
kv_cache_block_size,
kv_router_config,
router_clone.prefill_load_estimator.clone(),
).await {
tracing::error!(error = %e, "Failed to activate prefill router");
}
......@@ -105,6 +109,7 @@ impl PrefillRouter {
model_manager: Arc<ModelManager>,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
) -> Result<()> {
tracing::info!(
router_mode = ?self.router_mode,
......@@ -127,6 +132,7 @@ impl PrefillRouter {
&endpoint,
kv_cache_block_size,
kv_router_config,
prefill_load_estimator,
WORKER_TYPE_PREFILL,
Some(self.model_name.clone()),
self.is_eagle,
......
......@@ -6,6 +6,7 @@ use std::sync::{Arc, OnceLock};
use anyhow::Result;
use tokio_util::sync::CancellationToken;
use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_runtime::{
pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, RouterMode, ServerStreamingEngine, SingleIn,
......@@ -47,6 +48,7 @@ pub struct PrefillRouter {
cancel_token: CancellationToken,
router_mode: RouterMode,
enforce_disagg: bool,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
/// Model name used to look up the worker monitor for prefill client registration
model_name: String,
/// Namespace used to look up the correct WorkerSet's worker monitor
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment