Unverified Commit 95a750f4 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(replay): refactor offline components into cleaner lanes (#7866)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 210bbf5d
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::time::Duration;
use tokio::time::Instant;
use super::single::RequestId;
/// Per-request prefill figures passed by value to `PrefillLoadTracker::insert`
/// and `PrefillLoadTracker::remove`.
#[derive(Debug, Clone, Copy)]
pub(super) struct PrefillLoadState {
/// Token count that `insert` adds to, and `remove` subtracts from, the
/// tracker's `prefill_full_tokens_sum`.
pub(super) initial_effective_prefill_tokens: usize,
/// Not read by the tracker code in this file. Presumably the predicted time
/// needed to prefill this request, `None` when no estimate is available —
/// TODO confirm against the callers that populate it.
pub(super) expected_prefill_duration: Option<Duration>,
}
/// Bookkeeping for the prefill requests currently being tracked: their
/// insertion order, the sum of their token counts, and which request is the
/// current "anchor".
#[derive(Debug, Default)]
pub(super) struct PrefillLoadTracker {
/// Request ids in insertion order; `insert` pushes to the back and `remove`
/// takes the id out again (from the front or from anywhere in the queue).
pub(super) prefill_order: VecDeque<RequestId>,
/// Running sum of `initial_effective_prefill_tokens` over every request that
/// has been inserted and not yet removed.
pub(super) prefill_full_tokens_sum: usize,
/// The anchored request together with the instant at which it became the
/// anchor. Set by `insert` when no anchor exists and re-pointed at the front
/// of `prefill_order` when the anchored request is removed; `None` when the
/// queue is empty at that point.
pub(super) anchored_prefill: Option<(RequestId, Instant)>,
}
impl PrefillLoadTracker {
    /// Starts tracking a prefill request.
    ///
    /// The request's `initial_effective_prefill_tokens` is added to the running
    /// total and its id is appended to the back of `prefill_order`. If no
    /// request is anchored yet, this one becomes the anchor, stamped with
    /// `decay_now`.
    pub(super) fn insert(
        &mut self,
        request_id: &RequestId,
        prefill: PrefillLoadState,
        decay_now: Instant,
    ) {
        self.prefill_full_tokens_sum += prefill.initial_effective_prefill_tokens;
        self.prefill_order.push_back(request_id.clone());

        // Only the first request seen while nothing is anchored takes the anchor.
        if self.anchored_prefill.is_none() {
            self.anchored_prefill = Some((request_id.clone(), decay_now));
        }
    }

    /// Stops tracking a prefill request.
    ///
    /// Subtracts the request's `initial_effective_prefill_tokens` from the
    /// running total, drops its id from `prefill_order`, and, when the removed
    /// request was the anchor, re-anchors on the new front of the queue using
    /// `decay_now` as the anchor time.
    ///
    /// # Panics
    ///
    /// Panics with `"prefill_full_tokens_sum underflow"` if the subtraction
    /// would take the running total below zero.
    pub(super) fn remove(
        &mut self,
        request_id: &RequestId,
        prefill: PrefillLoadState,
        decay_now: Instant,
    ) {
        let remaining_tokens = self
            .prefill_full_tokens_sum
            .checked_sub(prefill.initial_effective_prefill_tokens)
            .expect("prefill_full_tokens_sum underflow");
        self.prefill_full_tokens_sum = remaining_tokens;

        // Removing the head is a cheap pop; anything else needs a scan.
        let is_front = self
            .prefill_order
            .front()
            .is_some_and(|front_id| front_id == request_id);
        if !is_front {
            self.prefill_order
                .retain(|queued_id| queued_id != request_id);
        } else {
            let popped = self.prefill_order.pop_front();
            debug_assert_eq!(popped.as_ref(), Some(request_id));
        }

        let anchor_was_removed = matches!(
            &self.anchored_prefill,
            Some((anchored_id, _)) if anchored_id == request_id
        );
        if anchor_was_removed {
            self.set_anchor_to_front(decay_now);
        }
    }

    /// Points the anchor at whichever request is now first in `prefill_order`,
    /// recording `now` as its anchor time. Clears the anchor when the queue is
    /// empty.
    pub(super) fn set_anchor_to_front(&mut self, now: Instant) {
        self.anchored_prefill = self
            .prefill_order
            .front()
            .map(|front_id| (front_id.clone(), now));
    }
}
This diff is collapsed.
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use dashmap::{DashMap, mapref::entry::Entry}; use dashmap::{DashMap, mapref::entry::Entry};
use dynamo_kv_router::{config::KvRouterConfig, protocols::WorkerId}; use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig, protocols::WorkerId};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use super::worker_monitor::LoadThresholdConfig; use super::worker_monitor::LoadThresholdConfig;
...@@ -568,6 +568,7 @@ impl ModelManager { ...@@ -568,6 +568,7 @@ impl ModelManager {
endpoint: &Endpoint, endpoint: &Endpoint,
kv_cache_block_size: u32, kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str, worker_type: &'static str,
model_name: Option<String>, model_name: Option<String>,
is_eagle: bool, is_eagle: bool,
...@@ -604,6 +605,7 @@ impl ModelManager { ...@@ -604,6 +605,7 @@ impl ModelManager {
kv_cache_block_size, kv_cache_block_size,
selector, selector,
kv_router_config, kv_router_config,
prefill_load_estimator,
worker_type, worker_type,
model_name, model_name,
is_eagle, is_eagle,
......
...@@ -7,6 +7,7 @@ use tokio::sync::mpsc::Sender; ...@@ -7,6 +7,7 @@ use tokio::sync::mpsc::Sender;
use anyhow::Context as _; use anyhow::Context as _;
use dashmap::DashSet; use dashmap::DashSet;
use dynamo_kv_router::PrefillLoadEstimator;
use futures::StreamExt; use futures::StreamExt;
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -74,6 +75,7 @@ pub struct ModelWatcher { ...@@ -74,6 +75,7 @@ pub struct ModelWatcher {
notify_on_model: Notify, notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>, model_update_tx: Option<Sender<ModelUpdate>>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
/// Guards against concurrent pipeline construction for the same (model, namespace). /// Guards against concurrent pipeline construction for the same (model, namespace).
registering_worker_sets: DashSet<String>, registering_worker_sets: DashSet<String>,
...@@ -118,6 +120,7 @@ impl ModelWatcher { ...@@ -118,6 +120,7 @@ impl ModelWatcher {
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
) -> ModelWatcher { ) -> ModelWatcher {
Self { Self {
...@@ -128,6 +131,7 @@ impl ModelWatcher { ...@@ -128,6 +131,7 @@ impl ModelWatcher {
notify_on_model: Notify::new(), notify_on_model: Notify::new(),
model_update_tx: None, model_update_tx: None,
chat_engine_factory, chat_engine_factory,
prefill_load_estimator,
metrics, metrics,
registering_worker_sets: DashSet::new(), registering_worker_sets: DashSet::new(),
} }
...@@ -465,6 +469,7 @@ impl ModelWatcher { ...@@ -465,6 +469,7 @@ impl ModelWatcher {
&endpoint, &endpoint,
card.kv_cache_block_size, card.kv_cache_block_size,
Some(self.router_config.kv_router_config.clone()), Some(self.router_config.kv_router_config.clone()),
self.prefill_load_estimator.clone(),
WORKER_TYPE_DECODE, // This is the decode router WORKER_TYPE_DECODE, // This is the decode router
Some(card.display_name.clone()), Some(card.display_name.clone()),
card.runtime_config.enable_eagle, card.runtime_config.enable_eagle,
...@@ -506,6 +511,7 @@ impl ModelWatcher { ...@@ -506,6 +511,7 @@ impl ModelWatcher {
self.router_config.router_mode, self.router_config.router_mode,
card.kv_cache_block_size, card.kv_cache_block_size,
Some(prefill_config), Some(prefill_config),
self.prefill_load_estimator.clone(),
self.router_config.enforce_disagg, self.router_config.enforce_disagg,
model_name.clone(), model_name.clone(),
namespace.clone(), namespace.clone(),
......
...@@ -12,7 +12,7 @@ use std::future::Future; ...@@ -12,7 +12,7 @@ use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig};
use dynamo_runtime::{discovery::ModelCardInstanceId, pipeline::RouterMode}; use dynamo_runtime::{discovery::ModelCardInstanceId, pipeline::RouterMode};
use crate::{ use crate::{
...@@ -68,6 +68,7 @@ pub enum EngineConfig { ...@@ -68,6 +68,7 @@ pub enum EngineConfig {
Dynamic { Dynamic {
model: Box<LocalModel>, model: Box<LocalModel>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
}, },
/// A Text engine receives text, does its own tokenization and prompt formatting. /// A Text engine receives text, does its own tokenization and prompt formatting.
......
...@@ -94,7 +94,9 @@ pub async fn prepare_engine( ...@@ -94,7 +94,9 @@ pub async fn prepare_engine(
) -> anyhow::Result<PreparedEngine> { ) -> anyhow::Result<PreparedEngine> {
match engine_config { match engine_config {
EngineConfig::Dynamic { EngineConfig::Dynamic {
model: local_model, .. model: local_model,
prefill_load_estimator,
..
} => { } => {
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
// Create metrics for migration tracking (not exposed via /metrics in Dynamic engine mode) // Create metrics for migration tracking (not exposed via /metrics in Dynamic engine mode)
...@@ -105,6 +107,7 @@ pub async fn prepare_engine( ...@@ -105,6 +107,7 @@ pub async fn prepare_engine(
RouterConfig::default(), RouterConfig::default(),
local_model.migration_limit(), local_model.migration_limit(),
None, None,
prefill_load_estimator,
metrics, metrics,
)); ));
let discovery = distributed_runtime.discovery(); let discovery = distributed_runtime.discovery();
......
...@@ -33,7 +33,11 @@ pub async fn run( ...@@ -33,7 +33,11 @@ pub async fn run(
} }
let grpc_service = match engine_config { let grpc_service = match engine_config {
EngineConfig::Dynamic { ref model, .. } => { EngineConfig::Dynamic {
ref model,
ref prefill_load_estimator,
..
} => {
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let router_config = model.router_config(); let router_config = model.router_config();
let migration_limit = model.migration_limit(); let migration_limit = model.migration_limit();
...@@ -48,6 +52,7 @@ pub async fn run( ...@@ -48,6 +52,7 @@ pub async fn run(
router_config.clone(), router_config.clone(),
migration_limit, migration_limit,
namespace_filter, namespace_filter,
prefill_load_estimator.clone(),
) )
.await?; .await?;
grpc_service grpc_service
...@@ -111,6 +116,7 @@ async fn run_watcher( ...@@ -111,6 +116,7 @@ async fn run_watcher(
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
namespace_filter: NamespaceFilter, namespace_filter: NamespaceFilter,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Create metrics for migration tracking (not exposed via /metrics in gRPC mode) // Create metrics for migration tracking (not exposed via /metrics in gRPC mode)
let metrics = Arc::new(Metrics::new()); let metrics = Arc::new(Metrics::new());
...@@ -120,6 +126,7 @@ async fn run_watcher( ...@@ -120,6 +126,7 @@ async fn run_watcher(
router_config, router_config,
migration_limit, migration_limit,
None, None,
prefill_load_estimator,
metrics, metrics,
); );
tracing::debug!("Waiting for remote model"); tracing::debug!("Waiting for remote model");
......
...@@ -67,6 +67,7 @@ pub async fn run( ...@@ -67,6 +67,7 @@ pub async fn run(
EngineConfig::Dynamic { EngineConfig::Dynamic {
ref model, ref model,
ref chat_engine_factory, ref chat_engine_factory,
ref prefill_load_estimator,
} => { } => {
// Pass the discovery client so the /health endpoint can query active instances // Pass the discovery client so the /health endpoint can query active instances
http_service_builder = http_service_builder =
...@@ -90,6 +91,7 @@ pub async fn run( ...@@ -90,6 +91,7 @@ pub async fn run(
Arc::new(http_service.clone()), Arc::new(http_service.clone()),
http_service.state().metrics_clone(), http_service.state().metrics_clone(),
chat_engine_factory.clone(), chat_engine_factory.clone(),
prefill_load_estimator.clone(),
) )
.await?; .await?;
http_service http_service
...@@ -167,6 +169,7 @@ async fn run_watcher( ...@@ -167,6 +169,7 @@ async fn run_watcher(
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new( let mut watch_obj = ModelWatcher::new(
runtime.clone(), runtime.clone(),
...@@ -174,6 +177,7 @@ async fn run_watcher( ...@@ -174,6 +177,7 @@ async fn run_watcher(
router_config, router_config,
migration_limit, migration_limit,
chat_engine_factory, chat_engine_factory,
prefill_load_estimator,
metrics.clone(), metrics.clone(),
); );
tracing::debug!("Waiting for remote model"); tracing::debug!("Waiting for remote model");
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use anyhow::Result; use anyhow::Result;
use dynamo_kv_router::{ use dynamo_kv_router::{
PrefillLoadEstimator,
config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env}, config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
indexer::KvRouterError, indexer::KvRouterError,
protocols::KV_EVENT_SUBJECT, protocols::KV_EVENT_SUBJECT,
protocols::{ protocols::{
BlockExtraInfo, BlockHashOptions, DpRank, RouterEvent, RouterRequest, RouterResponse, BlockExtraInfo, BlockHashOptions, DpRank, PrefillLoadHint, RouterEvent, RouterRequest,
TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq,
}, },
}; };
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -111,6 +113,7 @@ where ...@@ -111,6 +113,7 @@ where
scheduler: KvScheduler<Sel>, scheduler: KvScheduler<Sel>,
block_size: u32, block_size: u32,
kv_router_config: KvRouterConfig, kv_router_config: KvRouterConfig,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
cancellation_token: tokio_util::sync::CancellationToken, cancellation_token: tokio_util::sync::CancellationToken,
client: Client, client: Client,
is_eagle: bool, is_eagle: bool,
...@@ -128,6 +131,7 @@ where ...@@ -128,6 +131,7 @@ where
block_size: u32, block_size: u32,
selector: Sel, selector: Sel,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str, worker_type: &'static str,
model_name: Option<String>, model_name: Option<String>,
is_eagle: bool, is_eagle: bool,
...@@ -159,6 +163,7 @@ where ...@@ -159,6 +163,7 @@ where
workers_with_configs.clone(), workers_with_configs.clone(),
selector, selector,
&kv_router_config, &kv_router_config,
prefill_load_estimator.clone(),
worker_type, worker_type,
) )
.await?; .await?;
...@@ -184,6 +189,7 @@ where ...@@ -184,6 +189,7 @@ where
scheduler, scheduler,
block_size, block_size,
kv_router_config, kv_router_config,
prefill_load_estimator,
cancellation_token, cancellation_token,
client, client,
is_eagle, is_eagle,
...@@ -345,6 +351,8 @@ where ...@@ -345,6 +351,8 @@ where
let track_prefill_tokens = self let track_prefill_tokens = self
.kv_router_config .kv_router_config
.track_prefill_tokens(router_config_override); .track_prefill_tokens(router_config_override);
let prefill_load_hint =
self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens);
if let Err(e) = self if let Err(e) = self
.scheduler .scheduler
...@@ -355,6 +363,7 @@ where ...@@ -355,6 +363,7 @@ where
overlap: overlap_blocks, overlap: overlap_blocks,
track_prefill_tokens, track_prefill_tokens,
expected_output_tokens, expected_output_tokens,
prefill_load_hint,
worker, worker,
lora_name, lora_name,
}) })
...@@ -377,6 +386,42 @@ where ...@@ -377,6 +386,42 @@ where
self.scheduler.pending_count() self.scheduler.pending_count()
} }
/// Builds the optional prefill load hint for a request.
///
/// The prefix length is `overlap_blocks * block_size` tokens and the effective
/// input length is `isl_tokens` minus that prefix, saturating at zero.
///
/// Returns `None` when any of the following holds:
/// - `track_prefill_tokens` is `false`;
/// - the effective input length is zero;
/// - no prefill load estimator is configured on this router;
/// - the estimator fails to predict a duration (a warning is logged).
///
/// Otherwise returns a hint carrying the effective input length and the
/// predicted prefill duration.
fn prefill_load_hint_for(
    &self,
    isl_tokens: usize,
    overlap_blocks: u32,
    track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
    if !track_prefill_tokens {
        return None;
    }

    let prefix_tokens = (overlap_blocks as usize) * (self.block_size as usize);
    let remaining_isl = isl_tokens.saturating_sub(prefix_tokens);
    if remaining_isl == 0 {
        return None;
    }

    let estimator = self.prefill_load_estimator.as_ref()?;

    // NOTE(review): the leading `1` is presumably a batch size — confirm
    // against the `PrefillLoadEstimator` trait definition.
    let expected_prefill_duration = estimator
        .predict_prefill_duration(1, remaining_isl, prefix_tokens)
        .map_err(|error| {
            // Field names are spelled out so the emitted log keeps the same keys.
            tracing::warn!(
                effective_isl = remaining_isl,
                prefix = prefix_tokens,
                "failed to predict prefill duration for direct add_request path: {error}"
            );
        })
        .ok()?;

    Some(PrefillLoadHint {
        initial_effective_prefill_tokens: remaining_isl,
        expected_prefill_duration: Some(expected_prefill_duration),
    })
}
/// Get the worker type for this router ("prefill" or "decode"). /// Get the worker type for this router ("prefill" or "decode").
/// Used for Prometheus metric labeling. /// Used for Prometheus metric labeling.
pub fn worker_type(&self) -> &'static str { pub fn worker_type(&self) -> &'static str {
......
...@@ -6,6 +6,7 @@ use std::sync::{Arc, OnceLock}; ...@@ -6,6 +6,7 @@ use std::sync::{Arc, OnceLock};
use anyhow::Result; use anyhow::Result;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_runtime::{ use dynamo_runtime::{
pipeline::{ pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, RouterMode, ServerStreamingEngine, SingleIn, AsyncEngineContextProvider, ManyOut, Operator, RouterMode, ServerStreamingEngine, SingleIn,
...@@ -47,6 +48,7 @@ pub struct PrefillRouter { ...@@ -47,6 +48,7 @@ pub struct PrefillRouter {
cancel_token: CancellationToken, cancel_token: CancellationToken,
router_mode: RouterMode, router_mode: RouterMode,
enforce_disagg: bool, enforce_disagg: bool,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
/// Model name used to look up the worker monitor for prefill client registration /// Model name used to look up the worker monitor for prefill client registration
model_name: String, model_name: String,
/// Namespace used to look up the correct WorkerSet's worker monitor /// Namespace used to look up the correct WorkerSet's worker monitor
......
This diff is collapsed.
This diff is collapsed.
...@@ -344,6 +344,7 @@ mod integration_tests { ...@@ -344,6 +344,7 @@ mod integration_tests {
dynamo_llm::entrypoint::RouterConfig::default(), dynamo_llm::entrypoint::RouterConfig::default(),
0, // migration_limit 0, // migration_limit
None, None,
None,
service.state().metrics_clone(), service.state().metrics_clone(),
); );
// Start watching for model registrations via discovery interface // Start watching for model registrations via discovery interface
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment