Unverified Commit 7aa8e0e6 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: move worker_monitor to the llm crate (#3667)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 3d7f7a56
......@@ -31,6 +31,8 @@ The main KV-aware routing arguments:
- `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management.
- `--busy-threshold`: Threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines emit `ForwardPassMetrics`.
>[!Note]
> State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported.
>
......
......@@ -7,5 +7,8 @@ pub use model_manager::{ModelManager, ModelManagerError};
mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher};
mod worker_monitor;
pub use worker_monitor::{KvWorkerMonitor, WorkerLoadState};
/// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "v1/kv_routers";
......@@ -407,7 +407,7 @@ impl ModelWatcher {
NvCreateEmbeddingRequest,
Annotated<NvCreateEmbeddingResponse>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
client, self.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
......@@ -415,13 +415,12 @@ impl ModelWatcher {
.add_embeddings_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
// Case 3: Text + Chat
let push_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
)
.await?;
let push_router =
PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(client, self.router_mode, None, None)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_chat_completions_model(card.name(), checksum, engine)?;
......@@ -431,7 +430,7 @@ impl ModelWatcher {
NvCreateCompletionRequest,
Annotated<NvCreateCompletionResponse>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
client, self.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
......@@ -453,11 +452,11 @@ impl ModelWatcher {
PreprocessedEmbeddingRequest,
Annotated<EmbeddingsEngineOutput>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
client, self.router_mode, None, None
)
.await?;
// Note: Embeddings don't need KV routing complexity
// Note: Embeddings don't need KV routing complexity or load monitoring
let service_backend = ServiceBackend::from_engine(Arc::new(router));
// Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
......@@ -473,11 +472,12 @@ impl ModelWatcher {
.add_embeddings_model(card.name(), checksum, embedding_engine)?;
} else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
// Case 5: Tensor + Tensor (non-LLM)
// No KV cache concepts - not an LLM model
let push_router = PushRouter::<
NvCreateTensorRequest,
Annotated<NvCreateTensorResponse>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
client, self.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// TODO: Make load comparisons and runtime metrics a generic trait so this monitoring
// system is not tied to KV cache concepts, which are LLM-specific. This would allow
// different types of workers to define their own load metrics and busy thresholds.
use crate::component::{Client, InstanceSource};
use crate::traits::DistributedRuntimeProvider;
use crate::traits::events::EventSubscriber;
use crate::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::scoring::LoadEvent;
use crate::model_card::{self, ModelDeploymentCard};
use dynamo_runtime::component::Client;
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber;
use dynamo_runtime::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::watch;
use tokio_stream::StreamExt;
// Constants for monitoring configuration
const KV_METRICS_SUBJECT: &str = "kv_metrics";
// Internal structs for deserializing metrics events
#[derive(serde::Deserialize)]
struct LoadEvent {
worker_id: i64,
data: ForwardPassMetrics,
}
#[derive(serde::Deserialize)]
struct ForwardPassMetrics {
worker_stats: WorkerStats,
kv_stats: KvStats,
}
#[derive(serde::Deserialize)]
struct WorkerStats {
data_parallel_rank: Option<u32>,
}
#[derive(serde::Deserialize)]
struct KvStats {
kv_active_blocks: u64,
}
#[derive(serde::Deserialize, Clone)]
struct RuntimeConfig {
total_kv_blocks: Option<u64>,
data_parallel_size: u32,
}
/// Worker load monitoring state per dp_rank
#[derive(Clone, Debug, Default)]
pub struct WorkerLoadState {
......@@ -83,15 +50,15 @@ impl WorkerLoadState {
}
/// Worker monitor for tracking KV cache usage and busy states
pub struct WorkerMonitor {
pub struct KvWorkerMonitor {
client: Arc<Client>,
worker_load_states: Arc<RwLock<HashMap<i64, WorkerLoadState>>>,
busy_threshold: f64,
}
impl WorkerMonitor {
impl KvWorkerMonitor {
/// Create a new worker monitor with custom threshold
pub fn new_with_threshold(client: Arc<Client>, busy_threshold: f64) -> Self {
pub fn new(client: Arc<Client>, busy_threshold: f64) -> Self {
Self {
client,
worker_load_states: Arc::new(RwLock::new(HashMap::new())),
......@@ -103,9 +70,12 @@ impl WorkerMonitor {
pub fn load_states(&self) -> Arc<RwLock<HashMap<i64, WorkerLoadState>>> {
self.worker_load_states.clone()
}
}
#[async_trait]
impl WorkerLoadMonitor for KvWorkerMonitor {
/// Start background monitoring of worker KV cache usage
pub async fn start_monitoring(&self) -> anyhow::Result<()> {
async fn start_monitoring(&self) -> anyhow::Result<()> {
let endpoint = &self.client.endpoint;
let component = endpoint.component();
......@@ -114,19 +84,12 @@ impl WorkerMonitor {
return Ok(());
};
// WorkerMonitor is in the wrong crate. It deals with LLM things (KV) so it should be in
// dynamo-llm not dynamo-runtime.
// That means we cannot use ModelDeploymentCard, so use serde_json::Value for now .
// Watch for runtime config updates from model deployment cards
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
"v1/mdc/", // should be model_card::ROOT_PREFIX but wrong crate
model_card::ROOT_PATH,
key_extractors::lease_id,
|card: serde_json::Value| {
let runtime_config: Option<RuntimeConfig> = card
.get("runtime_config")
.and_then(|rc| serde_json::from_value(rc.clone()).ok());
runtime_config
},
|card: ModelDeploymentCard| Some(card.runtime_config),
component.drt().child_token(),
)
.await?;
......@@ -138,7 +101,7 @@ impl WorkerMonitor {
let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone();
let cancellation_token = component.drt().child_token();
let busy_threshold = self.busy_threshold; // Capture threshold for the closure
let busy_threshold = self.busy_threshold;
// Spawn background monitoring task
tokio::spawn(async move {
......@@ -151,7 +114,7 @@ impl WorkerMonitor {
break;
}
// Handle runtime config updates - now receives full HashMap
// Handle runtime config updates
_ = config_events_rx.changed() => {
let runtime_configs = config_events_rx.borrow().clone();
......@@ -163,7 +126,6 @@ impl WorkerMonitor {
let state = states.entry(*lease_id).or_default();
// Populate total_blocks for all dp_ranks (they share the same total)
// data_parallel_size defaults to 1 via serde in ModelRuntimeConfig
if let Some(total_blocks) = runtime_config.total_kv_blocks {
for dp_rank in 0..runtime_config.data_parallel_size {
state.kv_total_blocks.insert(dp_rank, total_blocks);
......
......@@ -281,11 +281,21 @@ where
let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let migration = Migration::from_mdc(card).into_operator();
// Create worker monitor only if busy_threshold is set
let worker_monitor = busy_threshold.map(|threshold| {
Arc::new(crate::discovery::KvWorkerMonitor::new(
Arc::new(client.clone()),
threshold,
)) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>
});
let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client.clone(),
router_mode,
busy_threshold,
worker_monitor,
)
.await?;
let service_backend = match router_mode {
......
......@@ -15,7 +15,7 @@ pub mod context;
pub mod error;
pub mod network;
pub use network::egress::addressed_router::{AddressedPushRouter, AddressedRequest};
pub use network::egress::push_router::{PushRouter, RouterMode};
pub use network::egress::push_router::{PushRouter, RouterMode, WorkerLoadMonitor};
pub mod registry;
pub use crate::engine::{
......
......@@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0
use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
use crate::utils::worker_monitor::WorkerMonitor;
use crate::{
component::{Client, Endpoint, InstanceSource},
engine::{AsyncEngine, Data},
......@@ -29,6 +28,15 @@ use std::{
};
use tokio_stream::StreamExt;
/// Trait for monitoring worker load and determining busy state.
/// Implementations can define custom load metrics and busy thresholds.
#[async_trait]
pub trait WorkerLoadMonitor: Send + Sync {
/// Start background monitoring of worker load.
/// This should spawn background tasks that update the client's free instances.
async fn start_monitoring(&self) -> anyhow::Result<()>;
}
#[derive(Clone)]
pub struct PushRouter<T, U>
where
......@@ -54,9 +62,6 @@ where
/// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>,
/// Worker monitor for tracking KV cache usage
worker_monitor: Option<Arc<WorkerMonitor>>,
/// Threshold for determining when a worker is busy (0.0 to 1.0)
/// If None, busy detection is disabled
busy_threshold: Option<f64>,
......@@ -97,36 +102,30 @@ where
{
/// Create a new PushRouter without busy threshold (no busy detection)
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
Self::from_client_with_threshold(client, router_mode, None).await
Self::from_client_with_threshold(client, router_mode, None, None).await
}
/// Create a new PushRouter with optional busy threshold
/// Create a new PushRouter with optional busy threshold and worker load monitor
pub async fn from_client_with_threshold(
client: Client,
router_mode: RouterMode,
busy_threshold: Option<f64>,
worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
// Create worker monitor only if we have a threshold and are in dynamic mode
let worker_monitor = match (busy_threshold, client.instance_source.as_ref()) {
(Some(threshold), InstanceSource::Dynamic(_)) => {
let monitor = Arc::new(WorkerMonitor::new_with_threshold(
Arc::new(client.clone()),
threshold,
));
monitor.start_monitoring().await?;
Some(monitor)
}
_ => None,
};
// Start worker monitor if provided and in dynamic mode
if let Some(monitor) = worker_monitor.as_ref()
&& matches!(client.instance_source.as_ref(), InstanceSource::Dynamic(_))
{
monitor.start_monitoring().await?;
}
let router = PushRouter {
client: client.clone(),
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
worker_monitor,
busy_threshold,
_phantom: PhantomData,
};
......
......@@ -10,6 +10,5 @@ pub mod stream;
pub mod task;
pub mod tasks;
pub mod typed_prefix_watcher;
pub mod worker_monitor;
pub use graceful_shutdown::GracefulShutdownTracker;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment