Unverified Commit 5b19a39b authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: allow router to not track active blocks (prefill), and to not track...


feat: allow router to not track active blocks (prefill), and to not track cached blocks (decode) (#3135)
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent f78189d7
...@@ -31,6 +31,8 @@ The main KV-aware routing arguments: ...@@ -31,6 +31,8 @@ The main KV-aware routing arguments:
>[!Note] >[!Note]
> State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported. > State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported.
>
> When `--kv-overlap-score-weight` is set to 0 or `--no-kv-events` is set, no KvIndexer is launched to drain and process KV events. In these cases we recommend configuring your backend workers not to relay events through `KvEventPublisher`, so that events do not accumulate in JetStream. Support for disabling KV event publishing entirely in these cases is a work in progress.
## Architecture ## Architecture
......
...@@ -99,6 +99,13 @@ pub struct Flags { ...@@ -99,6 +99,13 @@ pub struct Flags {
#[arg(long)] #[arg(long)]
pub router_replica_sync: Option<bool>, pub router_replica_sync: Option<bool>,
/// KV Router: Whether to track active blocks in the router for memory management.
/// When false, the router will not maintain state about which blocks are active,
/// reducing memory overhead but potentially affecting scheduling decisions.
/// Default: true
#[arg(long)]
pub router_track_active_blocks: Option<bool>,
/// Max model context length. Reduce this if you don't have enough VRAM for the full model /// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4). /// context length (e.g. Llama 4).
/// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json. /// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json.
...@@ -228,6 +235,7 @@ impl Flags { ...@@ -228,6 +235,7 @@ impl Flags {
self.router_temperature, self.router_temperature,
self.use_kv_events, self.use_kv_events,
self.router_replica_sync, self.router_replica_sync,
self.router_track_active_blocks,
self.max_num_batched_tokens, self.max_num_batched_tokens,
// defaulting below args (no longer maintaining new flags for dynamo-run) // defaulting below args (no longer maintaining new flags for dynamo-run)
None, None,
......
...@@ -42,12 +42,13 @@ impl KvRouterConfig { ...@@ -42,12 +42,13 @@ impl KvRouterConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_snapshot_threshold=10000, router_reset_states=false))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=10000, router_reset_states=false))]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
router_temperature: f64, router_temperature: f64,
use_kv_events: bool, use_kv_events: bool,
router_replica_sync: bool, router_replica_sync: bool,
router_track_active_blocks: bool,
router_snapshot_threshold: Option<u32>, router_snapshot_threshold: Option<u32>,
router_reset_states: bool, router_reset_states: bool,
) -> Self { ) -> Self {
...@@ -57,6 +58,7 @@ impl KvRouterConfig { ...@@ -57,6 +58,7 @@ impl KvRouterConfig {
router_temperature, router_temperature,
use_kv_events, use_kv_events,
router_replica_sync, router_replica_sync,
router_track_active_blocks,
router_snapshot_threshold, router_snapshot_threshold,
router_reset_states, router_reset_states,
..Default::default() ..Default::default()
......
...@@ -23,7 +23,6 @@ use serde::{Deserialize, Serialize}; ...@@ -23,7 +23,6 @@ use serde::{Deserialize, Serialize};
pub mod approx; pub mod approx;
pub mod indexer; pub mod indexer;
pub mod metrics_aggregator; pub mod metrics_aggregator;
pub mod prefill_counter;
pub mod protocols; pub mod protocols;
pub mod publisher; pub mod publisher;
pub mod recorder; pub mod recorder;
...@@ -102,6 +101,9 @@ pub struct KvRouterConfig { ...@@ -102,6 +101,9 @@ pub struct KvRouterConfig {
pub router_replica_sync: bool, pub router_replica_sync: bool,
/// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool,
// TODO: this is not actually used for now // TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting // Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32, pub max_num_batched_tokens: u32,
...@@ -120,6 +122,7 @@ impl Default for KvRouterConfig { ...@@ -120,6 +122,7 @@ impl Default for KvRouterConfig {
router_temperature: 0.0, router_temperature: 0.0,
use_kv_events: true, use_kv_events: true,
router_replica_sync: false, router_replica_sync: false,
router_track_active_blocks: true,
max_num_batched_tokens: 8192, max_num_batched_tokens: 8192,
router_snapshot_threshold: Some(10000), router_snapshot_threshold: Some(10000),
router_reset_states: false, router_reset_states: false,
...@@ -130,11 +133,13 @@ impl Default for KvRouterConfig { ...@@ -130,11 +133,13 @@ impl Default for KvRouterConfig {
impl KvRouterConfig { impl KvRouterConfig {
/// Create a new KvRouterConfig with optional weight values. /// Create a new KvRouterConfig with optional weight values.
/// If a weight is None, the default value will be used. /// If a weight is None, the default value will be used.
#[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
overlap_score_weight: Option<f64>, overlap_score_weight: Option<f64>,
temperature: Option<f64>, temperature: Option<f64>,
use_kv_events: Option<bool>, use_kv_events: Option<bool>,
replica_sync: Option<bool>, replica_sync: Option<bool>,
track_active_blocks: Option<bool>,
max_num_batched_tokens: Option<u32>, max_num_batched_tokens: Option<u32>,
router_snapshot_threshold: Option<Option<u32>>, router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>, router_reset_states: Option<bool>,
...@@ -145,6 +150,8 @@ impl KvRouterConfig { ...@@ -145,6 +150,8 @@ impl KvRouterConfig {
router_temperature: temperature.unwrap_or(default.router_temperature), router_temperature: temperature.unwrap_or(default.router_temperature),
use_kv_events: use_kv_events.unwrap_or(default.use_kv_events), use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync), router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
router_track_active_blocks: track_active_blocks
.unwrap_or(default.router_track_active_blocks),
max_num_batched_tokens: max_num_batched_tokens max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens), .unwrap_or(default.max_num_batched_tokens),
router_snapshot_threshold: router_snapshot_threshold router_snapshot_threshold: router_snapshot_threshold
...@@ -157,8 +164,17 @@ impl KvRouterConfig { ...@@ -157,8 +164,17 @@ impl KvRouterConfig {
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this // TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it // since both variants implement it
pub enum Indexer { pub enum Indexer {
/// Updates itself based on KV events emitted by backend workers.
/// Has the ability to persist and snapshot states.
KvIndexer(KvIndexer), KvIndexer(KvIndexer),
/// Predicts the cached blocks based on requests on a TTL basis.
/// Currently does not persist or snapshot states (WIP to enable that).
ApproxKvIndexer(ApproxKvIndexer), ApproxKvIndexer(ApproxKvIndexer),
/// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
/// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
None,
} }
impl Indexer { impl Indexer {
...@@ -169,6 +185,10 @@ impl Indexer { ...@@ -169,6 +185,10 @@ impl Indexer {
match self { match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await, Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await, Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::None => Ok(OverlapScores {
scores: HashMap::new(),
frequencies: Vec::new(),
}),
} }
} }
...@@ -176,6 +196,11 @@ impl Indexer { ...@@ -176,6 +196,11 @@ impl Indexer {
match self { match self {
Indexer::KvIndexer(indexer) => indexer.dump_events().await, Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await, Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
Indexer::None => {
panic!(
"Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
);
}
} }
} }
} }
...@@ -189,6 +214,8 @@ pub struct KvRouter { ...@@ -189,6 +214,8 @@ pub struct KvRouter {
scheduler: KvScheduler, scheduler: KvScheduler,
block_size: u32, block_size: u32,
kv_router_config: KvRouterConfig,
} }
impl KvRouter { impl KvRouter {
...@@ -234,7 +261,10 @@ impl KvRouter { ...@@ -234,7 +261,10 @@ impl KvRouter {
.await?; .await?;
let runtime_configs_rx = runtime_configs_watcher.receiver(); let runtime_configs_rx = runtime_configs_watcher.receiver();
let indexer = if kv_router_config.use_kv_events { let indexer = if kv_router_config.overlap_score_weight == 0.0 {
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer::None
} else if kv_router_config.use_kv_events {
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component); let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component);
Indexer::KvIndexer(KvIndexer::new( Indexer::KvIndexer(KvIndexer::new(
cancellation_token.clone(), cancellation_token.clone(),
...@@ -257,6 +287,7 @@ impl KvRouter { ...@@ -257,6 +287,7 @@ impl KvRouter {
runtime_configs_rx, runtime_configs_rx,
selector, selector,
kv_router_config.router_replica_sync, kv_router_config.router_replica_sync,
consumer_uuid.clone(),
) )
.await?; .await?;
...@@ -282,6 +313,7 @@ impl KvRouter { ...@@ -282,6 +313,7 @@ impl KvRouter {
indexer, indexer,
scheduler, scheduler,
block_size, block_size,
kv_router_config,
}) })
} }
...@@ -302,12 +334,25 @@ impl KvRouter { ...@@ -302,12 +334,25 @@ impl KvRouter {
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?; let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
// Determine who needs seq_hashes
let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;
// Optimize cloning: only clone if both need it, otherwise move
let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
match (approx_indexer_needs_it, scheduler_needs_it) {
(true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
(true, false) => (Some(seq_hashes), None),
(false, true) => (None, Some(seq_hashes)),
(false, false) => (None, None),
};
let best_worker_id = self let best_worker_id = self
.scheduler .scheduler
.schedule( .schedule(
context_id.to_string(), context_id.to_string(),
isl_tokens, isl_tokens,
seq_hashes.clone(), maybe_seq_hashes_2,
overlap_scores.clone(), overlap_scores.clone(),
router_config_override, router_config_override,
update_states, update_states,
...@@ -316,7 +361,7 @@ impl KvRouter { ...@@ -316,7 +361,7 @@ impl KvRouter {
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer { if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer indexer
.process_routing_decision(best_worker_id, block_hashes, seq_hashes) .process_routing_decision(best_worker_id, block_hashes, maybe_seq_hashes_1.unwrap())
.await .await
.unwrap(); .unwrap();
}; };
...@@ -337,13 +382,16 @@ impl KvRouter { ...@@ -337,13 +382,16 @@ impl KvRouter {
worker_id: i64, worker_id: i64,
) { ) {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes); let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
compute_seq_hash_for_block(&block_hashes)
});
self.scheduler self.scheduler
.add_request( .add_request(
request_id, request_id,
seq_hashes, maybe_seq_hashes,
isl_tokens, isl_tokens,
overlap_blocks, overlap_blocks,
worker_id, worker_id,
...@@ -351,11 +399,11 @@ impl KvRouter { ...@@ -351,11 +399,11 @@ impl KvRouter {
.await; .await;
} }
pub async fn mark_prefill_completed(&self, request_id: &str) { pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
self.scheduler.mark_prefill_completed(request_id).await self.scheduler.mark_prefill_completed(request_id).await
} }
pub async fn free(&self, request_id: &str) { pub async fn free(&self, request_id: &str) -> Result<()> {
self.scheduler.free(request_id).await self.scheduler.free(request_id).await
} }
...@@ -367,12 +415,16 @@ impl KvRouter { ...@@ -367,12 +415,16 @@ impl KvRouter {
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> { pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes).await?; let overlap_scores = self.indexer.find_matches(block_hashes).await?;
let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
compute_seq_hash_for_block(&block_hashes)
});
Ok(self Ok(self
.scheduler .scheduler
.get_potential_loads(seq_hashes, isl_tokens, overlap_scores) .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
.await) .await)
} }
...@@ -404,14 +456,12 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -404,14 +456,12 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
overlap_blocks, overlap_blocks,
} }
} }
RouterRequest::MarkPrefill => { RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
self.mark_prefill_completed(&context_id).await; success: self.mark_prefill_completed(&context_id).await.is_ok(),
RouterResponse::PrefillMarked { success: true } },
} RouterRequest::MarkFree => RouterResponse::FreeMarked {
RouterRequest::MarkFree => { success: self.free(&context_id).await.is_ok(),
self.free(&context_id).await; },
RouterResponse::FreeMarked { success: true }
}
}; };
let response = Annotated::from_data(response); let response = Annotated::from_data(response);
...@@ -541,7 +591,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -541,7 +591,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let wrapped_stream = Box::pin(async_stream::stream! { let wrapped_stream = Box::pin(async_stream::stream! {
if let Some(first_item) = response_stream.next().await { if let Some(first_item) = response_stream.next().await {
chooser.mark_prefill_completed(&context_id).await; if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
}
yield first_item; yield first_item;
} }
...@@ -549,7 +601,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -549,7 +601,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
yield item; yield item;
} }
chooser.free(&context_id).await; if let Err(e) = chooser.free(&context_id).await {
tracing::warn!("Failed to free request {context_id}: {e:?}");
}
}); });
Ok(ResponseStream::new(wrapped_stream, stream_context)) Ok(ResponseStream::new(wrapped_stream, stream_context))
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt;
use std::sync::Arc;
use uuid::Uuid;
use super::protocols::{PrefillEvent, PrefillEventData};
use crate::kv_router::PREFILL_SUBJECT;
use dashmap::DashMap;
use std::collections::HashMap;
use std::hash::Hash;
/// Copies every entry of `state` into a plain [`HashMap`].
///
/// Keys are cloned and values are copied, so the returned map owns its data and
/// does not borrow from the `DashMap` it was built from.
pub fn get_snapshot<K, V>(state: &DashMap<K, V>) -> HashMap<K, V>
where
    K: Clone + Hash + Eq,
    V: Copy,
{
    let mut snapshot = HashMap::new();
    for entry in state.iter() {
        snapshot.insert(entry.key().clone(), *entry.value());
    }
    snapshot
}
/// Per-request token counts together with their cached total.
#[derive(Default)]
struct PrefillCounterState {
    /// Token count recorded for each request id.
    tokens_map: HashMap<String, usize>,
    /// Sum of all values currently stored in `tokens_map`.
    running_sum: usize,
}

impl PrefillCounterState {
    /// Records `value` for `key`, replacing any earlier entry.
    ///
    /// Returns the count that was replaced, or `None` if `key` was not present.
    fn insert(&mut self, key: String, value: usize) -> Option<usize> {
        let previous = self.tokens_map.insert(key, value);
        // Retire the replaced count before adding the new one so the total stays exact.
        if let Some(replaced) = previous {
            self.running_sum -= replaced;
        }
        self.running_sum += value;
        previous
    }

    /// Drops the entry for `key` and subtracts its count from the total.
    ///
    /// Returns the removed count, or `None` if `key` was not present.
    fn remove(&mut self, key: &str) -> Option<usize> {
        let removed = self.tokens_map.remove(key)?;
        self.running_sum -= removed;
        Some(removed)
    }

    /// Returns the cached sum of all recorded counts.
    fn running_sum(&self) -> usize {
        self.running_sum
    }
}
/// A counter that tracks pending prefill tokens for each request.
///
/// This struct maintains a local map of request id to token count and a running
/// sum of all tokens. It does not publish or subscribe to events itself.
#[derive(Default)]
pub struct PrefillCounter {
    /// Per-request counts and their cached total.
    state: PrefillCounterState,
}

impl PrefillCounter {
    /// Records `tokens` for `request_id` without publishing anything.
    ///
    /// Returns the count previously stored for `request_id`, if any.
    fn insert_direct(&mut self, request_id: String, tokens: usize) -> Option<usize> {
        self.state.insert(request_id, tokens)
    }

    /// Removes `request_id` without publishing anything.
    ///
    /// Returns the removed count, or `None` if the request was not tracked.
    fn remove_direct(&mut self, request_id: &str) -> Option<usize> {
        self.state.remove(request_id)
    }

    /// Replaces the count of an already-tracked request; unknown requests are ignored.
    // Not called anywhere in this file; kept alongside the other direct mutators.
    #[allow(dead_code)]
    fn update_direct(&mut self, request_id: String, new_tokens: usize) {
        // Reuse the state's own bookkeeping rather than adjusting the sum through
        // signed casts, which can wrap for very large counts.
        if self.state.tokens_map.contains_key(&request_id) {
            self.state.insert(request_id, new_tokens);
        }
    }

    /// Returns the token count recorded for `request_id`, if it is tracked.
    pub fn get(&self, request_id: &str) -> Option<usize> {
        self.state.tokens_map.get(request_id).copied()
    }

    /// Returns the sum of the token counts of all tracked requests.
    pub fn running_sum(&self) -> usize {
        self.state.running_sum()
    }

    /// Returns the number of tracked requests.
    pub fn len(&self) -> usize {
        self.state.tokens_map.len()
    }

    /// Returns `true` when no request is tracked.
    pub fn is_empty(&self) -> bool {
        self.state.tokens_map.is_empty()
    }
}
/// A collection of PrefillCounters for multiple workers with centralized event handling
pub struct PrefillCountersMultiWorker {
pub counters: Arc<DashMap<i64, PrefillCounter>>,
pub request_to_workers: Arc<DashMap<String, i64>>,
component: Component,
router_id: Uuid,
}
impl PrefillCountersMultiWorker {
// Helper function to handle new prefill logic
fn handle_new_prefill(
counters: &Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: &Arc<DashMap<String, i64>>,
request_id: &str,
worker_id: i64,
tokens: usize,
) {
// Check if request already exists
if let Some(existing_worker_id) = request_to_workers.get(request_id) {
tracing::warn!(
"Request {} already exists for worker {}, but trying to add to worker {}",
request_id,
*existing_worker_id,
worker_id
);
}
// Update mapping
request_to_workers.insert(request_id.to_string(), worker_id);
// Get or create counter and insert using get_mut
if let Some(mut counter) = counters.get_mut(&worker_id) {
counter.insert_direct(request_id.to_string(), tokens);
} else {
tracing::warn!(
"Worker {} does not exist, creating new PrefillCounter",
worker_id
);
let mut new_counter = PrefillCounter::default();
new_counter.insert_direct(request_id.to_string(), tokens);
counters.insert(worker_id, new_counter);
};
}
// Helper function to handle complete prefill logic
fn handle_complete_prefill(
counters: &Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: &Arc<DashMap<String, i64>>,
request_id: &str,
) -> Option<usize> {
// Remove from request_to_workers and get the worker_id
let Some((_, worker_id)) = request_to_workers.remove(request_id) else {
tracing::warn!("Request {} not found in request_to_workers", request_id);
return None;
};
// Use the worker_id from request_to_workers with get_mut
let Some(mut counter) = counters.get_mut(&worker_id) else {
tracing::warn!(
"No counter found for worker {} for request {}",
worker_id,
request_id
);
return None;
};
let removed_tokens = counter.remove_direct(request_id);
if removed_tokens.is_none() {
tracing::warn!("Attempted to remove non-existent request: {}", request_id);
}
removed_tokens
}
pub fn new(component: Component) -> Self {
let counters = Arc::new(DashMap::new());
let request_to_workers = Arc::new(DashMap::new());
let router_id = Uuid::new_v4();
let multi_worker = Self {
counters: counters.clone(),
request_to_workers: request_to_workers.clone(),
component: component.clone(),
router_id,
};
// Start the subscription loop
let counters_clone = counters.clone();
let request_to_workers_clone = request_to_workers.clone();
let component_clone = component.clone();
let router_id_clone = router_id;
tokio::spawn(async move {
if let Err(e) = Self::subscribe_to_events(
counters_clone,
request_to_workers_clone,
component_clone,
router_id_clone,
)
.await
{
tracing::error!("Error in prefill events subscription: {}", e);
}
});
multi_worker
}
/// Background task to subscribe to prefill events and update all counters
async fn subscribe_to_events(
counters: Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: Arc<DashMap<String, i64>>,
component: Component,
router_id: Uuid,
) -> Result<()> {
let mut subscriber = component
.subscribe_with_type::<PrefillEvent>(PREFILL_SUBJECT)
.await?;
while let Some(result) = subscriber.next().await {
let Ok(event) = result else {
tracing::error!("Error receiving prefill event: {}", result.unwrap_err());
continue;
};
// Skip events emitted by itself
if event.router_id == router_id {
continue;
}
match event.data {
PrefillEventData::NewPrefill(tokens) => {
Self::handle_new_prefill(
&counters,
&request_to_workers,
&event.request_id,
event.worker_id,
tokens,
);
}
PrefillEventData::UpdatePrefill(_) => {
// Do nothing for now
continue;
}
PrefillEventData::CompletePrefill => {
Self::handle_complete_prefill(
&counters,
&request_to_workers,
&event.request_id,
);
}
}
}
Ok(())
}
pub async fn add_prefill(
&self,
worker_id: i64,
request_id: String,
new_tokens: usize,
) -> Result<()> {
let event = PrefillEvent {
request_id: request_id.clone(),
worker_id,
data: PrefillEventData::NewPrefill(new_tokens),
router_id: self.router_id,
};
self.component.publish(PREFILL_SUBJECT, &event).await?;
// Use the helper function
Self::handle_new_prefill(
&self.counters,
&self.request_to_workers,
&request_id,
worker_id,
new_tokens,
);
Ok(())
}
pub async fn remove_prefill(&self, request_id: &str) -> Result<Option<usize>> {
// Send the event first with dummy worker_id
let event = PrefillEvent {
request_id: request_id.to_string(),
worker_id: 0, // Dummy worker_id
data: PrefillEventData::CompletePrefill,
router_id: self.router_id,
};
self.component.publish(PREFILL_SUBJECT, &event).await?;
// Use the helper function
Ok(Self::handle_complete_prefill(
&self.counters,
&self.request_to_workers,
request_id,
))
}
/// Get the running sums for all workers as a HashMap<i64, usize>
pub async fn running_sums(&self) -> HashMap<i64, usize> {
self.counters
.iter()
.map(|entry| (*entry.key(), entry.value().running_sum()))
.collect()
}
/// Get a specific counter's running sum
pub async fn get_worker_sum(&self, worker_id: i64) -> Option<usize> {
self.counters.get(&worker_id).map(|c| c.running_sum())
}
}
#[cfg(test)]
mod integration_tests {
    use super::*;
    use dynamo_runtime::{DistributedRuntime, Runtime};
    use std::sync::{Arc, Mutex};
    use std::thread;
    use tokio::time::Duration;

    /// Slot through which a router thread hands a snapshot of running sums back to the test.
    type SumsSlot = Arc<Mutex<Option<HashMap<i64, usize>>>>;

    /// Runs one `PrefillCountersMultiWorker` through the add / snapshot / remove / snapshot
    /// sequence on its own distributed runtime.
    ///
    /// `extra_startup_delay` is slept after the common startup wait and before the first
    /// request is added. The snapshots taken after the additions and after the removals
    /// are stored in `sums_after_add` and `sums_after_remove`.
    async fn drive_router(
        worker_id: i64,
        request_prefix: &str,
        extra_startup_delay: Option<Duration>,
        requests: usize,
        tokens_per_request: usize,
        sums_after_add: SumsSlot,
        sums_after_remove: SumsSlot,
    ) -> Result<()> {
        // Both routers use the same namespace and component names.
        let runtime = Runtime::from_current()?;
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
        let component = distributed
            .namespace("test_prefill_multiworker")?
            .component("counters")?
            .service_builder()
            .create()
            .await?;

        let router = PrefillCountersMultiWorker::new(component);

        // Give some time for subscribers to initialize
        tokio::time::sleep(Duration::from_millis(3000)).await;
        if let Some(delay) = extra_startup_delay {
            tokio::time::sleep(delay).await;
        }

        for i in 0..requests {
            router
                .add_prefill(
                    worker_id,
                    format!("{request_prefix}_request_{i}"),
                    tokens_per_request,
                )
                .await?;
        }

        // Wait for synchronization, then record the sums seen after the additions
        tokio::time::sleep(Duration::from_millis(1000)).await;
        let sums = router.running_sums().await;
        *sums_after_add.lock().unwrap() = Some(sums);

        // Wait for the other router before removing this router's requests
        tokio::time::sleep(Duration::from_millis(2000)).await;
        for i in 0..requests {
            router
                .remove_prefill(&format!("{request_prefix}_request_{i}"))
                .await?;
        }

        // Wait for removal synchronization, then record the final sums
        tokio::time::sleep(Duration::from_millis(1000)).await;
        let final_sums = router.running_sums().await;
        *sums_after_remove.lock().unwrap() = Some(final_sums);

        // Keep runtime alive a bit longer for synchronization
        tokio::time::sleep(Duration::from_millis(1000)).await;

        runtime.shutdown();
        Ok(())
    }

    #[test]
    #[ignore]
    fn test_prefill_counter_multiworker_synchronization() -> Result<()> {
        // Initialize logging once
        dynamo_runtime::logging::init();

        let worker_id_1 = 1;
        let worker_id_2 = 2;
        let tokens_per_request = 100;
        let requests_per_worker = 10;

        // Shared state for collecting results from both threads
        let results1: SumsSlot = Arc::new(Mutex::new(None));
        let results2: SumsSlot = Arc::new(Mutex::new(None));
        let final_results1: SumsSlot = Arc::new(Mutex::new(None));
        let final_results2: SumsSlot = Arc::new(Mutex::new(None));

        // Thread 1: first distributed runtime, adds requests for worker 1
        let handle1 = {
            let (after_add, after_remove) = (results1.clone(), final_results1.clone());
            thread::spawn(move || {
                let rt = tokio::runtime::Runtime::new().unwrap();
                rt.block_on(drive_router(
                    worker_id_1,
                    "mw1",
                    None,
                    requests_per_worker,
                    tokens_per_request,
                    after_add,
                    after_remove,
                ))
            })
        };

        // Thread 2: second distributed runtime, adds requests for worker 2 after an
        // extra 500 ms so the first router has started
        let handle2 = {
            let (after_add, after_remove) = (results2.clone(), final_results2.clone());
            thread::spawn(move || {
                let rt = tokio::runtime::Runtime::new().unwrap();
                rt.block_on(drive_router(
                    worker_id_2,
                    "mw2",
                    Some(Duration::from_millis(500)),
                    requests_per_worker,
                    tokens_per_request,
                    after_add,
                    after_remove,
                ))
            })
        };

        // Wait for both threads to complete
        handle1.join().unwrap()?;
        handle2.join().unwrap()?;

        // Extract results
        let sums1 = results1.lock().unwrap().take().unwrap();
        let sums2 = results2.lock().unwrap().take().unwrap();
        let final_sums1 = final_results1.lock().unwrap().take().unwrap();
        let final_sums2 = final_results2.lock().unwrap().take().unwrap();

        // Verify both multi-workers see all requests
        let expected_sum = requests_per_worker * tokens_per_request;
        for (label, sums) in [("MultiWorker1", &sums1), ("MultiWorker2", &sums2)] {
            for worker_id in [worker_id_1, worker_id_2] {
                assert_eq!(
                    sums.get(&worker_id),
                    Some(&expected_sum),
                    "{label} should see worker {worker_id}'s requests"
                );
            }
        }

        // Verify both multi-workers show zero sums after removal
        for (label, sums) in [
            ("MultiWorker1", &final_sums1),
            ("MultiWorker2", &final_sums2),
        ] {
            for worker_id in [worker_id_1, worker_id_2] {
                assert_eq!(
                    sums.get(&worker_id).copied().unwrap_or(0),
                    0,
                    "{label} should show zero for worker {worker_id}"
                );
            }
        }

        Ok(())
    }
}
...@@ -163,7 +163,7 @@ pub struct ActiveSequenceEvent { ...@@ -163,7 +163,7 @@ pub struct ActiveSequenceEvent {
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ActiveSequenceEventData { pub enum ActiveSequenceEventData {
AddRequest { AddRequest {
token_sequence: Vec<SequenceHash>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
}, },
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::{Component, Instance}; use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher; use dynamo_runtime::traits::events::EventPublisher;
...@@ -56,7 +57,7 @@ pub struct SchedulingResponse { ...@@ -56,7 +57,7 @@ pub struct SchedulingResponse {
pub struct SchedulingRequest { pub struct SchedulingRequest {
pub request_id: String, pub request_id: String,
pub token_seq: Vec<SequenceHash>, pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize, pub isl_tokens: usize,
pub overlaps: OverlapScores, pub overlaps: OverlapScores,
pub decode_blocks: HashMap<i64, usize>, pub decode_blocks: HashMap<i64, usize>,
...@@ -96,6 +97,7 @@ impl KvScheduler { ...@@ -96,6 +97,7 @@ impl KvScheduler {
runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>, runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool, replica_sync: bool,
router_uuid: String,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone(); let instances: Vec<Instance> = instances_rx.borrow().clone();
...@@ -124,6 +126,7 @@ impl KvScheduler { ...@@ -124,6 +126,7 @@ impl KvScheduler {
block_size as usize, block_size as usize,
worker_ids, worker_ids,
replica_sync, replica_sync,
router_uuid,
)); ));
// Spawn background task to monitor and update workers_with_configs // Spawn background task to monitor and update workers_with_configs
...@@ -240,20 +243,26 @@ impl KvScheduler { ...@@ -240,20 +243,26 @@ impl KvScheduler {
}; };
request.respond(response); request.respond(response);
// Only update the state if update_states is true // Skip state update if not requested
if request.update_states { if !request.update_states {
let _ = slots_clone continue;
.add_request(
request.request_id,
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
selection.worker_id,
)
.await;
} }
continue; let request_id = request.request_id;
if let Err(e) = slots_clone
.add_request(
request_id.clone(),
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
selection.worker_id,
)
.await
{
tracing::warn!(
"Failed to add request {request_id} to local slot tracker: {e:?}"
);
}
} }
Err(KvSchedulerError::NoEndpoints) => { Err(KvSchedulerError::NoEndpoints) => {
tracing::trace!("no endpoints available; waiting for endpoints update"); tracing::trace!("no endpoints available; waiting for endpoints update");
...@@ -283,7 +292,7 @@ impl KvScheduler { ...@@ -283,7 +292,7 @@ impl KvScheduler {
&self, &self,
request_id: String, request_id: String,
isl_tokens: usize, isl_tokens: usize,
token_seq: Vec<SequenceHash>, token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores, overlaps: OverlapScores,
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
...@@ -316,7 +325,7 @@ impl KvScheduler { ...@@ -316,7 +325,7 @@ impl KvScheduler {
pub async fn add_request( pub async fn add_request(
&self, &self,
request_id: String, request_id: String,
token_sequence: Vec<SequenceHash>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
worker_id: i64, worker_id: i64,
...@@ -327,20 +336,19 @@ impl KvScheduler { ...@@ -327,20 +336,19 @@ impl KvScheduler {
.await; .await;
} }
pub async fn mark_prefill_completed(&self, request_id: &str) { pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
let _ = self self.slots
.slots
.mark_prefill_completed(&request_id.to_string()) .mark_prefill_completed(&request_id.to_string())
.await; .await
} }
pub async fn free(&self, request_id: &str) { pub async fn free(&self, request_id: &str) -> Result<()> {
let _ = self.slots.free(&request_id.to_string()).await; self.slots.free(&request_id.to_string()).await
} }
pub async fn get_potential_loads( pub async fn get_potential_loads(
&self, &self,
token_seq: Vec<SequenceHash>, token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize, isl_tokens: usize,
overlaps: OverlapScores, overlaps: OverlapScores,
) -> Vec<PotentialLoad> { ) -> Vec<PotentialLoad> {
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment