Unverified Commit 5b19a39b authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: allow router to not track active blocks (prefill), and to not track...


feat: allow router to not track active blocks (prefill), and to not track cached blocks (decode) (#3135)
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent f78189d7
......@@ -31,6 +31,8 @@ The main KV-aware routing arguments:
>[!Note]
> State persistence is only available when KV events are enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not currently supported.
>
> When `--kv-overlap-score-weight` is set to 0 or `--no-kv-events` is set, no KvIndexer will be launched to drain and process KV events. It's recommended to disable your backend workers from relaying events through `KvEventPublisher` to avoid event accumulation in JetStream. WIP to enable disabling publishing of KV events completely in these cases.
## Architecture
......
......@@ -99,6 +99,13 @@ pub struct Flags {
#[arg(long)]
pub router_replica_sync: Option<bool>,
/// KV Router: Whether to track active blocks in the router for memory management.
/// When false, the router will not maintain state about which blocks are active,
/// reducing memory overhead but potentially affecting scheduling decisions.
/// Default: true
#[arg(long)]
pub router_track_active_blocks: Option<bool>,
/// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4).
/// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json.
......@@ -228,6 +235,7 @@ impl Flags {
self.router_temperature,
self.use_kv_events,
self.router_replica_sync,
self.router_track_active_blocks,
self.max_num_batched_tokens,
// defaulting below args (no longer maintaining new flags for dynamo-run)
None,
......
......@@ -42,12 +42,13 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_snapshot_threshold=10000, router_reset_states=false))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=10000, router_reset_states=false))]
fn new(
overlap_score_weight: f64,
router_temperature: f64,
use_kv_events: bool,
router_replica_sync: bool,
router_track_active_blocks: bool,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
) -> Self {
......@@ -57,6 +58,7 @@ impl KvRouterConfig {
router_temperature,
use_kv_events,
router_replica_sync,
router_track_active_blocks,
router_snapshot_threshold,
router_reset_states,
..Default::default()
......
......@@ -23,7 +23,6 @@ use serde::{Deserialize, Serialize};
pub mod approx;
pub mod indexer;
pub mod metrics_aggregator;
pub mod prefill_counter;
pub mod protocols;
pub mod publisher;
pub mod recorder;
......@@ -102,6 +101,9 @@ pub struct KvRouterConfig {
pub router_replica_sync: bool,
/// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool,
// TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32,
......@@ -120,6 +122,7 @@ impl Default for KvRouterConfig {
router_temperature: 0.0,
use_kv_events: true,
router_replica_sync: false,
router_track_active_blocks: true,
max_num_batched_tokens: 8192,
router_snapshot_threshold: Some(10000),
router_reset_states: false,
......@@ -130,11 +133,13 @@ impl Default for KvRouterConfig {
impl KvRouterConfig {
/// Create a new KvRouterConfig with optional weight values.
/// If a weight is None, the default value will be used.
#[allow(clippy::too_many_arguments)]
pub fn new(
overlap_score_weight: Option<f64>,
temperature: Option<f64>,
use_kv_events: Option<bool>,
replica_sync: Option<bool>,
track_active_blocks: Option<bool>,
max_num_batched_tokens: Option<u32>,
router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>,
......@@ -145,6 +150,8 @@ impl KvRouterConfig {
router_temperature: temperature.unwrap_or(default.router_temperature),
use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
router_track_active_blocks: track_active_blocks
.unwrap_or(default.router_track_active_blocks),
max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens),
router_snapshot_threshold: router_snapshot_threshold
......@@ -157,8 +164,17 @@ impl KvRouterConfig {
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
/// Updates itself based on KV events emitted by backend workers.
/// Has the ability to persist and snapshot states.
KvIndexer(KvIndexer),
/// Predicts the cached blocks based on requests on a TTL basis.
/// Currently does not persist or snapshot states (WIP to enable that).
ApproxKvIndexer(ApproxKvIndexer),
/// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
/// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
None,
}
impl Indexer {
......@@ -169,6 +185,10 @@ impl Indexer {
match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::None => Ok(OverlapScores {
scores: HashMap::new(),
frequencies: Vec::new(),
}),
}
}
......@@ -176,6 +196,11 @@ impl Indexer {
match self {
Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
Indexer::None => {
panic!(
"Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
);
}
}
}
}
......@@ -189,6 +214,8 @@ pub struct KvRouter {
scheduler: KvScheduler,
block_size: u32,
kv_router_config: KvRouterConfig,
}
impl KvRouter {
......@@ -234,7 +261,10 @@ impl KvRouter {
.await?;
let runtime_configs_rx = runtime_configs_watcher.receiver();
let indexer = if kv_router_config.use_kv_events {
let indexer = if kv_router_config.overlap_score_weight == 0.0 {
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer::None
} else if kv_router_config.use_kv_events {
let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component);
Indexer::KvIndexer(KvIndexer::new(
cancellation_token.clone(),
......@@ -257,6 +287,7 @@ impl KvRouter {
runtime_configs_rx,
selector,
kv_router_config.router_replica_sync,
consumer_uuid.clone(),
)
.await?;
......@@ -282,6 +313,7 @@ impl KvRouter {
indexer,
scheduler,
block_size,
kv_router_config,
})
}
......@@ -302,12 +334,25 @@ impl KvRouter {
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
// Determine who needs seq_hashes
let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;
// Optimize cloning: only clone if both need it, otherwise move
let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
match (approx_indexer_needs_it, scheduler_needs_it) {
(true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
(true, false) => (Some(seq_hashes), None),
(false, true) => (None, Some(seq_hashes)),
(false, false) => (None, None),
};
let best_worker_id = self
.scheduler
.schedule(
context_id.to_string(),
isl_tokens,
seq_hashes.clone(),
maybe_seq_hashes_2,
overlap_scores.clone(),
router_config_override,
update_states,
......@@ -316,7 +361,7 @@ impl KvRouter {
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer
.process_routing_decision(best_worker_id, block_hashes, seq_hashes)
.process_routing_decision(best_worker_id, block_hashes, maybe_seq_hashes_1.unwrap())
.await
.unwrap();
};
......@@ -337,13 +382,16 @@ impl KvRouter {
worker_id: i64,
) {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
compute_seq_hash_for_block(&block_hashes)
});
self.scheduler
.add_request(
request_id,
seq_hashes,
maybe_seq_hashes,
isl_tokens,
overlap_blocks,
worker_id,
......@@ -351,11 +399,11 @@ impl KvRouter {
.await;
}
pub async fn mark_prefill_completed(&self, request_id: &str) {
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
self.scheduler.mark_prefill_completed(request_id).await
}
pub async fn free(&self, request_id: &str) {
pub async fn free(&self, request_id: &str) -> Result<()> {
self.scheduler.free(request_id).await
}
......@@ -367,12 +415,16 @@ impl KvRouter {
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
compute_seq_hash_for_block(&block_hashes)
});
Ok(self
.scheduler
.get_potential_loads(seq_hashes, isl_tokens, overlap_scores)
.get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
.await)
}
......@@ -404,14 +456,12 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
overlap_blocks,
}
}
RouterRequest::MarkPrefill => {
self.mark_prefill_completed(&context_id).await;
RouterResponse::PrefillMarked { success: true }
}
RouterRequest::MarkFree => {
self.free(&context_id).await;
RouterResponse::FreeMarked { success: true }
}
RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
success: self.mark_prefill_completed(&context_id).await.is_ok(),
},
RouterRequest::MarkFree => RouterResponse::FreeMarked {
success: self.free(&context_id).await.is_ok(),
},
};
let response = Annotated::from_data(response);
......@@ -541,7 +591,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let wrapped_stream = Box::pin(async_stream::stream! {
if let Some(first_item) = response_stream.next().await {
chooser.mark_prefill_completed(&context_id).await;
if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
}
yield first_item;
}
......@@ -549,7 +601,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
yield item;
}
chooser.free(&context_id).await;
if let Err(e) = chooser.free(&context_id).await {
tracing::warn!("Failed to free request {context_id}: {e:?}");
}
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt;
use std::sync::Arc;
use uuid::Uuid;
use super::protocols::{PrefillEvent, PrefillEventData};
use crate::kv_router::PREFILL_SUBJECT;
use dashmap::DashMap;
use std::collections::HashMap;
use std::hash::Hash;
pub fn get_snapshot<K, V>(state: &DashMap<K, V>) -> HashMap<K, V>
where
K: Clone + Hash + Eq,
V: Copy,
{
state
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect()
}
#[derive(Default)]
struct PrefillCounterState {
tokens_map: HashMap<String, usize>, // Plain HashMap
running_sum: usize, // Plain usize
}
impl PrefillCounterState {
fn insert(&mut self, key: String, value: usize) -> Option<usize> {
// Takes &mut self
let old_value = self.tokens_map.insert(key, value);
if let Some(old) = old_value {
self.running_sum -= old;
self.running_sum += value;
} else {
self.running_sum += value;
}
old_value
}
fn remove(&mut self, key: &str) -> Option<usize> {
// Takes &mut self
let removed = self.tokens_map.remove(key);
if let Some(value) = removed {
self.running_sum -= value;
}
removed
}
fn running_sum(&self) -> usize {
self.running_sum
}
}
/// A counter that tracks pending prefill tokens for each request.
///
/// This struct maintains a local hashmap of request_id to token count,
/// and a running sum of all tokens. It no longer handles its own subscriptions.
#[derive(Default)] // Removed Clone
pub struct PrefillCounter {
state: PrefillCounterState, // No Arc, direct ownership
}
impl PrefillCounter {
// Internal methods for direct state manipulation (no publishing)
fn insert_direct(&mut self, request_id: String, tokens: usize) -> Option<usize> {
// Takes &mut self
self.state.insert(request_id, tokens)
}
fn remove_direct(&mut self, request_id: &str) -> Option<usize> {
// Takes &mut self
self.state.remove(request_id)
}
#[allow(dead_code)]
fn update_direct(&mut self, request_id: String, new_tokens: usize) {
// Takes &mut self
if let Some(old_tokens) = self.state.tokens_map.get(&request_id).copied() {
let delta = new_tokens as isize - old_tokens as isize;
self.state.running_sum = (self.state.running_sum as isize + delta) as usize;
self.state.tokens_map.insert(request_id, new_tokens);
}
}
pub fn get(&self, request_id: &str) -> Option<usize> {
self.state.tokens_map.get(request_id).copied()
}
pub fn running_sum(&self) -> usize {
self.state.running_sum()
}
pub fn len(&self) -> usize {
self.state.tokens_map.len()
}
pub fn is_empty(&self) -> bool {
self.state.tokens_map.is_empty()
}
}
/// A collection of PrefillCounters for multiple workers with centralized event handling
pub struct PrefillCountersMultiWorker {
pub counters: Arc<DashMap<i64, PrefillCounter>>,
pub request_to_workers: Arc<DashMap<String, i64>>,
component: Component,
router_id: Uuid,
}
impl PrefillCountersMultiWorker {
// Helper function to handle new prefill logic
fn handle_new_prefill(
counters: &Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: &Arc<DashMap<String, i64>>,
request_id: &str,
worker_id: i64,
tokens: usize,
) {
// Check if request already exists
if let Some(existing_worker_id) = request_to_workers.get(request_id) {
tracing::warn!(
"Request {} already exists for worker {}, but trying to add to worker {}",
request_id,
*existing_worker_id,
worker_id
);
}
// Update mapping
request_to_workers.insert(request_id.to_string(), worker_id);
// Get or create counter and insert using get_mut
if let Some(mut counter) = counters.get_mut(&worker_id) {
counter.insert_direct(request_id.to_string(), tokens);
} else {
tracing::warn!(
"Worker {} does not exist, creating new PrefillCounter",
worker_id
);
let mut new_counter = PrefillCounter::default();
new_counter.insert_direct(request_id.to_string(), tokens);
counters.insert(worker_id, new_counter);
};
}
// Helper function to handle complete prefill logic
fn handle_complete_prefill(
counters: &Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: &Arc<DashMap<String, i64>>,
request_id: &str,
) -> Option<usize> {
// Remove from request_to_workers and get the worker_id
let Some((_, worker_id)) = request_to_workers.remove(request_id) else {
tracing::warn!("Request {} not found in request_to_workers", request_id);
return None;
};
// Use the worker_id from request_to_workers with get_mut
let Some(mut counter) = counters.get_mut(&worker_id) else {
tracing::warn!(
"No counter found for worker {} for request {}",
worker_id,
request_id
);
return None;
};
let removed_tokens = counter.remove_direct(request_id);
if removed_tokens.is_none() {
tracing::warn!("Attempted to remove non-existent request: {}", request_id);
}
removed_tokens
}
pub fn new(component: Component) -> Self {
let counters = Arc::new(DashMap::new());
let request_to_workers = Arc::new(DashMap::new());
let router_id = Uuid::new_v4();
let multi_worker = Self {
counters: counters.clone(),
request_to_workers: request_to_workers.clone(),
component: component.clone(),
router_id,
};
// Start the subscription loop
let counters_clone = counters.clone();
let request_to_workers_clone = request_to_workers.clone();
let component_clone = component.clone();
let router_id_clone = router_id;
tokio::spawn(async move {
if let Err(e) = Self::subscribe_to_events(
counters_clone,
request_to_workers_clone,
component_clone,
router_id_clone,
)
.await
{
tracing::error!("Error in prefill events subscription: {}", e);
}
});
multi_worker
}
/// Background task to subscribe to prefill events and update all counters
async fn subscribe_to_events(
counters: Arc<DashMap<i64, PrefillCounter>>,
request_to_workers: Arc<DashMap<String, i64>>,
component: Component,
router_id: Uuid,
) -> Result<()> {
let mut subscriber = component
.subscribe_with_type::<PrefillEvent>(PREFILL_SUBJECT)
.await?;
while let Some(result) = subscriber.next().await {
let Ok(event) = result else {
tracing::error!("Error receiving prefill event: {}", result.unwrap_err());
continue;
};
// Skip events emitted by itself
if event.router_id == router_id {
continue;
}
match event.data {
PrefillEventData::NewPrefill(tokens) => {
Self::handle_new_prefill(
&counters,
&request_to_workers,
&event.request_id,
event.worker_id,
tokens,
);
}
PrefillEventData::UpdatePrefill(_) => {
// Do nothing for now
continue;
}
PrefillEventData::CompletePrefill => {
Self::handle_complete_prefill(
&counters,
&request_to_workers,
&event.request_id,
);
}
}
}
Ok(())
}
pub async fn add_prefill(
&self,
worker_id: i64,
request_id: String,
new_tokens: usize,
) -> Result<()> {
let event = PrefillEvent {
request_id: request_id.clone(),
worker_id,
data: PrefillEventData::NewPrefill(new_tokens),
router_id: self.router_id,
};
self.component.publish(PREFILL_SUBJECT, &event).await?;
// Use the helper function
Self::handle_new_prefill(
&self.counters,
&self.request_to_workers,
&request_id,
worker_id,
new_tokens,
);
Ok(())
}
pub async fn remove_prefill(&self, request_id: &str) -> Result<Option<usize>> {
// Send the event first with dummy worker_id
let event = PrefillEvent {
request_id: request_id.to_string(),
worker_id: 0, // Dummy worker_id
data: PrefillEventData::CompletePrefill,
router_id: self.router_id,
};
self.component.publish(PREFILL_SUBJECT, &event).await?;
// Use the helper function
Ok(Self::handle_complete_prefill(
&self.counters,
&self.request_to_workers,
request_id,
))
}
/// Get the running sums for all workers as a HashMap<i64, usize>
pub async fn running_sums(&self) -> HashMap<i64, usize> {
self.counters
.iter()
.map(|entry| (*entry.key(), entry.value().running_sum()))
.collect()
}
/// Get a specific counter's running sum
pub async fn get_worker_sum(&self, worker_id: i64) -> Option<usize> {
self.counters.get(&worker_id).map(|c| c.running_sum())
}
}
#[cfg(test)]
mod integration_tests {
use super::*;
use dynamo_runtime::{DistributedRuntime, Runtime};
use std::sync::{Arc, Mutex};
use std::thread;
use tokio::time::Duration;
#[test]
#[ignore]
fn test_prefill_counter_multiworker_synchronization() -> Result<()> {
// Initialize logging once
dynamo_runtime::logging::init();
let worker_id_1 = 1;
let worker_id_2 = 2;
let tokens_per_request = 100;
let requests_per_worker = 10;
// Shared state for collecting results from both threads
let results1 = Arc::new(Mutex::new(None));
let results2 = Arc::new(Mutex::new(None));
let final_results1 = Arc::new(Mutex::new(None));
let final_results2 = Arc::new(Mutex::new(None));
let results1_clone = results1.clone();
let results2_clone = results2.clone();
let final_results1_clone = final_results1.clone();
let final_results2_clone = final_results2.clone();
// Thread 1: First distributed runtime with multi_worker1
let handle1 = thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and components with same names
let namespace = distributed.namespace("test_prefill_multiworker")?;
let component = namespace
.component("counters")?
.service_builder()
.create()
.await?;
// Create first PrefillCountersMultiWorker instance
let multi_worker1 = PrefillCountersMultiWorker::new(component);
// Give some time for subscribers to initialize
tokio::time::sleep(Duration::from_millis(3000)).await;
// Send requests to multi_worker1's worker
for i in 0..requests_per_worker {
let request_id = format!("mw1_request_{}", i);
multi_worker1
.add_prefill(worker_id_1, request_id, tokens_per_request)
.await?;
}
// Wait for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get running sums after additions
let sums1 = multi_worker1.running_sums().await;
*results1_clone.lock().unwrap() = Some(sums1);
// Wait for other thread to add its requests
tokio::time::sleep(Duration::from_millis(2000)).await;
// Remove all requests from multi_worker1
for i in 0..requests_per_worker {
let request_id = format!("mw1_request_{}", i);
multi_worker1.remove_prefill(&request_id).await?;
}
// Wait for removal synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get final running sums
let final_sums1 = multi_worker1.running_sums().await;
*final_results1_clone.lock().unwrap() = Some(final_sums1);
// Keep runtime alive a bit longer for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Shutdown runtime
runtime.shutdown();
Ok::<(), anyhow::Error>(())
})
});
// Thread 2: Second distributed runtime with multi_worker2
let handle2 = thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and components with same names
let namespace = distributed.namespace("test_prefill_multiworker")?;
let component = namespace
.component("counters")?
.service_builder()
.create()
.await?;
// Create second PrefillCountersMultiWorker instance
let multi_worker2 = PrefillCountersMultiWorker::new(component);
// Give some time for subscribers to initialize
tokio::time::sleep(Duration::from_millis(3000)).await;
// Wait a bit to ensure multi_worker1 has started
tokio::time::sleep(Duration::from_millis(500)).await;
// Send requests to multi_worker2's worker
for i in 0..requests_per_worker {
let request_id = format!("mw2_request_{}", i);
multi_worker2
.add_prefill(worker_id_2, request_id, tokens_per_request)
.await?;
}
// Wait for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get running sums after additions
let sums2 = multi_worker2.running_sums().await;
*results2_clone.lock().unwrap() = Some(sums2);
// Wait for other thread to remove its requests
tokio::time::sleep(Duration::from_millis(2000)).await;
// Remove all requests from multi_worker2
for i in 0..requests_per_worker {
let request_id = format!("mw2_request_{}", i);
multi_worker2.remove_prefill(&request_id).await?;
}
// Wait for removal synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Get final running sums
let final_sums2 = multi_worker2.running_sums().await;
*final_results2_clone.lock().unwrap() = Some(final_sums2);
// Keep runtime alive a bit longer for synchronization
tokio::time::sleep(Duration::from_millis(1000)).await;
// Shutdown runtime
runtime.shutdown();
Ok::<(), anyhow::Error>(())
})
});
// Wait for both threads to complete
handle1.join().unwrap()?;
handle2.join().unwrap()?;
// Extract results
let sums1 = results1.lock().unwrap().take().unwrap();
let sums2 = results2.lock().unwrap().take().unwrap();
let final_sums1 = final_results1.lock().unwrap().take().unwrap();
let final_sums2 = final_results2.lock().unwrap().take().unwrap();
// Verify both multi-workers see all requests
assert_eq!(
sums1.get(&worker_id_1),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker1 should see worker 1's requests"
);
assert_eq!(
sums1.get(&worker_id_2),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker1 should see worker 2's requests"
);
assert_eq!(
sums2.get(&worker_id_1),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker2 should see worker 1's requests"
);
assert_eq!(
sums2.get(&worker_id_2),
Some(&(requests_per_worker * tokens_per_request)),
"MultiWorker2 should see worker 2's requests"
);
// Verify both multi-workers show zero sums after removal
assert_eq!(
final_sums1.get(&worker_id_1).copied().unwrap_or(0),
0,
"MultiWorker1 should show zero for worker 1"
);
assert_eq!(
final_sums1.get(&worker_id_2).copied().unwrap_or(0),
0,
"MultiWorker1 should show zero for worker 2"
);
assert_eq!(
final_sums2.get(&worker_id_1).copied().unwrap_or(0),
0,
"MultiWorker2 should show zero for worker 1"
);
assert_eq!(
final_sums2.get(&worker_id_2).copied().unwrap_or(0),
0,
"MultiWorker2 should show zero for worker 2"
);
Ok(())
}
}
......@@ -163,7 +163,7 @@ pub struct ActiveSequenceEvent {
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ActiveSequenceEventData {
AddRequest {
token_sequence: Vec<SequenceHash>,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
},
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
......@@ -56,7 +57,7 @@ pub struct SchedulingResponse {
pub struct SchedulingRequest {
pub request_id: String,
pub token_seq: Vec<SequenceHash>,
pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize,
pub overlaps: OverlapScores,
pub decode_blocks: HashMap<i64, usize>,
......@@ -96,6 +97,7 @@ impl KvScheduler {
runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_uuid: String,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone();
......@@ -124,6 +126,7 @@ impl KvScheduler {
block_size as usize,
worker_ids,
replica_sync,
router_uuid,
));
// Spawn background task to monitor and update workers_with_configs
......@@ -240,20 +243,26 @@ impl KvScheduler {
};
request.respond(response);
// Only update the state if update_states is true
if request.update_states {
let _ = slots_clone
.add_request(
request.request_id,
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
selection.worker_id,
)
.await;
// Skip state update if not requested
if !request.update_states {
continue;
}
continue;
let request_id = request.request_id;
if let Err(e) = slots_clone
.add_request(
request_id.clone(),
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
selection.worker_id,
)
.await
{
tracing::warn!(
"Failed to add request {request_id} to local slot tracker: {e:?}"
);
}
}
Err(KvSchedulerError::NoEndpoints) => {
tracing::trace!("no endpoints available; waiting for endpoints update");
......@@ -283,7 +292,7 @@ impl KvScheduler {
&self,
request_id: String,
isl_tokens: usize,
token_seq: Vec<SequenceHash>,
token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores,
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
......@@ -316,7 +325,7 @@ impl KvScheduler {
pub async fn add_request(
&self,
request_id: String,
token_sequence: Vec<SequenceHash>,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
worker_id: i64,
......@@ -327,20 +336,19 @@ impl KvScheduler {
.await;
}
pub async fn mark_prefill_completed(&self, request_id: &str) {
let _ = self
.slots
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
self.slots
.mark_prefill_completed(&request_id.to_string())
.await;
.await
}
pub async fn free(&self, request_id: &str) {
let _ = self.slots.free(&request_id.to_string()).await;
pub async fn free(&self, request_id: &str) -> Result<()> {
self.slots.free(&request_id.to_string()).await
}
pub async fn get_potential_loads(
&self,
token_seq: Vec<SequenceHash>,
token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize,
overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment