Unverified Commit 5166a3dd authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Router replicas with state-sharing (#2264)

parent 10f4302d
...@@ -1857,6 +1857,7 @@ dependencies = [ ...@@ -1857,6 +1857,7 @@ dependencies = [
"chrono", "chrono",
"criterion", "criterion",
"cudarc", "cudarc",
"dashmap",
"derive-getters", "derive-getters",
"derive_builder", "derive_builder",
"dialoguer", "dialoguer",
......
...@@ -112,6 +112,12 @@ def parse_args(): ...@@ -112,6 +112,12 @@ def parse_args():
help=" KV Router. Disable KV events.", help=" KV Router. Disable KV events.",
) )
parser.set_defaults(use_kv_events=True) parser.set_defaults(use_kv_events=True)
parser.add_argument(
"--router-replica-sync",
action="store_true",
default=False,
help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.",
)
parser.add_argument( parser.add_argument(
"--static-endpoint", "--static-endpoint",
type=validate_static_endpoint, type=validate_static_endpoint,
...@@ -148,6 +154,7 @@ async def async_main(): ...@@ -148,6 +154,7 @@ async def async_main():
overlap_score_weight=flags.kv_overlap_score_weight, overlap_score_weight=flags.kv_overlap_score_weight,
router_temperature=flags.router_temperature, router_temperature=flags.router_temperature,
use_kv_events=flags.use_kv_events, use_kv_events=flags.use_kv_events,
router_replica_sync=flags.router_replica_sync,
) )
elif flags.router_mode == "random": elif flags.router_mode == "random":
router_mode = RouterMode.Random router_mode = RouterMode.Random
......
...@@ -84,7 +84,7 @@ use std::net::SocketAddr; ...@@ -84,7 +84,7 @@ use std::net::SocketAddr;
use std::time::Duration as StdDuration; use std::time::Duration as StdDuration;
use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics}; use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics};
use dynamo_llm::kv_router::scheduler::Endpoint; use dynamo_llm::kv_router::scoring::Endpoint;
use dynamo_llm::kv_router::scoring::ProcessedEndpoints; use dynamo_llm::kv_router::scoring::ProcessedEndpoints;
use dynamo_runtime::{ use dynamo_runtime::{
......
...@@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> {
let selector = Box::new(CustomWorkerSelector::default()); let selector = Box::new(CustomWorkerSelector::default());
let router = KvRouter::new(component.clone(), args.block_size, Some(selector), true).await?; let router = KvRouter::new(component.clone(), args.block_size, Some(selector), None).await?;
let router = Ingress::for_engine(Arc::new(router))?; let router = Ingress::for_engine(Arc::new(router))?;
component component
......
...@@ -17,12 +17,13 @@ For performance testing, compare a typical workload with `--router-mode random|r ...@@ -17,12 +17,13 @@ For performance testing, compare a typical workload with `--router-mode random|r
The KV-aware routing arguments: The KV-aware routing arguments:
- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks). - `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks). Defaults to 1.
- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 recovers the deterministic behavior where the min logit is picked. - `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 (default) recovers the deterministic behavior where the min logit is picked.
- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events. - `--use-kv-events`/`--no-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true (default), then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events.
- `--router-replica-sync`: Enables state synchronization between multiple router replicas via NATS. Disabled by default, and can be enabled by passing the flag in. When enabled, router replicas share their view of KV cache distribution and active sequences, allowing all routers to make optimal routing decisions even when requests are distributed across multiple router instances. This improves fault tolerance and routing accuracy in multi-router deployments.
## Architecture ## Architecture
...@@ -45,6 +46,22 @@ We can then use the default routing methods exposed by the client class to send ...@@ -45,6 +46,22 @@ We can then use the default routing methods exposed by the client class to send
KV Cache routing uses direct routing with a special worker selection algorithm. KV Cache routing uses direct routing with a special worker selection algorithm.
## Serving Two Router Replicas
For improved fault tolerance, you can launch two frontend + router replicas. Since the frontend and router are currently tied together, you'll need to use two different HTTP ports for each instance.
To enable state sharing between the router replicas (which provides more accurate routing decisions), use the `--router-replica-sync` flag when starting the frontend:
```bash
# Router replica 1
python -m dynamo.frontend --router-mode kv --port 8000 --router-replica-sync
# Router replica 2
python -m dynamo.frontend --router-mode kv --port 8001 --router-replica-sync
```
When `--router-replica-sync` is enabled, the router replicas will communicate with each other via NATS to maintain consistent state across instances. This allows both routers to have a complete view of the KV cache distribution and make optimal routing decisions, even when requests are distributed across multiple router instances.
## Understanding KV Cache ## Understanding KV Cache
The leading Large Language Models (LLMs) today are auto-regressive and based off of the [transformer architecture](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf). One key inference optimization technique is to cache the already computed keys and values and to reuse them for the future tokens. This is called the [KV Cache](https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/#key-value_caching). The leading Large Language Models (LLMs) today are auto-regressive and based off of the [transformer architecture](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf). One key inference optimization technique is to cache the already computed keys and values and to reuse them for the future tokens. This is called the [KV Cache](https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/#key-value_caching).
...@@ -88,30 +105,46 @@ Further details can be found for: [TRT-LLM](https://developer.nvidia.com/blog/in ...@@ -88,30 +105,46 @@ Further details can be found for: [TRT-LLM](https://developer.nvidia.com/blog/in
| |
+------------------+------------------+ +------------------+------------------+
| | | | | |
| KV match: 15% | KV match: 50% | KV match: 75% | Cached: 2 blocks | Cached: 5 blocks | Cached: 8 blocks
| Prefill: 8 blks | Prefill: 5 blks | Prefill: 2 blks
| Decode: 10 blks | Decode: 7 blks | Decode: 9 blks
v v v v v v
+----------------+ +----------------+ +----------------+ +----------------+ +----------------+ +----------------+
| Worker 1 | | Worker 2 | | Worker 3 | | Worker 1 | | Worker 2 | | Worker 3 |
| (Load: 30%) | | (Load: 50%) | | (Load: 80%) |
+----------------+ +----------------+ +----------------+ +----------------+ +----------------+ +----------------+
``` ```
Load balancing in LLM serving becomes complex when enabling KV Cache reuse. While KV Cache reuse can save significant computation, if the routing strategy is not aware of the unique KV states of each worker we can: Load balancing in LLM serving becomes complex when enabling KV Cache reuse. While KV Cache reuse can save significant computation, if the routing strategy is not aware of the unique KV states of each worker we can:
- miss opportunities for KV Cache reuse if routing to the wrong node - miss opportunities for KV Cache reuse if routing to the "wrong" node
- get into an imbalanced state where a few workers are processing many requests, lowering throughput of entire system - get into an imbalanced state where a few workers are processing many requests, lowering throughput of entire system
The best way to solve these issues is for the router to have a global view of KV Cache and load. With this view, the router can use a cost function to score the workers and make decisions to maximize cache hits while keeping the system balanced and throughput high. The router uses a cost function that considers both the prefill cost (influenced by cached blocks) and the decode load to make optimal routing decisions:
### Cost Calculation
1. **Prefill blocks**: The number of tokens that need to be processed during prefill is predicted based on the request's input tokens and the cached blocks available on each worker. This is divided by the block size to get the effective "prefill blocks". This prediction is updated when the first output token is produced, signaling prefill completion.
In the above image, our cost function is (KV match - Load) so we select Worker 2 even though Worker 3 would offer the best KV match. 2. **Decode blocks**: The number of blocks needed during the decode phase is predicted based on the request's input tokens and the current active sequences on each worker. This is updated when the request is freed (blocks are dereferenced or freed).
- Worker 1 = (0.15 - 0.30) = -0.15
- **Worker 2 = (0.50 - 0.50) = 0** 3. **Cost formula**: `cost = overlap_score_weight * prefill_blocks + decode_blocks`
- Worker 3 = (0.75 - 0.80) = -0.05 - Lower cost is better
- The `overlap_score_weight` parameter controls the importance of cache hits vs. load balancing
- A higher weight prioritizes cache reuse (better TTFT) while a lower weight prioritizes load distribution (better ITL)
### Worker Selection
The router selects the worker with the lowest cost. When `router_temperature` is set to a non-zero value, the router uses softmax sampling on the normalized cost logits to introduce randomness in the selection, which can help with load distribution.
Example calculation with `overlap_score_weight = 1.0`:
- Worker 1: cost = 1.0 * 8 + 10 = 18
- **Worker 2: cost = 1.0 * 5 + 7 = 12** (selected - lowest cost)
- Worker 3: cost = 1.0 * 2 + 9 = 11
## Events ## Events
In Dynamo, we want to support KV Cache Routing and load balancing for many backends that have different implementations of KV Cache and record different metrics. To that end, we built a KVPublisher that can be plugged into any framework to publish KV Events and a WorkerMetricsPublisher that can publish Metric Events. In Dynamo, we support KV Cache Routing for many backends that have different implementations of KV Cache. To enable this, we built a KVPublisher that can be plugged into any framework to publish KV Events.
On the receiving side we have a KVIndexer which accepts events from the KVPublisher and puts them into a global prefix tree and a KvMetricsAggregator which aggregates metric events by worker. On the receiving side we have a KVIndexer which accepts events from the KVPublisher and puts them into a global prefix tree for tracking cached blocks across all workers.
```text ```text
+----------------+ +-----------------+ +----------------+ +-----------------+
...@@ -121,13 +154,8 @@ On the receiving side we have a KVIndexer which accepts events from the KVPublis ...@@ -121,13 +154,8 @@ On the receiving side we have a KVIndexer which accepts events from the KVPublis
| +------------+ | remove_kv_block() | | KVIndexer | | | +------------+ | remove_kv_block() | | KVIndexer | |
| |KVPublisher | |------------------------>| +-------------+ | | |KVPublisher | |------------------------>| +-------------+ |
| +------------+ | | | | +------------+ | | |
| | num_request_waiting | +--------------+| | | | |
| +------------+ | gpu_cache_usage_perc | |KvMetricsAggre|| +----------------+ +-----------------+
| |KvMetrics | |------------------------>| | gator ||
| |Publisher | | ... | +--------------+|
| +------------+ | +-----------------+
+----------------+
``` ```
### KVPublisher ### KVPublisher
...@@ -144,18 +172,15 @@ The KVIndexer builds and maintains a global view of cached blocks in a prefix tr ...@@ -144,18 +172,15 @@ The KVIndexer builds and maintains a global view of cached blocks in a prefix tr
The KVIndexer has a method `find_matches_for_request`, which takes in tokens and returns a dictionary with keys of worker id and values of the number of matched KV Blocks. The KVIndexer has a method `find_matches_for_request`, which takes in tokens and returns a dictionary with keys of worker id and values of the number of matched KV Blocks.
### WorkerMetricsPublisher ### Inter-Router Communication
We added a KvMetrics Publisher which sends the following metrics to the KvMetricsAggregator:
- num_requests_waiting In multi-router deployments, each router only observes a subset of requests. To maintain a consistent global view of active sequences and KV cache states, routers broadcast their local actions to other replicas through three synchronization events:
- gpu_cache_usage_perc
- gpu_prefix_cache_hit_rate 1. **AddRequest**: Published when assigning a request to a worker, containing the request ID, worker ID, token sequence blocks, and overlap score. This updates other routers' tracking of which blocks are in use.
- request_active_slots
- request_total_slots 2. **MarkPrefillCompleted**: Published when a request transitions from prefill to decode phase, signaling that prefill tokens should no longer count toward the worker's active prefill load.
- kv_active_blocks
- kv_total_blocks
Currently, the WorkerMetricsPublisher exists as a Python binding. 3. **Free**: Published when a request completes and its resources are released, allowing other routers to update their block reference counts.
### KvMetricsAggregator Each event includes a unique router ID to prevent processing of self-generated events. This asynchronous communication ensures all routers maintain synchronized KV cache state for optimal routing decisions despite handling different request streams.
The KvMetricsAggregator receives these metrics and aggregates them. It has a method `get_metrics` which returns an object of `AggregatedMetrics`.
...@@ -96,6 +96,12 @@ pub struct Flags { ...@@ -96,6 +96,12 @@ pub struct Flags {
#[arg(long)] #[arg(long)]
pub use_kv_events: Option<bool>, pub use_kv_events: Option<bool>,
/// KV Router: Whether to enable replica synchronization across multiple router instances.
/// When true, routers will publish and subscribe to events to maintain consistent state.
/// Default: false
#[arg(long)]
pub router_replica_sync: Option<bool>,
/// Max model context length. Reduce this if you don't have enough VRAM for the full model /// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4). /// context length (e.g. Llama 4).
/// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json. /// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json.
...@@ -223,6 +229,7 @@ impl Flags { ...@@ -223,6 +229,7 @@ impl Flags {
self.kv_overlap_score_weight, self.kv_overlap_score_weight,
self.router_temperature, self.router_temperature,
self.use_kv_events, self.use_kv_events,
self.router_replica_sync,
self.max_num_batched_tokens, self.max_num_batched_tokens,
), ),
) )
......
...@@ -1152,6 +1152,7 @@ dependencies = [ ...@@ -1152,6 +1152,7 @@ dependencies = [
"candle-core", "candle-core",
"chrono", "chrono",
"cudarc", "cudarc",
"dashmap",
"derive-getters", "derive-getters",
"derive_builder", "derive_builder",
"dialoguer", "dialoguer",
......
...@@ -35,13 +35,19 @@ pub struct KvRouterConfig { ...@@ -35,13 +35,19 @@ pub struct KvRouterConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false))]
fn new(overlap_score_weight: f64, router_temperature: f64, use_kv_events: bool) -> Self { fn new(
overlap_score_weight: f64,
router_temperature: f64,
use_kv_events: bool,
router_replica_sync: bool,
) -> Self {
KvRouterConfig { KvRouterConfig {
inner: RsKvRouterConfig { inner: RsKvRouterConfig {
overlap_score_weight, overlap_score_weight,
router_temperature, router_temperature,
use_kv_events, use_kv_events,
router_replica_sync,
..Default::default() ..Default::default()
}, },
} }
......
...@@ -85,6 +85,7 @@ derive-getters = "0.5" ...@@ -85,6 +85,7 @@ derive-getters = "0.5"
offset-allocator = "0.2" offset-allocator = "0.2"
regex = "1" regex = "1"
rayon = "1" rayon = "1"
dashmap = { version = "5.5.3" }
# input/text # input/text
dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] } dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] }
......
...@@ -217,7 +217,7 @@ impl ModelManager { ...@@ -217,7 +217,7 @@ impl ModelManager {
component.clone(), component.clone(),
kv_cache_block_size, kv_cache_block_size,
Some(selector), Some(selector),
kv_router_config.unwrap_or_default().use_kv_events, kv_router_config,
) )
.await?; .await?;
let new_kv_chooser = Arc::new(chooser); let new_kv_chooser = Arc::new(chooser);
......
...@@ -15,11 +15,11 @@ use dynamo_runtime::{ ...@@ -15,11 +15,11 @@ use dynamo_runtime::{
protocols::annotated::Annotated, protocols::annotated::Annotated,
}; };
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use tokio::sync::Mutex;
pub mod approx; pub mod approx;
pub mod indexer; pub mod indexer;
pub mod metrics_aggregator; pub mod metrics_aggregator;
pub mod prefill_counter;
pub mod protocols; pub mod protocols;
pub mod publisher; pub mod publisher;
pub mod recorder; pub mod recorder;
...@@ -48,9 +48,18 @@ use dynamo_runtime::traits::events::EventSubscriber; ...@@ -48,9 +48,18 @@ use dynamo_runtime::traits::events::EventSubscriber;
// [gluo TODO] shouldn't need to be public // [gluo TODO] shouldn't need to be public
// this should be discovered from the component // this should be discovered from the component
// for metric scraping (pull-based)
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
// for metric publishing (push-based)
pub const KV_EVENT_SUBJECT: &str = "kv_events"; pub const KV_EVENT_SUBJECT: &str = "kv_events";
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate"; pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; pub const KV_METRICS_SUBJECT: &str = "kv_metrics";
// for inter-router comms
pub const PREFILL_SUBJECT: &str = "prefill_events";
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
/// A trait that users can implement to define custom selection logic /// A trait that users can implement to define custom selection logic
pub trait WorkerSelector { pub trait WorkerSelector {
...@@ -71,6 +80,8 @@ pub struct KvRouterConfig { ...@@ -71,6 +80,8 @@ pub struct KvRouterConfig {
pub use_kv_events: bool, pub use_kv_events: bool,
pub router_replica_sync: bool,
// TODO: this is not actually used for now // TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting // Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32, pub max_num_batched_tokens: u32,
...@@ -82,6 +93,7 @@ impl Default for KvRouterConfig { ...@@ -82,6 +93,7 @@ impl Default for KvRouterConfig {
overlap_score_weight: 1.0, overlap_score_weight: 1.0,
router_temperature: 0.0, router_temperature: 0.0,
use_kv_events: true, use_kv_events: true,
router_replica_sync: false,
max_num_batched_tokens: 8192, max_num_batched_tokens: 8192,
} }
} }
...@@ -94,6 +106,7 @@ impl KvRouterConfig { ...@@ -94,6 +106,7 @@ impl KvRouterConfig {
overlap_score_weight: Option<f64>, overlap_score_weight: Option<f64>,
temperature: Option<f64>, temperature: Option<f64>,
use_kv_events: Option<bool>, use_kv_events: Option<bool>,
replica_sync: Option<bool>,
max_num_batched_tokens: Option<u32>, max_num_batched_tokens: Option<u32>,
) -> Self { ) -> Self {
let default = Self::default(); let default = Self::default();
...@@ -101,6 +114,7 @@ impl KvRouterConfig { ...@@ -101,6 +114,7 @@ impl KvRouterConfig {
overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight), overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
router_temperature: temperature.unwrap_or(default.router_temperature), router_temperature: temperature.unwrap_or(default.router_temperature),
use_kv_events: use_kv_events.unwrap_or(default.use_kv_events), use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
max_num_batched_tokens: max_num_batched_tokens max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens), .unwrap_or(default.max_num_batched_tokens),
} }
...@@ -135,10 +149,6 @@ pub struct KvRouter { ...@@ -135,10 +149,6 @@ pub struct KvRouter {
scheduler: KvScheduler, scheduler: KvScheduler,
block_size: u32, block_size: u32,
// To ensure blocking reads / writes
// TODO: benchmark tradeoffs
find_best_match_mutex: Mutex<()>,
} }
impl KvRouter { impl KvRouter {
...@@ -146,8 +156,10 @@ impl KvRouter { ...@@ -146,8 +156,10 @@ impl KvRouter {
component: Component, component: Component,
block_size: u32, block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
use_kv_events: bool, kv_router_config: Option<KvRouterConfig>,
) -> Result<Self> { ) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default();
let cancellation_token = component let cancellation_token = component
.drt() .drt()
.primary_lease() .primary_lease()
...@@ -164,7 +176,7 @@ impl KvRouter { ...@@ -164,7 +176,7 @@ impl KvRouter {
} }
}; };
let indexer = if use_kv_events { let indexer = if kv_router_config.use_kv_events {
Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size)) Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
} else { } else {
// hard code 120 seconds for now // hard code 120 seconds for now
...@@ -176,10 +188,11 @@ impl KvRouter { ...@@ -176,10 +188,11 @@ impl KvRouter {
}; };
let scheduler = KvScheduler::start( let scheduler = KvScheduler::start(
component.namespace().clone(), component.clone(),
block_size, block_size,
instances_rx, instances_rx,
selector, selector,
kv_router_config.router_replica_sync,
) )
.await?; .await?;
...@@ -215,7 +228,6 @@ impl KvRouter { ...@@ -215,7 +228,6 @@ impl KvRouter {
indexer, indexer,
scheduler, scheduler,
block_size, block_size,
find_best_match_mutex: Mutex::new(()), // Add this
}) })
} }
...@@ -227,10 +239,6 @@ impl KvRouter { ...@@ -227,10 +239,6 @@ impl KvRouter {
context_id: &str, context_id: &str,
tokens: &[u32], tokens: &[u32],
) -> anyhow::Result<(i64, u32)> { ) -> anyhow::Result<(i64, u32)> {
// Acquire mutex to serialize access
// TODO: may as well make all the subroutines synchronous if benchmarking favors this
let _guard = self.find_best_match_mutex.lock().await;
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
...@@ -263,17 +271,14 @@ impl KvRouter { ...@@ -263,17 +271,14 @@ impl KvRouter {
Ok((best_worker_id, overlap_amount)) Ok((best_worker_id, overlap_amount))
} }
/// Free all blocks associated with a request pub async fn mark_prefill_completed(&self, request_id: &str) {
pub async fn mark_prefill_completed(&self, request_id: &String) {
self.scheduler.mark_prefill_completed(request_id).await self.scheduler.mark_prefill_completed(request_id).await
} }
/// Free all blocks associated with a request pub async fn free(&self, request_id: &str) {
pub async fn free(&self, request_id: &String) {
self.scheduler.free(request_id).await self.scheduler.free(request_id).await
} }
/// Get the block size this router was configured with
pub fn block_size(&self) -> u32 { pub fn block_size(&self) -> u32 {
self.block_size self.block_size
} }
......
...@@ -18,7 +18,7 @@ use std::sync::Once; ...@@ -18,7 +18,7 @@ use std::sync::Once;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics}; pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::KV_METRICS_ENDPOINT; use crate::kv_router::KV_METRICS_ENDPOINT;
use crate::kv_router::scheduler::Endpoint; use crate::kv_router::scoring::Endpoint;
use crate::kv_router::ProcessedEndpoints; use crate::kv_router::ProcessedEndpoints;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result};
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt;
use std::sync::Arc;
use uuid::Uuid;
use super::protocols::{PrefillEvent, PrefillEventData};
use crate::kv_router::PREFILL_SUBJECT;
use dashmap::DashMap;
use std::collections::HashMap;
use std::hash::Hash;
pub fn get_snapshot<K, V>(state: &DashMap<K, V>) -> HashMap<K, V>
where
K: Clone + Hash + Eq,
V: Copy,
{
state
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect()
}
#[derive(Default)]
struct PrefillCounterState {
tokens_map: HashMap<String, usize>, // Plain HashMap
running_sum: usize, // Plain usize
}
impl PrefillCounterState {
fn insert(&mut self, key: String, value: usize) -> Option<usize> {
// Takes &mut self
let old_value = self.tokens_map.insert(key, value);
if let Some(old) = old_value {
self.running_sum -= old;
self.running_sum += value;
} else {
self.running_sum += value;
}
old_value
}
fn remove(&mut self, key: &str) -> Option<usize> {
// Takes &mut self
let removed = self.tokens_map.remove(key);
if let Some(value) = removed {
self.running_sum -= value;
}
removed
}
fn running_sum(&self) -> usize {
self.running_sum
}
}
/// A counter that tracks pending prefill tokens for each request.
///
/// This struct maintains a local hashmap of request_id to token count,
/// and a running sum of all tokens. It no longer handles its own subscriptions.
#[derive(Default)] // Removed Clone
pub struct PrefillCounter {
state: PrefillCounterState, // No Arc, direct ownership
}
impl PrefillCounter {
// Internal methods for direct state manipulation (no publishing)
fn insert_direct(&mut self, request_id: String, tokens: usize) -> Option<usize> {
// Takes &mut self
self.state.insert(request_id, tokens)
}
fn remove_direct(&mut self, request_id: &str) -> Option<usize> {
// Takes &mut self
self.state.remove(request_id)
}
#[allow(dead_code)]
fn update_direct(&mut self, request_id: String, new_tokens: usize) {
// Takes &mut self
if let Some(old_tokens) = self.state.tokens_map.get(&request_id).copied() {
let delta = new_tokens as isize - old_tokens as isize;
self.state.running_sum = (self.state.running_sum as isize + delta) as usize;
self.state.tokens_map.insert(request_id, new_tokens);
}
}
pub fn get(&self, request_id: &str) -> Option<usize> {
self.state.tokens_map.get(request_id).copied()
}
pub fn running_sum(&self) -> usize {
self.state.running_sum()
}
pub fn len(&self) -> usize {
self.state.tokens_map.len()
}
pub fn is_empty(&self) -> bool {
self.state.tokens_map.is_empty()
}
}
/// A collection of PrefillCounters for multiple workers with centralized event handling
pub struct PrefillCountersMultiWorker {
pub counters: Arc<DashMap<i64, PrefillCounter>>,
pub request_to_workers: Arc<DashMap<String, i64>>,
component: Component,
router_id: Uuid,
}
impl PrefillCountersMultiWorker {
// Helper function to handle new prefill logic
fn handle_new_prefill(
counters: &Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: &Arc<DashMap<String, i64>>,
request_id: &str,
worker_id: i64,
tokens: usize,
) {
// Check if request already exists
if let Some(existing_worker_id) = request_to_workers.get(request_id) {
tracing::warn!(
"Request {} already exists for worker {}, but trying to add to worker {}",
request_id,
*existing_worker_id,
worker_id
);
}
// Update mapping
request_to_workers.insert(request_id.to_string(), worker_id);
// Get or create counter and insert using get_mut
if let Some(mut counter) = counters.get_mut(&worker_id) {
counter.insert_direct(request_id.to_string(), tokens);
} else {
tracing::warn!(
"Worker {} does not exist, creating new PrefillCounter",
worker_id
);
let mut new_counter = PrefillCounter::default();
new_counter.insert_direct(request_id.to_string(), tokens);
counters.insert(worker_id, new_counter);
};
}
// Helper function to handle complete prefill logic
fn handle_complete_prefill(
counters: &Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: &Arc<DashMap<String, i64>>,
request_id: &str,
) -> Option<usize> {
// Remove from request_to_workers and get the worker_id
let Some((_, worker_id)) = request_to_workers.remove(request_id) else {
tracing::warn!("Request {} not found in request_to_workers", request_id);
return None;
};
// Use the worker_id from request_to_workers with get_mut
let Some(mut counter) = counters.get_mut(&worker_id) else {
tracing::warn!(
"No counter found for worker {} for request {}",
worker_id,
request_id
);
return None;
};
let removed_tokens = counter.remove_direct(request_id);
if removed_tokens.is_none() {
tracing::warn!("Attempted to remove non-existent request: {}", request_id);
}
removed_tokens
}
pub fn new(component: Component) -> Self {
let counters = Arc::new(DashMap::new());
let request_to_workers = Arc::new(DashMap::new());
let router_id = Uuid::new_v4();
let multi_worker = Self {
counters: counters.clone(),
request_to_workers: request_to_workers.clone(),
component: component.clone(),
router_id,
};
// Start the subscription loop
let counters_clone = counters.clone();
let request_to_workers_clone = request_to_workers.clone();
let component_clone = component.clone();
let router_id_clone = router_id;
tokio::spawn(async move {
if let Err(e) = Self::subscribe_to_events(
counters_clone,
request_to_workers_clone,
component_clone,
router_id_clone,
)
.await
{
tracing::error!("Error in prefill events subscription: {}", e);
}
});
multi_worker
}
/// Background task to subscribe to prefill events and update all counters
async fn subscribe_to_events(
counters: Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: Arc<DashMap<String, i64>>,
component: Component,
router_id: Uuid,
) -> Result<()> {
let mut subscriber = component
.subscribe_with_type::<PrefillEvent>(PREFILL_SUBJECT)
.await?;
while let Some(result) = subscriber.next().await {
let Ok(event) = result else {
tracing::error!("Error receiving prefill event: {}", result.unwrap_err());
continue;
};
// Skip events emitted by itself
if event.router_id == router_id {
continue;
}
match event.data {
PrefillEventData::NewPrefill(tokens) => {
Self::handle_new_prefill(
&counters,
&request_to_workers,
&event.request_id,
event.worker_id,
tokens,
);
}
PrefillEventData::UpdatePrefill(_) => {
// Do nothing for now
continue;
}
PrefillEventData::CompletePrefill => {
Self::handle_complete_prefill(
&counters,
&request_to_workers,
&event.request_id,
);
}
}
}
Ok(())
}
pub async fn add_prefill(
&self,
worker_id: i64,
request_id: String,
new_tokens: usize,
) -> Result<()> {
let event = PrefillEvent {
request_id: request_id.clone(),
worker_id,
data: PrefillEventData::NewPrefill(new_tokens),
router_id: self.router_id,
};
self.component.publish(PREFILL_SUBJECT, &event).await?;
// Use the helper function
Self::handle_new_prefill(
&self.counters,
&self.request_to_workers,
&request_id,
worker_id,
new_tokens,
);
Ok(())
}
pub async fn remove_prefill(&self, request_id: &str) -> Result<Option<usize>> {
// Send the event first with dummy worker_id
let event = PrefillEvent {
request_id: request_id.to_string(),
worker_id: 0, // Dummy worker_id
data: PrefillEventData::CompletePrefill,
router_id: self.router_id,
};
self.component.publish(PREFILL_SUBJECT, &event).await?;
// Use the helper function
Ok(Self::handle_complete_prefill(
&self.counters,
&self.request_to_workers,
request_id,
))
}
/// Get the running sums for all workers as a HashMap<i64, usize>
pub async fn running_sums(&self) -> HashMap<i64, usize> {
self.counters
.iter()
.map(|entry| (*entry.key(), entry.value().running_sum()))
.collect()
}
/// Get a specific counter's running sum
pub async fn get_worker_sum(&self, worker_id: i64) -> Option<usize> {
self.counters.get(&worker_id).map(|c| c.running_sum())
}
}
#[cfg(test)]
mod integration_tests {
use super::*;
use dynamo_runtime::{DistributedRuntime, Runtime};
use std::sync::{Arc, Mutex};
use std::thread;
use tokio::time::Duration;
#[test]
#[ignore]
fn test_prefill_counter_multiworker_synchronization() -> Result<()> {
// Initialize logging once
dynamo_runtime::logging::init();
let worker_id_1 = 1;
let worker_id_2 = 2;
let tokens_per_request = 100;
let requests_per_worker = 10;
// Shared state for collecting results from both threads
let results1 = Arc::new(Mutex::new(None));
let results2 = Arc::new(Mutex::new(None));
let final_results1 = Arc::new(Mutex::new(None));
let final_results2 = Arc::new(Mutex::new(None));
let results1_clone = results1.clone();
let results2_clone = results2.clone();
let final_results1_clone = final_results1.clone();
let final_results2_clone = final_results2.clone();
// Thread 1: First distributed runtime with multi_worker1
let handle1 = thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and components with same names
let namespace = distributed.namespace("test_prefill_multiworker")?;
let component = namespace
.component("counters")?
.service_builder()
.create()
.await?;
// Create first PrefillCountersMultiWorker instance
let multi_worker1 = PrefillCountersMultiWorker::new(component);
// Give some time for subscribers to initialize
tokio::time::sleep(Duration::from_millis(3000)).await;
// Send requests to multi_worker1's worker
for i in 0..requests_per_worker {
let request_id = format!("mw1_request_{}", i);
multi_worker1
.add_prefill(worker_id_1, request_id, tokens_per_request)
.await?;
}
// Wait for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get running sums after additions
let sums1 = multi_worker1.running_sums().await;
*results1_clone.lock().unwrap() = Some(sums1);
// Wait for other thread to add its requests
tokio::time::sleep(Duration::from_millis(2000)).await;
// Remove all requests from multi_worker1
for i in 0..requests_per_worker {
let request_id = format!("mw1_request_{}", i);
multi_worker1.remove_prefill(&request_id).await?;
}
// Wait for removal synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get final running sums
let final_sums1 = multi_worker1.running_sums().await;
*final_results1_clone.lock().unwrap() = Some(final_sums1);
// Keep runtime alive a bit longer for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Shutdown runtime
runtime.shutdown();
Ok::<(), anyhow::Error>(())
})
});
// Thread 2: Second distributed runtime with multi_worker2
let handle2 = thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and components with same names
let namespace = distributed.namespace("test_prefill_multiworker")?;
let component = namespace
.component("counters")?
.service_builder()
.create()
.await?;
// Create second PrefillCountersMultiWorker instance
let multi_worker2 = PrefillCountersMultiWorker::new(component);
// Give some time for subscribers to initialize
tokio::time::sleep(Duration::from_millis(3000)).await;
// Wait a bit to ensure multi_worker1 has started
tokio::time::sleep(Duration::from_millis(500)).await;
// Send requests to multi_worker2's worker
for i in 0..requests_per_worker {
let request_id = format!("mw2_request_{}", i);
multi_worker2
.add_prefill(worker_id_2, request_id, tokens_per_request)
.await?;
}
// Wait for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get running sums after additions
let sums2 = multi_worker2.running_sums().await;
*results2_clone.lock().unwrap() = Some(sums2);
// Wait for other thread to remove its requests
tokio::time::sleep(Duration::from_millis(2000)).await;
// Remove all requests from multi_worker2
for i in 0..requests_per_worker {
let request_id = format!("mw2_request_{}", i);
multi_worker2.remove_prefill(&request_id).await?;
}
// Wait for removal synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get final running sums
let final_sums2 = multi_worker2.running_sums().await;
*final_results2_clone.lock().unwrap() = Some(final_sums2);
// Keep runtime alive a bit longer for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Shutdown runtime
runtime.shutdown();
Ok::<(), anyhow::Error>(())
})
});
// Wait for both threads to complete
handle1.join().unwrap()?;
handle2.join().unwrap()?;
// Extract results
let sums1 = results1.lock().unwrap().take().unwrap();
let sums2 = results2.lock().unwrap().take().unwrap();
let final_sums1 = final_results1.lock().unwrap().take().unwrap();
let final_sums2 = final_results2.lock().unwrap().take().unwrap();
// Verify both multi-workers see all requests
assert_eq!(
sums1.get(&worker_id_1),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker1 should see worker 1's requests"
);
assert_eq!(
sums1.get(&worker_id_2),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker1 should see worker 2's requests"
);
assert_eq!(
sums2.get(&worker_id_1),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker2 should see worker 1's requests"
);
assert_eq!(
sums2.get(&worker_id_2),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker2 should see worker 2's requests"
);
// Verify both multi-workers show zero sums after removal
assert_eq!(
final_sums1.get(&worker_id_1).copied().unwrap_or(0),
0,
"MultiWorker1 should show zero for worker 1"
);
assert_eq!(
final_sums1.get(&worker_id_2).copied().unwrap_or(0),
0,
"MultiWorker1 should show zero for worker 2"
);
assert_eq!(
final_sums2.get(&worker_id_1).copied().unwrap_or(0),
0,
"MultiWorker2 should show zero for worker 1"
);
assert_eq!(
final_sums2.get(&worker_id_2).copied().unwrap_or(0),
0,
"MultiWorker2 should show zero for worker 2"
);
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License"); use crate::tokens::{SequenceHash, Token};
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::tokens::Token;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RouterRequest { pub struct RouterRequest {
...@@ -128,6 +117,56 @@ impl From<i64> for ExternalSequenceBlockHash { ...@@ -128,6 +117,56 @@ impl From<i64> for ExternalSequenceBlockHash {
} }
} }
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillEvent {
pub request_id: String,
pub worker_id: i64,
pub data: PrefillEventData,
pub router_id: Uuid,
}
/// Represents the different stages of prefilling tokens for a request.
///
/// Each variant contains a `usize` representing the number of tokens
/// that are pending prefill in the request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum PrefillEventData {
NewPrefill(usize),
UpdatePrefill(usize),
CompletePrefill,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActiveSequenceEvent {
pub request_id: String,
pub worker_id: i64,
pub data: ActiveSequenceEventData,
pub router_id: Uuid,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ActiveSequenceEventData {
AddRequest {
token_sequence: Vec<SequenceHash>,
isl: usize,
overlap: u32,
},
Free,
MarkPrefillCompleted,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActiveBlockEvent {
pub request_id: String,
pub data: ActiveBlockEventData,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ActiveBlockEventData {
NewBlock(Vec<SequenceHash>),
FreeBlock,
}
/// Represents a collection of cache events and a shutdown flag. /// Represents a collection of cache events and a shutdown flag.
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheEvents { pub struct KvCacheEvents {
......
...@@ -16,12 +16,13 @@ ...@@ -16,12 +16,13 @@
use crate::kv_router::{ use crate::kv_router::{
indexer::{compute_block_hash_for_seq, RouterEvent}, indexer::{compute_block_hash_for_seq, RouterEvent},
protocols::*, protocols::*,
KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, scoring::LoadEvent,
KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, KV_METRICS_SUBJECT,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider}; use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider};
use dynamo_runtime::{ use dynamo_runtime::{
component::Component, component::{Component, Namespace},
pipeline::{ pipeline::{
network::Ingress, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, network::Ingress, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream,
SingleIn, SingleIn,
...@@ -499,9 +500,18 @@ impl WorkerMetricsPublisher { ...@@ -499,9 +500,18 @@ impl WorkerMetricsPublisher {
pub async fn create_endpoint(&self, component: Component) -> Result<()> { pub async fn create_endpoint(&self, component: Component) -> Result<()> {
let mut metrics_rx = self.rx.clone(); let mut metrics_rx = self.rx.clone();
let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone())); let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?; let handler = Ingress::for_engine(handler)?;
// let worker_id = component
// .drt()
// .primary_lease()
// .map(|lease| lease.id())
// .unwrap_or_else(|| {
// tracing::warn!("Component is static, assuming worker_id of 0");
// 0
// });
component component
.endpoint(KV_METRICS_ENDPOINT) .endpoint(KV_METRICS_ENDPOINT)
.endpoint_builder() .endpoint_builder()
...@@ -513,13 +523,90 @@ impl WorkerMetricsPublisher { ...@@ -513,13 +523,90 @@ impl WorkerMetricsPublisher {
.start() .start()
.await .await
} }
/// Starts a background task to publish metrics over NATS
///
/// This task monitors metric changes (specifically kv_active_blocks and num_requests_waiting)
/// and publishes stable metrics to NATS after they've been unchanged for 1ms.
#[allow(dead_code)]
fn start_nats_metrics_publishing(&self, namespace: Namespace, worker_id: i64) {
let nats_rx = self.rx.clone();
tokio::spawn(async move {
let mut rx = nats_rx;
let mut last_kv_active_blocks: Option<u64> = None;
let mut last_num_requests_waiting: Option<u64> = None;
let mut pending_publish: Option<Arc<ForwardPassMetrics>> = None;
let mut publish_timer =
Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(0)));
publish_timer.as_mut().reset(tokio::time::Instant::now()); // Complete immediately
loop {
tokio::select! {
// Handle metrics changes
result = rx.changed() => {
if result.is_err() {
tracing::debug!(
"Metrics publisher sender dropped, stopping NATS background task"
);
break;
}
let metrics = rx.borrow_and_update().clone();
// Extract the values we care about
let current_kv_active_blocks = metrics.kv_stats.kv_active_blocks;
let current_num_requests_waiting =
metrics.worker_stats.num_requests_waiting;
// Check if these specific metrics have changed
let has_changed = match (last_kv_active_blocks, last_num_requests_waiting) {
(Some(last_kv), Some(last_requests)) => {
last_kv != current_kv_active_blocks
|| last_requests != current_num_requests_waiting
}
_ => true, // First time, consider it changed
};
// If load metrics changed, schedule a publish
if has_changed {
pending_publish = Some(metrics.clone());
last_kv_active_blocks = Some(current_kv_active_blocks);
last_num_requests_waiting = Some(current_num_requests_waiting);
// Start the 1ms timer
publish_timer.as_mut().reset(
tokio::time::Instant::now() + tokio::time::Duration::from_millis(1)
);
}
}
// Timer expired - publish if we have pending metrics
_ = &mut publish_timer => {
if let Some(metrics) = pending_publish.take() {
// Create LoadEvent wrapping the metrics
let load_event = LoadEvent {
worker_id,
data: (*metrics).clone(),
};
if let Err(e) =
namespace.publish(KV_METRICS_SUBJECT, &load_event).await
{
tracing::warn!("Failed to publish metrics over NATS: {}", e);
}
}
}
}
}
});
}
} }
struct KvLoadEndpoingHander { struct KvLoadEndpointHandler {
metrics_rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>, metrics_rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>,
} }
impl KvLoadEndpoingHander { impl KvLoadEndpointHandler {
pub fn new(metrics_rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>) -> Self { pub fn new(metrics_rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>) -> Self {
Self { metrics_rx } Self { metrics_rx }
} }
...@@ -527,7 +614,7 @@ impl KvLoadEndpoingHander { ...@@ -527,7 +614,7 @@ impl KvLoadEndpoingHander {
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<()>, ManyOut<Annotated<ForwardPassMetrics>>, Error> impl AsyncEngine<SingleIn<()>, ManyOut<Annotated<ForwardPassMetrics>>, Error>
for KvLoadEndpoingHander for KvLoadEndpointHandler
{ {
async fn generate( async fn generate(
&self, &self,
...@@ -880,3 +967,116 @@ mod test_exponential_backoff { ...@@ -880,3 +967,116 @@ mod test_exponential_backoff {
assert!(max_calculated <= MAX_BACKOFF_MS); assert!(max_calculated <= MAX_BACKOFF_MS);
} }
} }
#[cfg(test)]
mod test_worker_metrics_publisher {
use super::*;
use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats};
use dynamo_runtime::traits::events::EventSubscriber; // Add this import
use dynamo_runtime::{DistributedRuntime, Runtime};
use futures::StreamExt;
#[tokio::test]
#[ignore] // Mark as ignored as requested
async fn test_metrics_publishing_behavior() -> Result<()> {
// Set up runtime and namespace
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::from_settings(rt.clone()).await?;
let namespace = drt.namespace("test".to_string())?;
// Create a subscriber for the metrics events using subscribe_with_type
let mut subscriber = namespace
.subscribe_with_type::<LoadEvent>(KV_METRICS_SUBJECT)
.await
.unwrap();
// Create WorkerMetricsPublisher
let publisher = WorkerMetricsPublisher::new().unwrap();
let worker_id = 1234;
// Start NATS metrics publishing
publisher.start_nats_metrics_publishing(namespace.clone(), worker_id);
// Allow some time for the background task to start
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
// Test 1: Publish 10 different metrics with 0.5ms intervals
// Only the last one should be published after 1ms of stability
for i in 0..10 {
let metrics = Arc::new(ForwardPassMetrics {
kv_stats: KvStats {
kv_active_blocks: (i * 100) as u64, // Changing load metric
kv_total_blocks: 1000,
gpu_cache_usage_perc: 0.5,
gpu_prefix_cache_hit_rate: 0.8,
},
worker_stats: WorkerStats {
num_requests_waiting: (i * 10) as u64, // Changing load metric
data_parallel_rank: None,
request_active_slots: 50,
request_total_slots: 100,
},
spec_decode_stats: None,
});
publisher.publish(metrics).unwrap();
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
// Wait a bit more than 1ms to ensure the last metric is published
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
// Verify we receive exactly one event with the last metric values
let result =
tokio::time::timeout(tokio::time::Duration::from_millis(500), subscriber.next())
.await
.unwrap();
let event = result.unwrap().unwrap(); // Unwrap the Option and the Result
assert_eq!(event.worker_id, worker_id);
assert_eq!(event.data.kv_stats.kv_active_blocks, 900); // Last value: 9 * 100
assert_eq!(event.data.worker_stats.num_requests_waiting, 90); // Last value: 9 * 10
// Ensure no more events are waiting
let no_msg =
tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await;
assert!(no_msg.is_err(), "Expected no more messages, but found one");
// Test 2: Publish 10 more metrics where everything changes EXCEPT the load metrics
for i in 0..10 {
let metrics = Arc::new(ForwardPassMetrics {
kv_stats: KvStats {
kv_active_blocks: 900, // Keep same as last published
kv_total_blocks: 1000 + (i * 100) as u64, // Change other metrics
gpu_cache_usage_perc: 0.3 + (i as f32 * 0.05), // Change other metrics
gpu_prefix_cache_hit_rate: 0.7 + (i as f32 * 0.01), // Change other metrics
},
worker_stats: WorkerStats {
num_requests_waiting: 90, // Keep same as last published
data_parallel_rank: None,
request_active_slots: 40 + (i * 5) as u64, // Change other metrics
request_total_slots: 100 + (i * 10) as u64, // Change other metrics
},
spec_decode_stats: None,
});
publisher.publish(metrics).unwrap();
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
// Wait to ensure no events are published
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
// Verify no events are received
let no_msg =
tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await;
assert!(
no_msg.is_err(),
"Expected no messages when load metrics don't change"
);
rt.shutdown();
Ok(())
}
}
This diff is collapsed.
...@@ -15,10 +15,39 @@ ...@@ -15,10 +15,39 @@
//! Scoring functions for the KV router. //! Scoring functions for the KV router.
use super::protocols::{ForwardPassMetrics, LoadMetrics};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use crate::kv_router::scheduler::Endpoint; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LoadEvent {
pub worker_id: i64,
pub data: ForwardPassMetrics,
}
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Endpoint {
pub name: String,
pub subject: String,
pub data: LoadMetrics,
}
impl Endpoint {
pub fn worker_id(&self) -> i64 {
i64::from_str_radix(
self.subject
.split("-")
.last()
.expect("invalid subject")
.to_string()
.as_str(),
16,
)
.expect("invalid worker id")
}
}
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct ProcessedEndpoints { pub struct ProcessedEndpoints {
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment