Unverified Commit 7aa8e0e6 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: move worker_monitor to the llm crate (#3667)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 3d7f7a56
...@@ -31,6 +31,8 @@ The main KV-aware routing arguments: ...@@ -31,6 +31,8 @@ The main KV-aware routing arguments:
- `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management. - `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management.
- `--busy-threshold`: Threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines emit `ForwardPassMetrics`.
>[!Note] >[!Note]
> State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported. > State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported.
> >
......
...@@ -7,5 +7,8 @@ pub use model_manager::{ModelManager, ModelManagerError}; ...@@ -7,5 +7,8 @@ pub use model_manager::{ModelManager, ModelManagerError};
mod watcher; mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher}; pub use watcher::{ModelUpdate, ModelWatcher};
mod worker_monitor;
pub use worker_monitor::{KvWorkerMonitor, WorkerLoadState};
/// The root etcd path for KV Router registrations /// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "v1/kv_routers"; pub const KV_ROUTERS_ROOT_PATH: &str = "v1/kv_routers";
...@@ -407,7 +407,7 @@ impl ModelWatcher { ...@@ -407,7 +407,7 @@ impl ModelWatcher {
NvCreateEmbeddingRequest, NvCreateEmbeddingRequest,
Annotated<NvCreateEmbeddingResponse>, Annotated<NvCreateEmbeddingResponse>,
>::from_client_with_threshold( >::from_client_with_threshold(
client, self.router_mode, self.busy_threshold client, self.router_mode, None, None
) )
.await?; .await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
...@@ -415,13 +415,12 @@ impl ModelWatcher { ...@@ -415,13 +415,12 @@ impl ModelWatcher {
.add_embeddings_model(card.name(), checksum, engine)?; .add_embeddings_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_chat() { } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
// Case 3: Text + Chat // Case 3: Text + Chat
let push_router = PushRouter::< let push_router =
NvCreateChatCompletionRequest, PushRouter::<
Annotated<NvCreateChatCompletionStreamResponse>, NvCreateChatCompletionRequest,
>::from_client_with_threshold( Annotated<NvCreateChatCompletionStreamResponse>,
client, self.router_mode, self.busy_threshold >::from_client_with_threshold(client, self.router_mode, None, None)
) .await?;
.await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_chat_completions_model(card.name(), checksum, engine)?; .add_chat_completions_model(card.name(), checksum, engine)?;
...@@ -431,7 +430,7 @@ impl ModelWatcher { ...@@ -431,7 +430,7 @@ impl ModelWatcher {
NvCreateCompletionRequest, NvCreateCompletionRequest,
Annotated<NvCreateCompletionResponse>, Annotated<NvCreateCompletionResponse>,
>::from_client_with_threshold( >::from_client_with_threshold(
client, self.router_mode, self.busy_threshold client, self.router_mode, None, None
) )
.await?; .await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
...@@ -453,11 +452,11 @@ impl ModelWatcher { ...@@ -453,11 +452,11 @@ impl ModelWatcher {
PreprocessedEmbeddingRequest, PreprocessedEmbeddingRequest,
Annotated<EmbeddingsEngineOutput>, Annotated<EmbeddingsEngineOutput>,
>::from_client_with_threshold( >::from_client_with_threshold(
client, self.router_mode, self.busy_threshold client, self.router_mode, None, None
) )
.await?; .await?;
// Note: Embeddings don't need KV routing complexity // Note: Embeddings don't need KV routing complexity or load monitoring
let service_backend = ServiceBackend::from_engine(Arc::new(router)); let service_backend = ServiceBackend::from_engine(Arc::new(router));
// Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend // Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
...@@ -473,11 +472,12 @@ impl ModelWatcher { ...@@ -473,11 +472,12 @@ impl ModelWatcher {
.add_embeddings_model(card.name(), checksum, embedding_engine)?; .add_embeddings_model(card.name(), checksum, embedding_engine)?;
} else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() { } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
// Case 5: Tensor + Tensor (non-LLM) // Case 5: Tensor + Tensor (non-LLM)
// No KV cache concepts - not an LLM model
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateTensorRequest, NvCreateTensorRequest,
Annotated<NvCreateTensorResponse>, Annotated<NvCreateTensorResponse>,
>::from_client_with_threshold( >::from_client_with_threshold(
client, self.router_mode, self.busy_threshold client, self.router_mode, None, None
) )
.await?; .await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// TODO: Make load comparisons and runtime metrics a generic trait so this monitoring use crate::kv_router::KV_METRICS_SUBJECT;
// system is not tied to KV cache concepts, which are LLM-specific. This would allow use crate::kv_router::scoring::LoadEvent;
// different types of workers to define their own load metrics and busy thresholds. use crate::model_card::{self, ModelDeploymentCard};
use dynamo_runtime::component::Client;
use crate::component::{Client, InstanceSource}; use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use crate::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use crate::traits::events::EventSubscriber; use dynamo_runtime::traits::events::EventSubscriber;
use crate::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction}; use dynamo_runtime::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use tokio::sync::watch;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
// Constants for monitoring configuration
const KV_METRICS_SUBJECT: &str = "kv_metrics";
// Internal structs for deserializing metrics events
#[derive(serde::Deserialize)]
struct LoadEvent {
worker_id: i64,
data: ForwardPassMetrics,
}
#[derive(serde::Deserialize)]
struct ForwardPassMetrics {
worker_stats: WorkerStats,
kv_stats: KvStats,
}
#[derive(serde::Deserialize)]
struct WorkerStats {
data_parallel_rank: Option<u32>,
}
#[derive(serde::Deserialize)]
struct KvStats {
kv_active_blocks: u64,
}
#[derive(serde::Deserialize, Clone)]
struct RuntimeConfig {
total_kv_blocks: Option<u64>,
data_parallel_size: u32,
}
/// Worker load monitoring state per dp_rank /// Worker load monitoring state per dp_rank
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct WorkerLoadState { pub struct WorkerLoadState {
...@@ -83,15 +50,15 @@ impl WorkerLoadState { ...@@ -83,15 +50,15 @@ impl WorkerLoadState {
} }
/// Worker monitor for tracking KV cache usage and busy states /// Worker monitor for tracking KV cache usage and busy states
pub struct WorkerMonitor { pub struct KvWorkerMonitor {
client: Arc<Client>, client: Arc<Client>,
worker_load_states: Arc<RwLock<HashMap<i64, WorkerLoadState>>>, worker_load_states: Arc<RwLock<HashMap<i64, WorkerLoadState>>>,
busy_threshold: f64, busy_threshold: f64,
} }
impl WorkerMonitor { impl KvWorkerMonitor {
/// Create a new worker monitor with custom threshold /// Create a new worker monitor with custom threshold
pub fn new_with_threshold(client: Arc<Client>, busy_threshold: f64) -> Self { pub fn new(client: Arc<Client>, busy_threshold: f64) -> Self {
Self { Self {
client, client,
worker_load_states: Arc::new(RwLock::new(HashMap::new())), worker_load_states: Arc::new(RwLock::new(HashMap::new())),
...@@ -103,9 +70,12 @@ impl WorkerMonitor { ...@@ -103,9 +70,12 @@ impl WorkerMonitor {
pub fn load_states(&self) -> Arc<RwLock<HashMap<i64, WorkerLoadState>>> { pub fn load_states(&self) -> Arc<RwLock<HashMap<i64, WorkerLoadState>>> {
self.worker_load_states.clone() self.worker_load_states.clone()
} }
}
#[async_trait]
impl WorkerLoadMonitor for KvWorkerMonitor {
/// Start background monitoring of worker KV cache usage /// Start background monitoring of worker KV cache usage
pub async fn start_monitoring(&self) -> anyhow::Result<()> { async fn start_monitoring(&self) -> anyhow::Result<()> {
let endpoint = &self.client.endpoint; let endpoint = &self.client.endpoint;
let component = endpoint.component(); let component = endpoint.component();
...@@ -114,19 +84,12 @@ impl WorkerMonitor { ...@@ -114,19 +84,12 @@ impl WorkerMonitor {
return Ok(()); return Ok(());
}; };
// WorkerMonitor is in the wrong crate. It deals with LLM things (KV) so it should be in // Watch for runtime config updates from model deployment cards
// dynamo-llm not dynamo-runtime.
// That means we cannot use ModelDeploymentCard, so use serde_json::Value for now .
let runtime_configs_watcher = watch_prefix_with_extraction( let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client, etcd_client,
"v1/mdc/", // should be model_card::ROOT_PREFIX but wrong crate model_card::ROOT_PATH,
key_extractors::lease_id, key_extractors::lease_id,
|card: serde_json::Value| { |card: ModelDeploymentCard| Some(card.runtime_config),
let runtime_config: Option<RuntimeConfig> = card
.get("runtime_config")
.and_then(|rc| serde_json::from_value(rc.clone()).ok());
runtime_config
},
component.drt().child_token(), component.drt().child_token(),
) )
.await?; .await?;
...@@ -138,7 +101,7 @@ impl WorkerMonitor { ...@@ -138,7 +101,7 @@ impl WorkerMonitor {
let worker_load_states = self.worker_load_states.clone(); let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone(); let client = self.client.clone();
let cancellation_token = component.drt().child_token(); let cancellation_token = component.drt().child_token();
let busy_threshold = self.busy_threshold; // Capture threshold for the closure let busy_threshold = self.busy_threshold;
// Spawn background monitoring task // Spawn background monitoring task
tokio::spawn(async move { tokio::spawn(async move {
...@@ -151,7 +114,7 @@ impl WorkerMonitor { ...@@ -151,7 +114,7 @@ impl WorkerMonitor {
break; break;
} }
// Handle runtime config updates - now receives full HashMap // Handle runtime config updates
_ = config_events_rx.changed() => { _ = config_events_rx.changed() => {
let runtime_configs = config_events_rx.borrow().clone(); let runtime_configs = config_events_rx.borrow().clone();
...@@ -163,7 +126,6 @@ impl WorkerMonitor { ...@@ -163,7 +126,6 @@ impl WorkerMonitor {
let state = states.entry(*lease_id).or_default(); let state = states.entry(*lease_id).or_default();
// Populate total_blocks for all dp_ranks (they share the same total) // Populate total_blocks for all dp_ranks (they share the same total)
// data_parallel_size defaults to 1 via serde in ModelRuntimeConfig
if let Some(total_blocks) = runtime_config.total_kv_blocks { if let Some(total_blocks) = runtime_config.total_kv_blocks {
for dp_rank in 0..runtime_config.data_parallel_size { for dp_rank in 0..runtime_config.data_parallel_size {
state.kv_total_blocks.insert(dp_rank, total_blocks); state.kv_total_blocks.insert(dp_rank, total_blocks);
......
...@@ -281,11 +281,21 @@ where ...@@ -281,11 +281,21 @@ where
let preprocessor_op = preprocessor.into_operator(); let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator(); let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let migration = Migration::from_mdc(card).into_operator(); let migration = Migration::from_mdc(card).into_operator();
// Create worker monitor only if busy_threshold is set
let worker_monitor = busy_threshold.map(|threshold| {
Arc::new(crate::discovery::KvWorkerMonitor::new(
Arc::new(client.clone()),
threshold,
)) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>
});
let router = let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold( PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client.clone(), client.clone(),
router_mode, router_mode,
busy_threshold, busy_threshold,
worker_monitor,
) )
.await?; .await?;
let service_backend = match router_mode { let service_backend = match router_mode {
......
...@@ -15,7 +15,7 @@ pub mod context; ...@@ -15,7 +15,7 @@ pub mod context;
pub mod error; pub mod error;
pub mod network; pub mod network;
pub use network::egress::addressed_router::{AddressedPushRouter, AddressedRequest}; pub use network::egress::addressed_router::{AddressedPushRouter, AddressedRequest};
pub use network::egress::push_router::{PushRouter, RouterMode}; pub use network::egress::push_router::{PushRouter, RouterMode, WorkerLoadMonitor};
pub mod registry; pub mod registry;
pub use crate::engine::{ pub use crate::engine::{
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG}; use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
use crate::utils::worker_monitor::WorkerMonitor;
use crate::{ use crate::{
component::{Client, Endpoint, InstanceSource}, component::{Client, Endpoint, InstanceSource},
engine::{AsyncEngine, Data}, engine::{AsyncEngine, Data},
...@@ -29,6 +28,15 @@ use std::{ ...@@ -29,6 +28,15 @@ use std::{
}; };
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
/// Trait for monitoring worker load and determining busy state.
/// Implementations can define custom load metrics and busy thresholds.
#[async_trait]
pub trait WorkerLoadMonitor: Send + Sync {
/// Start background monitoring of worker load.
/// This should spawn background tasks that update the client's free instances.
async fn start_monitoring(&self) -> anyhow::Result<()>;
}
#[derive(Clone)] #[derive(Clone)]
pub struct PushRouter<T, U> pub struct PushRouter<T, U>
where where
...@@ -54,9 +62,6 @@ where ...@@ -54,9 +62,6 @@ where
/// addresses it, then passes it to AddressedPushRouter which does the network traffic. /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>, addressed: Arc<AddressedPushRouter>,
/// Worker monitor for tracking KV cache usage
worker_monitor: Option<Arc<WorkerMonitor>>,
/// Threshold for determining when a worker is busy (0.0 to 1.0) /// Threshold for determining when a worker is busy (0.0 to 1.0)
/// If None, busy detection is disabled /// If None, busy detection is disabled
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
...@@ -97,36 +102,30 @@ where ...@@ -97,36 +102,30 @@ where
{ {
/// Create a new PushRouter without busy threshold (no busy detection) /// Create a new PushRouter without busy threshold (no busy detection)
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> { pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
Self::from_client_with_threshold(client, router_mode, None).await Self::from_client_with_threshold(client, router_mode, None, None).await
} }
/// Create a new PushRouter with optional busy threshold /// Create a new PushRouter with optional busy threshold and worker load monitor
pub async fn from_client_with_threshold( pub async fn from_client_with_threshold(
client: Client, client: Client,
router_mode: RouterMode, router_mode: RouterMode,
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
) -> anyhow::Result<Self> { ) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?; let addressed = addressed_router(&client.endpoint).await?;
// Create worker monitor only if we have a threshold and are in dynamic mode // Start worker monitor if provided and in dynamic mode
let worker_monitor = match (busy_threshold, client.instance_source.as_ref()) { if let Some(monitor) = worker_monitor.as_ref()
(Some(threshold), InstanceSource::Dynamic(_)) => { && matches!(client.instance_source.as_ref(), InstanceSource::Dynamic(_))
let monitor = Arc::new(WorkerMonitor::new_with_threshold( {
Arc::new(client.clone()), monitor.start_monitoring().await?;
threshold, }
));
monitor.start_monitoring().await?;
Some(monitor)
}
_ => None,
};
let router = PushRouter { let router = PushRouter {
client: client.clone(), client: client.clone(),
addressed, addressed,
router_mode, router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)), round_robin_counter: Arc::new(AtomicU64::new(0)),
worker_monitor,
busy_threshold, busy_threshold,
_phantom: PhantomData, _phantom: PhantomData,
}; };
......
...@@ -10,6 +10,5 @@ pub mod stream; ...@@ -10,6 +10,5 @@ pub mod stream;
pub mod task; pub mod task;
pub mod tasks; pub mod tasks;
pub mod typed_prefix_watcher; pub mod typed_prefix_watcher;
pub mod worker_monitor;
pub use graceful_shutdown::GracefulShutdownTracker; pub use graceful_shutdown::GracefulShutdownTracker;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment