Unverified Commit b7fe46b1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(mocker): add multi-worker replay and router startup fixes (#7553)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 82794761
......@@ -689,6 +689,20 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
results
}
/// Return true if any worker satisfies the provided predicate on active token count.
pub fn any_worker_matches_active_tokens(
&self,
mut predicate: impl FnMut(WorkerWithDpRank, usize) -> bool,
) -> bool {
let table = self.workers.read();
for (worker, lock) in &table.slots {
if predicate(*worker, lock.read().active_tokens()) {
return true;
}
}
false
}
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
let mut counts: HashMap<String, usize> = HashMap::new();
for entry in self.request_to_lora.iter() {
......
......@@ -117,7 +117,7 @@ impl SequencePublisher for NoopSequencePublisher {
}
/// Minimal [`WorkerConfigLike`] for scheduler/queue tests and benchmarks.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimpleWorkerConfig {
pub data_parallel_start_rank: u32,
pub data_parallel_size: u32,
......
......@@ -587,16 +587,13 @@ impl ModelManager {
// Get of create runtime config watcher for this endpoint
let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;
let selector = Box::new(DefaultWorkerSelector::new(
kv_router_config.clone(),
worker_type,
));
let selector = DefaultWorkerSelector::new(kv_router_config.clone(), worker_type);
let chooser = KvRouter::new(
endpoint.clone(),
client,
workers_with_configs,
kv_cache_block_size,
Some(selector),
selector,
kv_router_config,
worker_type,
model_name,
......
......@@ -38,8 +38,6 @@ pub mod metrics;
pub mod prefill_router;
pub mod publisher;
pub mod push_router;
pub mod queue;
pub mod recorder;
pub mod remote_indexer;
pub mod scheduler;
pub mod sequence;
......@@ -54,7 +52,7 @@ use crate::{
discovery::RuntimeConfigWatch,
kv_router::{
remote_indexer::RemoteIndexer,
scheduler::{KvScheduler, PotentialLoad},
scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
sequence::{SequenceError, SequenceRequest},
},
local_model::runtime_config::ModelRuntimeConfig,
......@@ -109,10 +107,6 @@ pub fn router_discovery_query(namespace: String, component: String) -> Discovery
}
}
/// Concrete `WorkerSelector` bound to the runtime config type.
pub type WorkerSelector =
dyn dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync;
#[derive(Clone)]
pub enum Indexer {
/// Single-threaded radix tree with channel-based event processing.
......@@ -297,23 +291,29 @@ impl Indexer {
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
pub struct KvRouter {
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
indexer: Indexer,
scheduler: KvScheduler,
scheduler: KvScheduler<Sel>,
block_size: u32,
kv_router_config: KvRouterConfig,
cancellation_token: tokio_util::sync::CancellationToken,
client: Client,
}
impl KvRouter {
impl<Sel> KvRouter<Sel>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
#[allow(clippy::too_many_arguments)]
pub async fn new(
endpoint: Endpoint,
client: Client,
mut workers_with_configs: RuntimeConfigWatch,
block_size: u32,
selector: Option<Box<WorkerSelector>>,
selector: Sel,
kv_router_config: Option<KvRouterConfig>,
worker_type: &'static str,
model_name: Option<String>,
......@@ -327,10 +327,13 @@ impl KvRouter {
if !kv_router_config.skip_initial_worker_wait {
let _ = workers_with_configs
.wait_for(|m| !m.is_empty())
.wait_for(|m| m.len() >= kv_router_config.min_initial_workers)
.await
.map_err(|_| {
anyhow::anyhow!("runtime config watch closed before any workers appeared")
anyhow::anyhow!(
"runtime config watch closed before {} workers appeared",
kv_router_config.min_initial_workers
)
})?;
}
......@@ -596,7 +599,11 @@ impl KvRouter {
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
for KvRouter<Sel>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
async fn generate(
&self,
request: SingleIn<RouterRequest>,
......@@ -649,7 +656,10 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
}
}
impl Drop for KvRouter {
impl<Sel> Drop for KvRouter<Sel>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
fn drop(&mut self) {
tracing::info!("Dropping KvRouter - cancelling background tasks");
self.cancellation_token.cancel();
......
......@@ -490,11 +490,11 @@ impl Drop for KvEventPublisher {
}
}
use dynamo_kv_router::EventSink;
use dynamo_kv_router::RouterEventSink;
struct EventPlanePublisher(EventPublisher);
impl EventSink for EventPlanePublisher {
impl RouterEventSink for EventPlanePublisher {
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
self.0.publish(event)
}
......@@ -502,7 +502,7 @@ impl EventSink for EventPlanePublisher {
struct JetStreamPublisher(NatsQueue);
impl EventSink for JetStreamPublisher {
impl RouterEventSink for JetStreamPublisher {
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
NatsQueue::publish_event(&self.0, KV_EVENT_SUBJECT, event)
}
......@@ -510,7 +510,7 @@ impl EventSink for JetStreamPublisher {
/// Publishes a single [`KvCacheEvent`] to the event sink and, when present, the local indexer.
/// Errors are logged and swallowed so the caller loop can continue uninterrupted.
async fn emit<P: EventSink>(
async fn emit<P: RouterEventSink>(
publisher: &P,
local_indexer: &Option<Arc<LocalKvIndexer>>,
worker_id: u64,
......@@ -530,7 +530,7 @@ async fn emit<P: EventSink>(
impl BatchingState {
/// Publishes any pending batch as a single [`RouterEvent`] and advances the monotonic
/// batch ID. No-ops when nothing is pending, so callers may call unconditionally.
async fn flush<P: EventSink + Send + Sync + 'static>(
async fn flush<P: RouterEventSink + Send + Sync + 'static>(
&mut self,
publisher: &P,
local_indexer: &Option<Arc<LocalKvIndexer>>,
......@@ -581,7 +581,7 @@ impl BatchingState {
/// - A `Stored` event's `parent_hash` breaks the sequential chain
/// - The batch window expires (`Some(timeout_ms)`; `None` = disabled, flush every event)
/// - Channel is closed or a cancellation signal is received
async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
......@@ -719,7 +719,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
}
/// Batched event processor for ephemeral transports (NATS Core / ZMQ).
async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
async fn start_event_processor<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
......@@ -740,7 +740,7 @@ async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
}
/// Batched event processor using JetStream (durable).
async fn start_event_processor_jetstream<P: EventSink + Send + Sync + 'static>(
async fn start_event_processor_jetstream<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
......@@ -1481,7 +1481,7 @@ mod tests_startup_helpers {
}
}
impl EventSink for MockComponent {
impl RouterEventSink for MockComponent {
fn publish_event(
&self,
event: &RouterEvent,
......@@ -2466,7 +2466,7 @@ mod event_processor_tests {
}
}
impl EventSink for MockPublisher {
impl RouterEventSink for MockPublisher {
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
self.events.lock().unwrap().push(event.clone());
async { Ok(()) }
......@@ -2759,6 +2759,116 @@ mod event_processor_tests {
assert_eq!(total_blocks, 3, "All 3 blocks should be accounted for");
}
/// Test that reusing an older parent hash breaks the current sequential batch.
#[tokio::test]
async fn test_run_event_processor_loop_reused_parent_hash_breaks_chain() {
let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
tx.send(local_gpu_event(KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(1),
tokens_hash: LocalBlockHash(100),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(2),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 2,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(3),
tokens_hash: LocalBlockHash(300),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
2,
"Reused parent hash should flush the current batch before starting a new one"
);
if let KvCacheEventData::Stored(data) = &events[0].event.data {
assert_eq!(
data.blocks.len(),
2,
"First batch should keep the valid chain"
);
assert_eq!(
data.parent_hash, None,
"First batch should preserve the original root parent"
);
} else {
panic!("Expected first event to be Stored");
}
if let KvCacheEventData::Stored(data) = &events[1].event.data {
assert_eq!(
data.blocks.len(),
1,
"Second batch should contain only the inconsistent event"
);
assert_eq!(
data.parent_hash,
Some(ExternalSequenceBlockHash(1)),
"Second batch should preserve the reused parent hash"
);
} else {
panic!("Expected second event to be Stored");
}
}
/// Test that with short timeout and slow input, events are NOT batched
/// Parametrized over different timeout values: 0ms, 0.1ms, 0.2ms
/// All use 2ms delay between events, so each event times out before the next arrives
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS;
use crate::kv_router::sequence::RuntimeSequencePublisher;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
/// Concrete `SchedulerQueue` wired to the runtime publisher and config types.
pub type SchedulerQueue = dynamo_kv_router::queue::SchedulerQueue<
RuntimeSequencePublisher,
ModelRuntimeConfig,
RouterSchedulingPolicy,
>;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::recorder::Recorder;
use dynamo_kv_router::protocols::RouterEvent;
// Type alias for backward compatibility
pub type KvRecorder = Recorder<RouterEvent>;
#[cfg(test)]
mod tests {
use super::*;
use dynamo_kv_router::indexer::{KvIndexer, KvIndexerMetrics};
use dynamo_kv_router::protocols::*;
use std::time::Duration;
use tempfile::tempdir;
use tokio::fs;
use tokio_util::sync::CancellationToken;
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
hashes
.iter()
.map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100),
mm_extra_info: None,
})
.collect()
}
fn add_blocks(
hashes: Vec<u64>,
parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: make_blocks(hashes),
})
}
fn create_store_event(
worker_id: WorkerId,
event_id: u64,
hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent::new(
worker_id,
KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
)
}
fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
RouterEvent::new(
worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
)
}
#[tokio::test]
async fn test_recorder_streams_events_to_file() {
// Create a temporary directory for output files
let dir = tempdir().unwrap();
let file_path = dir.path().join("kv_events.jsonl");
// Part 1: Record events to a file
let token = CancellationToken::new();
let recorder = KvRecorder::new(token.clone(), &file_path, None, None, None)
.await
.unwrap();
let event_tx = recorder.event_sender();
// Create first event from worker 1 using helper function
let event1 = create_store_event(1, 42, vec![1, 2, 3], None);
// Create second event from worker 2 using helper function
let event2 = create_remove_event(1, 43, vec![2, 3]);
// Send both events one after another
event_tx.send(event1).await.unwrap();
event_tx.send(event2).await.unwrap();
// Allow some time for processing
tokio::time::sleep(Duration::from_millis(10)).await;
// Check that both events were recorded
assert_eq!(recorder.event_count().await, 2);
// Force shutdown to flush file
recorder.shutdown();
tokio::time::sleep(Duration::from_millis(10)).await;
// Read the file and verify content
let content = fs::read_to_string(&file_path).await.unwrap();
let lines: Vec<&str> = content.lines().collect();
// Print the content of the JSONL file
println!("JSONL file content:");
for (i, line) in lines.iter().enumerate() {
println!("Line {}: {}", i + 1, line);
}
assert_eq!(lines.len(), 2, "Expected 2 lines in the file");
// Part 2: Now create a KvIndexer and load the events from the file
let indexer_token = CancellationToken::new();
let kv_block_size = 32; // Default block size for testing
let kv_indexer_metrics = KvIndexerMetrics::new_unregistered();
let indexer = KvIndexer::new(
indexer_token.clone(),
kv_block_size,
kv_indexer_metrics.into(),
);
let indexer_event_tx = indexer.event_sender();
// Use the send_events method to load events from file to indexer
let count = KvRecorder::send_events(&file_path, &indexer_event_tx, false, None, None)
.await
.unwrap();
assert_eq!(count, 2, "Expected to send 2 events from file to indexer");
}
}
......@@ -3,15 +3,14 @@
pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
pub use dynamo_kv_router::scheduling::{
KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse,
KvSchedulerError, LocalScheduler, PotentialLoad, SchedulingRequest, SchedulingResponse,
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;
use dynamo_kv_router::selector::WorkerSelector as WorkerSelectorTrait;
use super::WorkerSelector;
use super::metrics::ROUTER_QUEUE_METRICS;
use super::queue::SchedulerQueue;
use super::sequence::{
ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
RuntimeSequencePublisher, SequenceError, SequenceRequest, create_multi_worker_sequences,
};
use crate::discovery::RuntimeConfigWatch;
use crate::local_model::runtime_config::ModelRuntimeConfig;
......@@ -22,35 +21,32 @@ use dynamo_kv_router::{
};
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "bench")]
use std::time::Instant;
use dynamo_tokens::SequenceHash;
pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMulti>,
queue: Arc<SchedulerQueue>,
pub struct KvScheduler<Sel = DefaultWorkerSelector>
where
Sel: WorkerSelectorTrait<ModelRuntimeConfig>,
{
inner: Arc<
LocalScheduler<RuntimeSequencePublisher, ModelRuntimeConfig, RouterSchedulingPolicy, Sel>,
>,
}
impl KvScheduler {
impl<Sel> KvScheduler<Sel>
where
Sel: WorkerSelectorTrait<ModelRuntimeConfig> + Send + Sync + 'static,
{
pub async fn start(
component: Component,
block_size: u32,
workers_with_configs: RuntimeConfigWatch,
selector: Option<Box<WorkerSelector>>,
selector: Sel,
kv_router_config: &KvRouterConfig,
worker_type: &'static str,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::new(None, worker_type)));
// Get initial workers from watch receiver.
// When skip_initial_worker_wait is false, the caller ensures at least one
// worker is present (via wait_for). When true the map may be empty;
// workers will be lazily registered via allowed_worker_ids per-request.
let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
workers_with_configs.borrow().clone();
......@@ -66,56 +62,11 @@ impl KvScheduler {
.await
.map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
// Spawn background task to sync slots when the watch value changes.
//
// In EPP mode (skip_initial_worker_wait=true) we skip the monitoring task:
// the per-request allowed_worker_ids is the source of truth, workers are
// lazily registered via register_external_workers() from the C bindings,
// and update_workers() would impose discovery-based lifecycle (add/remove)
// on the slot tracker, conflicting with EPP ownership.
if kv_router_config.skip_initial_worker_wait {
let watch_worker_configs = !kv_router_config.skip_initial_worker_wait;
if !watch_worker_configs {
tracing::info!("skipping discovery-based worker monitoring");
} else {
let slots_monitor = slots.clone();
let mut monitor_rx = workers_with_configs.clone();
let monitor_cancel_token = component.drt().child_token();
tokio::spawn(async move {
tracing::trace!("KvScheduler workers monitoring task started");
let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
loop {
tokio::select! {
_ = monitor_cancel_token.cancelled() => {
tracing::trace!("KvScheduler workers monitoring task shutting down");
break;
}
result = monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
break;
}
}
}
let current_workers = monitor_rx.borrow_and_update().clone();
if current_workers != last_workers {
let dp_range: HashMap<u64, (u32, u32)> = current_workers
.iter()
.map(|(&id, c)| {
(id, (c.data_parallel_start_rank, c.data_parallel_size))
})
.collect();
slots_monitor.update_workers(&dp_range);
last_workers = current_workers;
}
}
});
}
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
let scheduler_cancel_token = component.drt().primary_token();
let policy =
RouterSchedulingPolicy::new(kv_router_config.router_queue_policy, block_size as usize);
tracing::info!(
......@@ -123,52 +74,36 @@ impl KvScheduler {
kv_router_config.router_queue_policy
);
let queue = Arc::new(SchedulerQueue::new(
slots.clone(),
let inner = Arc::new(LocalScheduler::new(
slots,
workers_with_configs.clone(),
kv_router_config.router_queue_threshold,
block_size,
selector,
policy,
component.drt().child_token(),
worker_type,
watch_worker_configs,
));
let queue_clone = queue.clone();
// Background task: receive requests and periodically recheck pending
let metrics_scheduler = Arc::clone(&inner);
let metrics_cancel_token = component.drt().child_token();
tokio::spawn(async move {
let mut request_rx = request_rx;
let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
tracing::trace!("scheduler background task started");
ROUTER_QUEUE_METRICS.set_pending(worker_type, metrics_scheduler.pending_count());
loop {
tokio::select! {
_ = scheduler_cancel_token.cancelled() => {
tracing::trace!("scheduler background task shutting down");
break;
}
request = request_rx.recv() => {
let Some(request) = request else {
tracing::warn!("scheduler shutdown");
break;
};
tracing::trace!("received request to be scheduled");
queue_clone.enqueue(request).await;
ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
}
_ = metrics_cancel_token.cancelled() => break,
_ = recheck_interval.tick() => {
queue_clone.update().await;
ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
ROUTER_QUEUE_METRICS
.set_pending(worker_type, metrics_scheduler.pending_count());
}
}
}
tracing::trace!("background endpoint subscriber shutting down");
});
Ok(KvScheduler {
request_tx,
slots,
queue,
})
Ok(Self { inner })
}
#[expect(clippy::too_many_arguments)]
......@@ -185,85 +120,51 @@ impl KvScheduler {
expected_output_tokens: Option<u32>,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<SchedulingResponse, KvSchedulerError> {
#[cfg(feature = "bench")]
let start = Instant::now();
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
maybe_request_id,
token_seq,
isl_tokens,
overlaps,
decode_blocks: HashMap::new(),
prefill_tokens: HashMap::new(),
router_config_override: router_config_override.cloned(),
update_states,
lora_name,
priority_jump,
expected_output_tokens,
allowed_worker_ids,
resp_tx: Some(resp_tx),
};
self.request_tx
.send(request)
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
#[cfg(feature = "bench")]
let send_elapsed = start.elapsed();
let response = resp_rx
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)??;
#[cfg(feature = "bench")]
let total_elapsed = start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
isl_tokens,
send_us = send_elapsed.as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"scheduler.schedule completed"
);
Ok(response)
let response = self
.inner
.schedule(
maybe_request_id,
isl_tokens,
token_seq,
overlaps,
router_config_override,
update_states,
lora_name,
priority_jump,
expected_output_tokens,
allowed_worker_ids,
)
.await;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.pending_count());
response
}
/// Register externally-provided workers in the slot tracker.
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
self.queue.register_workers(worker_ids);
self.inner.register_workers(worker_ids);
}
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.slots.add_request(req).await
self.inner.add_request(req).await
}
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots
.mark_prefill_completed(&request_id.to_string())
.await?;
self.queue.update().await;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
self.inner.mark_prefill_completed(request_id).await?;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.pending_count());
Ok(())
}
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.free(&request_id.to_string()).await?;
self.queue.update().await;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
self.inner.free(request_id).await?;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.pending_count());
Ok(())
}
/// Number of requests currently parked in the scheduler queue.
pub fn pending_count(&self) -> usize {
self.queue.pending_count()
self.inner.pending_count()
}
/// Get the worker type for this scheduler ("prefill" or "decode").
/// Used for Prometheus metric labeling.
pub fn worker_type(&self) -> &'static str {
self.slots.worker_type()
self.inner.worker_type()
}
pub fn add_output_block(
......@@ -271,8 +172,7 @@ impl KvScheduler {
request_id: &str,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
self.slots
.add_output_block(&request_id.to_string(), decay_fraction)
self.inner.add_output_block(request_id, decay_fraction)
}
pub fn get_potential_loads(
......@@ -281,34 +181,11 @@ impl KvScheduler {
isl_tokens: usize,
overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
let (decode_blocks, prefill_tokens) =
self.slots
.potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);
// Get all unique WorkerWithDpRank from both hashmaps
let mut workers: HashSet<dynamo_kv_router::protocols::WorkerWithDpRank> = HashSet::new();
workers.extend(decode_blocks.keys().copied());
workers.extend(prefill_tokens.keys().copied());
// Create PotentialLoad for each worker
let mut loads = Vec::new();
for worker in workers {
loads.push(PotentialLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens
.get(&worker)
.copied()
.unwrap_or(isl_tokens),
potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
});
}
loads
self.inner
.get_potential_loads(token_seq, isl_tokens, overlaps)
}
/// Get active request counts grouped by LORA name
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
self.slots.get_active_lora_counts()
self.inner.get_active_lora_counts()
}
}
......@@ -20,7 +20,8 @@ use dashmap::DashMap;
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData};
use dynamo_mocker::common::bootstrap::{BootstrapServer, connect_to_prefill};
use dynamo_mocker::common::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, OutputSignal,
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs, OutputSignal, RawKvEvent,
RawKvEventSink,
};
use dynamo_mocker::common::utils::{compute_kv_transfer_delay, sleep_precise};
use dynamo_mocker::engine::create_engine;
......@@ -48,11 +49,7 @@ pub const MOCKER_COMPONENT: &str = "mocker";
struct KvEventSinkAdapter(KvEventPublisher);
impl KvCacheEventSink for KvEventSinkAdapter {
fn publish(
&self,
event: KvCacheEvent,
_block_token_ids: Option<&[Vec<u32>]>,
) -> anyhow::Result<()> {
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
self.0
.publish(event)
.map_err(|e| anyhow::anyhow!("Failed to send KV event: {}", e))
......@@ -77,13 +74,8 @@ enum ZmqRawKvEvent {
},
}
struct ZmqKvEventMsg {
event: KvCacheEvent,
block_token_ids: Option<Vec<Vec<u32>>>,
}
struct ZmqKvEventSink {
tx: mpsc::UnboundedSender<ZmqKvEventMsg>,
tx: mpsc::UnboundedSender<RawKvEvent>,
}
/// Maximum number of entries in the replay ring buffer.
......@@ -96,7 +88,7 @@ impl ZmqKvEventSink {
dp_rank: u32,
block_size: u32,
) -> Result<Self> {
let (tx, mut rx) = mpsc::unbounded_channel::<ZmqKvEventMsg>();
let (tx, mut rx) = mpsc::unbounded_channel::<RawKvEvent>();
// Bind the PUB socket before returning so that any SUB connect()
// that follows is guaranteed to find the endpoint already listening.
......@@ -250,17 +242,10 @@ impl ZmqKvEventSink {
}
}
impl KvCacheEventSink for ZmqKvEventSink {
fn publish(
&self,
event: KvCacheEvent,
block_token_ids: Option<&[Vec<u32>]>,
) -> anyhow::Result<()> {
impl RawKvEventSink for ZmqKvEventSink {
fn publish(&self, event: RawKvEvent) -> anyhow::Result<()> {
self.tx
.send(ZmqKvEventMsg {
event,
block_token_ids: block_token_ids.map(|t| t.to_vec()),
})
.send(event)
.map_err(|_| anyhow::anyhow!("ZMQ event sink channel closed"))
}
}
......@@ -413,8 +398,8 @@ impl MockEngine {
for dp_rank in 0..args.dp_size {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (kv_event_sink, relay_publisher): (
Option<Arc<dyn KvCacheEventSink>>,
let (kv_event_publishers, relay_publisher): (
KvEventPublishers,
Option<KvEventPublisher>,
) = match component {
Some(comp) if args.zmq_kv_events_port.is_some() => {
......@@ -442,14 +427,17 @@ impl MockEngine {
None,
) {
Ok(publisher) => (
Some(Arc::new(sink) as Arc<dyn KvCacheEventSink>),
KvEventPublishers::new(
None,
Some(Arc::new(sink) as Arc<dyn RawKvEventSink>),
),
Some(publisher),
),
Err(e) => {
tracing::error!(
"Failed to create KV event relay for dp_rank {dp_rank}: {e}"
);
(None, None)
(KvEventPublishers::default(), None)
}
}
}
......@@ -457,7 +445,7 @@ impl MockEngine {
tracing::error!(
"Failed to create ZMQ KV event sink for dp_rank {dp_rank}: {e}"
);
(None, None)
(KvEventPublishers::default(), None)
}
}
}
......@@ -471,26 +459,29 @@ impl MockEngine {
None,
) {
Ok(publisher) => (
Some(Arc::new(KvEventSinkAdapter(publisher))
as Arc<dyn KvCacheEventSink>),
KvEventPublishers::new(
Some(Arc::new(KvEventSinkAdapter(publisher))
as Arc<dyn KvCacheEventSink>),
None,
),
None,
),
Err(e) => {
tracing::error!(
"Failed to create KV event publisher for dp_rank {dp_rank}: {e}"
);
(None, None)
(KvEventPublishers::default(), None)
}
}
}
None => (None, None),
None => (KvEventPublishers::default(), None),
};
let scheduler = create_engine(
args.clone(),
dp_rank,
Some(output_tx),
kv_event_sink,
kv_event_publishers,
Some(cancel_token.clone()),
);
......
......@@ -13,7 +13,6 @@ repository.workspace = true
[dependencies]
# repo
dynamo-kv-router = { workspace = true }
dynamo-runtime = { workspace = true }
dynamo-tokens = { workspace = true }
# workspace
......@@ -41,3 +40,4 @@ tokio-timerfd = "0.2"
[dev-dependencies]
rstest = "0.18.2"
tempfile = { workspace = true }
......@@ -5,14 +5,15 @@
//!
//! Enabled by setting `DYN_MOCKER_KV_CACHE_TRACE=1` or `true`.
use dynamo_runtime::config::environment_names::mocker;
use std::env;
use std::sync::LazyLock;
use std::time::{SystemTime, UNIX_EPOCH};
const DYN_MOCKER_KV_CACHE_TRACE: &str = "DYN_MOCKER_KV_CACHE_TRACE";
/// Check the env var to enable KV cache allocation/eviction trace logs.
pub static KV_CACHE_TRACE_ENABLED: LazyLock<bool> = LazyLock::new(|| {
env::var(mocker::DYN_MOCKER_KV_CACHE_TRACE)
env::var(DYN_MOCKER_KV_CACHE_TRACE)
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false)
});
......
This diff is collapsed.
......@@ -134,13 +134,12 @@ impl ActiveSequence {
let hashes = self.block_hashes[hash_start..hash_end].to_vec();
let token_ids = if self.emit_token_ids && hash_start < hash_end {
let all_token_ids: Vec<Vec<u32>> = self
.tokens
.blocks()
.iter()
.map(|b| b.tokens().to_vec())
.collect();
Some(all_token_ids[hash_start..hash_end].to_vec())
Some(
self.tokens.blocks()[hash_start..hash_end]
.iter()
.map(|b| b.tokens().to_vec())
.collect(),
)
} else {
None
};
......@@ -276,13 +275,15 @@ impl ActiveSequence {
}
// Free all blocks when we reach max tokens
signals.extend(self.free_signal());
signals.extend(self.free_signal_for_tokens(self.len()));
signals
}
/// Free all blocks, generating appropriate signals for each block type
pub fn free_signal(&self) -> Vec<MoveBlock> {
self.unique_blocks
fn free_signal_for_tokens(&self, active_tokens: usize) -> Vec<MoveBlock> {
let active_blocks = active_tokens
.div_ceil(self.block_size)
.min(self.unique_blocks.len());
self.unique_blocks[..active_blocks]
.iter()
.rev()
.map(|block| match block {
......@@ -296,6 +297,11 @@ impl ActiveSequence {
.collect()
}
/// Free the currently active allocation footprint.
pub fn free_signal(&self) -> Vec<MoveBlock> {
self.free_signal_for_tokens(self.num_allocated_tokens)
}
/// Move the request to a preempted state and return the free signals from freeing current blocks.
/// Upon preemption, the sequence retains the tokens generated during the decode phase (if any).
/// Resets `num_allocated_tokens` so re-admission will re-allocate from scratch.
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -11,5 +11,5 @@ pub mod cache;
pub mod common;
pub mod engine;
pub mod kv_manager;
pub mod replay;
pub mod scheduler;
pub mod simulation;
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment