Unverified Commit b7fe46b1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(mocker): add multi-worker replay and router startup fixes (#7553)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 82794761
......@@ -689,6 +689,20 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
results
}
/// Reports whether at least one tracked worker passes `predicate`.
///
/// The predicate is called with each worker key and that worker's current
/// active-token count. Evaluation stops at the first worker for which the
/// predicate returns `true`; `false` is returned when no worker matches.
pub fn any_worker_matches_active_tokens(
    &self,
    mut predicate: impl FnMut(WorkerWithDpRank, usize) -> bool,
) -> bool {
    let table = self.workers.read();
    // Bind the iterator to a local so it is dropped before the `table` guard.
    let mut slots = (&table.slots).into_iter();
    slots.any(|(worker, lock)| predicate(*worker, lock.read().active_tokens()))
}
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
let mut counts: HashMap<String, usize> = HashMap::new();
for entry in self.request_to_lora.iter() {
......
......@@ -117,7 +117,7 @@ impl SequencePublisher for NoopSequencePublisher {
}
/// Minimal [`WorkerConfigLike`] for scheduler/queue tests and benchmarks.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimpleWorkerConfig {
pub data_parallel_start_rank: u32,
pub data_parallel_size: u32,
......
......@@ -587,16 +587,13 @@ impl ModelManager {
// Get or create runtime config watcher for this endpoint
let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;
let selector = Box::new(DefaultWorkerSelector::new(
kv_router_config.clone(),
worker_type,
));
let selector = DefaultWorkerSelector::new(kv_router_config.clone(), worker_type);
let chooser = KvRouter::new(
endpoint.clone(),
client,
workers_with_configs,
kv_cache_block_size,
Some(selector),
selector,
kv_router_config,
worker_type,
model_name,
......
......@@ -38,8 +38,6 @@ pub mod metrics;
pub mod prefill_router;
pub mod publisher;
pub mod push_router;
pub mod queue;
pub mod recorder;
pub mod remote_indexer;
pub mod scheduler;
pub mod sequence;
......@@ -54,7 +52,7 @@ use crate::{
discovery::RuntimeConfigWatch,
kv_router::{
remote_indexer::RemoteIndexer,
scheduler::{KvScheduler, PotentialLoad},
scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
sequence::{SequenceError, SequenceRequest},
},
local_model::runtime_config::ModelRuntimeConfig,
......@@ -109,10 +107,6 @@ pub fn router_discovery_query(namespace: String, component: String) -> Discovery
}
}
/// Concrete `WorkerSelector` bound to the runtime config type.
pub type WorkerSelector =
dyn dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync;
#[derive(Clone)]
pub enum Indexer {
/// Single-threaded radix tree with channel-based event processing.
......@@ -297,23 +291,29 @@ impl Indexer {
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
pub struct KvRouter {
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
indexer: Indexer,
scheduler: KvScheduler,
scheduler: KvScheduler<Sel>,
block_size: u32,
kv_router_config: KvRouterConfig,
cancellation_token: tokio_util::sync::CancellationToken,
client: Client,
}
impl KvRouter {
impl<Sel> KvRouter<Sel>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
#[allow(clippy::too_many_arguments)]
pub async fn new(
endpoint: Endpoint,
client: Client,
mut workers_with_configs: RuntimeConfigWatch,
block_size: u32,
selector: Option<Box<WorkerSelector>>,
selector: Sel,
kv_router_config: Option<KvRouterConfig>,
worker_type: &'static str,
model_name: Option<String>,
......@@ -327,10 +327,13 @@ impl KvRouter {
if !kv_router_config.skip_initial_worker_wait {
let _ = workers_with_configs
.wait_for(|m| !m.is_empty())
.wait_for(|m| m.len() >= kv_router_config.min_initial_workers)
.await
.map_err(|_| {
anyhow::anyhow!("runtime config watch closed before any workers appeared")
anyhow::anyhow!(
"runtime config watch closed before {} workers appeared",
kv_router_config.min_initial_workers
)
})?;
}
......@@ -596,7 +599,11 @@ impl KvRouter {
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
for KvRouter<Sel>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
async fn generate(
&self,
request: SingleIn<RouterRequest>,
......@@ -649,7 +656,10 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
}
}
impl Drop for KvRouter {
impl<Sel> Drop for KvRouter<Sel>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
fn drop(&mut self) {
tracing::info!("Dropping KvRouter - cancelling background tasks");
self.cancellation_token.cancel();
......
......@@ -490,11 +490,11 @@ impl Drop for KvEventPublisher {
}
}
use dynamo_kv_router::EventSink;
use dynamo_kv_router::RouterEventSink;
struct EventPlanePublisher(EventPublisher);
impl EventSink for EventPlanePublisher {
impl RouterEventSink for EventPlanePublisher {
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
self.0.publish(event)
}
......@@ -502,7 +502,7 @@ impl EventSink for EventPlanePublisher {
struct JetStreamPublisher(NatsQueue);
impl EventSink for JetStreamPublisher {
impl RouterEventSink for JetStreamPublisher {
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
NatsQueue::publish_event(&self.0, KV_EVENT_SUBJECT, event)
}
......@@ -510,7 +510,7 @@ impl EventSink for JetStreamPublisher {
/// Publishes a single [`KvCacheEvent`] to the event sink and, when present, the local indexer.
/// Errors are logged and swallowed so the caller loop can continue uninterrupted.
async fn emit<P: EventSink>(
async fn emit<P: RouterEventSink>(
publisher: &P,
local_indexer: &Option<Arc<LocalKvIndexer>>,
worker_id: u64,
......@@ -530,7 +530,7 @@ async fn emit<P: EventSink>(
impl BatchingState {
/// Publishes any pending batch as a single [`RouterEvent`] and advances the monotonic
/// batch ID. No-ops when nothing is pending, so callers may call unconditionally.
async fn flush<P: EventSink + Send + Sync + 'static>(
async fn flush<P: RouterEventSink + Send + Sync + 'static>(
&mut self,
publisher: &P,
local_indexer: &Option<Arc<LocalKvIndexer>>,
......@@ -581,7 +581,7 @@ impl BatchingState {
/// - A `Stored` event's `parent_hash` breaks the sequential chain
/// - The batch window expires (`Some(timeout_ms)`; `None` = disabled, flush every event)
/// - Channel is closed or a cancellation signal is received
async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
......@@ -719,7 +719,7 @@ async fn run_event_processor_loop<P: EventSink + Send + Sync + 'static>(
}
/// Batched event processor for ephemeral transports (NATS Core / ZMQ).
async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
async fn start_event_processor<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
......@@ -740,7 +740,7 @@ async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
}
/// Batched event processor using JetStream (durable).
async fn start_event_processor_jetstream<P: EventSink + Send + Sync + 'static>(
async fn start_event_processor_jetstream<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
......@@ -1481,7 +1481,7 @@ mod tests_startup_helpers {
}
}
impl EventSink for MockComponent {
impl RouterEventSink for MockComponent {
fn publish_event(
&self,
event: &RouterEvent,
......@@ -2466,7 +2466,7 @@ mod event_processor_tests {
}
}
impl EventSink for MockPublisher {
impl RouterEventSink for MockPublisher {
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
self.events.lock().unwrap().push(event.clone());
async { Ok(()) }
......@@ -2759,6 +2759,116 @@ mod event_processor_tests {
assert_eq!(total_blocks, 3, "All 3 blocks should be accounted for");
}
/// Test that reusing an older parent hash breaks the current sequential batch.
///
/// Three single-block `Stored` events are sent: event 0 is a root block
/// (no parent), event 1 names block 1 as its parent, and event 2 names
/// block 1 as its parent again instead of continuing from block 2. The test
/// expects exactly two published events: blocks 1 and 2 batched together,
/// followed by block 3 on its own with its parent hash preserved.
#[tokio::test]
async fn test_run_event_processor_loop_reused_parent_hash_breaks_chain() {
let timeout_ms = Some(100); // 100ms timeout
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
// The clone is moved into the processor task; `publisher` is kept to inspect results.
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
// Event 0: root of the chain (no parent), stores block hash 1.
tx.send(local_gpu_event(KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(1),
tokens_hash: LocalBlockHash(100),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
// Event 1: parent is block 1, so it extends the chain with block hash 2.
tx.send(local_gpu_event(KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(2),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
// Event 2: names block 1 as parent again (not block 2), which is the
// reused-parent case this test is about.
tx.send(local_gpu_event(KvCacheEvent {
event_id: 2,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(3),
tokens_hash: LocalBlockHash(300),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
// 2ms is well below the 100ms batch timeout configured above.
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
// Close the channel and wait for the processor task to finish before
// inspecting what was published.
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
2,
"Reused parent hash should flush the current batch before starting a new one"
);
// First published event: the two chained blocks, rooted at `None`.
if let KvCacheEventData::Stored(data) = &events[0].event.data {
assert_eq!(
data.blocks.len(),
2,
"First batch should keep the valid chain"
);
assert_eq!(
data.parent_hash, None,
"First batch should preserve the original root parent"
);
} else {
panic!("Expected first event to be Stored");
}
// Second published event: only the block whose parent hash was reused.
if let KvCacheEventData::Stored(data) = &events[1].event.data {
assert_eq!(
data.blocks.len(),
1,
"Second batch should contain only the inconsistent event"
);
assert_eq!(
data.parent_hash,
Some(ExternalSequenceBlockHash(1)),
"Second batch should preserve the reused parent hash"
);
} else {
panic!("Expected second event to be Stored");
}
}
/// Test that with short timeout and slow input, events are NOT batched
/// Parametrized over different timeout values: 0ms, 0.1ms, 0.2ms
/// All use 2ms delay between events, so each event times out before the next arrives
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS;
use crate::kv_router::sequence::RuntimeSequencePublisher;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
/// Concrete [`dynamo_kv_router::queue::SchedulerQueue`] wired to the runtime
/// sequence publisher, the runtime model config, and the router scheduling
/// policy types.
pub type SchedulerQueue = dynamo_kv_router::queue::SchedulerQueue<
RuntimeSequencePublisher,
ModelRuntimeConfig,
RouterSchedulingPolicy,
>;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::recorder::Recorder;
use dynamo_kv_router::protocols::RouterEvent;
/// [`Recorder`] specialized to [`RouterEvent`]; kept as a type alias for
/// backward compatibility.
pub type KvRecorder = Recorder<RouterEvent>;
#[cfg(test)]
mod tests {
use super::*;
use dynamo_kv_router::indexer::{KvIndexer, KvIndexerMetrics};
use dynamo_kv_router::protocols::*;
use std::time::Duration;
use tempfile::tempdir;
use tokio::fs;
use tokio_util::sync::CancellationToken;
/// Builds one stored-block record per input value.
///
/// Each value `h` becomes a block whose local tokens hash is `h` and whose
/// external block hash is `h * 100`; no multimodal info is attached.
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
    hashes
        .into_iter()
        .map(|hash| {
            let tokens_hash = LocalBlockHash(hash);
            let block_hash = ExternalSequenceBlockHash(hash * 100);
            KvCacheStoredBlockData {
                tokens_hash,
                block_hash,
                mm_extra_info: None,
            }
        })
        .collect()
}
/// Wraps blocks built from `hashes` in a `Stored` event payload with the
/// given optional parent hash.
fn add_blocks(
    hashes: Vec<u64>,
    parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
    let blocks = make_blocks(hashes);
    let stored = KvCacheStoreData {
        parent_hash,
        blocks,
    };
    KvCacheEventData::Stored(stored)
}
/// Creates a router event for `worker_id` that stores the blocks derived
/// from `hashes` under the optional `parent`, always on dp rank 0.
fn create_store_event(
    worker_id: WorkerId,
    event_id: u64,
    hashes: Vec<u64>,
    parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
    let data = add_blocks(hashes, parent);
    let cache_event = KvCacheEvent {
        event_id,
        data,
        dp_rank: 0,
    };
    RouterEvent::new(worker_id, cache_event)
}
/// Creates a router event for `worker_id` that removes the blocks whose
/// external hashes are `h * 100` for each `h` in `hashes`, on dp rank 0.
fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
    let removed = KvCacheRemoveData {
        block_hashes: hashes
            .into_iter()
            .map(|hash| ExternalSequenceBlockHash(hash * 100))
            .collect(),
    };
    let cache_event = KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(removed),
        dp_rank: 0,
    };
    RouterEvent::new(worker_id, cache_event)
}
/// Records a store event and a remove event to a JSONL file through
/// `KvRecorder`, checks that two lines were written, then replays the file
/// into a `KvIndexer` event channel and checks that two events were sent.
#[tokio::test]
async fn test_recorder_streams_events_to_file() {
// Create a temporary directory for output files
let dir = tempdir().unwrap();
let file_path = dir.path().join("kv_events.jsonl");
// Part 1: Record events to a file
let token = CancellationToken::new();
let recorder = KvRecorder::new(token.clone(), &file_path, None, None, None)
.await
.unwrap();
let event_tx = recorder.event_sender();
// Create first event (store of blocks 1, 2, 3) from worker 1 using helper function
let event1 = create_store_event(1, 42, vec![1, 2, 3], None);
// Create second event (removal of blocks 2, 3), also from worker 1, using helper function
let event2 = create_remove_event(1, 43, vec![2, 3]);
// Send both events one after another
event_tx.send(event1).await.unwrap();
event_tx.send(event2).await.unwrap();
// Allow some time for processing
// NOTE(review): synchronisation relies on a fixed 10ms sleep, so this may be
// timing-sensitive on a slow machine.
tokio::time::sleep(Duration::from_millis(10)).await;
// Check that both events were recorded
assert_eq!(recorder.event_count().await, 2);
// Force shutdown to flush file
recorder.shutdown();
tokio::time::sleep(Duration::from_millis(10)).await;
// Read the file and verify content
let content = fs::read_to_string(&file_path).await.unwrap();
let lines: Vec<&str> = content.lines().collect();
// Print the content of the JSONL file
println!("JSONL file content:");
for (i, line) in lines.iter().enumerate() {
println!("Line {}: {}", i + 1, line);
}
assert_eq!(lines.len(), 2, "Expected 2 lines in the file");
// Part 2: Now create a KvIndexer and load the events from the file
let indexer_token = CancellationToken::new();
let kv_block_size = 32; // Default block size for testing
let kv_indexer_metrics = KvIndexerMetrics::new_unregistered();
let indexer = KvIndexer::new(
indexer_token.clone(),
kv_block_size,
kv_indexer_metrics.into(),
);
let indexer_event_tx = indexer.event_sender();
// Use the send_events method to load events from file to indexer
let count = KvRecorder::send_events(&file_path, &indexer_event_tx, false, None, None)
.await
.unwrap();
assert_eq!(count, 2, "Expected to send 2 events from file to indexer");
}
}
......@@ -3,15 +3,14 @@
pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
pub use dynamo_kv_router::scheduling::{
KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse,
KvSchedulerError, LocalScheduler, PotentialLoad, SchedulingRequest, SchedulingResponse,
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;
use dynamo_kv_router::selector::WorkerSelector as WorkerSelectorTrait;
use super::WorkerSelector;
use super::metrics::ROUTER_QUEUE_METRICS;
use super::queue::SchedulerQueue;
use super::sequence::{
ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
RuntimeSequencePublisher, SequenceError, SequenceRequest, create_multi_worker_sequences,
};
use crate::discovery::RuntimeConfigWatch;
use crate::local_model::runtime_config::ModelRuntimeConfig;
......@@ -22,35 +21,32 @@ use dynamo_kv_router::{
};
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "bench")]
use std::time::Instant;
use dynamo_tokens::SequenceHash;
pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMulti>,
queue: Arc<SchedulerQueue>,
pub struct KvScheduler<Sel = DefaultWorkerSelector>
where
Sel: WorkerSelectorTrait<ModelRuntimeConfig>,
{
inner: Arc<
LocalScheduler<RuntimeSequencePublisher, ModelRuntimeConfig, RouterSchedulingPolicy, Sel>,
>,
}
impl KvScheduler {
impl<Sel> KvScheduler<Sel>
where
Sel: WorkerSelectorTrait<ModelRuntimeConfig> + Send + Sync + 'static,
{
pub async fn start(
component: Component,
block_size: u32,
workers_with_configs: RuntimeConfigWatch,
selector: Option<Box<WorkerSelector>>,
selector: Sel,
kv_router_config: &KvRouterConfig,
worker_type: &'static str,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::new(None, worker_type)));
// Get initial workers from watch receiver.
// When skip_initial_worker_wait is false, the caller ensures at least one
// worker is present (via wait_for). When true the map may be empty;
// workers will be lazily registered via allowed_worker_ids per-request.
let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
workers_with_configs.borrow().clone();
......@@ -66,56 +62,11 @@ impl KvScheduler {
.await
.map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
// Spawn background task to sync slots when the watch value changes.
//
// In EPP mode (skip_initial_worker_wait=true) we skip the monitoring task:
// the per-request allowed_worker_ids is the source of truth, workers are
// lazily registered via register_external_workers() from the C bindings,
// and update_workers() would impose discovery-based lifecycle (add/remove)
// on the slot tracker, conflicting with EPP ownership.
if kv_router_config.skip_initial_worker_wait {
let watch_worker_configs = !kv_router_config.skip_initial_worker_wait;
if !watch_worker_configs {
tracing::info!("skipping discovery-based worker monitoring");
} else {
let slots_monitor = slots.clone();
let mut monitor_rx = workers_with_configs.clone();
let monitor_cancel_token = component.drt().child_token();
tokio::spawn(async move {
tracing::trace!("KvScheduler workers monitoring task started");
let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
loop {
tokio::select! {
_ = monitor_cancel_token.cancelled() => {
tracing::trace!("KvScheduler workers monitoring task shutting down");
break;
}
result = monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
break;
}
}
}
let current_workers = monitor_rx.borrow_and_update().clone();
if current_workers != last_workers {
let dp_range: HashMap<u64, (u32, u32)> = current_workers
.iter()
.map(|(&id, c)| {
(id, (c.data_parallel_start_rank, c.data_parallel_size))
})
.collect();
slots_monitor.update_workers(&dp_range);
last_workers = current_workers;
}
}
});
}
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
let scheduler_cancel_token = component.drt().primary_token();
let policy =
RouterSchedulingPolicy::new(kv_router_config.router_queue_policy, block_size as usize);
tracing::info!(
......@@ -123,52 +74,36 @@ impl KvScheduler {
kv_router_config.router_queue_policy
);
let queue = Arc::new(SchedulerQueue::new(
slots.clone(),
let inner = Arc::new(LocalScheduler::new(
slots,
workers_with_configs.clone(),
kv_router_config.router_queue_threshold,
block_size,
selector,
policy,
component.drt().child_token(),
worker_type,
watch_worker_configs,
));
let queue_clone = queue.clone();
// Background task: receive requests and periodically recheck pending
let metrics_scheduler = Arc::clone(&inner);
let metrics_cancel_token = component.drt().child_token();
tokio::spawn(async move {
let mut request_rx = request_rx;
let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
tracing::trace!("scheduler background task started");
ROUTER_QUEUE_METRICS.set_pending(worker_type, metrics_scheduler.pending_count());
loop {
tokio::select! {
_ = scheduler_cancel_token.cancelled() => {
tracing::trace!("scheduler background task shutting down");
break;
}
request = request_rx.recv() => {
let Some(request) = request else {
tracing::warn!("scheduler shutdown");
break;
};
tracing::trace!("received request to be scheduled");
queue_clone.enqueue(request).await;
ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
}
_ = metrics_cancel_token.cancelled() => break,
_ = recheck_interval.tick() => {
queue_clone.update().await;
ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
ROUTER_QUEUE_METRICS
.set_pending(worker_type, metrics_scheduler.pending_count());
}
}
}
tracing::trace!("background endpoint subscriber shutting down");
});
Ok(KvScheduler {
request_tx,
slots,
queue,
})
Ok(Self { inner })
}
#[expect(clippy::too_many_arguments)]
......@@ -185,85 +120,51 @@ impl KvScheduler {
expected_output_tokens: Option<u32>,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<SchedulingResponse, KvSchedulerError> {
#[cfg(feature = "bench")]
let start = Instant::now();
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
maybe_request_id,
token_seq,
isl_tokens,
overlaps,
decode_blocks: HashMap::new(),
prefill_tokens: HashMap::new(),
router_config_override: router_config_override.cloned(),
update_states,
lora_name,
priority_jump,
expected_output_tokens,
allowed_worker_ids,
resp_tx: Some(resp_tx),
};
self.request_tx
.send(request)
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
#[cfg(feature = "bench")]
let send_elapsed = start.elapsed();
let response = resp_rx
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)??;
#[cfg(feature = "bench")]
let total_elapsed = start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
isl_tokens,
send_us = send_elapsed.as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"scheduler.schedule completed"
);
Ok(response)
let response = self
.inner
.schedule(
maybe_request_id,
isl_tokens,
token_seq,
overlaps,
router_config_override,
update_states,
lora_name,
priority_jump,
expected_output_tokens,
allowed_worker_ids,
)
.await;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.pending_count());
response
}
/// Register externally-provided workers in the slot tracker.
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
self.queue.register_workers(worker_ids);
self.inner.register_workers(worker_ids);
}
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.slots.add_request(req).await
self.inner.add_request(req).await
}
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots
.mark_prefill_completed(&request_id.to_string())
.await?;
self.queue.update().await;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
self.inner.mark_prefill_completed(request_id).await?;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.pending_count());
Ok(())
}
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.free(&request_id.to_string()).await?;
self.queue.update().await;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
self.inner.free(request_id).await?;
ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.pending_count());
Ok(())
}
/// Number of requests currently parked in the scheduler queue.
pub fn pending_count(&self) -> usize {
self.queue.pending_count()
self.inner.pending_count()
}
/// Get the worker type for this scheduler ("prefill" or "decode").
/// Used for Prometheus metric labeling.
pub fn worker_type(&self) -> &'static str {
self.slots.worker_type()
self.inner.worker_type()
}
pub fn add_output_block(
......@@ -271,8 +172,7 @@ impl KvScheduler {
request_id: &str,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
self.slots
.add_output_block(&request_id.to_string(), decay_fraction)
self.inner.add_output_block(request_id, decay_fraction)
}
pub fn get_potential_loads(
......@@ -281,34 +181,11 @@ impl KvScheduler {
isl_tokens: usize,
overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
let (decode_blocks, prefill_tokens) =
self.slots
.potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);
// Get all unique WorkerWithDpRank from both hashmaps
let mut workers: HashSet<dynamo_kv_router::protocols::WorkerWithDpRank> = HashSet::new();
workers.extend(decode_blocks.keys().copied());
workers.extend(prefill_tokens.keys().copied());
// Create PotentialLoad for each worker
let mut loads = Vec::new();
for worker in workers {
loads.push(PotentialLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens
.get(&worker)
.copied()
.unwrap_or(isl_tokens),
potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
});
}
loads
self.inner
.get_potential_loads(token_seq, isl_tokens, overlaps)
}
/// Get active request counts grouped by LORA name
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
self.slots.get_active_lora_counts()
self.inner.get_active_lora_counts()
}
}
......@@ -20,7 +20,8 @@ use dashmap::DashMap;
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData};
use dynamo_mocker::common::bootstrap::{BootstrapServer, connect_to_prefill};
use dynamo_mocker::common::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, OutputSignal,
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs, OutputSignal, RawKvEvent,
RawKvEventSink,
};
use dynamo_mocker::common::utils::{compute_kv_transfer_delay, sleep_precise};
use dynamo_mocker::engine::create_engine;
......@@ -48,11 +49,7 @@ pub const MOCKER_COMPONENT: &str = "mocker";
struct KvEventSinkAdapter(KvEventPublisher);
impl KvCacheEventSink for KvEventSinkAdapter {
fn publish(
&self,
event: KvCacheEvent,
_block_token_ids: Option<&[Vec<u32>]>,
) -> anyhow::Result<()> {
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
self.0
.publish(event)
.map_err(|e| anyhow::anyhow!("Failed to send KV event: {}", e))
......@@ -77,13 +74,8 @@ enum ZmqRawKvEvent {
},
}
struct ZmqKvEventMsg {
event: KvCacheEvent,
block_token_ids: Option<Vec<Vec<u32>>>,
}
struct ZmqKvEventSink {
tx: mpsc::UnboundedSender<ZmqKvEventMsg>,
tx: mpsc::UnboundedSender<RawKvEvent>,
}
/// Maximum number of entries in the replay ring buffer.
......@@ -96,7 +88,7 @@ impl ZmqKvEventSink {
dp_rank: u32,
block_size: u32,
) -> Result<Self> {
let (tx, mut rx) = mpsc::unbounded_channel::<ZmqKvEventMsg>();
let (tx, mut rx) = mpsc::unbounded_channel::<RawKvEvent>();
// Bind the PUB socket before returning so that any SUB connect()
// that follows is guaranteed to find the endpoint already listening.
......@@ -250,17 +242,10 @@ impl ZmqKvEventSink {
}
}
impl KvCacheEventSink for ZmqKvEventSink {
fn publish(
&self,
event: KvCacheEvent,
block_token_ids: Option<&[Vec<u32>]>,
) -> anyhow::Result<()> {
impl RawKvEventSink for ZmqKvEventSink {
fn publish(&self, event: RawKvEvent) -> anyhow::Result<()> {
self.tx
.send(ZmqKvEventMsg {
event,
block_token_ids: block_token_ids.map(|t| t.to_vec()),
})
.send(event)
.map_err(|_| anyhow::anyhow!("ZMQ event sink channel closed"))
}
}
......@@ -413,8 +398,8 @@ impl MockEngine {
for dp_rank in 0..args.dp_size {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (kv_event_sink, relay_publisher): (
Option<Arc<dyn KvCacheEventSink>>,
let (kv_event_publishers, relay_publisher): (
KvEventPublishers,
Option<KvEventPublisher>,
) = match component {
Some(comp) if args.zmq_kv_events_port.is_some() => {
......@@ -442,14 +427,17 @@ impl MockEngine {
None,
) {
Ok(publisher) => (
Some(Arc::new(sink) as Arc<dyn KvCacheEventSink>),
KvEventPublishers::new(
None,
Some(Arc::new(sink) as Arc<dyn RawKvEventSink>),
),
Some(publisher),
),
Err(e) => {
tracing::error!(
"Failed to create KV event relay for dp_rank {dp_rank}: {e}"
);
(None, None)
(KvEventPublishers::default(), None)
}
}
}
......@@ -457,7 +445,7 @@ impl MockEngine {
tracing::error!(
"Failed to create ZMQ KV event sink for dp_rank {dp_rank}: {e}"
);
(None, None)
(KvEventPublishers::default(), None)
}
}
}
......@@ -471,26 +459,29 @@ impl MockEngine {
None,
) {
Ok(publisher) => (
Some(Arc::new(KvEventSinkAdapter(publisher))
as Arc<dyn KvCacheEventSink>),
KvEventPublishers::new(
Some(Arc::new(KvEventSinkAdapter(publisher))
as Arc<dyn KvCacheEventSink>),
None,
),
None,
),
Err(e) => {
tracing::error!(
"Failed to create KV event publisher for dp_rank {dp_rank}: {e}"
);
(None, None)
(KvEventPublishers::default(), None)
}
}
}
None => (None, None),
None => (KvEventPublishers::default(), None),
};
let scheduler = create_engine(
args.clone(),
dp_rank,
Some(output_tx),
kv_event_sink,
kv_event_publishers,
Some(cancel_token.clone()),
);
......
......@@ -13,7 +13,6 @@ repository.workspace = true
[dependencies]
# repo
dynamo-kv-router = { workspace = true }
dynamo-runtime = { workspace = true }
dynamo-tokens = { workspace = true }
# workspace
......@@ -41,3 +40,4 @@ tokio-timerfd = "0.2"
[dev-dependencies]
rstest = "0.18.2"
tempfile = { workspace = true }
......@@ -5,14 +5,15 @@
//!
//! Enabled by setting `DYN_MOCKER_KV_CACHE_TRACE=1` or `true`.
use dynamo_runtime::config::environment_names::mocker;
use std::env;
use std::sync::LazyLock;
use std::time::{SystemTime, UNIX_EPOCH};
const DYN_MOCKER_KV_CACHE_TRACE: &str = "DYN_MOCKER_KV_CACHE_TRACE";
/// Check the env var to enable KV cache allocation/eviction trace logs.
pub static KV_CACHE_TRACE_ENABLED: LazyLock<bool> = LazyLock::new(|| {
env::var(mocker::DYN_MOCKER_KV_CACHE_TRACE)
env::var(DYN_MOCKER_KV_CACHE_TRACE)
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false)
});
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use dynamo_kv_router::config::RouterQueuePolicy;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
......@@ -17,11 +18,66 @@ use dynamo_tokens::{BlockHash, SequenceHash, Token};
/// Trait for publishing KV cache events.
/// This abstracts the runtime dependency so mocker components can remain generic.
pub trait KvCacheEventSink: Send + Sync {
fn publish(
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()>;
}
/// Raw KV event payload used by transport-specific publishers such as the
/// vLLM-native ZMQ event stream.
#[derive(Debug, Clone)]
pub struct RawKvEvent {
/// The KV cache event being published.
pub event: KvCacheEvent,
/// Optional token IDs accompanying the event.
/// NOTE(review): presumably one inner `Vec` per block in the event — confirm.
pub block_token_ids: Option<Vec<Vec<u32>>>,
}
/// Trait for publishing transport-specific raw KV event payloads.
pub trait RawKvEventSink: Send + Sync {
/// Hands one raw event (cache event plus optional token IDs) to the sink.
///
/// Returns an error when the sink cannot accept the event.
fn publish(&self, event: RawKvEvent) -> anyhow::Result<()>;
}
/// Shared KV event publisher bundle used by schedulers and KV managers.
///
/// Either sink may be absent; a default-constructed bundle has neither.
#[derive(Clone, Default)]
pub struct KvEventPublishers {
/// Sink that receives the plain `KvCacheEvent`.
event_sink: Option<Arc<dyn KvCacheEventSink>>,
/// Sink that receives a `RawKvEvent` (the event plus optional block token IDs).
raw_sink: Option<Arc<dyn RawKvEventSink>>,
}
impl KvEventPublishers {
pub fn new(
event_sink: Option<Arc<dyn KvCacheEventSink>>,
raw_sink: Option<Arc<dyn RawKvEventSink>>,
) -> Self {
Self {
event_sink,
raw_sink,
}
}
pub fn raw_enabled(&self) -> bool {
self.raw_sink.is_some()
}
pub fn is_empty(&self) -> bool {
self.event_sink.is_none() && self.raw_sink.is_none()
}
pub fn publish(
&self,
event: KvCacheEvent,
block_token_ids: Option<&[Vec<u32>]>,
) -> anyhow::Result<()>;
) -> anyhow::Result<()> {
if let Some(sink) = self.event_sink.as_ref() {
sink.publish(event.clone())?;
}
if let Some(sink) = self.raw_sink.as_ref() {
sink.publish(RawKvEvent {
event,
block_token_ids: block_token_ids.map(|token_ids| token_ids.to_vec()),
})?;
}
Ok(())
}
}
pub type NumBlocks = usize;
......@@ -186,8 +242,7 @@ pub struct MockEngineArgs {
#[validate(range(min = 1))]
pub num_gpu_blocks: usize,
#[builder(default = "64")]
#[validate(range(min = 2))]
#[builder(default = "0")]
pub block_size: usize,
// This was 1024 in the past but reverted back to 256
......@@ -310,6 +365,10 @@ pub struct MockEngineArgs {
#[builder(default)]
pub preemption_mode: PreemptionMode,
/// Optional replay-only override for the router queue policy.
#[builder(default = "None")]
pub router_queue_policy: Option<RouterQueuePolicy>,
/// SGLang-specific configuration. Only used when `engine_type == Sglang`.
#[builder(default = "None")]
pub sglang: Option<SglangArgs>,
......@@ -320,14 +379,70 @@ impl Default for MockEngineArgs {
MockEngineArgsBuilder::default()
.build()
.expect("Failed to build default MockEngineArgs")
.normalized()
.expect("Failed to normalize default MockEngineArgs")
}
}
impl MockEngineArgs {
/// Block size that `normalized` substitutes for an unset (`0`) `block_size`
/// when `engine_type` is `EngineType::Vllm`.
const DEFAULT_VLLM_BLOCK_SIZE: usize = 64;
/// Block size that `normalized` substitutes for an unset (`0`) `block_size`
/// when `engine_type` is `EngineType::Sglang` and no `sglang.page_size` is given.
const DEFAULT_SGLANG_BLOCK_SIZE: usize = 1;
/// Returns a fresh `MockEngineArgsBuilder` obtained from its `Default` impl.
pub fn builder() -> MockEngineArgsBuilder {
MockEngineArgsBuilder::default()
}
/// Resolve engine-specific defaults for `block_size` and validate the result.
///
/// A `block_size` of `0` is treated as "unset":
/// - vLLM: replaced by `DEFAULT_VLLM_BLOCK_SIZE`.
/// - SGLang: replaced by `sglang.page_size` when present, otherwise by
///   `DEFAULT_SGLANG_BLOCK_SIZE`.
///
/// # Errors
/// - SGLang with both `block_size` and `sglang.page_size` set to different values.
/// - `block_size` still `0` after defaulting (e.g. `sglang.page_size = 0`).
/// - SGLang with `sglang.chunked_prefill_size` not divisible by `block_size`.
/// - Any failure reported by `validate()`.
pub fn normalized(mut self) -> anyhow::Result<Self> {
    match self.engine_type {
        EngineType::Vllm => {
            if self.block_size == 0 {
                self.block_size = Self::DEFAULT_VLLM_BLOCK_SIZE;
            }
        }
        EngineType::Sglang => {
            let page_size = self.sglang.as_ref().and_then(|sglang| sglang.page_size);
            match (self.block_size, page_size) {
                (0, None) => {
                    self.block_size = Self::DEFAULT_SGLANG_BLOCK_SIZE;
                }
                (0, Some(page_size)) => {
                    self.block_size = page_size;
                }
                (block_size, Some(page_size)) if block_size == page_size => {}
                (_, Some(page_size)) => {
                    return Err(anyhow::anyhow!(
                        "engine_type=sglang requires block_size and sglang.page_size to match when both are set, got block_size={} and sglang.page_size={page_size}",
                        self.block_size,
                    ));
                }
                (_, None) => {}
            }
        }
    }

    // This guard must run before the divisibility check below: an explicit
    // `sglang.page_size = 0` resolves `block_size` to 0, and `% 0` panics.
    if self.block_size == 0 {
        return Err(anyhow::anyhow!("block_size must be greater than 0"));
    }

    if self.engine_type == EngineType::Sglang
        && let Some(chunked_prefill_size) = self
            .sglang
            .as_ref()
            .and_then(|sglang| sglang.chunked_prefill_size)
        && chunked_prefill_size % self.block_size != 0
    {
        return Err(anyhow::anyhow!(
            "engine_type=sglang requires sglang.chunked_prefill_size to be divisible by block_size, got chunked_prefill_size={} and block_size={}",
            chunked_prefill_size,
            self.block_size,
        ));
    }

    self.validate()
        .map_err(|error| anyhow::anyhow!("Failed to validate MockEngineArgs: {error}"))?;

    Ok(self)
}
pub fn is_prefill(&self) -> bool {
self.worker_type == WorkerType::Prefill
}
......@@ -342,11 +457,13 @@ impl MockEngineArgs {
/// Create MockEngineArgs from a JSON file containing extra engine arguments
pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
let mut builder = Self::builder();
// Load and parse the JSON file
let file_content = std::fs::read_to_string(path)?;
let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(&file_content)?;
Self::from_json_str(&file_content)
}
pub fn from_json_str(content: &str) -> anyhow::Result<Self> {
let mut builder = Self::builder();
let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(content)?;
// Define valid field names
let valid_fields: HashSet<&str> = [
......@@ -377,6 +494,7 @@ impl MockEngineArgs {
"zmq_kv_events_port",
"zmq_replay_port",
"preemption_mode",
"router_queue_policy",
"sglang",
]
.iter()
......@@ -533,6 +651,13 @@ impl MockEngineArgs {
builder = builder.preemption_mode(mode);
}
if let Some(value) = extra_args.get("router_queue_policy")
&& let Some(policy_str) = value.as_str()
{
let policy = policy_str.parse().map_err(|e: String| anyhow::anyhow!(e))?;
builder = builder.router_queue_policy(Some(policy));
}
if let Some(value) = extra_args.get("sglang") {
let cfg: SglangArgs = serde_json::from_value(value.clone())
.map_err(|e| anyhow::anyhow!("Failed to parse sglang config: {}", e))?;
......@@ -615,12 +740,14 @@ impl MockEngineArgs {
builder
.build()
.map_err(|e| anyhow::anyhow!("Failed to build MockEngineArgs: {}", e))
.and_then(Self::normalized)
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_unique_block_default_uniqueness() {
......@@ -647,4 +774,132 @@ mod tests {
}
}
}
/// SGLang with only `sglang.page_size` set on the builder: after `normalized`
/// the `block_size` must equal that page size.
#[test]
fn test_normalized_sglang_uses_page_size_alias_for_block_size() {
let args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.sglang(Some(SglangArgs {
page_size: Some(16),
..Default::default()
}))
.build()
.unwrap()
.normalized()
.unwrap();
assert_eq!(args.block_size, 16);
}
/// SGLang with `block_size` and `sglang.page_size` both set to the same value
/// is accepted and keeps that value.
#[test]
fn test_normalized_sglang_accepts_equal_block_size_and_page_size() {
let args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(8)
.sglang(Some(SglangArgs {
page_size: Some(8),
..Default::default()
}))
.build()
.unwrap()
.normalized()
.unwrap();
assert_eq!(args.block_size, 8);
}
/// SGLang with `block_size` (8) and `sglang.page_size` (4) set to different
/// values must fail normalization with the "to match" error message.
#[test]
fn test_normalized_sglang_rejects_mismatched_block_size_and_page_size() {
let error = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(8)
.sglang(Some(SglangArgs {
page_size: Some(4),
..Default::default()
}))
.build()
.unwrap()
.normalized()
.unwrap_err();
assert!(
error
.to_string()
.contains("block_size and sglang.page_size to match"),
"unexpected error: {error}",
);
}
/// SGLang with neither `block_size` nor `sglang` set on the builder normalizes
/// to a `block_size` of 1.
#[test]
fn test_normalized_sglang_defaults_block_size_to_one() {
let args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.build()
.unwrap()
.normalized()
.unwrap();
assert_eq!(args.block_size, 1);
}
/// Writes a JSON config with `engine_type = "sglang"` and `sglang.page_size = 32`
/// to a temporary file and checks that `from_json_file` yields `block_size == 32`.
#[test]
fn test_from_json_file_normalizes_sglang_page_size() {
let tempdir = tempfile::tempdir().unwrap();
let path = tempdir.path().join("args.json");
std::fs::write(
&path,
serde_json::to_string(&json!({
"engine_type": "sglang",
"sglang": {
"page_size": 32
}
}))
.unwrap(),
)
.unwrap();
let args = MockEngineArgs::from_json_file(&path).unwrap();
assert_eq!(args.block_size, 32);
}
/// SGLang with `block_size = 4` and `chunked_prefill_size = 6` (not a multiple
/// of 4) must fail normalization with the divisibility error message.
#[test]
fn test_normalized_sglang_rejects_chunked_prefill_not_divisible_by_block_size() {
let error = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.sglang(Some(SglangArgs {
page_size: Some(4),
chunked_prefill_size: Some(6),
..Default::default()
}))
.build()
.unwrap()
.normalized()
.unwrap_err();
assert!(
error
.to_string()
.contains("chunked_prefill_size to be divisible by block_size"),
"unexpected error: {error}",
);
}
/// SGLang with `block_size = 4` and `chunked_prefill_size = 8` (a multiple of 4)
/// normalizes successfully and keeps `block_size == 4`.
#[test]
fn test_normalized_sglang_accepts_chunked_prefill_divisible_by_block_size() {
let args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.sglang(Some(SglangArgs {
page_size: Some(4),
chunked_prefill_size: Some(8),
..Default::default()
}))
.build()
.unwrap()
.normalized()
.unwrap();
assert_eq!(args.block_size, 4);
}
}
......@@ -134,13 +134,12 @@ impl ActiveSequence {
let hashes = self.block_hashes[hash_start..hash_end].to_vec();
let token_ids = if self.emit_token_ids && hash_start < hash_end {
let all_token_ids: Vec<Vec<u32>> = self
.tokens
.blocks()
.iter()
.map(|b| b.tokens().to_vec())
.collect();
Some(all_token_ids[hash_start..hash_end].to_vec())
Some(
self.tokens.blocks()[hash_start..hash_end]
.iter()
.map(|b| b.tokens().to_vec())
.collect(),
)
} else {
None
};
......@@ -276,13 +275,15 @@ impl ActiveSequence {
}
// Free all blocks when we reach max tokens
signals.extend(self.free_signal());
signals.extend(self.free_signal_for_tokens(self.len()));
signals
}
/// Free all blocks, generating appropriate signals for each block type
pub fn free_signal(&self) -> Vec<MoveBlock> {
self.unique_blocks
fn free_signal_for_tokens(&self, active_tokens: usize) -> Vec<MoveBlock> {
let active_blocks = active_tokens
.div_ceil(self.block_size)
.min(self.unique_blocks.len());
self.unique_blocks[..active_blocks]
.iter()
.rev()
.map(|block| match block {
......@@ -296,6 +297,11 @@ impl ActiveSequence {
.collect()
}
/// Free the currently active allocation footprint.
///
/// Delegates to `free_signal_for_tokens` with `self.num_allocated_tokens`, so
/// the returned signals are sized by the allocated token count.
pub fn free_signal(&self) -> Vec<MoveBlock> {
self.free_signal_for_tokens(self.num_allocated_tokens)
}
/// Move the request to a preempted state and return the free signals from freeing current blocks.
/// Upon preemption, the sequence retains the tokens generated during the decode phase (if any).
/// Resets `num_allocated_tokens` so re-admission will re-allocate from scratch.
......
......@@ -3,12 +3,10 @@
//! Engine factory — creates the appropriate scheduler based on [`EngineType`].
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{EngineType, KvCacheEventSink, MockEngineArgs, OutputSignal};
use crate::common::protocols::{EngineType, KvEventPublishers, MockEngineArgs, OutputSignal};
use crate::scheduler::{Scheduler, SchedulerHandle, SglangScheduler};
/// Create a scheduler for the configured engine type.
......@@ -19,7 +17,7 @@ pub fn create_engine(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Box<dyn SchedulerHandle> {
match args.engine_type {
......@@ -27,14 +25,14 @@ pub fn create_engine(
args,
dp_rank,
output_tx,
kv_event_sink,
kv_event_publishers,
cancellation_token,
)),
EngineType::Sglang => Box::new(SglangScheduler::new(
args,
dp_rank,
output_tx,
kv_event_sink,
kv_event_publishers,
cancellation_token,
)),
}
......
......@@ -7,11 +7,10 @@
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::cache::radix_cache::{NodeId, RadixCache};
use crate::common::kv_cache_trace;
use crate::common::protocols::KvCacheEventSink;
use crate::common::protocols::KvEventPublishers;
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData,
......@@ -31,27 +30,32 @@ pub struct AllocResult {
pub struct SglangKvManager {
cache: RadixCache,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
kv_event_publishers: KvEventPublishers,
dp_rank: u32,
next_event_id: u64,
/// Maps pool_idx → block_hash assigned during Stored events,
/// so Removed events can use the same block_hash.
idx_to_block_hash: HashMap<usize, ExternalSequenceBlockHash>,
/// Tracks how many live pool slots currently advertise the same logical
/// block hash so router events reflect logical block visibility, not
/// transient slot ownership.
block_hash_refcounts: HashMap<ExternalSequenceBlockHash, usize>,
}
impl SglangKvManager {
pub fn new(
total_tokens: usize,
page_size: usize,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
kv_event_publishers: KvEventPublishers,
dp_rank: u32,
) -> Self {
Self {
cache: RadixCache::new(total_tokens, page_size),
kv_event_sink,
kv_event_publishers,
dp_rank,
next_event_id: 0,
idx_to_block_hash: HashMap::new(),
block_hash_refcounts: HashMap::new(),
}
}
......@@ -94,6 +98,39 @@ impl SglangKvManager {
})
}
/// Resume allocation for a request that already owns the slots of its prefix.
///
/// Chunked-prefill continuation uses this when the request-held prefix may
/// reach past the page-aligned prefix cached in the radix tree. Only the
/// tokens after `prefix_len` get fresh slots; the first `prefix_len` entries
/// of `prefix_indices` are reused as-is.
///
/// Returns `None` when the token pool cannot supply the new slots.
pub fn allocate_after_prefix(
    &mut self,
    token_ids: &[u64],
    prefix_len: usize,
    prefix_indices: &[usize],
    last_node: NodeId,
) -> Option<AllocResult> {
    let suffix_len = token_ids.len().saturating_sub(prefix_len);
    let fresh_indices = self.cache.token_pool.allocate(suffix_len)?;

    // Reused prefix slots first, freshly allocated slots after them.
    let kv_indices: Vec<usize> = prefix_indices[..prefix_len]
        .iter()
        .chain(fresh_indices.iter())
        .copied()
        .collect();

    self.cache.inc_lock_ref(last_node);

    // The new blocks chain from the hash of the last prefix slot, if the
    // prefix is non-empty and that slot has an advertised hash.
    let parent_hash = prefix_len
        .checked_sub(1)
        .and_then(|last_prefix_pos| kv_indices.get(last_prefix_pos))
        .and_then(|slot| self.idx_to_block_hash.get(slot).copied());

    self.publish_stored_event(&token_ids[prefix_len..], &fresh_indices, parent_hash);
    self.log_trace("allocation", suffix_len);

    Some(AllocResult {
        prefix_len,
        kv_indices,
        last_node,
    })
}
/// Cache a completed request's full sequence into the radix tree.
///
/// Inserts the full token sequence so future requests can reuse it,
......@@ -152,6 +189,18 @@ impl SglangKvManager {
self.cache.dec_lock_ref(last_node);
}
/// Hand request-owned token slots back to the free pool, then emit removal
/// events for the slots that had been advertised to the router.
///
/// An empty `indices` slice is a no-op.
pub fn free_indices(&mut self, indices: &[usize]) {
    if !indices.is_empty() {
        self.cache.token_pool.free(indices);
        self.publish_removed_event(indices);
        self.log_trace("free", indices.len());
    }
}
/// Collect token indices from the matched prefix path by walking root→last_node.
fn collect_path_indices(&self, last_node: NodeId) -> Vec<usize> {
if last_node == self.cache.root() {
......@@ -210,19 +259,12 @@ impl SglangKvManager {
if indices.is_empty() {
return;
}
let Some(ref sink) = self.kv_event_sink else {
return;
};
let mut blocks = Vec::with_capacity(indices.len());
let mut computed_blocks = Vec::with_capacity(indices.len());
let mut running_hash = parent_hash.map_or(0u64, |h| h.0);
for (i, &idx) in indices.iter().enumerate() {
// tokens_hash: per-token content hash for router prefix matching
let token_bytes: Vec<u8> = token_ids
.get(i)
.unwrap_or(&(idx as u64))
.to_le_bytes()
.to_vec();
let token = token_ids.get(i).copied().unwrap_or(idx as u64);
let token_bytes = token.to_le_bytes();
let tokens_hash = dynamo_kv_router::protocols::compute_block_hash(&token_bytes);
// block_hash: cumulative hash (parent_hash, token_id) so it's unique
......@@ -234,14 +276,36 @@ impl SglangKvManager {
let block_hash = ExternalSequenceBlockHash(running_hash);
self.idx_to_block_hash.insert(idx, block_hash);
blocks.push(KvCacheStoredBlockData {
*self.block_hash_refcounts.entry(block_hash).or_default() += 1;
computed_blocks.push(KvCacheStoredBlockData {
block_hash,
tokens_hash,
mm_extra_info: None,
});
}
if self.kv_event_publishers.is_empty() {
return;
}
let first_new = computed_blocks.iter().position(|block| {
self.block_hash_refcounts
.get(&block.block_hash)
.copied()
.unwrap_or_default()
== 1
});
let Some(first_new) = first_new else {
return;
};
let parent_hash = if first_new == 0 {
parent_hash
} else {
Some(computed_blocks[first_new - 1].block_hash)
};
let blocks = computed_blocks.into_iter().skip(first_new).collect();
let event = KvCacheEvent {
event_id: self.next_event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
......@@ -252,20 +316,31 @@ impl SglangKvManager {
};
self.next_event_id += 1;
if let Err(e) = sink.publish(event, None) {
if let Err(e) = self.kv_event_publishers.publish(event, None) {
tracing::warn!("Failed to publish SGLang KV event: {e}");
}
}
fn publish_removed_event(&mut self, evicted_indices: &[usize]) {
let Some(ref sink) = self.kv_event_sink else {
if self.kv_event_publishers.is_empty() {
return;
};
}
let block_hashes: Vec<ExternalSequenceBlockHash> = evicted_indices
.iter()
.filter_map(|&idx| self.idx_to_block_hash.remove(&idx))
.collect();
let mut block_hashes = Vec::new();
for &idx in evicted_indices {
let Some(block_hash) = self.idx_to_block_hash.remove(&idx) else {
continue;
};
let Some(refcount) = self.block_hash_refcounts.get_mut(&block_hash) else {
continue;
};
if *refcount > 1 {
*refcount -= 1;
continue;
}
self.block_hash_refcounts.remove(&block_hash);
block_hashes.push(block_hash);
}
if block_hashes.is_empty() {
return;
......@@ -278,7 +353,7 @@ impl SglangKvManager {
};
self.next_event_id += 1;
if let Err(e) = sink.publish(event, None) {
if let Err(e) = self.kv_event_publishers.publish(event, None) {
tracing::warn!("Failed to publish SGLang KV remove event: {e}");
}
}
......@@ -287,8 +362,11 @@ impl SglangKvManager {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::Mutex;
use crate::common::protocols::KvCacheEventSink;
struct MockSink {
events: Mutex<Vec<KvCacheEvent>>,
}
......@@ -299,17 +377,18 @@ mod tests {
events: Mutex::new(Vec::new()),
}
}
fn event_count(&self) -> usize {
self.events.lock().unwrap().len()
}
fn clone_events(&self) -> Vec<KvCacheEvent> {
self.events.lock().unwrap().clone()
}
}
impl KvCacheEventSink for MockSink {
fn publish(
&self,
event: KvCacheEvent,
_block_token_ids: Option<&[Vec<u32>]>,
) -> anyhow::Result<()> {
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
self.events.lock().unwrap().push(event);
Ok(())
}
......@@ -317,7 +396,7 @@ mod tests {
#[test]
fn test_allocate_cache_miss() {
let mut mgr = SglangKvManager::new(100, 1, None, 0);
let mut mgr = SglangKvManager::new(100, 1, KvEventPublishers::default(), 0);
let result = mgr.allocate_for_request(&[1, 2, 3, 4, 5]).unwrap();
assert_eq!(result.prefix_len, 0);
......@@ -327,7 +406,7 @@ mod tests {
#[test]
fn test_allocate_cache_hit() {
let mut mgr = SglangKvManager::new(100, 1, None, 0);
let mut mgr = SglangKvManager::new(100, 1, KvEventPublishers::default(), 0);
// First request: allocate and cache
let r1 = mgr.allocate_for_request(&[1, 2, 3, 4, 5]).unwrap();
......@@ -343,7 +422,7 @@ mod tests {
#[test]
fn test_free_request_without_caching() {
let mut mgr = SglangKvManager::new(100, 1, None, 0);
let mut mgr = SglangKvManager::new(100, 1, KvEventPublishers::default(), 0);
let result = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
mgr.free_request(result.last_node);
......@@ -355,7 +434,8 @@ mod tests {
#[test]
fn test_event_publishing() {
let sink = Arc::new(MockSink::new());
let mut mgr = SglangKvManager::new(100, 1, Some(sink.clone()), 0);
let mut mgr =
SglangKvManager::new(100, 1, KvEventPublishers::new(Some(sink.clone()), None), 0);
let r = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
assert_eq!(sink.event_count(), 1); // BlockStored for 3 new pages
......@@ -368,13 +448,93 @@ mod tests {
assert_eq!(sink.event_count(), 1); // no new event
}
/// Two requests with the same tokens `[1, 2, 3]` must yield a single Stored
/// event with three blocks. Freeing the first request's indices publishes
/// nothing; freeing the second publishes one Removed event with three hashes.
#[test]
fn test_duplicate_logical_blocks_publish_once_and_remove_once() {
let sink = Arc::new(MockSink::new());
let mut mgr =
SglangKvManager::new(100, 1, KvEventPublishers::new(Some(sink.clone()), None), 0);
let req1 = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
let req2 = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
let events = sink.clone_events();
assert_eq!(events.len(), 1);
let KvCacheEventData::Stored(store) = &events[0].data else {
panic!("expected stored event");
};
assert_eq!(store.blocks.len(), 3);
// First free: event count must stay at 1.
mgr.free_indices(&req1.kv_indices);
assert_eq!(sink.event_count(), 1);
// Second free: exactly one Removed event is expected.
mgr.free_indices(&req2.kv_indices);
let events = sink.clone_events();
assert_eq!(events.len(), 2);
let KvCacheEventData::Removed(remove) = &events[1].data else {
panic!("expected removed event");
};
assert_eq!(remove.block_hashes.len(), 3);
}
#[test]
fn test_allocate_oom() {
let mut mgr = SglangKvManager::new(3, 1, None, 0);
let mut mgr = SglangKvManager::new(3, 1, KvEventPublishers::default(), 0);
let _r = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
// Pool is full
let result = mgr.allocate_for_request(&[4, 5, 6]);
assert!(result.is_none());
}
/// Allocates a 3-token chunk, caches it as unfinished, then allocates the full
/// 6-token sequence. Expects two Stored events: the first with no parent hash,
/// the second chained from the last block of the first and carrying only the
/// three new blocks, with the second allocation reporting a prefix of 3.
#[test]
fn test_chunked_prefill_parent_hash() {
let sink = Arc::new(MockSink::new());
let mut mgr =
SglangKvManager::new(32, 1, KvEventPublishers::new(Some(sink.clone()), None), 0);
let tokens = [11, 22, 33, 44, 55, 66];
let chunk1_len = 3;
let chunk2_len = 6;
let alloc1 = mgr.allocate_for_request(&tokens[..chunk1_len]).unwrap();
let new_last =
mgr.cache_unfinished_req(&tokens[..chunk1_len], &alloc1.kv_indices, alloc1.last_node);
let alloc2 = mgr.allocate_for_request(&tokens[..chunk2_len]).unwrap();
mgr.free_request(new_last);
let events = sink.events.lock().unwrap();
assert_eq!(events.len(), 2, "expected two stored events");
let KvCacheEventData::Stored(store1) = &events[0].data else {
panic!("expected first event to be Stored");
};
let KvCacheEventData::Stored(store2) = &events[1].data else {
panic!("expected second event to be Stored");
};
assert!(
store1.parent_hash.is_none(),
"first chunk should start from the root"
);
let last_block_hash = store1
.blocks
.last()
.expect("first chunk should store at least one block")
.block_hash;
assert_eq!(
store2.parent_hash,
Some(last_block_hash),
"second chunk should chain from the last block of chunk 1"
);
assert_eq!(
store2.blocks.len(),
chunk2_len - chunk1_len,
"second chunk should only emit new blocks"
);
assert_eq!(
alloc2.prefix_len, chunk1_len,
"second chunk should reuse the cached partial prefix"
);
}
}
......@@ -36,7 +36,7 @@
//! implementation of the main block manager.
use crate::cache::HashCache;
use crate::common::kv_cache_trace;
use crate::common::protocols::{KvCacheEventSink, MoveBlock, PrefillCost};
use crate::common::protocols::{KvEventPublishers, MoveBlock, PrefillCost};
use crate::common::sequence::ActiveSequence;
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
......@@ -45,29 +45,28 @@ use dynamo_kv_router::protocols::{
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash};
use std::collections::HashMap;
use std::sync::Arc;
pub struct KvManager {
cache: HashCache,
block_size: usize,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
kv_event_publishers: KvEventPublishers,
dp_rank: u32,
next_event_id: u64,
}
impl KvManager {
/// Create a manager with no event sinks attached and DP rank 0.
pub fn new(max_capacity: usize, block_size: usize) -> Self {
    Self::new_with_event_sink(max_capacity, block_size, KvEventPublishers::default(), 0)
}
pub fn new_with_event_sink(
max_capacity: usize,
block_size: usize,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
kv_event_publishers: KvEventPublishers,
dp_rank: u32,
) -> Self {
debug_assert!(max_capacity > 0, "max_capacity must be > 0");
if kv_event_sink.is_some() {
if !kv_event_publishers.is_empty() {
tracing::info!(
"KvManager initialized with event sink for DP rank {dp_rank} with block_size {block_size}"
);
......@@ -76,7 +75,7 @@ impl KvManager {
KvManager {
cache: HashCache::new(max_capacity),
block_size,
kv_event_sink,
kv_event_publishers,
dp_rank,
next_event_id: 0,
}
......@@ -104,9 +103,9 @@ impl KvManager {
self.cache.max_capacity(),
);
let Some(ref sink) = self.kv_event_sink else {
if self.kv_event_publishers.is_empty() {
return;
};
}
let event_data = if is_store {
let num_blocks = full_blocks.len();
......@@ -145,7 +144,10 @@ impl KvManager {
dp_rank: self.dp_rank,
};
if let Err(e) = sink.publish(event, token_ids.as_deref()) {
if let Err(e) = self
.kv_event_publishers
.publish(event, token_ids.as_deref())
{
tracing::warn!("Failed to publish KV event: {e}");
}
}
......@@ -384,6 +386,9 @@ impl KvManager {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use crate::common::protocols::KvCacheEventSink;
#[test]
fn test_failure_on_max_capacity() {
......@@ -548,11 +553,7 @@ mod tests {
}
impl KvCacheEventSink for CapturingSink {
fn publish(
&self,
event: KvCacheEvent,
_block_token_ids: Option<&[Vec<u32>]>,
) -> anyhow::Result<()> {
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
self.events.lock().unwrap().push(event);
Ok(())
}
......@@ -563,8 +564,12 @@ mod tests {
let mut seq = ActiveSequence::new(tokens, 100, Some(block_size), true, false);
let sink = Arc::new(CapturingSink::default());
let mut manager =
KvManager::new_with_event_sink(256, block_size, Some(sink.clone() as _), 0);
let mut manager = KvManager::new_with_event_sink(
256,
block_size,
KvEventPublishers::new(Some(sink.clone() as _), None),
0,
);
// Chunk 1: allocate blocks 0-3
let signal = seq.prepare_allocation(256).unwrap();
......@@ -603,4 +608,42 @@ mod tests {
"second chunk's parent should be block 3's seq_hash"
);
}
/// Regression test for preempting a sequence twice. After the first reset the
/// manager holds 0 active blocks; after re-allocating only the prompt it holds
/// 2; the second reset must bring it back to 0.
#[test]
fn test_repreempt_after_partial_recompute_only_frees_reallocated_blocks() {
let mut seq = ActiveSequence::new((0..6).collect(), 16, Some(4), true, false);
let mut manager = KvManager::new(16, 4);
let signal = seq.take_creation_signal().unwrap();
assert_eq!(manager.process(&signal), 2);
// Three generate steps, after which three blocks are expected to be active.
for _ in 0..3 {
let signals = seq.generate();
for signal in &signals {
manager.process(signal);
}
if seq.generated_tokens() < seq.max_output_tokens() {
seq.commit_allocation(seq.len());
}
}
assert_eq!(manager.num_active_blocks(), 3);
// First preemption.
let first_reset = seq.reset_with_signal();
for signal in &first_reset {
manager.process(signal);
}
assert_eq!(manager.num_active_blocks(), 0);
// Re-admit the prompt only.
let prompt_only = seq.prepare_allocation(seq.num_input_tokens()).unwrap();
assert_eq!(manager.process(&prompt_only), 2);
seq.commit_allocation(seq.num_input_tokens());
assert_eq!(manager.num_active_blocks(), 2);
// Second preemption.
let second_reset = seq.reset_with_signal();
for signal in &second_reset {
manager.process(signal);
}
assert_eq!(manager.num_active_blocks(), 0);
}
}
......@@ -11,5 +11,5 @@ pub mod cache;
pub mod common;
pub mod engine;
pub mod kv_manager;
pub mod replay;
pub mod scheduler;
pub mod simulation;
......@@ -2,18 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::time::Instant;
use anyhow::{Context, Result, anyhow, bail};
use serde::Serialize;
use serde::ser::{SerializeMap, Serializer};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs, WorkerType};
#[derive(Debug, Clone)]
pub struct TraceSimulationReport {
pub request_counts: TraceRequestCounts,
......@@ -369,200 +362,6 @@ impl TraceCollector {
}
}
/// One JSON line of a trace file as deserialized by `load_trace_requests`.
///
/// Every field is optional at the serde level (`#[serde(default)]`); the
/// loader decides which absences are errors.
#[derive(Debug, Deserialize)]
struct RawTraceRecord {
// Preferred arrival time; the loader falls back to `created_time` when absent.
#[serde(default)]
timestamp: Option<f64>,
// Fallback arrival time used only when `timestamp` is absent.
#[serde(default)]
created_time: Option<f64>,
// Prompt length; also accepted under the key `input_tokens`.
#[serde(default, alias = "input_tokens")]
input_length: Option<usize>,
// Becomes the request's `max_output_tokens`; also accepted as `output_tokens`.
#[serde(default, alias = "output_tokens")]
output_length: Option<usize>,
// Hash ids from which the loader synthesizes the request's tokens.
#[serde(default)]
hash_ids: Option<Vec<u64>>,
}
/// Replay a timestamped trace file offline and return the report with the
/// simulation's wall-clock duration (in milliseconds) attached.
///
/// Arguments are validated first; loading the trace is not part of the timed
/// section.
pub fn simulate_trace_file(
    args: MockEngineArgs,
    trace_path: &Path,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    validate_offline_replay_args(&args, num_workers)?;
    let trace_requests = load_trace_requests(trace_path, args.block_size, true)?;

    let timer = Instant::now();
    let report = crate::scheduler::vllm::simulate_trace(args, trace_requests)?;
    let wall_time_ms = timer.elapsed().as_secs_f64() * 1000.0;

    Ok(report.with_wall_time_ms(wall_time_ms))
}
/// Replay a trace file in fixed-concurrency mode and return the report with
/// the simulation's wall-clock duration (in milliseconds) attached.
///
/// Timestamps in the trace are not required in this mode.
pub fn simulate_concurrency_file(
    args: MockEngineArgs,
    trace_path: &Path,
    max_in_flight: usize,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    let trace_requests = load_trace_requests(trace_path, args.block_size, false)?;

    let timer = Instant::now();
    let report = simulate_concurrency_requests(args, trace_requests, max_in_flight, num_workers)?;
    let wall_time_ms = timer.elapsed().as_secs_f64() * 1000.0;

    Ok(report.with_wall_time_ms(wall_time_ms))
}
/// Run the fixed-concurrency simulation over already-loaded requests.
///
/// # Errors
/// Fails when the arguments do not pass `validate_offline_concurrency_args`,
/// when `requests` is empty, or when the simulation itself fails.
pub fn simulate_concurrency_requests(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    validate_offline_concurrency_args(&args, num_workers, max_in_flight)?;

    if !requests.is_empty() {
        return crate::scheduler::vllm::simulate_concurrency(args, requests, max_in_flight);
    }
    bail!("concurrency replay requires at least one request");
}
/// Reject configurations that offline trace replay does not support.
///
/// Checks, in order: exactly one worker, the vLLM engine, an aggregated
/// worker, and a data-parallel size of 1. The first failing check is reported.
fn validate_offline_replay_args(args: &MockEngineArgs, num_workers: usize) -> Result<()> {
    if num_workers != 1 {
        bail!("trace replay only supports num_workers=1, got {num_workers}");
    }

    let engine_type = &args.engine_type;
    if *engine_type != EngineType::Vllm {
        bail!("trace replay only supports engine_type=vllm, got {engine_type:?}");
    }

    let worker_type = &args.worker_type;
    if *worker_type != WorkerType::Aggregated {
        bail!("trace replay only supports aggregated workers, got {worker_type:?}");
    }

    let dp_size = &args.dp_size;
    if *dp_size != 1 {
        bail!("trace replay only supports data_parallel_size=1, got {dp_size}");
    }

    Ok(())
}
/// Validate arguments for fixed-concurrency replay: `max_in_flight` must be at
/// least 1, and everything `validate_offline_replay_args` checks must hold too.
fn validate_offline_concurrency_args(
    args: &MockEngineArgs,
    num_workers: usize,
    max_in_flight: usize,
) -> Result<()> {
    if max_in_flight < 1 {
        bail!("concurrency replay requires max_in_flight >= 1");
    }

    validate_offline_replay_args(args, num_workers)
}
fn load_trace_requests(
trace_path: &Path,
trace_block_size: usize,
timestamps_required: bool,
) -> Result<Vec<DirectRequest>> {
let file = File::open(trace_path)
.with_context(|| format!("failed to open trace file {}", trace_path.display()))?;
let reader = BufReader::new(file);
let mut requests = Vec::new();
for (line_idx, line) in reader.lines().enumerate() {
let line = line.with_context(|| {
format!(
"failed to read line {} from {}",
line_idx + 1,
trace_path.display()
)
})?;
if line.trim().is_empty() {
continue;
}
let raw: RawTraceRecord = serde_json::from_str(&line).with_context(|| {
format!(
"failed to parse line {} from {} as JSON",
line_idx + 1,
trace_path.display()
)
})?;
let input_length = raw
.input_length
.ok_or_else(|| anyhow!("trace line {} is missing input_length", line_idx + 1))?;
let output_length = raw
.output_length
.ok_or_else(|| anyhow!("trace line {} is missing output_length", line_idx + 1))?;
let hash_ids = raw
.hash_ids
.ok_or_else(|| anyhow!("trace line {} is missing hash_ids", line_idx + 1))?;
let arrival_timestamp_ms = if timestamps_required {
match raw.timestamp.or(raw.created_time) {
Some(timestamp_ms) => Some(timestamp_ms),
None => return Err(anyhow!("trace line {} is missing timestamp", line_idx + 1)),
}
} else {
None
};
let tokens = synthesize_tokens_from_hash_ids(&hash_ids, input_length, trace_block_size)
.with_context(|| {
format!(
"failed to synthesize tokens from hash_ids on line {}",
line_idx + 1
)
})?;
requests.push(DirectRequest {
tokens,
max_output_tokens: output_length,
uuid: Some(Uuid::new_v4()),
dp_rank: 0,
arrival_timestamp_ms,
});
}
if requests.is_empty() {
bail!(
"trace file {} did not contain any requests",
trace_path.display()
);
}
Ok(requests)
}
fn synthesize_tokens_from_hash_ids(
hash_ids: &[u64],
input_length: usize,
trace_block_size: usize,
) -> Result<Vec<u32>> {
let mut tokens = Vec::with_capacity(input_length);
for &hash_id in hash_ids {
let token_id = u32::try_from(hash_id)
.map_err(|_| anyhow!("hash_id {hash_id} exceeds u32::MAX for token synthesis"))?;
// TODO: Replace this repeated-token expansion with a hash-native prompt representation.
tokens.extend((0..trace_block_size).map(|_| token_id));
if tokens.len() >= input_length {
tokens.truncate(input_length);
return Ok(tokens);
}
}
bail!(
"input_length {} exceeds synthesized capacity {} from {} hash_ids and block_size {}",
input_length,
hash_ids.len() * trace_block_size,
hash_ids.len(),
trace_block_size
);
}
fn mean(values: &[f64]) -> f64 {
if values.is_empty() {
0.0
......@@ -620,41 +419,3 @@ fn std_dev(values: &[f64]) -> f64 {
/ values.len() as f64;
variance.sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
/// Feeds one request with tokens at 10, 11, 12 and 110 ms into a
/// `TraceCollector` and checks the reported TPOT, ITL, TTST and per-user
/// throughput statistics; the inter-token gaps are 1, 1 and 98 ms.
#[test]
fn test_replay_itl_uses_per_token_gaps() {
let mut collector = TraceCollector::default();
let uuid = Uuid::from_u128(11);
collector.on_arrival(uuid, 0.0, 4, 4);
collector.on_admit(uuid, 0.0, 0);
collector.on_token(uuid, 10.0);
collector.on_token(uuid, 11.0);
collector.on_token(uuid, 12.0);
collector.on_token(uuid, 110.0);
let report = collector.finish();
// Mean of the three gaps: (1 + 1 + 98) / 3.
assert!((report.latency.tpot.mean_ms - (100.0 / 3.0)).abs() < 1e-9);
assert!((report.latency.itl.distribution.mean_ms - (100.0 / 3.0)).abs() < 1e-9);
assert_eq!(report.latency.itl.distribution.median_ms, 1.0);
assert_eq!(report.latency.itl.distribution.p75_ms, 98.0);
assert_eq!(report.latency.itl.distribution.p90_ms, 98.0);
assert_eq!(report.latency.itl.distribution.p95_ms, 98.0);
assert_eq!(report.latency.itl.max_ms, 98.0);
assert_eq!(report.latency.ttst.min_ms, 1.0);
assert_eq!(report.latency.ttst.max_ms, 1.0);
assert_eq!(
report.latency.output_token_throughput_per_user.min_ms,
1000.0 / 98.0
);
assert_eq!(
report.latency.output_token_throughput_per_user.max_ms,
1000.0
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::path::Path;
use std::time::Instant;
use anyhow::{Result, bail};
use dynamo_kv_router::config::KvRouterConfig;
use super::loader::load_trace_requests;
use super::online;
use super::validate::{
validate_offline_concurrency_args, validate_offline_replay_args,
validate_online_concurrency_args, validate_online_replay_args,
};
use super::{ReplayRouterMode, TraceSimulationReport};
use crate::common::protocols::{DirectRequest, MockEngineArgs};
/// Offline trace-file replay with default routing.
///
/// Thin wrapper over `simulate_trace_file_with_router_mode` that passes no
/// router config and `ReplayRouterMode::RoundRobin`.
pub fn simulate_trace_file(
args: MockEngineArgs,
trace_path: &Path,
num_workers: usize,
arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> {
simulate_trace_file_with_router_mode(
args,
None,
trace_path,
num_workers,
arrival_speedup_ratio,
ReplayRouterMode::RoundRobin,
)
}
/// Replay a trace file offline with an explicit router mode and return the
/// report with the simulation's wall-clock duration (in milliseconds) attached.
///
/// The arguments are normalized and validated before the trace is loaded;
/// only the simulation itself is timed.
pub fn simulate_trace_file_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace_path: &Path,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    validate_offline_replay_args(&normalized_args, num_workers, router_mode)?;
    let trace_requests = load_trace_requests(trace_path, normalized_args.block_size, true)?;

    let timer = Instant::now();
    let report = crate::replay::offline::simulate_trace(
        normalized_args,
        router_config,
        trace_requests,
        num_workers,
        arrival_speedup_ratio,
        router_mode,
    )?;
    let wall_time_ms = timer.elapsed().as_secs_f64() * 1000.0;

    Ok(report.with_wall_time_ms(wall_time_ms))
}
/// Replays the trace file at `trace_path` through the `online` replay path with the default
/// routing setup.
///
/// Thin convenience wrapper over [`simulate_trace_live_file_with_router_mode`]: no
/// [`KvRouterConfig`] is supplied and [`ReplayRouterMode::RoundRobin`] is used.
pub fn simulate_trace_live_file(
    args: MockEngineArgs,
    trace_path: &Path,
    num_workers: usize,
    arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_trace_live_file_with_router_mode(
        args,
        router_config,
        trace_path,
        num_workers,
        arrival_speedup_ratio,
        router_mode,
    )
}
/// Replays the trace file at `trace_path` through `online::simulate_trace_requests`.
///
/// The arguments are normalized and checked with `validate_online_replay_args` before the
/// trace is loaded (timestamps are required). The report is returned exactly as produced by
/// the `online` module.
///
/// # Errors
///
/// Propagates any failure from normalization, validation, trace loading, or the simulation.
pub fn simulate_trace_live_file_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace_path: &Path,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    validate_online_replay_args(&engine_args, num_workers)?;

    let trace = load_trace_requests(trace_path, engine_args.block_size, true)?;

    online::simulate_trace_requests(
        engine_args,
        router_config,
        trace,
        num_workers,
        arrival_speedup_ratio,
        router_mode,
    )
}
/// Replays an in-memory list of requests with the default routing setup.
///
/// Thin convenience wrapper over [`simulate_trace_requests_with_router_mode`]: no
/// [`KvRouterConfig`] is supplied and [`ReplayRouterMode::RoundRobin`] is used.
pub fn simulate_trace_requests(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_trace_requests_with_router_mode(
        args,
        router_config,
        requests,
        num_workers,
        arrival_speedup_ratio,
        router_mode,
    )
}
/// Replays an in-memory list of requests through `crate::replay::offline::simulate_trace`.
///
/// The arguments are normalized and checked with `validate_offline_replay_args`, then the
/// request list is required to be non-empty. The resulting report is stamped via
/// `with_wall_time_ms` with the elapsed wall-clock milliseconds of the simulation call.
///
/// # Errors
///
/// Fails if normalization or validation fails, if `requests` is empty, or if the simulation
/// itself returns an error.
pub fn simulate_trace_requests_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    validate_offline_replay_args(&engine_args, num_workers, router_mode)?;
    if requests.is_empty() {
        bail!("trace replay requires at least one request");
    }

    let timer = Instant::now();
    let simulated = crate::replay::offline::simulate_trace(
        engine_args,
        router_config,
        requests,
        num_workers,
        arrival_speedup_ratio,
        router_mode,
    )?;
    let wall_time_ms = timer.elapsed().as_secs_f64() * 1000.0;
    Ok(simulated.with_wall_time_ms(wall_time_ms))
}
/// Replays an in-memory list of requests through the `online` replay path with the default
/// routing setup.
///
/// Thin convenience wrapper over [`simulate_trace_live_requests_with_router_mode`]: no
/// [`KvRouterConfig`] is supplied and [`ReplayRouterMode::RoundRobin`] is used.
pub fn simulate_trace_live_requests(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_trace_live_requests_with_router_mode(
        args,
        router_config,
        requests,
        num_workers,
        arrival_speedup_ratio,
        router_mode,
    )
}
/// Replays an in-memory list of requests through `online::simulate_trace_requests`.
///
/// The arguments are normalized and checked with `validate_online_replay_args`, then the
/// request list is required to be non-empty. The report is returned exactly as produced by
/// the `online` module.
///
/// # Errors
///
/// Fails if normalization or validation fails, if `requests` is empty, or if the simulation
/// itself returns an error.
pub fn simulate_trace_live_requests_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    validate_online_replay_args(&engine_args, num_workers)?;
    if requests.is_empty() {
        bail!("trace replay requires at least one request");
    }

    online::simulate_trace_requests(
        engine_args,
        router_config,
        requests,
        num_workers,
        arrival_speedup_ratio,
        router_mode,
    )
}
/// Runs a concurrency-bounded replay of the trace file at `trace_path` with the default
/// routing setup.
///
/// Thin convenience wrapper over [`simulate_concurrency_file_with_router_mode`]: no
/// [`KvRouterConfig`] is supplied and [`ReplayRouterMode::RoundRobin`] is used.
pub fn simulate_concurrency_file(
    args: MockEngineArgs,
    trace_path: &Path,
    max_in_flight: usize,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_concurrency_file_with_router_mode(
        args,
        router_config,
        trace_path,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
/// Runs a concurrency-bounded replay of the trace file at `trace_path`.
///
/// The arguments are normalized so that `block_size` can be used to load the trace
/// (timestamps are not required here). Everything else, including argument validation, is
/// delegated to [`simulate_concurrency_requests_with_router_mode`]. The returned report is
/// stamped via `with_wall_time_ms` with the elapsed wall-clock milliseconds of that delegated
/// call (file loading is not timed).
///
/// # Errors
///
/// Propagates any failure from normalization, trace loading, or the delegated simulation.
pub fn simulate_concurrency_file_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace_path: &Path,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let trace = load_trace_requests(trace_path, engine_args.block_size, false)?;

    let timer = Instant::now();
    let simulated = simulate_concurrency_requests_with_router_mode(
        engine_args,
        router_config,
        trace,
        max_in_flight,
        num_workers,
        router_mode,
    )?;
    let wall_time_ms = timer.elapsed().as_secs_f64() * 1000.0;
    Ok(simulated.with_wall_time_ms(wall_time_ms))
}
/// Runs a concurrency-bounded replay of the trace file at `trace_path` through the `online`
/// replay path with the default routing setup.
///
/// Thin convenience wrapper over [`simulate_concurrency_live_file_with_router_mode`]: no
/// [`KvRouterConfig`] is supplied and [`ReplayRouterMode::RoundRobin`] is used.
pub fn simulate_concurrency_live_file(
    args: MockEngineArgs,
    trace_path: &Path,
    max_in_flight: usize,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_concurrency_live_file_with_router_mode(
        args,
        router_config,
        trace_path,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
/// Runs a concurrency-bounded replay of the trace file at `trace_path` through
/// `online::simulate_concurrency_requests`.
///
/// The arguments are normalized and checked with `validate_online_concurrency_args` before
/// the trace is loaded (timestamps are not required here). The report is returned exactly as
/// produced by the `online` module.
///
/// # Errors
///
/// Propagates any failure from normalization, validation, trace loading, or the simulation.
pub fn simulate_concurrency_live_file_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace_path: &Path,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    validate_online_concurrency_args(&engine_args, num_workers, max_in_flight)?;

    let trace = load_trace_requests(trace_path, engine_args.block_size, false)?;

    online::simulate_concurrency_requests(
        engine_args,
        router_config,
        trace,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
/// Runs a concurrency-bounded replay of an in-memory request list through the `online`
/// replay path with the default routing setup.
///
/// Thin convenience wrapper over [`simulate_concurrency_live_requests_with_router_mode`]: no
/// [`KvRouterConfig`] is supplied and [`ReplayRouterMode::RoundRobin`] is used.
pub fn simulate_concurrency_live_requests(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_concurrency_live_requests_with_router_mode(
        args,
        router_config,
        requests,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
/// Runs a concurrency-bounded replay of an in-memory request list through
/// `online::simulate_concurrency_requests`.
///
/// The arguments are normalized and checked with `validate_online_concurrency_args`, then
/// the request list is required to be non-empty. The report is returned exactly as produced
/// by the `online` module.
///
/// # Errors
///
/// Fails if normalization or validation fails, if `requests` is empty, or if the simulation
/// itself returns an error.
pub fn simulate_concurrency_live_requests_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    validate_online_concurrency_args(&engine_args, num_workers, max_in_flight)?;
    if requests.is_empty() {
        bail!("concurrency replay requires at least one request");
    }

    online::simulate_concurrency_requests(
        engine_args,
        router_config,
        requests,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
/// Runs a concurrency-bounded replay of an in-memory request list with the default routing
/// setup.
///
/// Thin convenience wrapper over [`simulate_concurrency_requests_with_router_mode`]: no
/// [`KvRouterConfig`] is supplied and [`ReplayRouterMode::RoundRobin`] is used.
pub fn simulate_concurrency_requests(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_concurrency_requests_with_router_mode(
        args,
        router_config,
        requests,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
/// Runs a concurrency-bounded replay of an in-memory request list through
/// `crate::replay::offline::simulate_concurrency`.
///
/// The arguments are normalized and checked with `validate_offline_concurrency_args`, then
/// the request list is required to be non-empty. The report is returned exactly as produced
/// by the offline module; no wall time is attached here.
///
/// # Errors
///
/// Fails if normalization or validation fails, if `requests` is empty, or if the simulation
/// itself returns an error.
pub fn simulate_concurrency_requests_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    validate_offline_concurrency_args(&engine_args, num_workers, max_in_flight, router_mode)?;
    if requests.is_empty() {
        bail!("concurrency replay requires at least one request");
    }

    crate::replay::offline::simulate_concurrency(
        engine_args,
        router_config,
        requests,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use anyhow::{Context, Result, anyhow, bail};
use serde::Deserialize;
use uuid::Uuid;
use crate::common::protocols::DirectRequest;
/// One record of a trace file, deserialized from a single JSON line.
///
/// Every field is optional at the deserialization level; `load_trace_requests` decides
/// which ones must be present and reports a per-line error when a required one is missing.
#[derive(Debug, Deserialize)]
struct RawTraceRecord {
/// Arrival timestamp of the request; preferred over `created_time` when both are present.
/// NOTE(review): presumably milliseconds, since it feeds `arrival_timestamp_ms` - confirm
/// against the trace producer.
#[serde(default)]
timestamp: Option<f64>,
/// Fallback arrival timestamp, used only when `timestamp` is absent.
#[serde(default)]
created_time: Option<f64>,
/// Number of prompt tokens to synthesize; also accepted under the key `input_tokens`.
#[serde(default, alias = "input_tokens")]
input_length: Option<usize>,
/// Becomes the request's `max_output_tokens`; also accepted under the key `output_tokens`.
#[serde(default, alias = "output_tokens")]
output_length: Option<usize>,
/// Hash ids from which the prompt tokens are synthesized (each id is expanded into a run
/// of repeated token ids).
#[serde(default)]
hash_ids: Option<Vec<u64>>,
}
pub(super) fn load_trace_requests(
trace_path: &Path,
trace_block_size: usize,
timestamps_required: bool,
) -> Result<Vec<DirectRequest>> {
let file = File::open(trace_path)
.with_context(|| format!("failed to open trace file {}", trace_path.display()))?;
let reader = BufReader::new(file);
let mut requests = Vec::new();
for (line_idx, line) in reader.lines().enumerate() {
let line = line.with_context(|| {
format!(
"failed to read line {} from {}",
line_idx + 1,
trace_path.display()
)
})?;
if line.trim().is_empty() {
continue;
}
let raw: RawTraceRecord = serde_json::from_str(&line).with_context(|| {
format!(
"failed to parse line {} from {} as JSON",
line_idx + 1,
trace_path.display()
)
})?;
let input_length = raw
.input_length
.ok_or_else(|| anyhow!("trace line {} is missing input_length", line_idx + 1))?;
let output_length = raw
.output_length
.ok_or_else(|| anyhow!("trace line {} is missing output_length", line_idx + 1))?;
let hash_ids = raw
.hash_ids
.ok_or_else(|| anyhow!("trace line {} is missing hash_ids", line_idx + 1))?;
let arrival_timestamp_ms = if timestamps_required {
match raw.timestamp.or(raw.created_time) {
Some(timestamp_ms) => Some(timestamp_ms),
None => return Err(anyhow!("trace line {} is missing timestamp", line_idx + 1)),
}
} else {
None
};
let tokens = synthesize_tokens_from_hash_ids(&hash_ids, input_length, trace_block_size)
.with_context(|| {
format!(
"failed to synthesize tokens from hash_ids on line {}",
line_idx + 1
)
})?;
requests.push(DirectRequest {
tokens,
max_output_tokens: output_length,
uuid: Some(Uuid::new_v4()),
dp_rank: 0,
arrival_timestamp_ms,
});
}
if requests.is_empty() {
bail!(
"trace file {} did not contain any requests",
trace_path.display()
);
}
Ok(requests)
}
/// Expands `hash_ids` into a prompt of exactly `input_length` token ids.
///
/// Each hash id is converted to a `u32` token id and repeated up to `trace_block_size`
/// times, in order, until `input_length` tokens have been produced. Hash ids that come after
/// the point where the prompt is complete are neither used nor validated.
///
/// # Errors
///
/// Fails if a hash id that is needed does not fit in a `u32`, or if all hash ids together
/// (`hash_ids.len() * trace_block_size` tokens) cannot fill `input_length` tokens.
fn synthesize_tokens_from_hash_ids(
    hash_ids: &[u64],
    input_length: usize,
    trace_block_size: usize,
) -> Result<Vec<u32>> {
    // Upper bound on the tokens the hash ids can expand to. Saturating, so that extreme
    // inputs cannot overflow (which would panic in debug builds).
    let synthesized_capacity = hash_ids.len().saturating_mul(trace_block_size);

    // `input_length` is taken verbatim from the trace file. Reserving it unchecked would
    // panic or abort on an absurd value before the "exceeds synthesized capacity" error
    // below could be reported, so never reserve more than can actually be produced.
    let mut tokens = Vec::with_capacity(input_length.min(synthesized_capacity));

    for &hash_id in hash_ids {
        let token_id = u32::try_from(hash_id)
            .map_err(|_| anyhow!("hash_id {hash_id} exceeds u32::MAX for token synthesis"))?;
        // TODO: Replace this repeated-token expansion with a hash-native prompt representation.
        // Invariant: tokens.len() <= input_length, so the subtraction cannot underflow.
        let remaining = input_length - tokens.len();
        tokens.extend(std::iter::repeat(token_id).take(trace_block_size.min(remaining)));
        if tokens.len() >= input_length {
            return Ok(tokens);
        }
    }

    bail!(
        "input_length {} exceeds synthesized capacity {} from {} hash_ids and block_size {}",
        input_length,
        synthesized_capacity,
        hash_ids.len(),
        trace_block_size
    );
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment