Unverified commit 5b19a39b, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: allow router to not track active blocks (prefill), and to not track...


feat: allow router to not track active blocks (prefill), and to not track cached blocks (decode) (#3135)
Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent f78189d7
......@@ -31,6 +31,8 @@ The main KV-aware routing arguments:
>[!Note]
> State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported.
>
> When `--kv-overlap-score-weight` is set to 0 or `--no-kv-events` is set, no KvIndexer is launched to drain and process KV events. In these cases, we recommend configuring your backend workers not to relay events through `KvEventPublisher`, so that events do not accumulate in JetStream. Work is in progress to allow disabling KV event publishing entirely in these cases.
## Architecture
......
......@@ -99,6 +99,13 @@ pub struct Flags {
#[arg(long)]
pub router_replica_sync: Option<bool>,
/// KV Router: Whether to track active blocks in the router for memory management.
/// When false, the router will not maintain state about which blocks are active,
/// reducing memory overhead but potentially affecting scheduling decisions.
/// Default: true
#[arg(long)]
pub router_track_active_blocks: Option<bool>,
/// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4).
/// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json.
......@@ -228,6 +235,7 @@ impl Flags {
self.router_temperature,
self.use_kv_events,
self.router_replica_sync,
self.router_track_active_blocks,
self.max_num_batched_tokens,
// defaulting below args (no longer maintaining new flags for dynamo-run)
None,
......
......@@ -42,12 +42,13 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_snapshot_threshold=10000, router_reset_states=false))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=10000, router_reset_states=false))]
fn new(
overlap_score_weight: f64,
router_temperature: f64,
use_kv_events: bool,
router_replica_sync: bool,
router_track_active_blocks: bool,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
) -> Self {
......@@ -57,6 +58,7 @@ impl KvRouterConfig {
router_temperature,
use_kv_events,
router_replica_sync,
router_track_active_blocks,
router_snapshot_threshold,
router_reset_states,
..Default::default()
......
......@@ -23,7 +23,6 @@ use serde::{Deserialize, Serialize};
pub mod approx;
pub mod indexer;
pub mod metrics_aggregator;
pub mod prefill_counter;
pub mod protocols;
pub mod publisher;
pub mod recorder;
......@@ -102,6 +101,9 @@ pub struct KvRouterConfig {
pub router_replica_sync: bool,
/// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool,
// TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32,
......@@ -120,6 +122,7 @@ impl Default for KvRouterConfig {
router_temperature: 0.0,
use_kv_events: true,
router_replica_sync: false,
router_track_active_blocks: true,
max_num_batched_tokens: 8192,
router_snapshot_threshold: Some(10000),
router_reset_states: false,
......@@ -130,11 +133,13 @@ impl Default for KvRouterConfig {
impl KvRouterConfig {
/// Create a new KvRouterConfig with optional weight values.
/// If a weight is None, the default value will be used.
#[allow(clippy::too_many_arguments)]
pub fn new(
overlap_score_weight: Option<f64>,
temperature: Option<f64>,
use_kv_events: Option<bool>,
replica_sync: Option<bool>,
track_active_blocks: Option<bool>,
max_num_batched_tokens: Option<u32>,
router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>,
......@@ -145,6 +150,8 @@ impl KvRouterConfig {
router_temperature: temperature.unwrap_or(default.router_temperature),
use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
router_track_active_blocks: track_active_blocks
.unwrap_or(default.router_track_active_blocks),
max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens),
router_snapshot_threshold: router_snapshot_threshold
......@@ -157,8 +164,17 @@ impl KvRouterConfig {
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
/// The indexing backend held by the router, or `None` when indexing is disabled.
pub enum Indexer {
/// Updates itself based on KV events emitted by backend workers.
/// Has the ability to persist and snapshot states.
KvIndexer(KvIndexer),
/// Predicts the cached blocks based on requests on a TTL basis.
/// Currently does not persist or snapshot states (WIP to enable that).
ApproxKvIndexer(ApproxKvIndexer),
/// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
/// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
None,
}
impl Indexer {
......@@ -169,6 +185,10 @@ impl Indexer {
match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::None => Ok(OverlapScores {
scores: HashMap::new(),
frequencies: Vec::new(),
}),
}
}
......@@ -176,6 +196,11 @@ impl Indexer {
match self {
Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
Indexer::None => {
panic!(
"Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
);
}
}
}
}
......@@ -189,6 +214,8 @@ pub struct KvRouter {
scheduler: KvScheduler,
block_size: u32,
kv_router_config: KvRouterConfig,
}
impl KvRouter {
......@@ -234,7 +261,10 @@ impl KvRouter {
.await?;
let runtime_configs_rx = runtime_configs_watcher.receiver();
let indexer = if kv_router_config.use_kv_events {
let indexer = if kv_router_config.overlap_score_weight == 0.0 {
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer::None
} else if kv_router_config.use_kv_events {
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component);
Indexer::KvIndexer(KvIndexer::new(
cancellation_token.clone(),
......@@ -257,6 +287,7 @@ impl KvRouter {
runtime_configs_rx,
selector,
kv_router_config.router_replica_sync,
consumer_uuid.clone(),
)
.await?;
......@@ -282,6 +313,7 @@ impl KvRouter {
indexer,
scheduler,
block_size,
kv_router_config,
})
}
......@@ -302,12 +334,25 @@ impl KvRouter {
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
// Determine who needs seq_hashes
let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;
// Optimize cloning: only clone if both need it, otherwise move
let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
match (approx_indexer_needs_it, scheduler_needs_it) {
(true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
(true, false) => (Some(seq_hashes), None),
(false, true) => (None, Some(seq_hashes)),
(false, false) => (None, None),
};
let best_worker_id = self
.scheduler
.schedule(
context_id.to_string(),
isl_tokens,
seq_hashes.clone(),
maybe_seq_hashes_2,
overlap_scores.clone(),
router_config_override,
update_states,
......@@ -316,7 +361,7 @@ impl KvRouter {
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer
.process_routing_decision(best_worker_id, block_hashes, seq_hashes)
.process_routing_decision(best_worker_id, block_hashes, maybe_seq_hashes_1.unwrap())
.await
.unwrap();
};
......@@ -337,13 +382,16 @@ impl KvRouter {
worker_id: i64,
) {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
compute_seq_hash_for_block(&block_hashes)
});
self.scheduler
.add_request(
request_id,
seq_hashes,
maybe_seq_hashes,
isl_tokens,
overlap_blocks,
worker_id,
......@@ -351,11 +399,11 @@ impl KvRouter {
.await;
}
pub async fn mark_prefill_completed(&self, request_id: &str) {
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
self.scheduler.mark_prefill_completed(request_id).await
}
pub async fn free(&self, request_id: &str) {
pub async fn free(&self, request_id: &str) -> Result<()> {
self.scheduler.free(request_id).await
}
......@@ -367,12 +415,16 @@ impl KvRouter {
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
compute_seq_hash_for_block(&block_hashes)
});
Ok(self
.scheduler
.get_potential_loads(seq_hashes, isl_tokens, overlap_scores)
.get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
.await)
}
......@@ -404,14 +456,12 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
overlap_blocks,
}
}
RouterRequest::MarkPrefill => {
self.mark_prefill_completed(&context_id).await;
RouterResponse::PrefillMarked { success: true }
}
RouterRequest::MarkFree => {
self.free(&context_id).await;
RouterResponse::FreeMarked { success: true }
}
RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
success: self.mark_prefill_completed(&context_id).await.is_ok(),
},
RouterRequest::MarkFree => RouterResponse::FreeMarked {
success: self.free(&context_id).await.is_ok(),
},
};
let response = Annotated::from_data(response);
......@@ -541,7 +591,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let wrapped_stream = Box::pin(async_stream::stream! {
if let Some(first_item) = response_stream.next().await {
chooser.mark_prefill_completed(&context_id).await;
if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
}
yield first_item;
}
......@@ -549,7 +601,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
yield item;
}
chooser.free(&context_id).await;
if let Err(e) = chooser.free(&context_id).await {
tracing::warn!("Failed to free request {context_id}: {e:?}");
}
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt;
use std::sync::Arc;
use uuid::Uuid;
use super::protocols::{PrefillEvent, PrefillEventData};
use crate::kv_router::PREFILL_SUBJECT;
use dashmap::DashMap;
use std::collections::HashMap;
use std::hash::Hash;
/// Copies every entry of `state` into a plain `HashMap`.
///
/// Keys are cloned and values are copied out one entry at a time while
/// iterating over the concurrent map.
pub fn get_snapshot<K, V>(state: &DashMap<K, V>) -> HashMap<K, V>
where
    K: Clone + Hash + Eq,
    V: Copy,
{
    let mut snapshot = HashMap::new();
    for entry in state.iter() {
        snapshot.insert(entry.key().clone(), *entry.value());
    }
    snapshot
}
/// Token bookkeeping for a set of requests: a per-request token count plus
/// the sum of all counts currently stored.
#[derive(Default)]
struct PrefillCounterState {
    /// Token count recorded for each tracked request key.
    tokens_map: HashMap<String, usize>,
    /// Sum of every value currently held in `tokens_map`.
    running_sum: usize,
}

impl PrefillCounterState {
    /// Stores `value` under `key` and returns the value it replaced, if any.
    ///
    /// The running sum drops the replaced value (when present) before the
    /// new value is added.
    fn insert(&mut self, key: String, value: usize) -> Option<usize> {
        let previous = self.tokens_map.insert(key, value);
        if let Some(replaced) = previous {
            self.running_sum -= replaced;
        }
        self.running_sum += value;
        previous
    }

    /// Removes `key` and returns its token count, or `None` if it was not tracked.
    ///
    /// The running sum is reduced by the removed count.
    fn remove(&mut self, key: &str) -> Option<usize> {
        let taken = self.tokens_map.remove(key)?;
        self.running_sum -= taken;
        Some(taken)
    }

    /// Returns the sum of all token counts currently tracked.
    fn running_sum(&self) -> usize {
        self.running_sum
    }
}
/// A counter that tracks pending prefill tokens for each request.
///
/// This struct owns a map of request id to token count together with a
/// running sum of all tracked tokens. It only manipulates that local state;
/// nothing in this type publishes or subscribes to events.
#[derive(Default)]
pub struct PrefillCounter {
    /// Owned bookkeeping; every mutation goes through `&mut self`.
    state: PrefillCounterState,
}

impl PrefillCounter {
    /// Records `tokens` for `request_id`, returning the count it replaced, if any.
    fn insert_direct(&mut self, request_id: String, tokens: usize) -> Option<usize> {
        self.state.insert(request_id, tokens)
    }

    /// Stops tracking `request_id`, returning its count if it was tracked.
    fn remove_direct(&mut self, request_id: &str) -> Option<usize> {
        self.state.remove(request_id)
    }

    /// Changes the count of an already-tracked request to `new_tokens`.
    ///
    /// Unknown request ids are ignored; this never creates a new entry.
    // Not called anywhere in this file yet, hence the lint allowance.
    #[allow(dead_code)]
    fn update_direct(&mut self, request_id: String, new_tokens: usize) {
        let Some(slot) = self.state.tokens_map.get_mut(&request_id) else {
            return;
        };
        let old_tokens = std::mem::replace(slot, new_tokens);
        // The running sum already includes `old_tokens`, so subtracting first
        // stays in range without any signed casts.
        self.state.running_sum = self.state.running_sum - old_tokens + new_tokens;
    }

    /// Returns the token count tracked for `request_id`, if any.
    pub fn get(&self, request_id: &str) -> Option<usize> {
        self.state.tokens_map.get(request_id).copied()
    }

    /// Returns the sum of token counts over all tracked requests.
    pub fn running_sum(&self) -> usize {
        self.state.running_sum()
    }

    /// Returns the number of tracked requests.
    pub fn len(&self) -> usize {
        self.state.tokens_map.len()
    }

    /// Returns `true` when no requests are tracked.
    pub fn is_empty(&self) -> bool {
        self.state.tokens_map.is_empty()
    }
}
/// A collection of PrefillCounters for multiple workers with centralized event handling
pub struct PrefillCountersMultiWorker {
/// One `PrefillCounter` per worker, keyed by worker id.
pub counters: Arc<DashMap<i64, PrefillCounter>>,
/// Maps each tracked request id to the worker id it was added under.
pub request_to_workers: Arc<DashMap<String, i64>>,
/// Component used to publish and subscribe to prefill events on `PREFILL_SUBJECT`.
component: Component,
/// Id stamped on events this instance publishes; the subscription loop skips
/// events carrying the same id.
router_id: Uuid,
}
impl PrefillCountersMultiWorker {
// Helper function to handle new prefill logic
fn handle_new_prefill(
counters: &Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: &Arc<DashMap<String, i64>>,
request_id: &str,
worker_id: i64,
tokens: usize,
) {
// Check if request already exists
if let Some(existing_worker_id) = request_to_workers.get(request_id) {
tracing::warn!(
"Request {} already exists for worker {}, but trying to add to worker {}",
request_id,
*existing_worker_id,
worker_id
);
}
// Update mapping
request_to_workers.insert(request_id.to_string(), worker_id);
// Get or create counter and insert using get_mut
if let Some(mut counter) = counters.get_mut(&worker_id) {
counter.insert_direct(request_id.to_string(), tokens);
} else {
tracing::warn!(
"Worker {} does not exist, creating new PrefillCounter",
worker_id
);
let mut new_counter = PrefillCounter::default();
new_counter.insert_direct(request_id.to_string(), tokens);
counters.insert(worker_id, new_counter);
};
}
/// Marks the prefill of `request_id` as finished.
///
/// Removes the request from the request-to-worker mapping and from the
/// counter of the worker it was mapped to. Returns the token count that was
/// removed, or `None` (after logging a warning) when the request, its
/// worker's counter, or its entry in that counter is missing.
fn handle_complete_prefill(
    counters: &Arc<DashMap<i64, PrefillCounter>>,
    request_to_workers: &Arc<DashMap<String, i64>>,
    request_id: &str,
) -> Option<usize> {
    // Taking the mapping out yields the owning worker id.
    let worker_id = match request_to_workers.remove(request_id) {
        Some((_, owner)) => owner,
        None => {
            tracing::warn!("Request {} not found in request_to_workers", request_id);
            return None;
        }
    };

    match counters.get_mut(&worker_id) {
        Some(mut counter) => {
            let removed_tokens = counter.remove_direct(request_id);
            if removed_tokens.is_none() {
                tracing::warn!("Attempted to remove non-existent request: {}", request_id);
            }
            removed_tokens
        }
        None => {
            tracing::warn!(
                "No counter found for worker {} for request {}",
                worker_id,
                request_id
            );
            None
        }
    }
}
/// Creates an instance with empty maps and a fresh random router id, and
/// spawns a background task that runs the prefill event subscription loop.
///
/// The spawned task shares the counter and request maps with the returned
/// value. If the subscription loop returns an error, it is logged.
pub fn new(component: Component) -> Self {
    let router_id = Uuid::new_v4();
    let counters = Arc::new(DashMap::new());
    let request_to_workers = Arc::new(DashMap::new());

    // Handles moved into the subscription task.
    let task_counters = Arc::clone(&counters);
    let task_requests = Arc::clone(&request_to_workers);
    let task_component = component.clone();

    tokio::spawn(async move {
        let outcome = Self::subscribe_to_events(
            task_counters,
            task_requests,
            task_component,
            router_id,
        )
        .await;
        if let Err(e) = outcome {
            tracing::error!("Error in prefill events subscription: {}", e);
        }
    });

    Self {
        counters,
        request_to_workers,
        component,
        router_id,
    }
}
/// Background task to subscribe to prefill events and update all counters
async fn subscribe_to_events(
counters: Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: Arc<DashMap<String, i64>>,
component: Component,
router_id: Uuid,
) -> Result<()> {
let mut subscriber = component
.subscribe_with_type::<PrefillEvent>(PREFILL_SUBJECT)
.await?;
while let Some(result) = subscriber.next().await {
let Ok(event) = result else {
tracing::error!("Error receiving prefill event: {}", result.unwrap_err());
continue;
};
// Skip events emitted by itself
if event.router_id == router_id {
continue;
}
match event.data {
PrefillEventData::NewPrefill(tokens) => {
Self::handle_new_prefill(
&counters,
&request_to_workers,
&event.request_id,
event.worker_id,
tokens,
);
}
PrefillEventData::UpdatePrefill(_) => {
// Do nothing for now
continue;
}
PrefillEventData::CompletePrefill => {
Self::handle_complete_prefill(
&counters,
&request_to_workers,
&event.request_id,
);
}
}
}
Ok(())
}
/// Publishes a `NewPrefill` event for `request_id` on `worker_id`, then
/// records the prefill in the local maps.
///
/// # Errors
///
/// Returns the publish error, in which case the local maps are left untouched.
pub async fn add_prefill(
    &self,
    worker_id: i64,
    request_id: String,
    new_tokens: usize,
) -> Result<()> {
    let announcement = PrefillEvent {
        request_id,
        worker_id,
        data: PrefillEventData::NewPrefill(new_tokens),
        router_id: self.router_id,
    };
    self.component
        .publish(PREFILL_SUBJECT, &announcement)
        .await?;

    // Apply the same change locally; the id is borrowed back from the event.
    Self::handle_new_prefill(
        &self.counters,
        &self.request_to_workers,
        &announcement.request_id,
        worker_id,
        new_tokens,
    );
    Ok(())
}
/// Publishes a `CompletePrefill` event for `request_id`, then removes the
/// request from the local maps.
///
/// Returns the token count that was removed locally, or `None` when the
/// request was not tracked.
///
/// # Errors
///
/// Returns the publish error, in which case the local maps are left untouched.
pub async fn remove_prefill(&self, request_id: &str) -> Result<Option<usize>> {
    // The event is sent first; its worker_id is a dummy value.
    let completion = PrefillEvent {
        request_id: request_id.to_owned(),
        worker_id: 0,
        data: PrefillEventData::CompletePrefill,
        router_id: self.router_id,
    };
    self.component.publish(PREFILL_SUBJECT, &completion).await?;

    let removed =
        Self::handle_complete_prefill(&self.counters, &self.request_to_workers, request_id);
    Ok(removed)
}
/// Returns each worker's running token sum, keyed by worker id.
pub async fn running_sums(&self) -> HashMap<i64, usize> {
    let mut sums = HashMap::new();
    for entry in self.counters.iter() {
        sums.insert(*entry.key(), entry.value().running_sum());
    }
    sums
}
/// Returns the running token sum for `worker_id`, or `None` when that worker
/// has no counter.
pub async fn get_worker_sum(&self, worker_id: i64) -> Option<usize> {
    let counter = self.counters.get(&worker_id)?;
    Some(counter.running_sum())
}
}
// Integration test for `PrefillCountersMultiWorker`: two instances, each on its
// own thread and runtime, are expected to converge on the same per-worker sums
// through published prefill events.
#[cfg(test)]
mod integration_tests {
use super::*;
use dynamo_runtime::{DistributedRuntime, Runtime};
use std::sync::{Arc, Mutex};
use std::thread;
use tokio::time::Duration;
// Marked `#[ignore]`, so it only runs when ignored tests are requested.
// NOTE(review): presumably needs the services behind
// `DistributedRuntime::from_settings` to be reachable -- confirm before enabling.
//
// Nominal timeline, derived from the sleeps below:
//   thread 1: add ~3.0s, snapshot ~4.0s, remove ~6.0s, final snapshot ~7.0s
//   thread 2: add ~3.5s, snapshot ~4.5s, remove ~6.5s, final snapshot ~7.5s
// Both intermediate snapshots are taken after both threads have added their
// requests, and both final snapshots after both threads have removed them.
#[test]
#[ignore]
fn test_prefill_counter_multiworker_synchronization() -> Result<()> {
// Initialize logging once
dynamo_runtime::logging::init();
let worker_id_1 = 1;
let worker_id_2 = 2;
let tokens_per_request = 100;
let requests_per_worker = 10;
// Shared state for collecting results from both threads
let results1 = Arc::new(Mutex::new(None));
let results2 = Arc::new(Mutex::new(None));
let final_results1 = Arc::new(Mutex::new(None));
let final_results2 = Arc::new(Mutex::new(None));
let results1_clone = results1.clone();
let results2_clone = results2.clone();
let final_results1_clone = final_results1.clone();
let final_results2_clone = final_results2.clone();
// Thread 1: First distributed runtime with multi_worker1
let handle1 = thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and components with same names
let namespace = distributed.namespace("test_prefill_multiworker")?;
let component = namespace
.component("counters")?
.service_builder()
.create()
.await?;
// Create first PrefillCountersMultiWorker instance
let multi_worker1 = PrefillCountersMultiWorker::new(component);
// Give some time for subscribers to initialize
tokio::time::sleep(Duration::from_millis(3000)).await;
// Send requests to multi_worker1's worker
for i in 0..requests_per_worker {
let request_id = format!("mw1_request_{}", i);
multi_worker1
.add_prefill(worker_id_1, request_id, tokens_per_request)
.await?;
}
// Wait for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get running sums after additions
let sums1 = multi_worker1.running_sums().await;
*results1_clone.lock().unwrap() = Some(sums1);
// Wait for other thread to add its requests
tokio::time::sleep(Duration::from_millis(2000)).await;
// Remove all requests from multi_worker1
for i in 0..requests_per_worker {
let request_id = format!("mw1_request_{}", i);
multi_worker1.remove_prefill(&request_id).await?;
}
// Wait for removal synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get final running sums
let final_sums1 = multi_worker1.running_sums().await;
*final_results1_clone.lock().unwrap() = Some(final_sums1);
// Keep runtime alive a bit longer for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Shutdown runtime
runtime.shutdown();
Ok::<(), anyhow::Error>(())
})
});
// Thread 2: Second distributed runtime with multi_worker2
let handle2 = thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and components with same names
let namespace = distributed.namespace("test_prefill_multiworker")?;
let component = namespace
.component("counters")?
.service_builder()
.create()
.await?;
// Create second PrefillCountersMultiWorker instance
let multi_worker2 = PrefillCountersMultiWorker::new(component);
// Give some time for subscribers to initialize
tokio::time::sleep(Duration::from_millis(3000)).await;
// Wait a bit to ensure multi_worker1 has started
tokio::time::sleep(Duration::from_millis(500)).await;
// Send requests to multi_worker2's worker
for i in 0..requests_per_worker {
let request_id = format!("mw2_request_{}", i);
multi_worker2
.add_prefill(worker_id_2, request_id, tokens_per_request)
.await?;
}
// Wait for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get running sums after additions
let sums2 = multi_worker2.running_sums().await;
*results2_clone.lock().unwrap() = Some(sums2);
// Wait for other thread to remove its requests
tokio::time::sleep(Duration::from_millis(2000)).await;
// Remove all requests from multi_worker2
for i in 0..requests_per_worker {
let request_id = format!("mw2_request_{}", i);
multi_worker2.remove_prefill(&request_id).await?;
}
// Wait for removal synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get final running sums
let final_sums2 = multi_worker2.running_sums().await;
*final_results2_clone.lock().unwrap() = Some(final_sums2);
// Keep runtime alive a bit longer for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Shutdown runtime
runtime.shutdown();
Ok::<(), anyhow::Error>(())
})
});
// Wait for both threads to complete
// `join().unwrap()` panics if a thread panicked; `?` propagates an error the
// thread's async block returned.
handle1.join().unwrap()?;
handle2.join().unwrap()?;
// Extract results
let sums1 = results1.lock().unwrap().take().unwrap();
let sums2 = results2.lock().unwrap().take().unwrap();
let final_sums1 = final_results1.lock().unwrap().take().unwrap();
let final_sums2 = final_results2.lock().unwrap().take().unwrap();
// Verify both multi-workers see all requests
assert_eq!(
sums1.get(&worker_id_1),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker1 should see worker 1's requests"
);
assert_eq!(
sums1.get(&worker_id_2),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker1 should see worker 2's requests"
);
assert_eq!(
sums2.get(&worker_id_1),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker2 should see worker 1's requests"
);
assert_eq!(
sums2.get(&worker_id_2),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker2 should see worker 2's requests"
);
// Verify both multi-workers show zero sums after removal
// A worker missing from the final map is treated the same as a zero sum.
assert_eq!(
final_sums1.get(&worker_id_1).copied().unwrap_or(0),
0,
"MultiWorker1 should show zero for worker 1"
);
assert_eq!(
final_sums1.get(&worker_id_2).copied().unwrap_or(0),
0,
"MultiWorker1 should show zero for worker 2"
);
assert_eq!(
final_sums2.get(&worker_id_1).copied().unwrap_or(0),
0,
"MultiWorker2 should show zero for worker 1"
);
assert_eq!(
final_sums2.get(&worker_id_2).copied().unwrap_or(0),
0,
"MultiWorker2 should show zero for worker 2"
);
Ok(())
}
}
......@@ -163,7 +163,7 @@ pub struct ActiveSequenceEvent {
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ActiveSequenceEventData {
AddRequest {
token_sequence: Vec<SequenceHash>,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
},
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
......@@ -56,7 +57,7 @@ pub struct SchedulingResponse {
pub struct SchedulingRequest {
pub request_id: String,
pub token_seq: Vec<SequenceHash>,
pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize,
pub overlaps: OverlapScores,
pub decode_blocks: HashMap<i64, usize>,
......@@ -96,6 +97,7 @@ impl KvScheduler {
runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_uuid: String,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone();
......@@ -124,6 +126,7 @@ impl KvScheduler {
block_size as usize,
worker_ids,
replica_sync,
router_uuid,
));
// Spawn background task to monitor and update workers_with_configs
......@@ -240,20 +243,26 @@ impl KvScheduler {
};
request.respond(response);
// Only update the state if update_states is true
if request.update_states {
let _ = slots_clone
.add_request(
request.request_id,
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
selection.worker_id,
)
.await;
// Skip state update if not requested
if !request.update_states {
continue;
}
continue;
let request_id = request.request_id;
if let Err(e) = slots_clone
.add_request(
request_id.clone(),
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
selection.worker_id,
)
.await
{
tracing::warn!(
"Failed to add request {request_id} to local slot tracker: {e:?}"
);
}
}
Err(KvSchedulerError::NoEndpoints) => {
tracing::trace!("no endpoints available; waiting for endpoints update");
......@@ -283,7 +292,7 @@ impl KvScheduler {
&self,
request_id: String,
isl_tokens: usize,
token_seq: Vec<SequenceHash>,
token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores,
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
......@@ -316,7 +325,7 @@ impl KvScheduler {
pub async fn add_request(
&self,
request_id: String,
token_sequence: Vec<SequenceHash>,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
worker_id: i64,
......@@ -327,20 +336,19 @@ impl KvScheduler {
.await;
}
pub async fn mark_prefill_completed(&self, request_id: &str) {
let _ = self
.slots
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
self.slots
.mark_prefill_completed(&request_id.to_string())
.await;
.await
}
pub async fn free(&self, request_id: &str) {
let _ = self.slots.free(&request_id.to_string()).await;
pub async fn free(&self, request_id: &str) -> Result<()> {
self.slots.free(&request_id.to_string()).await
}
pub async fn get_potential_loads(
&self,
token_seq: Vec<SequenceHash>,
token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize,
overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
......
......@@ -106,7 +106,7 @@ impl ActiveSequences {
fn remove_block(&mut self, request_id: &RequestId, block: &SequenceHash) {
let Some(request_ids) = self.unique_blocks.get_mut(block) else {
panic!("Cannot remove a block that does not exist.")
return;
};
// Remove the unique block if no more requests using it
......@@ -122,10 +122,15 @@ impl ActiveSequences {
pub fn add_request(
&mut self,
request_id: RequestId,
token_sequence: Vec<SequenceHash>,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
) -> HashSet<RequestId> {
// Check for double-add and panic early
if self.active_seqs.contains_key(&request_id) {
panic!("Request {request_id} is already active. Cannot accept double-add.");
}
// Lazily check and clean up expired requests, capturing removed IDs
let removed_requests = self.force_expiry();
......@@ -134,12 +139,16 @@ impl ActiveSequences {
.insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens;
for block in &token_sequence {
self.add_block(request_id.clone(), block);
if let Some(sequence) = token_sequence {
for block in &sequence {
self.add_block(request_id.clone(), block);
}
self.active_seqs.insert(request_id.clone(), sequence);
} else {
// dummy empty sequence
self.active_seqs.insert(request_id.clone(), Vec::new());
}
self.active_seqs.insert(request_id.clone(), token_sequence);
removed_requests
}
......@@ -160,11 +169,15 @@ impl ActiveSequences {
pub fn potential_blocks_and_tokens(
&self,
token_sequence: &[SequenceHash],
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlap: u32,
) -> (usize, usize) {
let potential_blocks = self.new_blocks(token_sequence) + self.active_blocks;
let potential_blocks = if let Some(token_seq) = token_sequence {
self.new_blocks(token_seq) + self.active_blocks
} else {
self.active_blocks
};
let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
(potential_blocks, potential_tokens)
}
......@@ -189,17 +202,19 @@ impl ActiveSequences {
self.expiry_requests.remove(request_id);
let Some(token_seq) = self.active_seqs.get(request_id) else {
tracing::warn!("Trying to free free non-existent request {request_id}");
return 0;
// Remove from active_seqs and get the token sequence
let token_seq = match self.active_seqs.remove(request_id) {
Some(seq) => seq,
None => {
tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks;
}
};
for block in token_seq.clone() {
for block in token_seq {
self.remove_block(request_id, &block)
}
self.active_seqs.remove(request_id).unwrap();
self.active_blocks
}
......@@ -230,7 +245,7 @@ impl ActiveSequences {
enum UpdateSequences {
AddRequest {
request_id: RequestId,
token_sequence: Vec<SequenceHash>,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
resp_tx: tokio::sync::oneshot::Sender<HashSet<RequestId>>,
......@@ -250,7 +265,7 @@ enum UpdateSequences {
resp_tx: tokio::sync::oneshot::Sender<usize>,
},
PotentialBlocksAndTokens {
token_sequence: Arc<Vec<SequenceHash>>,
token_sequence: Option<Arc<Vec<SequenceHash>>>,
isl: usize,
overlap: u32,
resp_tx: tokio::sync::oneshot::Sender<(usize, usize)>,
......@@ -281,13 +296,21 @@ impl ActiveSequencesMultiWorker {
block_size: usize,
worker_ids: Vec<WorkerId>,
replica_sync: bool,
router_uuid: String,
) -> Self {
assert!(block_size > 1, "block_size must be greater than 1");
let senders = Arc::new(DashMap::new());
let handles = Arc::new(DashMap::new());
let request_to_worker = Arc::new(DashMap::new());
let router_id = Uuid::new_v4();
let router_id = Uuid::parse_str(&router_uuid).unwrap_or_else(|e| {
tracing::warn!(
"Failed to parse router UUID '{}': {}, using new UUID",
router_uuid,
e
);
Uuid::new_v4()
});
for worker_id in worker_ids {
// Create a child cancellation token from the component's runtime
......@@ -313,6 +336,7 @@ impl ActiveSequencesMultiWorker {
let request_to_worker_clone = request_to_worker.clone();
let component_clone = component.clone();
let router_id_clone = router_id;
let cancel_token = component.drt().runtime().child_token();
tokio::spawn(async move {
// NATS subscription loop
......@@ -321,6 +345,7 @@ impl ActiveSequencesMultiWorker {
request_to_worker_clone,
component_clone,
router_id_clone,
cancel_token,
)
.await
{
......@@ -389,7 +414,7 @@ impl ActiveSequencesMultiWorker {
resp_tx,
} => {
let potential_tokens = active_sequences.potential_blocks_and_tokens(
&token_sequence,
token_sequence.as_ref().map(|v| v.as_slice()),
isl,
overlap,
);
......@@ -432,67 +457,83 @@ impl ActiveSequencesMultiWorker {
request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
component: Component,
router_id: Uuid,
cancel_token: CancellationToken,
) -> Result<()> {
let mut subscriber = component
.subscribe_with_type::<ActiveSequenceEvent>(ACTIVE_SEQUENCES_SUBJECT)
.await?;
while let Some(result) = subscriber.next().await {
let Ok(event) = result else {
tracing::error!(
"Error receiving active sequence event: {}",
result.unwrap_err()
);
continue;
};
// Skip events emitted by itself
if event.router_id == router_id {
continue;
}
loop {
tokio::select! {
// Handle incoming events
result = subscriber.next() => {
let Some(result) = result else {
// Stream ended
break;
};
match &event.data {
ActiveSequenceEventData::AddRequest {
token_sequence,
isl,
overlap,
} => {
request_to_worker.insert(event.request_id.clone(), event.worker_id);
if let Some(sender) = senders.get(&event.worker_id) {
// For replicated events, we create a dummy response channel since we don't need to handle expired requests
let (resp_tx, _) = tokio::sync::oneshot::channel();
let _ = sender.send(UpdateSequences::AddRequest {
request_id: event.request_id.clone(),
token_sequence: token_sequence.clone(),
isl: *isl,
overlap: *overlap,
resp_tx,
});
} else {
tracing::warn!(
"Worker {} not found, cannot process AddRequest",
event.worker_id
let Ok(event) = result else {
tracing::error!(
"Error receiving active sequence event: {}",
result.unwrap_err()
);
continue;
};
// Skip events emitted by itself
if event.router_id == router_id {
continue;
}
}
ActiveSequenceEventData::Free => {
if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id)
&& let Some(sender) = senders.get(&worker_id)
{
let _ = sender.send(UpdateSequences::Free {
request_id: event.request_id.clone(),
});
match &event.data {
ActiveSequenceEventData::AddRequest {
token_sequence,
isl,
overlap,
} => {
request_to_worker.insert(event.request_id.clone(), event.worker_id);
if let Some(sender) = senders.get(&event.worker_id) {
// For replicated events, we create a dummy response channel since we don't need to handle expired requests
let (resp_tx, _) = tokio::sync::oneshot::channel();
let _ = sender.send(UpdateSequences::AddRequest {
request_id: event.request_id.clone(),
token_sequence: token_sequence.clone(),
isl: *isl,
overlap: *overlap,
resp_tx,
});
} else {
tracing::warn!(
"Worker {} not found, cannot process AddRequest",
event.worker_id
);
}
}
ActiveSequenceEventData::Free => {
if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id)
&& let Some(sender) = senders.get(&worker_id)
{
let _ = sender.send(UpdateSequences::Free {
request_id: event.request_id.clone(),
});
}
}
ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker_id) = request_to_worker.get(&event.request_id)
&& let Some(sender) = senders.get(&*worker_id)
{
let _ = sender.send(UpdateSequences::MarkPrefillCompleted {
request_id: event.request_id.clone(),
});
}
}
}
}
ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker_id) = request_to_worker.get(&event.request_id)
&& let Some(sender) = senders.get(&*worker_id)
{
let _ = sender.send(UpdateSequences::MarkPrefillCompleted {
request_id: event.request_id.clone(),
});
}
// Handle cancellation
_ = cancel_token.cancelled() => {
tracing::debug!("Subscription task cancelled");
break;
}
}
}
......@@ -522,6 +563,10 @@ impl ActiveSequencesMultiWorker {
if let Some((_, handle)) = self.handles.remove(worker_id) {
handle.abort();
}
// Clean up request_to_worker mappings for this worker
self.request_to_worker
.retain(|_request_id, mapped_worker_id| *mapped_worker_id != *worker_id);
}
// Add new workers
......@@ -540,7 +585,7 @@ impl ActiveSequencesMultiWorker {
pub async fn add_request(
&self,
request_id: RequestId,
token_sequence: Vec<SequenceHash>,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
worker_id: WorkerId,
......@@ -743,13 +788,13 @@ impl ActiveSequencesMultiWorker {
/// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap
pub async fn potential_blocks_and_tokens(
&self,
token_sequence: Vec<SequenceHash>,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlaps: OverlapScores,
) -> (HashMap<WorkerId, usize>, HashMap<WorkerId, usize>) {
let mut potential_blocks = HashMap::new();
let mut potential_tokens = HashMap::new();
let token_sequence_shared = Arc::new(token_sequence);
let token_sequence_shared = token_sequence.map(Arc::new);
let mut receivers = Vec::new();
// Send queries to all workers in parallel
......@@ -823,307 +868,273 @@ impl Drop for ActiveSequencesMultiWorker {
mod tests {
use super::*;
use dynamo_runtime::{DistributedRuntime, Runtime};
use std::sync::{Arc, Mutex};
use std::thread;
use std::sync::Arc;
#[test]
#[tokio::test]
#[ignore]
fn test_multi_worker_block_sharing() -> Result<()> {
async fn test_multi_worker_cross_instance_sync() -> Result<()> {
// Initialize logging once
dynamo_runtime::logging::init();
let block_size = 4; // arbitrary block size
// Shared state for collecting results from both threads
let active_tokens_after_add = Arc::new(Mutex::new(HashMap::new()));
let potential_blocks_result = Arc::new(Mutex::new(HashMap::new()));
let active_blocks_after_free = Arc::new(Mutex::new(HashMap::new()));
let active_tokens_after_free = Arc::new(Mutex::new(HashMap::new()));
let active_tokens_after_add_clone = active_tokens_after_add.clone();
let potential_blocks_result_clone = potential_blocks_result.clone();
let active_blocks_after_free_clone = active_blocks_after_free.clone();
let active_tokens_after_free_clone = active_tokens_after_free.clone();
// Clone again for the second thread
let active_tokens_after_add_clone2 = active_tokens_after_add.clone();
let potential_blocks_result_clone2 = potential_blocks_result.clone();
let active_blocks_after_free_clone2 = active_blocks_after_free.clone();
let active_tokens_after_free_clone2 = active_tokens_after_free.clone();
// Thread 1: First runtime with workers 0 and 1
let handle1 = thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and component with same names as thread 2
let namespace = distributed.namespace("test_multiworker_sequences")?;
let component = namespace
.component("sequences")?
.service_builder()
.create()
.await?;
// Create multi-worker sequence manager with workers 0 and 1
let worker_ids = vec![0, 1];
let seq_manager =
ActiveSequencesMultiWorker::new(component, block_size, worker_ids, true);
// Give some time for the subscription loop to start
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Add requests to workers
// Worker 0: sequence [0, 1, 2]
seq_manager
.add_request(
"request_0".to_string(),
vec![0, 1, 2],
12, // ISL (3 blocks * 4 block_size)
0, // no overlap
0, // worker_id
)
.await?;
// Worker 1: sequence [3, 4]
seq_manager
.add_request(
"request_1".to_string(),
vec![3, 4],
8, // ISL (2 blocks * 4 block_size)
0, // no overlap
1, // worker_id
)
.await?;
// Give some time for the commands to be processed and synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Get active tokens from workers 0 and 1
let tokens = seq_manager.active_tokens().await;
active_tokens_after_add_clone
.lock()
.unwrap()
.insert(0, tokens.get(&0).copied().unwrap_or(0));
active_tokens_after_add_clone
.lock()
.unwrap()
.insert(1, tokens.get(&1).copied().unwrap_or(0));
// Test potential blocks for sequence [0, 1]
let potential = seq_manager.potential_blocks(vec![0, 1]).await;
potential_blocks_result_clone
.lock()
.unwrap()
.insert(0, potential.get(&0).copied().unwrap_or(0));
potential_blocks_result_clone
.lock()
.unwrap()
.insert(1, potential.get(&1).copied().unwrap_or(0));
// Wait for second thread to process its requests
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Free requests from workers 0 and 1
seq_manager.free(&"request_0".to_string()).await?;
seq_manager.free(&"request_1".to_string()).await?;
// Give some time for the commands to be processed
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Get final active blocks and tokens
let blocks = seq_manager.active_blocks().await;
let tokens = seq_manager.active_tokens().await;
active_blocks_after_free_clone
.lock()
.unwrap()
.insert(0, blocks.get(&0).copied().unwrap_or(0));
active_blocks_after_free_clone
.lock()
.unwrap()
.insert(1, blocks.get(&1).copied().unwrap_or(0));
active_tokens_after_free_clone
.lock()
.unwrap()
.insert(0, tokens.get(&0).copied().unwrap_or(0));
active_tokens_after_free_clone
.lock()
.unwrap()
.insert(1, tokens.get(&1).copied().unwrap_or(0));
// Keep runtime alive a bit longer for synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Shutdown runtime
runtime.shutdown();
Ok::<(), anyhow::Error>(())
})
});
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Thread 2: Second runtime with worker 2
let handle2 = thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and component with same names as thread 1
let namespace = distributed.namespace("test_multiworker_sequences")?;
let component = namespace
.component("sequences")?
.service_builder()
.create()
.await?;
// Create multi-worker sequence manager with worker 2
let worker_ids = vec![2];
let seq_manager =
ActiveSequencesMultiWorker::new(component, block_size, worker_ids, true);
// Give some time for the subscription loop to start
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Wait a bit to ensure thread 1 has started
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Worker 2: sequence [0, 1, 2, 3]
seq_manager
.add_request(
"request_2".to_string(),
vec![0, 1, 2, 3],
16, // ISL (4 blocks * 4 block_size)
0, // no overlap
2, // worker_id
)
.await?;
// Give some time for the commands to be processed and synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Get active tokens from worker 2
let tokens = seq_manager.active_tokens().await;
active_tokens_after_add_clone2
.lock()
.unwrap()
.insert(2, tokens.get(&2).copied().unwrap_or(0));
// Test potential blocks for sequence [0, 1]
let potential = seq_manager.potential_blocks(vec![0, 1]).await;
potential_blocks_result_clone2
.lock()
.unwrap()
.insert(2, potential.get(&2).copied().unwrap_or(0));
// Wait for first thread to free its requests
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Free request from worker 2
seq_manager.free(&"request_2".to_string()).await?;
// Give some time for the commands to be processed
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Get final active blocks and tokens
let blocks = seq_manager.active_blocks().await;
let tokens = seq_manager.active_tokens().await;
active_blocks_after_free_clone2
.lock()
.unwrap()
.insert(2, blocks.get(&2).copied().unwrap_or(0));
active_tokens_after_free_clone2
.lock()
.unwrap()
.insert(2, tokens.get(&2).copied().unwrap_or(0));
// Keep runtime alive a bit longer for synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Shutdown runtime
runtime.shutdown();
Ok::<(), anyhow::Error>(())
})
});
// Create namespace and shared component for both seq_managers
let namespace = distributed.namespace("test_cross_instance_sync")?;
let component = namespace
.component("sequences")?
.service_builder()
.create()
.await?;
// Wait for both threads to complete
handle1.join().unwrap()?;
handle2.join().unwrap()?;
// Create multi-worker sequence managers with ALL workers [0, 1, 2]
// Both use the same component to ensure event synchronization works
let worker_ids = vec![0, 1, 2];
let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(),
block_size,
worker_ids.clone(),
true,
Uuid::new_v4().to_string(),
));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component,
block_size,
worker_ids,
true,
Uuid::new_v4().to_string(),
));
// Give some time for the subscription loops to start
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
// PHASE 1: Add requests using both seq_manager_1 and seq_manager_2
// Add request_0 to worker 0: sequence [0, 1, 2]
seq_manager_1
.add_request(
"request_0".to_string(),
Some(vec![0, 1, 2]),
12, // ISL (3 blocks * 4 block_size)
0, // no overlap
0, // worker_id
)
.await?;
// Extract results
let tokens_after_add = active_tokens_after_add.lock().unwrap();
let potential_blocks = potential_blocks_result.lock().unwrap();
let blocks_after_free = active_blocks_after_free.lock().unwrap();
let tokens_after_free = active_tokens_after_free.lock().unwrap();
// Add request_1 to worker 1: sequence [3, 4]
seq_manager_1
.add_request(
"request_1".to_string(),
Some(vec![3, 4]),
8, // ISL (2 blocks * 4 block_size)
0, // no overlap
1, // worker_id
)
.await?;
// Verify active tokens after adding requests
assert_eq!(
tokens_after_add[&0], 12,
"Worker 0 should have 12 active tokens"
);
assert_eq!(
tokens_after_add[&1], 8,
"Worker 1 should have 8 active tokens"
);
assert_eq!(
tokens_after_add[&2], 16,
"Worker 2 should have 16 active tokens"
);
// Add request_2 to worker 2: sequence [0, 1, 2, 3] using seq_manager_2
seq_manager_2
.add_request(
"request_2".to_string(),
Some(vec![0, 1, 2, 3]),
16, // ISL (4 blocks * 4 block_size)
0, // no overlap
2, // worker_id
)
.await?;
// Test potential blocks for sequence [0, 1]
// Worker 0 should return 3 (already has blocks 0, 1, 2, so no new blocks needed for [0, 1])
assert_eq!(
potential_blocks[&0], 3,
"Worker 0 should have 3 potential blocks"
);
// Give some time for synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Worker 1 should return 4 (has blocks 3, 4, would need to add blocks 0, 1)
assert_eq!(
potential_blocks[&1], 4,
"Worker 1 should have 4 potential blocks"
);
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
let blocks_phase1 = seq_manager_1.active_blocks().await;
let tokens_phase1 = seq_manager_1.active_tokens().await;
// Worker 2 should return 4 (already has blocks 0, 1, 2, 3, so no new blocks needed for [0, 1])
// Verify that seq_manager_1 sees all requests including request_2 from thread 2
assert_eq!(
potential_blocks[&2], 4,
"Worker 2 should have 4 potential blocks"
blocks_phase1[&0], 3,
"Worker 0 should have 3 active blocks (from request_0)"
);
// Verify active blocks are zero for all workers
assert_eq!(
blocks_after_free[&0], 0,
"Worker 0 should have 0 active blocks"
blocks_phase1[&1], 2,
"Worker 1 should have 2 active blocks (from request_1)"
);
assert_eq!(
blocks_after_free[&1], 0,
"Worker 1 should have 0 active blocks"
blocks_phase1[&2], 4,
"Worker 2 should have 4 active blocks (from request_2 added by seq_manager_2)"
);
assert_eq!(
blocks_after_free[&2], 0,
"Worker 2 should have 0 active blocks"
tokens_phase1[&0], 12,
"Worker 0 should have 12 active tokens"
);
// Verify active tokens are zero for all workers
assert_eq!(tokens_phase1[&1], 8, "Worker 1 should have 8 active tokens");
assert_eq!(
tokens_after_free[&0], 0,
"Worker 0 should have 0 active tokens after freeing all"
tokens_phase1[&2], 16,
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
);
// PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2
// Free request_2 (which was added by seq_manager_2) using seq_manager_1
seq_manager_1.free(&"request_2".to_string()).await?;
// Free request_0 and request_1 (which were added by seq_manager_1) using seq_manager_2
seq_manager_2.free(&"request_0".to_string()).await?;
seq_manager_2.free(&"request_1".to_string()).await?;
// Give some time for synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_2 to verify everything is empty
let blocks_phase2 = seq_manager_2.active_blocks().await;
let tokens_phase2 = seq_manager_2.active_tokens().await;
// Verify phase 2 results - everything should be empty
for worker_id in 0..=2 {
assert_eq!(
blocks_phase2[&worker_id], 0,
"Worker {} should have 0 active blocks after all requests freed",
worker_id
);
assert_eq!(
tokens_phase2[&worker_id], 0,
"Worker {} should have 0 active tokens after all requests freed",
worker_id
);
}
Ok(())
}
#[tokio::test]
#[ignore]
async fn test_multi_worker_no_token_sequence_sync() -> Result<()> {
// Initialize logging once
dynamo_runtime::logging::init();
let block_size = 4; // arbitrary block size
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and shared component for both seq_managers
let namespace = distributed.namespace("test_no_token_seq_sync")?;
let component = namespace
.component("sequences")?
.service_builder()
.create()
.await?;
// Create multi-worker sequence managers with ALL workers [0, 1, 2]
// Both use the same component to ensure event synchronization works
let worker_ids = vec![0, 1, 2];
let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(),
block_size,
worker_ids.clone(),
true,
Uuid::new_v4().to_string(),
));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component,
block_size,
worker_ids,
true,
Uuid::new_v4().to_string(),
));
// Give some time for the subscription loops to start
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
// PHASE 1: Add requests (without token sequences) using both seq_managers
// Add request_0 to worker 0 with no token sequence
seq_manager_1
.add_request(
"request_0".to_string(),
None, // No token sequence
12, // ISL (12 tokens)
0, // no overlap
0, // worker_id
)
.await?;
// Add request_1 to worker 1 with no token sequence
seq_manager_1
.add_request(
"request_1".to_string(),
None, // No token sequence
8, // ISL (8 tokens)
0, // no overlap
1, // worker_id
)
.await?;
// Add request_2 to worker 2 with no token sequence using seq_manager_2
seq_manager_2
.add_request(
"request_2".to_string(),
None, // No token sequence
16, // ISL (16 tokens)
0, // no overlap
2, // worker_id
)
.await?;
// Give some time for synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
let tokens_phase1 = seq_manager_1.active_tokens().await;
// Verify that seq_manager_1 sees all requests including request_2 from thread 2
assert_eq!(
tokens_after_free[&1], 0,
"Worker 1 should have 0 active tokens after freeing all"
tokens_phase1[&0], 12,
"Worker 0 should have 12 active tokens"
);
assert_eq!(tokens_phase1[&1], 8, "Worker 1 should have 8 active tokens");
assert_eq!(
tokens_after_free[&2], 0,
"Worker 2 should have 0 active tokens after freeing all"
tokens_phase1[&2], 16,
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
);
// PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2
// Mark prefill completed and free request_2 (which was added by seq_manager_2) using seq_manager_1
seq_manager_1
.mark_prefill_completed(&"request_2".to_string())
.await?;
seq_manager_1.free(&"request_2".to_string()).await?;
// Mark prefill completed and free requests 0 and 1 (which were added by seq_manager_1) using seq_manager_2
seq_manager_2
.mark_prefill_completed(&"request_0".to_string())
.await?;
seq_manager_2
.mark_prefill_completed(&"request_1".to_string())
.await?;
seq_manager_2.free(&"request_0".to_string()).await?;
seq_manager_2.free(&"request_1".to_string()).await?;
// Give some time for synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_2 to verify everything is empty
let tokens_phase2 = seq_manager_2.active_tokens().await;
// Verify phase 2 results - everything should be empty
for worker_id in 0..=2 {
assert_eq!(
tokens_phase2[&worker_id], 0,
"Worker {} should have 0 active tokens after all requests freed",
worker_id
);
}
Ok(())
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment