Commit b7fe46b1 (unverified), authored by Yan Ru Pei, committed by GitHub
Browse files

feat(mocker): add multi-worker replay and router startup fixes (#7553)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 82794761
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::time::Duration;
use crate::common::protocols::OutputSignal;
use crate::kv_manager::SglangKvManager;
use super::config::{SglangConfig, floor_to_block};
use super::request::SglangRequest;
/// Outcome of one simulated decode step.
#[derive(Default)]
pub(super) struct DecodeResult {
/// Requests that were retracted from the running batch during this step.
pub(super) requests: Vec<SglangRequest>,
/// One signal per request that produced an output token in this step.
pub(super) output_signals: Vec<OutputSignal>,
/// `true` when at least one request was retracted.
pub(super) retracted_any: bool,
/// Simulated time at which the step ends, in milliseconds.
pub(super) end_ms: f64,
}
fn decode_capacity_state(
running: &[SglangRequest],
kv_manager: &SglangKvManager,
config: &SglangConfig,
) -> (usize, usize) {
let actual_available =
kv_manager.cache().available_tokens() + kv_manager.cache().evictable_size;
let reserved_tokens = running
.iter()
.map(SglangRequest::extra_reserved_tokens)
.sum::<usize>();
let logical_available = actual_available.saturating_sub(reserved_tokens);
let page_growth_needed = running
.iter()
.map(|req| {
if req.current_sequence_len() + 1 > req.allocated_tokens {
config.block_size
} else {
0
}
})
.sum();
(
actual_available,
logical_available.saturating_sub(page_growth_needed),
)
}
/// Publishes the page-aligned, already-materialized prefix of an unfinished
/// request into the cache.
///
/// Does nothing when no full page has been materialized, when the aligned
/// prefix is already cached, or when the request has no `last_node`. On
/// success, `last_node` and `cached_tokens` are advanced to the newly cached
/// prefix.
pub(super) fn cache_materialized_prefix(
    req: &mut SglangRequest,
    kv_manager: &mut SglangKvManager,
    config: &SglangConfig,
) {
    let page_aligned = req.page_aligned_materialized_tokens(config.block_size);
    let has_uncached_pages = page_aligned != 0 && page_aligned > req.cached_tokens;
    if !has_uncached_pages {
        return;
    }

    if let Some(anchor) = req.last_node {
        let prefix = req.sequence_prefix(page_aligned);
        let prefix_indices = &req.kv_indices[..page_aligned];
        let updated_node = kv_manager.cache_unfinished_req(&prefix, prefix_indices, anchor);
        req.last_node = Some(updated_node);
        req.cached_tokens = page_aligned;
        req.debug_assert_invariants(config.block_size);
    }
}
/// Retracts running requests until the remaining batch can take a decode step.
///
/// While capacity is insufficient and more than one request is running, the
/// request with the fewest output tokens (the first one among ties) is removed
/// from `running`: its KV indices beyond `cached_tokens` are freed, its
/// `last_node` is released, and its allocation state is reset. At least one
/// request is always left running.
///
/// After the loop, the cache is asked to evict enough tokens that the pool has
/// one free slot per remaining request.
///
/// Returns the retracted requests.
pub(super) fn check_decode_mem(
    running: &mut Vec<SglangRequest>,
    kv_manager: &mut SglangKvManager,
    config: &SglangConfig,
) -> Vec<SglangRequest> {
    let mut retracted = Vec::new();

    loop {
        let (actual_available, logical_after_growth) =
            decode_capacity_state(running, kv_manager, config);
        let fits = actual_available >= running.len() && logical_after_growth > 0;
        if fits || running.len() <= 1 {
            break;
        }

        let Some(victim_idx) = running
            .iter()
            .enumerate()
            .min_by_key(|(_, candidate)| candidate.output_len())
            .map(|(position, _)| position)
        else {
            break;
        };

        let mut victim = running.swap_remove(victim_idx);
        // Only the tail beyond the cached prefix is released directly.
        kv_manager.free_indices(&victim.kv_indices[victim.cached_tokens..]);
        if let Some(node) = victim.last_node.take() {
            kv_manager.free_request(node);
        }
        victim.reset_for_retract();
        victim.debug_assert_invariants(config.block_size);
        retracted.push(victim);
    }

    // Guarantee one free pool slot per request that is still running.
    let shortfall = running
        .len()
        .saturating_sub(kv_manager.cache().token_pool.available());
    if shortfall > 0 {
        kv_manager.evict(shortfall);
    }

    if !retracted.is_empty() {
        tracing::warn!(
            num_retracted = retracted.len(),
            remaining = running.len(),
            "SGLang decode retract requests because KV pool is full"
        );
    }
    retracted
}
/// Simulates one decode iteration over the running batch.
///
/// The step predicts the decode time from the perf model, retracts requests if
/// KV memory is insufficient, then allocates one KV slot and appends one
/// output token for every remaining request. Requests that reach
/// `max_output_tokens` are finalized and removed from `running`.
///
/// * `current_time_ms` - simulated start time of the step, in milliseconds.
/// * `apply_speedup` - when `true` and the combined speedup ratio is positive,
///   the predicted time is divided by `speedup_ratio * decode_speedup_ratio`.
///
/// Returns a [`DecodeResult`] with the retracted requests, the per-request
/// output signals and the simulated end time. For an empty batch the result is
/// the default one with `end_ms == current_time_ms`.
pub(super) fn simulate_decode_step(
running: &mut Vec<SglangRequest>,
kv_manager: &mut SglangKvManager,
config: &SglangConfig,
current_time_ms: f64,
apply_speedup: bool,
) -> DecodeResult {
if running.is_empty() {
return DecodeResult {
end_ms: current_time_ms,
..DecodeResult::default()
};
}
// Timing is predicted from the batch as it is BEFORE any retraction below.
let total_context: usize = running
.iter()
.map(SglangRequest::current_sequence_len)
.sum();
let avg_context = total_context / running.len();
let decode_time =
config
.perf_model
.predict_decode_time(running.len(), total_context, avg_context);
// The predicted time is treated as milliseconds and converted to seconds here.
let unscaled_time = Duration::from_secs_f64(decode_time / 1000.0);
let effective_ratio = config.speedup_ratio * config.decode_speedup_ratio;
let total_time = if apply_speedup && effective_ratio > 0.0 && unscaled_time > Duration::ZERO {
Duration::from_secs_f64(unscaled_time.as_secs_f64() / effective_ratio)
} else {
unscaled_time
};
// Make room for the step; retracted requests no longer decode in this step.
let retracted = check_decode_mem(running, kv_manager, config);
let retracted_any = !retracted.is_empty();
let mut output_signals = Vec::with_capacity(running.len());
let mut completed_indices = Vec::new();
for (idx, req) in running.iter_mut().enumerate() {
// Ensure at least one free slot before allocating this request's token.
if kv_manager.cache().token_pool.available() == 0 {
kv_manager.evict(1);
}
// Must be evaluated before the token is appended below.
let crossing_page_boundary = req.current_sequence_len() + 1 > req.allocated_tokens;
let last_idx = req.kv_indices.last().copied();
let Some(new_idx) = kv_manager.allocate_decode_token(last_idx) else {
// Allocation failed: the request stays in the batch and emits no signal this step.
tracing::warn!(uuid = %req.uuid, "Failed to allocate decode token, skipping output");
continue;
};
req.kv_indices.push(new_idx);
if crossing_page_boundary {
req.allocated_tokens += config.block_size;
}
req.append_output_token(req.next_output_token());
req.debug_assert_invariants(config.block_size);
let is_complete = req.output_len() >= req.max_output_tokens;
output_signals.push(OutputSignal {
uuid: req.uuid,
completed: is_complete,
});
if is_complete {
// Only whole pages are handed to the cache; the unaligned tail is freed.
let sequence = req.sequence_tokens();
let tokens_to_cache = floor_to_block(sequence.len(), config.block_size);
if req.kv_indices.len() > tokens_to_cache {
kv_manager.free_indices(&req.kv_indices[tokens_to_cache..]);
}
if let Some(last_node) = req.last_node.take() {
if tokens_to_cache > 0 {
kv_manager.cache_finished_req(
&sequence[..tokens_to_cache],
&req.kv_indices[..tokens_to_cache],
last_node,
);
} else {
kv_manager.free_request(last_node);
}
}
completed_indices.push(idx);
continue;
}
cache_materialized_prefix(req, kv_manager, config);
req.debug_assert_invariants(config.block_size);
}
// Remove from the highest index down so `swap_remove` never moves an element
// that is itself still scheduled for removal.
for &idx in completed_indices.iter().rev() {
running.swap_remove(idx);
}
DecodeResult {
requests: retracted,
output_signals,
retracted_any,
end_ms: current_time_ms + total_time.as_secs_f64() * 1000.0,
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, KvEventPublishers, MockEngineArgs, OutputSignal};
use crate::common::utils::sleep_until_precise;
use crate::scheduler::{
AdmissionEvent, MockerMetrics, RouterEventVisibility, SchedulerHandle,
capture_deferred_kv_publish_sink, publish_deferred_kv_events,
};
use super::core::SglangCore;
use super::request::{SglangRequest, direct_to_sglang};
/// Cloneable handle to a background SGLang scheduler task.
#[derive(Clone)]
pub struct SglangScheduler {
// Sender side of the channel the background task reads requests from.
request_tx: mpsc::UnboundedSender<DirectRequest>,
// Watch receiver for the metrics the background task sends after each pass.
metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
// Shared guard; its `Drop` cancels the task's token once the last handle clone is gone.
_cancel_guard: Arc<CancelGuard>,
}
/// Cancels the wrapped token when the guard is dropped.
struct CancelGuard(CancellationToken);

impl Drop for CancelGuard {
    fn drop(&mut self) {
        let Self(token) = self;
        token.cancel();
    }
}
impl SglangScheduler {
/// Creates a scheduler and spawns its background task, without an admission-event channel.
///
/// Must be called from within a Tokio runtime because it uses `tokio::spawn`.
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Self {
Self::new_internal(
args,
dp_rank,
output_tx,
kv_event_publishers,
cancellation_token,
None,
)
}
/// Same as [`SglangScheduler::new`], but every admission produced by a pass is
/// also sent on `admission_tx` when one is provided.
pub(crate) fn new_with_admission(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
) -> Self {
Self::new_internal(
args,
dp_rank,
output_tx,
kv_event_publishers,
cancellation_token,
admission_tx,
)
}
/// Shared constructor: sets up the request and metrics channels and spawns the scheduling loop.
fn new_internal(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
) -> Self {
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let total_blocks = args.num_gpu_blocks as u64;
let initial_metrics = MockerMetrics::new(dp_rank, 0, total_blocks);
let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);
// A fresh token is used when the caller does not supply one.
let cancel_token = cancellation_token.unwrap_or_default();
let cancel_token_clone = cancel_token.clone();
let cancel_guard = Arc::new(CancelGuard(cancel_token));
tokio::spawn(async move {
// KV events emitted by the core are buffered and published explicitly below.
let (deferred_kv_events, buffering_publishers) =
capture_deferred_kv_publish_sink(kv_event_publishers.raw_enabled());
let mut core = SglangCore::new_with_sink(args, dp_rank, buffering_publishers);
loop {
// Returns `None` on cancellation or when the request channel is closed.
if receive_requests(
&mut core.waiting,
&mut request_rx,
&cancel_token_clone,
&core.running,
)
.await
.is_none()
{
break;
}
let iteration_start = Instant::now();
let pass = core.execute_pass_internal(None, 0.0);
if let Some(admission_tx) = admission_tx.as_ref() {
for admission in &pass.admissions {
let _ = admission_tx.send(admission.clone());
}
}
// Buffered KV events become visible either before or after the simulated pass duration.
if pass.router_event_visibility == RouterEventVisibility::PassStart {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
// `pass.end_ms` is used as the pass duration in milliseconds (the pass started at 0.0).
let total_time = std::time::Duration::from_secs_f64(pass.end_ms / 1000.0);
if total_time > std::time::Duration::ZERO {
sleep_until_precise(iteration_start + total_time).await;
}
if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
flush_output_signals(&output_tx, &pass.output_signals);
// Publish whatever is still buffered, regardless of the visibility mode.
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
let _ = metrics_tx.send(MockerMetrics::new(
dp_rank,
pass.active_decode_blocks,
total_blocks,
));
}
});
Self {
request_tx,
metrics_rx,
_cancel_guard: cancel_guard,
}
}
}
impl SchedulerHandle for SglangScheduler {
    /// Queues `request` for the background task; a send failure is ignored.
    fn receive(&self, request: DirectRequest) {
        // Best effort: the error only means the receiving side is gone.
        drop(self.request_tx.send(request));
    }

    /// Returns a new sender feeding the same request channel.
    fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
        let sender = &self.request_tx;
        sender.clone()
    }

    /// Returns a new watch receiver observing the scheduler's metrics.
    fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
        let receiver = &self.metrics_rx;
        receiver.clone()
    }
}
async fn receive_requests(
waiting: &mut std::collections::VecDeque<SglangRequest>,
request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
cancel_token: &CancellationToken,
running: &[SglangRequest],
) -> Option<()> {
if cancel_token.is_cancelled() {
return None;
}
if waiting.is_empty() && running.is_empty() {
tokio::select! {
biased;
_ = cancel_token.cancelled() => return None,
result = request_rx.recv() => {
let request = result?;
waiting.push_back(direct_to_sglang(request));
}
}
}
while let Ok(request) = request_rx.try_recv() {
waiting.push_back(direct_to_sglang(request));
}
Some(())
}
/// Sends a copy of every signal on `output_tx`, if a sender is configured.
///
/// Send errors are ignored, and all signals are attempted regardless.
fn flush_output_signals(
    output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
    output_signals: &[OutputSignal],
) {
    if let Some(sender) = output_tx {
        for signal in output_signals.iter().cloned() {
            let _ = sender.send(signal);
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! SGLang scheduler simulation with adaptive admission control.
//!
//! Reference: sglang/python/sglang/srt/managers/scheduler.py
mod config;
mod core;
mod decode;
mod live;
mod policy;
mod prefill;
mod request;
pub(crate) use core::SglangCore;
pub use live::SglangScheduler;
#[cfg(test)]
mod tests;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashSet, VecDeque};
use crate::cache::radix_cache::RadixCache;
use crate::kv_manager::SglangKvManager;
use super::config::{
IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD, IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD,
LPM_FALLBACK_THRESHOLD, SchedulePolicy, SglangConfig,
};
use super::request::SglangRequest;
/// Reorders the waiting queue according to the configured schedule policy.
///
/// `Fifo` leaves the queue untouched. `Lpm` reorders by longest cached prefix
/// match; see [`reorder_by_longest_prefix`].
pub(super) fn apply_schedule_policy(
    waiting: &mut VecDeque<SglangRequest>,
    kv_manager: &SglangKvManager,
    config: &SglangConfig,
) {
    match config.schedule_policy {
        SchedulePolicy::Fifo => {}
        SchedulePolicy::Lpm => reorder_by_longest_prefix(waiting, kv_manager, config),
    }
}

/// Longest-prefix-match ordering of the waiting queue.
///
/// Queues longer than `LPM_FALLBACK_THRESHOLD` are left as-is. Otherwise every
/// request is scored by its prefix match length in the KV cache. Requests whose
/// cached prefix is short but which share a long prefix with an earlier request
/// in the same queue are temporarily deprioritized. The final order is:
/// non-deprioritized first, then longer cached prefix first, with ties kept in
/// their original order.
fn reorder_by_longest_prefix(
    waiting: &mut VecDeque<SglangRequest>,
    kv_manager: &SglangKvManager,
    config: &SglangConfig,
) {
    if waiting.len() > LPM_FALLBACK_THRESHOLD {
        return;
    }

    let page_size = config.block_size.max(1);
    let queued_tokens: usize = waiting
        .iter()
        .map(SglangRequest::current_sequence_len)
        .sum();
    // Scratch cache used only to detect shared prefixes within this queue.
    let mut waiting_queue_cache = RadixCache::new(queued_tokens.max(page_size), page_size);
    let mut temporary_deprioritized = HashSet::new();
    let mut scored = Vec::with_capacity(waiting.len());

    for req in waiting.drain(..) {
        let sequence = req.sequence_tokens();
        let prefix_len = kv_manager.cache().prefix_match_len(&sequence);

        if prefix_len <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD {
            let in_batch_prefix = waiting_queue_cache.prefix_match_len(&sequence);
            if in_batch_prefix >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD {
                temporary_deprioritized.insert(req.uuid);
            } else if !sequence.is_empty() {
                let values: Vec<usize> = (0..sequence.len()).collect();
                waiting_queue_cache.insert(&sequence, &values);
            }
        }

        let deprioritized = temporary_deprioritized.contains(&req.uuid);
        scored.push((prefix_len, deprioritized, req));
    }

    // Stable sort: `false` sorts before `true`, then descending prefix length.
    scored.sort_by_key(|(prefix_len, deprioritized, _)| {
        (*deprioritized, std::cmp::Reverse(*prefix_len))
    });
    waiting.extend(scored.into_iter().map(|(_, _, req)| req));
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use super::super::AdmissionEvent;
use super::config::{SglangConfig, ceil_to_block};
use super::request::SglangRequest;
use crate::kv_manager::SglangKvManager;
/// Result of one prefill admission pass over the waiting queue.
pub(super) struct AdmitResult {
/// Requests admitted for prefill in this pass, in admission order.
pub(super) can_run: Vec<SglangRequest>,
/// One admission event per admitted request.
pub(super) admissions: Vec<AdmissionEvent>,
/// Sum over admitted requests of the materialized length after their chunk.
pub(super) total_isl: usize,
/// Sum over admitted requests of the prefix length reported by the allocation.
pub(super) total_prefix: usize,
/// `true` when admission stopped because a KV allocation failed.
pub(super) oom: bool,
}
/// Admits requests from the front of `waiting` into a new prefill batch.
///
/// Requests are taken in queue order until one cannot be admitted; that
/// request is pushed back to the front of `waiting` and admission stops.
/// Three budgets are tracked: total KV tokens (`rem_total_tokens`), input
/// tokens per batch (`rem_input_tokens`) and chunked-prefill tokens
/// (`rem_chunk_tokens`).
///
/// * `new_token_ratio` - fraction of each running request's remaining
///   (clipped) output that is reserved up front from the total budget.
/// * `running` - requests currently decoding; used only for reservations.
///
/// Returns an [`AdmitResult`] describing the admitted requests.
pub(super) fn get_new_batch_prefill(
waiting: &mut VecDeque<SglangRequest>,
kv_manager: &mut SglangKvManager,
config: &SglangConfig,
new_token_ratio: f64,
running: &[SglangRequest],
) -> AdmitResult {
let cache = kv_manager.cache();
// Expected future decode output of the running batch, scaled by `new_token_ratio`.
let reserved_decode_output: f64 = running
.iter()
.map(|req| {
let remaining_output = req
.remaining_output_tokens()
.min(config.clip_max_new_tokens);
remaining_output as f64 * new_token_ratio
})
.sum();
// Tokens allocated but not yet backed by KV indices, across waiting and running requests.
let reserved_page_overhead = waiting
.iter()
.map(SglangRequest::extra_reserved_tokens)
.sum::<usize>()
+ running
.iter()
.map(SglangRequest::extra_reserved_tokens)
.sum::<usize>();
let mut rem_total_tokens = (cache.available_tokens() + cache.evictable_size)
.saturating_sub(reserved_page_overhead) as f64
- reserved_decode_output;
let mut rem_input_tokens = config.max_prefill_tokens as f64;
let mut rem_chunk_tokens = config.chunked_prefill_size as f64;
let mut can_run = Vec::new();
let mut admissions = Vec::new();
let mut rejected = VecDeque::new();
let mut oom = false;
let mut total_isl = 0usize;
let mut total_prefix = 0usize;
while let Some(mut req) = waiting.pop_front() {
let extend_input = req.extend_input_len();
// Nothing left to prefill for this request: stop admitting.
if extend_input == 0 {
rejected.push_back(req);
break;
}
let total_needed = req.total_tokens_needed(config.clip_max_new_tokens) as f64;
if total_needed >= rem_total_tokens {
rejected.push_back(req);
break;
}
// Inputs that fit within `chunked_prefill_size` are taken whole; larger inputs are
// truncated to the page-aligned remainder of the chunk budget.
let chunk_tokens = if extend_input <= config.chunked_prefill_size {
extend_input
} else {
let chunk = (rem_chunk_tokens as usize / config.block_size) * config.block_size;
if chunk == 0 {
rejected.push_back(req);
break;
}
chunk.min(extend_input)
};
// Budgets are charged in whole pages.
let charged_input_tokens = ceil_to_block(chunk_tokens, config.block_size) as f64;
if charged_input_tokens > rem_input_tokens || charged_input_tokens > rem_chunk_tokens {
rejected.push_back(req);
break;
}
let chunk_end = req.materialized_tokens + chunk_tokens;
let old_allocated_tokens = req.allocated_tokens;
let prev_node = req.last_node.take();
let alloc_tokens = req.sequence_prefix(chunk_end);
let actual_new_tokens = alloc_tokens.len().saturating_sub(req.materialized_tokens);
// Evict up front so the pool can hold the new tokens.
let available = kv_manager.cache().token_pool.available();
if available < actual_new_tokens {
kv_manager.evict(actual_new_tokens - available);
}
// A request with already-materialized tokens continues from its previous node;
// otherwise it is allocated from scratch.
let alloc = if req.materialized_tokens > 0 {
let Some(last_node) = prev_node else {
rejected.push_back(req);
break;
};
kv_manager.allocate_after_prefix(
&alloc_tokens,
req.materialized_tokens,
&req.kv_indices[..req.materialized_tokens],
last_node,
)
} else {
kv_manager.allocate_for_request(&alloc_tokens)
};
let Some(alloc) = alloc else {
// Allocation failed: restore the node taken above and report OOM.
req.last_node = prev_node;
rejected.push_back(req);
oom = true;
break;
};
// The previous node is released now that the request holds the new one.
if let Some(node) = prev_node {
kv_manager.free_request(node);
}
req.last_node = Some(alloc.last_node);
req.kv_indices = alloc.kv_indices;
req.materialized_tokens = chunk_end;
req.allocated_tokens = ceil_to_block(chunk_end, config.block_size);
req.debug_assert_invariants(config.block_size);
// A truncated (chunked) request does not start decoding yet, so no output is reserved.
let is_truncated = chunk_end < req.current_sequence_len();
let output_reserve = if is_truncated {
0
} else {
req.remaining_output_tokens()
.min(config.clip_max_new_tokens)
};
admissions.push(AdmissionEvent {
uuid: req.uuid,
reused_input_tokens: alloc.prefix_len,
});
total_isl += chunk_end;
total_prefix += alloc.prefix_len;
rem_total_tokens -= (req.allocated_tokens - old_allocated_tokens + output_reserve) as f64;
rem_input_tokens -= charged_input_tokens;
rem_chunk_tokens -= charged_input_tokens;
can_run.push(req);
if rem_chunk_tokens <= 0.0 {
break;
}
}
// Put the rejected request back at the front of the queue.
while let Some(req) = rejected.pop_back() {
waiting.push_front(req);
}
AdmitResult {
can_run,
admissions,
total_isl,
total_prefix,
oom,
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use uuid::Uuid;
use crate::cache::radix_cache::NodeId;
use crate::common::protocols::DirectRequest;
/// Scheduler-side state of a single request.
#[derive(Clone, Debug)]
pub(super) struct SglangRequest {
/// Identifier of the request.
pub(super) uuid: Uuid,
/// Prompt tokens of the request.
pub(super) prompt_tokens: Vec<u64>,
/// Output length at which the request counts as complete.
pub(super) max_output_tokens: usize,
/// Output tokens generated so far.
pub(super) output_ids: Vec<u64>,
/// Radix-cache node currently held by the request; the invariants expect it to be
/// set exactly when `materialized_tokens > 0`.
pub(super) last_node: Option<NodeId>,
/// KV indices of the request; the invariants expect one per materialized token.
pub(super) kv_indices: Vec<usize>,
/// Number of sequence tokens that currently have KV state.
pub(super) materialized_tokens: usize,
/// Page-aligned number of leading tokens already published to the cache.
pub(super) cached_tokens: usize,
/// Page-aligned token capacity accounted to the request.
pub(super) allocated_tokens: usize,
}
impl SglangRequest {
    /// Number of prompt tokens.
    pub(super) fn prompt_len(&self) -> usize {
        self.prompt_tokens.len()
    }

    /// Number of output tokens generated so far.
    pub(super) fn output_len(&self) -> usize {
        self.output_ids.len()
    }

    /// Length of the full sequence: prompt plus generated output.
    pub(super) fn current_sequence_len(&self) -> usize {
        self.prompt_len() + self.output_len()
    }

    /// Number of sequence tokens that are not materialized yet (saturating at zero).
    pub(super) fn extend_input_len(&self) -> usize {
        self.current_sequence_len()
            .saturating_sub(self.materialized_tokens)
    }

    /// Tokens still needed by this request: un-materialized input plus the
    /// remaining output, with the output part capped at `clip_max_new_tokens`.
    pub(super) fn total_tokens_needed(&self, clip_max_new_tokens: usize) -> usize {
        let remaining_input = self.extend_input_len();
        let remaining_output = self.remaining_output_tokens().min(clip_max_new_tokens);
        remaining_input + remaining_output
    }

    /// Output tokens left before `max_output_tokens` is reached (saturating at zero).
    pub(super) fn remaining_output_tokens(&self) -> usize {
        self.max_output_tokens.saturating_sub(self.output_len())
    }

    /// Tokens accounted to the request that are not backed by a KV index yet.
    pub(super) fn extra_reserved_tokens(&self) -> usize {
        self.allocated_tokens.saturating_sub(self.kv_indices.len())
    }

    /// `materialized_tokens` rounded down to a multiple of `block_size`.
    ///
    /// Panics if `block_size` is zero.
    pub(super) fn page_aligned_materialized_tokens(&self, block_size: usize) -> usize {
        self.materialized_tokens / block_size * block_size
    }

    /// Returns the full sequence (prompt followed by output) as a new vector.
    pub(super) fn sequence_tokens(&self) -> Vec<u64> {
        // Allocate once with the exact final size instead of cloning the prompt
        // and growing it afterwards.
        let mut sequence = Vec::with_capacity(self.current_sequence_len());
        sequence.extend_from_slice(&self.prompt_tokens);
        sequence.extend_from_slice(&self.output_ids);
        sequence
    }

    /// Returns the first `len` tokens of the sequence as a new vector.
    ///
    /// Panics if `len` exceeds the current sequence length.
    pub(super) fn sequence_prefix(&self, len: usize) -> Vec<u64> {
        let prompt_len = self.prompt_len();
        if len <= prompt_len {
            return self.prompt_tokens[..len].to_vec();
        }
        // Slice first so an out-of-range `len` panics before anything is allocated.
        let output_part = &self.output_ids[..len - prompt_len];
        let mut prefix = Vec::with_capacity(len);
        prefix.extend_from_slice(&self.prompt_tokens);
        prefix.extend_from_slice(output_part);
        prefix
    }

    /// Derives the next synthetic output token by hashing the request uuid
    /// together with the current output length.
    pub(super) fn next_output_token(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        self.uuid.hash(&mut hasher);
        self.output_len().hash(&mut hasher);
        hasher.finish()
    }

    /// Appends an output token and counts it as materialized.
    pub(super) fn append_output_token(&mut self, token: u64) {
        self.output_ids.push(token);
        self.materialized_tokens += 1;
    }

    /// Checks the bookkeeping invariants of the request in debug builds.
    ///
    /// In release builds this is a no-op.
    pub(super) fn debug_assert_invariants(&self, block_size: usize) {
        #[cfg(debug_assertions)]
        {
            let sequence_len = self.current_sequence_len();
            debug_assert!(
                self.cached_tokens <= self.materialized_tokens,
                "request {} cached {} tokens but materialized {}",
                self.uuid,
                self.cached_tokens,
                self.materialized_tokens
            );
            debug_assert!(
                self.materialized_tokens <= sequence_len,
                "request {} materialized {} tokens but sequence length is {sequence_len}",
                self.uuid,
                self.materialized_tokens
            );
            debug_assert_eq!(
                self.kv_indices.len(),
                self.materialized_tokens,
                "request {} has {} kv indices but {} materialized tokens",
                self.uuid,
                self.kv_indices.len(),
                self.materialized_tokens
            );
            debug_assert!(
                self.allocated_tokens >= self.materialized_tokens,
                "request {} allocated {} tokens but materialized {}",
                self.uuid,
                self.allocated_tokens,
                self.materialized_tokens
            );
            debug_assert_eq!(
                self.cached_tokens % block_size,
                0,
                "request {} cached tokens {} are not page-aligned to block size {block_size}",
                self.uuid,
                self.cached_tokens
            );
            debug_assert!(
                self.allocated_tokens == 0 || self.allocated_tokens.is_multiple_of(block_size),
                "request {} allocated tokens {} are not page-aligned to block size {block_size}",
                self.uuid,
                self.allocated_tokens
            );
            debug_assert!(
                self.extra_reserved_tokens() < block_size,
                "request {} reserves {} extra tokens with block size {block_size}",
                self.uuid,
                self.extra_reserved_tokens()
            );
            debug_assert_eq!(
                self.last_node.is_some(),
                self.materialized_tokens > 0,
                "request {} has last_node={} but materialized_tokens={}",
                self.uuid,
                self.last_node.is_some(),
                self.materialized_tokens
            );
        }
    }

    /// Clears all KV/allocation state after a retract; prompt and generated
    /// output are kept.
    pub(super) fn reset_for_retract(&mut self) {
        self.last_node = None;
        self.kv_indices.clear();
        self.materialized_tokens = 0;
        self.cached_tokens = 0;
        self.allocated_tokens = 0;
    }
}
/// Converts an incoming [`DirectRequest`] into a fresh, un-materialized
/// [`SglangRequest`].
///
/// A random v4 uuid is generated when the request carries none.
pub(super) fn direct_to_sglang(req: DirectRequest) -> SglangRequest {
    let uuid = req.uuid.unwrap_or_else(Uuid::new_v4);
    // Lossless widening of the token ids.
    let prompt_tokens = req.tokens.iter().copied().map(u64::from).collect();
    SglangRequest {
        uuid,
        prompt_tokens,
        max_output_tokens: req.max_output_tokens,
        output_ids: Vec::new(),
        last_node: None,
        kv_indices: Vec::new(),
        materialized_tokens: 0,
        cached_tokens: 0,
        allocated_tokens: 0,
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::time::Duration;
use dynamo_kv_router::indexer::{METRIC_EVENT_REMOVED, METRIC_EVENT_STORED};
use dynamo_kv_router::protocols::WorkerId;
use rstest::rstest;
use tokio::sync::mpsc;
use uuid::Uuid;
use super::config::{SchedulePolicy, SglangConfig, ceil_to_block};
use super::core::SglangCore;
use super::decode;
use super::decode::simulate_decode_step;
use super::live::SglangScheduler;
use super::policy::apply_schedule_policy;
use super::prefill::get_new_batch_prefill;
use super::request::SglangRequest;
use crate::common::protocols::{
DirectRequest, EngineType, KvEventPublishers, MockEngineArgs, OutputSignal, SglangArgs,
};
use crate::kv_manager::SglangKvManager;
use crate::scheduler::test_utils::{
RouterIndexerHarness, nth_stored_hashes, removed_event_count, stored_hashes,
};
use crate::scheduler::{RouterEventVisibility, SchedulerHandle, capture_router_event_sink};
// Fixed worker id for the tests in this file; presumably consumed by the
// router-indexer tests (usage not visible here — confirm).
const ROUTER_TEST_WORKER_ID: WorkerId = 17;
/// Builds SGLang mock-engine arguments for tests.
///
/// `block_size` is used both as the engine block size and as the SGLang page
/// size; the speedup ratio is fixed at 1.0.
fn test_args(
    num_gpu_blocks: usize,
    block_size: usize,
    chunked_prefill_size: usize,
) -> MockEngineArgs {
    let sglang_args = SglangArgs {
        page_size: Some(block_size),
        chunked_prefill_size: Some(chunked_prefill_size),
        ..Default::default()
    };
    let builder = MockEngineArgs::builder()
        .engine_type(EngineType::Sglang)
        .num_gpu_blocks(num_gpu_blocks)
        .block_size(block_size)
        .speedup_ratio(1.0)
        .sglang(Some(sglang_args));
    builder.build().unwrap()
}
/// Builds a `DirectRequest` for DP rank 0 with no uuid and no arrival timestamp.
fn direct_request(tokens: Vec<u32>, max_output_tokens: usize) -> DirectRequest {
    DirectRequest {
        uuid: None,
        dp_rank: 0,
        arrival_timestamp_ms: None,
        tokens,
        max_output_tokens,
    }
}
/// Allocates KV state for `prompt_tokens`, runs exactly one decode step on the
/// resulting request, and returns it.
///
/// Panics if the allocation fails, if the step does not emit exactly one
/// output signal, or if the request is no longer running afterwards.
fn make_decoded_request(
    kv_manager: &mut SglangKvManager,
    config: &SglangConfig,
    prompt_tokens: Vec<u64>,
    max_output_tokens: usize,
) -> SglangRequest {
    let prompt_len = prompt_tokens.len();
    let alloc = kv_manager.allocate_for_request(&prompt_tokens).unwrap();
    let prefilled = SglangRequest {
        uuid: Uuid::new_v4(),
        prompt_tokens,
        max_output_tokens,
        output_ids: Vec::new(),
        last_node: Some(alloc.last_node),
        kv_indices: alloc.kv_indices,
        materialized_tokens: prompt_len,
        cached_tokens: 0,
        allocated_tokens: ceil_to_block(prompt_len, config.block_size),
    };

    let mut batch = vec![prefilled];
    let step = simulate_decode_step(&mut batch, kv_manager, config, 0.0, false);
    assert_eq!(step.output_signals.len(), 1);
    batch.pop().unwrap()
}
mod scheduling {
use super::*;
#[tokio::test]
async fn test_sglang_scheduler_fifo_ordering() {
let args = MockEngineArgs::builder()
.num_gpu_blocks(100)
.block_size(64)
.speedup_ratio(100.0)
.build()
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler =
SglangScheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
let num_requests = 5;
let max_output = 3;
for i in 0..num_requests {
scheduler.receive(crate::common::protocols::DirectRequest {
tokens: vec![i as u32; 10],
max_output_tokens: max_output,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let expected_signals = num_requests * max_output;
let mut received = 0;
let timeout = tokio::time::sleep(Duration::from_secs(5));
tokio::pin!(timeout);
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
received += 1;
if received >= expected_signals {
break;
}
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => break,
}
}
assert_eq!(received, expected_signals);
}
#[test]
fn test_lpm_reorders_by_current_sequence_prefix_match() {
let mut kv_manager = SglangKvManager::new(1000, 1, KvEventPublishers::default(), 0);
kv_manager
.cache_mut()
.insert(&[1, 2, 3, 4, 5], &[0, 1, 2, 3, 4]);
let config = SglangConfig {
schedule_policy: SchedulePolicy::Lpm,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let no_match_uuid = Uuid::new_v4();
let match_uuid = Uuid::new_v4();
let mut waiting = VecDeque::from([
SglangRequest {
uuid: no_match_uuid,
prompt_tokens: vec![9, 8, 7],
max_output_tokens: 1,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
},
SglangRequest {
uuid: match_uuid,
prompt_tokens: vec![1, 2, 3, 4, 5],
max_output_tokens: 1,
output_ids: vec![6, 7],
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
},
]);
apply_schedule_policy(&mut waiting, &kv_manager, &config);
assert_eq!(waiting[0].uuid, match_uuid);
assert_eq!(waiting[1].uuid, no_match_uuid);
}
#[test]
fn test_lpm_deprioritizes_duplicate_short_prefixes() {
let config = SglangConfig {
schedule_policy: SchedulePolicy::Lpm,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.block_size(1)
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let kv_manager = SglangKvManager::new(1000, 1, KvEventPublishers::default(), 0);
let duplicate_prefix = (0..32).collect::<Vec<_>>();
let mut waiting = VecDeque::new();
for _ in 0..33 {
waiting.push_back(SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: duplicate_prefix.clone(),
max_output_tokens: 1,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
});
}
let unique_uuid = Uuid::new_v4();
waiting.push_back(SglangRequest {
uuid: unique_uuid,
prompt_tokens: (100..132).collect(),
max_output_tokens: 1,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
});
apply_schedule_policy(&mut waiting, &kv_manager, &config);
assert_eq!(
waiting.iter().position(|req| req.uuid == unique_uuid),
Some(1)
);
}
}
mod core_behavior {
use super::*;
#[test]
fn test_chunked_prefill_budget_is_page_aware() {
let config = SglangConfig {
chunked_prefill_size: 8,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.block_size(4)
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let mut kv_manager = SglangKvManager::new(10000, 4, KvEventPublishers::default(), 0);
let mut waiting = VecDeque::from([SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1; 6],
max_output_tokens: 3,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
}]);
let admit = get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &[]);
assert_eq!(admit.can_run.len(), 1);
assert_eq!(admit.can_run[0].materialized_tokens, 6);
assert_eq!(admit.can_run[0].allocated_tokens, 8);
}
#[test]
fn test_chunked_prefill_subpage_budget_defers_next_request() {
let config = SglangConfig {
chunked_prefill_size: 8,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.block_size(4)
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let first_uuid = Uuid::new_v4();
let second_uuid = Uuid::new_v4();
let mut kv_manager = SglangKvManager::new(10000, 4, KvEventPublishers::default(), 0);
let mut waiting = VecDeque::from([
SglangRequest {
uuid: first_uuid,
prompt_tokens: vec![1; 7],
max_output_tokens: 3,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
},
SglangRequest {
uuid: second_uuid,
prompt_tokens: vec![2; 8],
max_output_tokens: 3,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
},
]);
let admit = get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &[]);
assert_eq!(admit.can_run.len(), 1);
assert_eq!(admit.can_run[0].uuid, first_uuid);
assert_eq!(waiting.len(), 1);
assert_eq!(waiting[0].uuid, second_uuid);
}
#[test]
fn test_decode_allocation_is_page_aware() {
let config = SglangConfig::from_args(
&MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.speedup_ratio(1.0)
.build()
.unwrap(),
);
let mut kv_manager = SglangKvManager::new(64, 4, KvEventPublishers::default(), 0);
let alloc = kv_manager
.allocate_for_request(&[1, 2, 3, 4, 5, 6])
.unwrap();
let mut running = vec![SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1, 2, 3, 4, 5, 6],
max_output_tokens: 4,
output_ids: Vec::new(),
last_node: Some(alloc.last_node),
kv_indices: alloc.kv_indices,
materialized_tokens: 6,
cached_tokens: 4,
allocated_tokens: 8,
}];
let first = simulate_decode_step(&mut running, &mut kv_manager, &config, 0.0, false);
assert_eq!(running[0].allocated_tokens, 8);
assert_eq!(running[0].output_len(), 1);
assert_eq!(first.output_signals.len(), 1);
simulate_decode_step(&mut running, &mut kv_manager, &config, 0.0, false);
assert_eq!(running[0].allocated_tokens, 8);
simulate_decode_step(&mut running, &mut kv_manager, &config, 0.0, false);
assert_eq!(running[0].allocated_tokens, 12);
}
/// Two otherwise identical single-request decode steps are simulated, one with
/// `decode_speedup_ratio = 1.0` and one with `4.0`; the reported `end_ms`
/// values must differ by a factor of four.
#[test]
fn test_decode_speedup_ratio_scales_sglang_decode_time() {
    // Everything except the decode-specific ratio is shared between the runs.
    let config_with_decode_ratio = |decode_ratio: f64| {
        let args = MockEngineArgs::builder()
            .engine_type(EngineType::Sglang)
            .block_size(4)
            .speedup_ratio(2.0)
            .decode_speedup_ratio(decode_ratio)
            .build()
            .unwrap();
        SglangConfig::from_args(&args)
    };
    // Each run owns a private KV manager holding one request whose 4-token
    // prompt has already been allocated.
    let prefilled_run = || {
        let mut kv_manager = SglangKvManager::new(64, 4, KvEventPublishers::default(), 0);
        let alloc = kv_manager.allocate_for_request(&[1, 2, 3, 4]).unwrap();
        let running = vec![SglangRequest {
            uuid: Uuid::new_v4(),
            prompt_tokens: vec![1, 2, 3, 4],
            max_output_tokens: 4,
            output_ids: Vec::new(),
            last_node: Some(alloc.last_node),
            kv_indices: alloc.kv_indices,
            materialized_tokens: 4,
            cached_tokens: 0,
            allocated_tokens: 4,
        }];
        (kv_manager, running)
    };

    let base_config = config_with_decode_ratio(1.0);
    let fast_config = config_with_decode_ratio(4.0);
    let (mut base_kv_manager, mut base_running) = prefilled_run();
    let (mut fast_kv_manager, mut fast_running) = prefilled_run();

    let base = simulate_decode_step(
        &mut base_running,
        &mut base_kv_manager,
        &base_config,
        0.0,
        true,
    );
    let fast = simulate_decode_step(
        &mut fast_running,
        &mut fast_kv_manager,
        &fast_config,
        0.0,
        true,
    );

    assert!(base.end_ms > fast.end_ms);
    let ratio = base.end_ms / fast.end_ms;
    assert!(
        (ratio - 4.0).abs() < 1e-3,
        "expected 4x decode speedup ratio, got {ratio}"
    );
}
/// Verifies that `decode::check_decode_mem` keeps the already generated
/// `output_ids` of the request it retracts while clearing that request's
/// KV bookkeeping (`materialized_tokens` and `kv_indices`).
#[test]
fn test_check_decode_mem_preserves_generated_output_on_retract() {
let config = SglangConfig::from_args(
&MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.speedup_ratio(1.0)
.build()
.unwrap(),
);
// NOTE(review): the manager is built with 8 and the two allocations below take
// 4 tokens each — presumably this leaves no free capacity, forcing a retraction.
// Confirm against `SglangKvManager::new`.
let mut kv_manager = SglangKvManager::new(8, 4, KvEventPublishers::default(), 0);
// Allocation order matters: `first` backs request 0, `second` backs request 1.
let first = kv_manager.cache_mut().token_pool.allocate(4).unwrap();
let second = kv_manager.cache_mut().token_pool.allocate(4).unwrap();
let mut running = vec![
// Request 0: 4 prompt tokens plus 3 generated tokens.
SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1, 2, 3, 4],
max_output_tokens: 10,
output_ids: vec![11, 12, 13],
last_node: None,
kv_indices: first,
materialized_tokens: 7,
cached_tokens: 4,
allocated_tokens: 8,
},
// Request 1: 4 prompt tokens plus 1 generated token.
SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![9, 8, 7, 6],
max_output_tokens: 10,
output_ids: vec![21],
last_node: None,
kv_indices: second,
materialized_tokens: 5,
cached_tokens: 4,
allocated_tokens: 8,
},
];
let retracted = decode::check_decode_mem(&mut running, &mut kv_manager, &config);
// Exactly one request is retracted, and its output matches request 1 (`[21]`).
assert_eq!(retracted.len(), 1);
assert_eq!(retracted[0].output_ids, vec![21]);
// The retracted request no longer holds any materialized KV state.
assert_eq!(retracted[0].materialized_tokens, 0);
assert!(retracted[0].kv_indices.is_empty());
}
/// After one decode step on a request that is still running, the cache must
/// report a 4-token match for the request's 4-token sequence prefix.
#[test]
fn test_unfinished_decode_request_is_cached_after_output() {
    let engine_args = MockEngineArgs::builder()
        .engine_type(EngineType::Sglang)
        .block_size(4)
        .speedup_ratio(1.0)
        .build()
        .unwrap();
    let config = SglangConfig::from_args(&engine_args);

    let mut kv_manager = SglangKvManager::new(64, 4, KvEventPublishers::default(), 0);
    let allocation = kv_manager.allocate_for_request(&[1, 2, 3, 4]).unwrap();
    // The request starts fully materialized but with nothing recorded as cached.
    let request = SglangRequest {
        uuid: Uuid::new_v4(),
        prompt_tokens: vec![1, 2, 3, 4],
        max_output_tokens: 4,
        output_ids: Vec::new(),
        last_node: Some(allocation.last_node),
        kv_indices: allocation.kv_indices,
        materialized_tokens: 4,
        cached_tokens: 0,
        allocated_tokens: 4,
    };
    let mut running = vec![request];

    simulate_decode_step(&mut running, &mut kv_manager, &config, 0.0, false);

    let prompt_prefix = running[0].sequence_prefix(4);
    let matched_len = kv_manager.cache().prefix_match_len(&prompt_prefix);
    assert_eq!(matched_len, 4);
}
/// A single pass over one 6-token request (block size 4, page size 4) must
/// finish no request and report exactly two active decode blocks.
#[test]
fn test_active_decode_blocks_tracks_page_reserved_occupancy_in_blocks() {
    let sglang_args = SglangArgs {
        chunked_prefill_size: Some(8),
        page_size: Some(4),
        ..Default::default()
    };
    let args = MockEngineArgs::builder()
        .engine_type(EngineType::Sglang)
        .num_gpu_blocks(32)
        .block_size(4)
        .speedup_ratio(1.0)
        .sglang(Some(sglang_args))
        .build()
        .unwrap();
    let request = crate::common::protocols::DirectRequest {
        tokens: vec![1; 6],
        max_output_tokens: 2,
        uuid: None,
        dp_rank: 0,
        arrival_timestamp_ms: None,
    };

    let mut engine = SglangCore::new(args);
    engine.receive(request);
    let outcome = engine.execute_pass_internal(None, 0.0);

    assert_eq!(outcome.completed_requests, 0);
    assert_eq!(outcome.active_decode_blocks, 2);
}
/// A pass executed by the SGLang core must tag its router events with
/// `RouterEventVisibility::PassEnd`.
#[test]
fn test_sglang_pass_visibility_is_pass_end() {
    let args = test_args(32, 4, 4);
    let mut engine = SglangCore::new_with_kv_capture(args, ROUTER_TEST_WORKER_ID);
    engine.receive(direct_request(vec![1, 2, 3, 4], 1));

    let outcome = engine.execute_pass_internal(None, 0.0);

    assert_eq!(
        outcome.router_event_visibility,
        RouterEventVisibility::PassEnd
    );
}
}
/// Submits `num_requests` prompts of `prompt_len` tokens to `scheduler`, waits
/// for `num_requests * max_output_tokens` output signals, then checks the
/// published metrics.
///
/// With `use_shared_tokens`, every prompt starts with `prompt_len / 2` copies
/// of token `1`; the rest of each prompt is unique to the request. Waiting uses
/// an inactivity timeout: the loop gives up once two seconds pass without a
/// new signal.
async fn assert_sglang_scheduler_completes_all(
    scheduler: &SglangScheduler,
    output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
    num_requests: usize,
    prompt_len: usize,
    max_output_tokens: usize,
    use_shared_tokens: bool,
) {
    let shared_len = if use_shared_tokens { prompt_len / 2 } else { 0 };
    let unique_len = prompt_len - shared_len;
    for request_idx in 0..num_requests {
        // Shared part first, then ids that are distinct per request and position.
        let input_tokens: Vec<u32> = std::iter::repeat(1u32)
            .take(shared_len)
            .chain(
                (0..unique_len).map(|offset| (request_idx * unique_len + offset) as u32 + 1000),
            )
            .collect();
        scheduler.receive(crate::common::protocols::DirectRequest {
            tokens: input_tokens,
            max_output_tokens,
            uuid: None,
            dp_rank: 0,
            arrival_timestamp_ms: None,
        });
    }

    let expected_tokens = num_requests * max_output_tokens;
    let mut received_tokens = 0;
    let idle_timer = tokio::time::sleep(Duration::from_secs(2));
    tokio::pin!(idle_timer);
    loop {
        tokio::select! {
            biased;
            Some(_) = output_rx.recv() => {
                received_tokens += 1;
                if received_tokens >= expected_tokens {
                    break;
                }
                // Progress was made: restart the inactivity window.
                idle_timer.set(tokio::time::sleep(Duration::from_secs(2)));
            }
            _ = &mut idle_timer => break,
        }
    }
    assert_eq!(received_tokens, expected_tokens);

    // Short pause before sampling the metrics watch channel.
    tokio::time::sleep(Duration::from_millis(100)).await;
    let snapshot = scheduler.metrics_receiver().borrow().clone();
    assert!(snapshot.active_decode_blocks > 0);
    assert!(snapshot.total_blocks > 0);
    assert!((0.0..=1.0).contains(&snapshot.gpu_cache_usage_perc));
}
mod router_events {
use super::*;
/// Parameterized end-to-end check that the SGLang scheduler emits every
/// expected output signal.
///
/// The cases cover shared versus unique prompt prefixes, the `"fifo"` and
/// `"lpm"` schedule policies, and page sizes 1 and 4.
#[rstest]
#[case::case_1(false, "fifo", 1)]
#[case::case_2(true, "fifo", 1)]
#[case::case_3(false, "lpm", 1)]
#[case::case_4(true, "lpm", 1)]
#[case::case_5(false, "fifo", 4)]
#[case::case_6(true, "fifo", 4)]
#[case::case_7(false, "lpm", 4)]
#[case::case_8(true, "lpm", 4)]
#[tokio::test]
async fn test_sglang_scheduler_token_generation_patterns(
#[case] use_shared_tokens: bool,
#[case] schedule_policy: &str,
#[case] page_size: usize,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
.block_size(64)
.speedup_ratio(10.0)
.sglang(Some(SglangArgs {
schedule_policy: Some(schedule_policy.to_string()),
page_size: Some(page_size),
..Default::default()
}))
.build()
.unwrap();
let scheduler =
SglangScheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
// Positional arguments: 200 requests, 1000 prompt tokens each, 100 output tokens each.
assert_sglang_scheduler_completes_all(
&scheduler,
&mut output_rx,
200,
1000,
100,
use_shared_tokens,
)
.await;
}
/// Runs two passes over a 6-token prompt and feeds every emitted KV event to
/// the router indexer harness, which asserts after each batch that no event
/// was rejected.
#[tokio::test]
async fn test_chunked_prefill_events_apply_cleanly() {
    let indexer = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
    let mut engine = SglangCore::new_with_kv_capture(test_args(32, 4, 4), ROUTER_TEST_WORKER_ID);
    engine.receive(direct_request(vec![1, 2, 3, 4, 5, 6], 2));

    let first_pass = engine.execute_pass_internal(None, 0.0);
    let mut prompt_hashes = stored_hashes(&first_pass.kv_events);
    assert_eq!(prompt_hashes.len(), 4);
    indexer.apply_events(first_pass.kv_events).await;

    // The second pass starts where the first one ended; only the hashes of
    // its first store event are added.
    let second_pass = engine.execute_pass_internal(None, first_pass.end_ms);
    prompt_hashes.extend(nth_stored_hashes(&second_pass.kv_events, 0));
    indexer.apply_events(second_pass.kv_events).await;

    assert_eq!(prompt_hashes.len(), 6);
    assert!(indexer.ok_count(METRIC_EVENT_STORED) >= 2);
    indexer.shutdown();
}
/// Runs two passes over a 4-token prompt and checks that the store events of
/// both passes carry six block hashes in total and apply to the router indexer
/// without errors.
#[tokio::test]
async fn test_decode_growth_events_apply_cleanly() {
    let indexer = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
    let mut engine = SglangCore::new_with_kv_capture(test_args(32, 4, 16), ROUTER_TEST_WORKER_ID);
    engine.receive(direct_request(vec![7, 8, 9, 10], 5));

    let first_pass = engine.execute_pass_internal(None, 0.0);
    let mut full_hashes = stored_hashes(&first_pass.kv_events);
    indexer.apply_events(first_pass.kv_events).await;

    // Continue simulated time from the end of the first pass.
    let second_pass = engine.execute_pass_internal(None, first_pass.end_ms);
    full_hashes.extend(stored_hashes(&second_pass.kv_events));
    indexer.apply_events(second_pass.kv_events).await;

    assert_eq!(full_hashes.len(), 6);
    assert!(indexer.ok_count(METRIC_EVENT_STORED) >= 2);
    indexer.shutdown();
}
/// Retracting one of two decoded requests must emit removal events that the
/// router indexer accepts, while the first request's stored hashes still
/// yield an overlap score of 4 afterwards.
#[tokio::test]
async fn test_retract_frees_do_not_leave_stale_blocks() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let args = test_args(8, 4, 16);
let config = SglangConfig::from_args(&args);
// `buffer` collects every event published through `sink`; each `drain()` below
// hands over the events produced since the previous drain.
let (buffer, sink) = capture_router_event_sink(ROUTER_TEST_WORKER_ID);
let mut kv_manager =
SglangKvManager::new(10, 4, KvEventPublishers::new(Some(sink), None), 0);
let req1 = make_decoded_request(&mut kv_manager, &config, vec![1, 2, 3, 4], 4);
let req1_events = buffer.drain();
// Remember req1's hashes so its overlap can be queried after the retraction.
let req1_hashes = stored_hashes(&req1_events);
harness.apply_events(req1_events).await;
let req2 = make_decoded_request(&mut kv_manager, &config, vec![9, 8, 7, 6], 4);
harness.apply_events(buffer.drain()).await;
let mut running = vec![req1, req2];
let retracted = decode::check_decode_mem(&mut running, &mut kv_manager, &config);
assert_eq!(retracted.len(), 1);
let retract_events = buffer.drain();
// The retraction must publish at least one removal event.
assert!(removed_event_count(&retract_events) > 0);
harness.apply_events(retract_events).await;
assert_eq!(harness.overlap_for_hashes(req1_hashes).await, 4);
assert!(harness.ok_count(METRIC_EVENT_REMOVED) > 0);
harness.shutdown();
}
/// Drives one request through three passes and checks that the third pass
/// emits removal events and that every event applies to the router indexer
/// without errors.
#[tokio::test]
async fn test_completion_tail_free_emits_valid_removals() {
    let indexer = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
    let mut engine = SglangCore::new_with_kv_capture(test_args(32, 4, 16), ROUTER_TEST_WORKER_ID);
    engine.receive(direct_request(vec![11, 12, 13, 14], 3));

    let first_pass = engine.execute_pass_internal(None, 0.0);
    let prompt_hashes = nth_stored_hashes(&first_pass.kv_events, 0);
    let mut full_hashes = stored_hashes(&first_pass.kv_events);
    indexer.apply_events(first_pass.kv_events).await;

    let second_pass = engine.execute_pass_internal(None, first_pass.end_ms);
    full_hashes.extend(stored_hashes(&second_pass.kv_events));
    indexer.apply_events(second_pass.kv_events).await;

    // The third pass is the one expected to publish removals.
    let third_pass = engine.execute_pass_internal(None, second_pass.end_ms);
    assert!(removed_event_count(&third_pass.kv_events) > 0);
    full_hashes.extend(stored_hashes(&third_pass.kv_events));
    indexer.apply_events(third_pass.kv_events).await;

    assert_eq!(prompt_hashes.len(), 4);
    assert!(full_hashes.len() >= prompt_hashes.len());
    assert!(indexer.ok_count(METRIC_EVENT_REMOVED) > 0);
    indexer.shutdown();
}
/// Walks one request through chunked prefill, a decode step, a retraction and
/// re-admission, then runs everything to completion.
///
/// All KV events emitted along the way are applied to the router indexer; the
/// test requires at least one removal and no event errors.
#[tokio::test]
async fn test_mixed_chunk_decode_retract_reprefill_complete_events_apply_cleanly() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let args = test_args(8, 4, 4);
let config = SglangConfig::from_args(&args);
let (buffer, sink) = capture_router_event_sink(ROUTER_TEST_WORKER_ID);
let mut kv_manager =
SglangKvManager::new(12, 4, KvEventPublishers::new(Some(sink), None), 0);
// Start with a single fresh 6-token request that has no KV state yet.
let mut waiting = VecDeque::from([SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1, 2, 3, 4, 5, 6],
max_output_tokens: 3,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
}]);
// First prefill batch: cache what was materialized and put the request back at
// the front of the waiting queue for a second batch.
let chunk1 = get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &[]);
let mut req1 = chunk1.can_run.into_iter().next().unwrap();
decode::cache_materialized_prefix(&mut req1, &mut kv_manager, &config);
waiting.push_front(req1);
harness.apply_events(buffer.drain()).await;
// Second prefill batch, followed by one decode step that must emit one signal.
let chunk2 = get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &[]);
let mut running = chunk2.can_run;
let decode1 = simulate_decode_step(&mut running, &mut kv_manager, &config, 0.0, false);
assert_eq!(decode1.output_signals.len(), 1);
harness.apply_events(buffer.drain()).await;
// Add a second decoded request, then expect the memory check to retract exactly one.
let req1 = running.pop().unwrap();
let req2 = make_decoded_request(&mut kv_manager, &config, vec![9, 10, 11, 12], 3);
harness.apply_events(buffer.drain()).await;
let mut running = vec![req1, req2];
let mut retracted = decode::check_decode_mem(&mut running, &mut kv_manager, &config);
assert_eq!(retracted.len(), 1);
harness.apply_events(buffer.drain()).await;
// The retracted request becomes the only waiting request.
let mut waiting = VecDeque::from([retracted.pop().unwrap()]);
let mut now_ms = 0.0;
let mut saw_remove = harness.ok_count(METRIC_EVENT_REMOVED) > 0;
// Alternate admission and decode until both queues are empty.
loop {
let admit =
get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &running);
for mut req in admit.can_run {
// Not fully materialized yet: cache the prefix and re-queue; otherwise start decoding.
if req.materialized_tokens < req.current_sequence_len() {
decode::cache_materialized_prefix(&mut req, &mut kv_manager, &config);
waiting.push_front(req);
} else {
running.push(req);
}
}
let events = buffer.drain();
saw_remove |= removed_event_count(&events) > 0;
harness.apply_events(events).await;
if running.is_empty() {
if waiting.is_empty() {
break;
}
// Nothing to decode yet; try another admission round.
continue;
}
let decode =
simulate_decode_step(&mut running, &mut kv_manager, &config, now_ms, false);
now_ms = decode.end_ms;
// Requests handed back by the decode step return to the front of the waiting
// queue; iterating in reverse keeps their relative order.
for req in decode.requests.into_iter().rev() {
waiting.push_front(req);
}
let events = buffer.drain();
saw_remove |= removed_event_count(&events) > 0;
harness.apply_events(events).await;
if running.is_empty() && waiting.is_empty() {
break;
}
}
assert!(saw_remove);
harness.assert_no_event_errors();
harness.shutdown();
}
/// Runs a live `SglangScheduler` with a very small KV budget (4 blocks of 4
/// tokens) against eight single-token prompts.
///
/// All 32 output signals must arrive within five seconds, and the router
/// indexer must report no event errors and at least one applied removal.
#[tokio::test]
async fn test_live_pathological_load_no_router_event_errors() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
// Events published through `sink` are forwarded to the harness indexer by `forward_task`.
let (sink, forward_task) = harness.spawn_forwarder();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler = SglangScheduler::new(
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.num_gpu_blocks(4)
.block_size(4)
.speedup_ratio(1000.0)
.sglang(Some(SglangArgs {
page_size: Some(4),
chunked_prefill_size: Some(4),
..Default::default()
}))
.build()
.unwrap(),
0,
Some(output_tx),
KvEventPublishers::new(Some(sink.clone()), None),
None,
);
for _ in 0..8 {
scheduler.receive(direct_request(vec![42], 4));
}
// 8 requests with 4 output tokens each.
let expected = 8 * 4;
let mut seen = 0;
// One overall deadline for the whole run; it is never reset.
let timeout = tokio::time::sleep(Duration::from_secs(5));
tokio::pin!(timeout);
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
seen += 1;
if seen == expected {
break;
}
}
_ = &mut timeout => {
break;
}
}
}
assert_eq!(seen, expected);
// Drop the scheduler and the local sink handle before awaiting the forwarder.
// NOTE(review): presumably this closes every sender so the forwarder loop can
// end; confirm that the scheduler releases its sink clone on drop.
drop(scheduler);
drop(sink);
forward_task.await.unwrap();
harness.flush().await;
harness.assert_no_event_errors();
assert!(harness.ok_count(METRIC_EVENT_REMOVED) > 0);
harness.shutdown();
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use anyhow::anyhow;
use dynamo_kv_router::indexer::{
KvIndexerInterface, KvIndexerMetrics, LocalKvIndexer, METRIC_EVENT_REMOVED,
METRIC_EVENT_STORED, METRIC_STATUS_BLOCK_NOT_FOUND, METRIC_STATUS_INVALID_BLOCK,
METRIC_STATUS_OK, METRIC_STATUS_PARENT_NOT_FOUND,
};
use dynamo_kv_router::protocols::{
KvCacheEvent, KvCacheEventData, LocalBlockHash, RouterEvent, WorkerId, WorkerWithDpRank,
};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use super::{DirectRequest, OutputSignal, SchedulerHandle};
use crate::common::protocols::KvCacheEventSink;
/// Test harness that pairs a `LocalKvIndexer` with its metrics so tests can
/// apply router events and then inspect per-status event counters.
pub(crate) struct RouterIndexerHarness {
/// Indexer that the events under test are applied to.
indexer: Arc<LocalKvIndexer>,
/// Metrics handle passed to the indexer at construction; read back by the counter helpers.
metrics: Arc<KvIndexerMetrics>,
/// Worker (with dp rank 0) whose overlap score `overlap_for_hashes` reports.
worker: WorkerWithDpRank,
}
impl RouterIndexerHarness {
pub(crate) fn new(block_size: u32, worker_id: WorkerId) -> Self {
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let indexer = Arc::new(LocalKvIndexer::new(
CancellationToken::new(),
block_size,
metrics.clone(),
4096,
));
Self {
indexer,
metrics,
worker: WorkerWithDpRank::new(worker_id, 0),
}
}
pub(crate) async fn apply_events<I>(&self, events: I)
where
I: IntoIterator<Item = RouterEvent>,
{
for event in events {
self.indexer.apply_event_with_buffer(event).await.unwrap();
}
let _ = self.indexer.flush().await;
self.assert_no_event_errors();
}
pub(crate) async fn overlap_for_hashes(&self, local_hashes: Vec<LocalBlockHash>) -> u32 {
self.indexer
.find_matches(local_hashes)
.await
.unwrap()
.scores
.get(&self.worker)
.copied()
.unwrap_or(0)
}
pub(crate) fn ok_count(&self, event_type: &'static str) -> u64 {
metric_value(&self.metrics, event_type, METRIC_STATUS_OK)
}
pub(crate) fn status_count(&self, event_type: &'static str, status: &'static str) -> u64 {
metric_value(&self.metrics, event_type, status)
}
pub(crate) fn invalid_counts(&self) -> Vec<(&'static str, &'static str, u64)> {
[METRIC_EVENT_STORED, METRIC_EVENT_REMOVED]
.into_iter()
.flat_map(|event_type| {
[
METRIC_STATUS_PARENT_NOT_FOUND,
METRIC_STATUS_BLOCK_NOT_FOUND,
METRIC_STATUS_INVALID_BLOCK,
]
.into_iter()
.map(move |status| (event_type, status, self.status_count(event_type, status)))
})
.collect()
}
pub(crate) fn invalid_event_count(&self) -> u64 {
self.invalid_counts()
.into_iter()
.map(|(_, _, count)| count)
.sum()
}
pub(crate) fn spawn_forwarder(&self) -> (Arc<TestKvEventSink>, JoinHandle<()>) {
let (event_tx, mut event_rx) = mpsc::unbounded_channel::<RouterEvent>();
let sink = Arc::new(TestKvEventSink {
worker_id: self.worker.worker_id,
event_tx,
});
let indexer = self.indexer.clone();
let forwarder = tokio::spawn(async move {
while let Some(event) = event_rx.recv().await {
indexer.apply_event_with_buffer(event).await.unwrap();
}
let _ = indexer.flush().await;
});
(sink, forwarder)
}
pub(crate) async fn flush(&self) {
let _ = self.indexer.flush().await;
}
pub(crate) fn assert_no_event_errors(&self) {
let breakdown = self
.invalid_counts()
.into_iter()
.filter(|(_, _, count)| *count > 0)
.map(|(event_type, status, count)| format!("{event_type}/{status}={count}"))
.collect::<Vec<_>>()
.join(", ");
assert_eq!(
self.invalid_event_count(),
0,
"router indexer reported invalid KV events{}",
if breakdown.is_empty() {
String::new()
} else {
format!(": {breakdown}")
}
);
}
pub(crate) fn shutdown(&self) {
self.indexer.shutdown();
}
}
/// Event sink for tests: wraps each published KV cache event in a
/// `RouterEvent` and sends it over an unbounded channel.
#[derive(Clone)]
pub(crate) struct TestKvEventSink {
/// Worker id stamped on every forwarded `RouterEvent`.
worker_id: WorkerId,
/// Sending half of the channel that carries the forwarded events.
event_tx: mpsc::UnboundedSender<RouterEvent>,
}
impl KvCacheEventSink for TestKvEventSink {
    /// Wraps `event` with this sink's worker id and sends it on the channel.
    ///
    /// Returns an error when the receiving side of the channel is gone.
    fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
        let routed = RouterEvent::new(self.worker_id, event);
        match self.event_tx.send(routed) {
            Ok(()) => Ok(()),
            Err(_) => Err(anyhow!("router test event channel closed")),
        }
    }
}
/// Reads the `kv_cache_events_applied` counter for the given
/// `(event_type, status)` label pair.
///
/// Panics if the metric cannot be looked up for those labels.
pub(crate) fn metric_value(
    metrics: &KvIndexerMetrics,
    event_type: &'static str,
    status: &'static str,
) -> u64 {
    let labels = [event_type, status];
    let counter = metrics
        .kv_cache_events_applied
        .get_metric_with_label_values(&labels)
        .unwrap();
    counter.get()
}
/// Collects the `tokens_hash` of every block in every `Stored` event, in
/// event order; other event kinds are ignored.
pub(crate) fn stored_hashes(events: &[RouterEvent]) -> Vec<LocalBlockHash> {
    let mut hashes = Vec::new();
    for router_event in events {
        if let KvCacheEventData::Stored(store) = &router_event.event.data {
            hashes.extend(store.blocks.iter().map(|block| block.tokens_hash));
        }
    }
    hashes
}
/// Returns the block hashes of the `nth` (zero-based) `Stored` event in
/// `events`, or an empty vector when there are not that many `Stored` events.
pub(crate) fn nth_stored_hashes(events: &[RouterEvent], nth: usize) -> Vec<LocalBlockHash> {
    let nth_store = events
        .iter()
        .filter_map(|router_event| match &router_event.event.data {
            KvCacheEventData::Stored(store) => Some(store),
            _ => None,
        })
        .nth(nth);
    match nth_store {
        Some(store) => store.blocks.iter().map(|block| block.tokens_hash).collect(),
        None => Vec::new(),
    }
}
/// Counts the `Removed` events in `events`.
pub(crate) fn removed_event_count(events: &[RouterEvent]) -> usize {
    let mut removed = 0;
    for router_event in events {
        if matches!(router_event.event.data, KvCacheEventData::Removed(_)) {
            removed += 1;
        }
    }
    removed
}
/// Send `num_requests` to a scheduler, collect all output signals, and assert
/// that the scheduler produces `num_requests * max_output_tokens` signals
/// and returns to idle (0 active decode blocks, zero cache usage).
///
/// When `use_shared_tokens` is true, every request starts with the same random
/// prefix of `input_len / 2` tokens to exercise prefix caching / radix tree
/// reuse. The rest of each prompt is random per request and sized so that the
/// prompt is exactly `input_len` tokens long, including for odd `input_len`.
///
/// Waiting uses an inactivity timeout: the loop gives up once two seconds pass
/// without a new output signal, and the count assertion then fails.
pub(crate) async fn assert_scheduler_completes_all(
    scheduler: &dyn SchedulerHandle,
    output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
    num_requests: usize,
    input_len: usize,
    max_output_tokens: usize,
    use_shared_tokens: bool,
) {
    // Random token ids in the range 0..50000.
    let random_tokens = |len: usize| -> Vec<u32> {
        (0..len).map(|_| rand::random::<u32>() % 50000).collect()
    };
    let shared_tokens = use_shared_tokens.then(|| random_tokens(input_len / 2));
    for _ in 0..num_requests {
        let input_tokens = match &shared_tokens {
            Some(shared) => {
                let mut tokens = shared.clone();
                // Fill up to `input_len`; using `input_len / 2` again would
                // drop one token whenever `input_len` is odd.
                tokens.extend(random_tokens(input_len - shared.len()));
                tokens
            }
            None => random_tokens(input_len),
        };
        scheduler.receive(DirectRequest {
            tokens: input_tokens,
            max_output_tokens,
            uuid: None,
            dp_rank: 0,
            arrival_timestamp_ms: None,
        });
    }
    let expected_tokens = num_requests * max_output_tokens;
    let mut received_tokens = 0;
    let timeout = tokio::time::sleep(Duration::from_secs(2));
    tokio::pin!(timeout);
    loop {
        tokio::select! {
            biased;
            Some(_) = output_rx.recv() => {
                received_tokens += 1;
                if received_tokens >= expected_tokens {
                    break;
                }
                // Progress was made: restart the inactivity window.
                timeout.set(tokio::time::sleep(Duration::from_secs(2)));
            }
            _ = &mut timeout => break,
        }
    }
    assert_eq!(
        received_tokens, expected_tokens,
        "Expected {expected_tokens} output signals, got {received_tokens}"
    );
    // Short pause before sampling the metrics watch channel.
    tokio::time::sleep(Duration::from_millis(100)).await;
    let metrics = scheduler.metrics_receiver().borrow().clone();
    assert_eq!(
        metrics.active_decode_blocks, 0,
        "Scheduler should be idle after all requests complete, got {} active blocks",
        metrics.active_decode_blocks
    );
    assert_eq!(
        metrics.gpu_cache_usage_perc, 0.0,
        "Scheduler should report zero cache usage after draining, got {}",
        metrics.gpu_cache_usage_perc
    );
    assert!(
        metrics.total_blocks > 0,
        "Scheduler should populate total_blocks, got {}",
        metrics.total_blocks
    );
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Asynchronous Scheduler for LLM Request Management
//!
//! This module implements an asynchronous scheduler that handles three main functions:
//! 1. Receiving new requests and placing them in the waiting queue
//! 2. Scheduling waiting requests against available KV cache resources
//! 3. Simulating the execution of running requests with realistic timing
//!
//! ## Scheduling Process
//! The scheduler checks direct block capacity to determine if there's sufficient
//! KV cache space for new requests. It also enforces a batched tokens budget to prevent
//! oversubscription of computational resources. Only requests that can be allocated
//! these resources are moved from waiting to running state.
//!
//! ## Request Simulation
//! The simulation models two key phases:
//! - Prefill phase: Uses a quadratic cost function: (cached_tokens + new_tokens) * new_tokens
//! - Decode phase: Uses a cost function proportional to active KV blocks (linear)
//!
//! ## Resource Management
//! The scheduler communicates with the KvManager through MoveBlock signals at each
//! stage of request processing. When resources become constrained, it employs a
//! preemption strategy (LIFO by default, matching vLLM v1) where a running request
//! is evicted and placed at the front of the waiting queue to be rescheduled later.
//!
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, MoveBlock, OutputSignal, PreemptionMode,
WorkerType,
};
use crate::common::running_mean::RunningMean;
use crate::common::sequence::ActiveSequence;
use crate::common::utils::sleep_until_precise;
use crate::kv_manager::KvManager;
use crate::simulation::{TraceCollector, TraceSimulationReport};
use dynamo_kv_router::protocols::DpRank;
use dynamo_tokens::blocks::UniqueBlock;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use validator::Validate;
/// Simple metrics struct for mocker's internal use
#[derive(Clone, Default, Debug)]
pub struct MockerMetrics {
/// Data-parallel rank of the scheduler that produced this snapshot.
pub dp_rank: DpRank,
/// Number of active KV blocks reported by the KV manager at publish time.
pub active_decode_blocks: u64,
}
/// Enum representing either a direct request or an active sequence
pub enum Request {
/// Request as received, not yet admitted for prefill.
Direct(DirectRequest),
/// Request that has been converted into a sequence on admission.
Active(ActiveSequence),
}
/// Queues and request storage owned by the scheduler's background task.
#[derive(Default)]
struct SchedulerState {
/// Requests not yet admitted: new arrivals go to the back, preempted requests to the front.
waiting: VecDeque<Uuid>,
/// Requests admitted for prefill, processed from the front.
prefill: VecDeque<Uuid>,
/// Requests that finished prefill and are being decoded.
decode: VecDeque<Uuid>,
/// Request data for every uuid currently tracked in any queue.
requests: HashMap<Uuid, Request>,
}
impl SchedulerState {
    /// True when no request is tracked at all.
    fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }

    /// Registers a new request at the back of the waiting queue and returns
    /// its uuid (freshly generated when the request carries none).
    fn receive(&mut self, request: DirectRequest) -> Uuid {
        let uuid = match request.uuid {
            Some(existing) => existing,
            None => Uuid::new_v4(),
        };
        self.requests.insert(uuid, Request::Direct(request));
        self.waiting.push_back(uuid);
        uuid
    }

    /// Moves the request at the front of the waiting queue into the prefill
    /// queue, converting a `DirectRequest` into an `ActiveSequence` first.
    ///
    /// Returns `None` when nothing is waiting or when `max_num_seqs` is
    /// already reached by the prefill and decode queues combined.
    fn admit_one(&mut self, args: &MockEngineArgs) -> Option<Uuid> {
        let uuid = *self.waiting.front()?;
        let num_active = self.prefill.len() + self.decode.len();
        match args.max_num_seqs {
            Some(limit) if num_active >= limit => return None,
            _ => {}
        }
        self.waiting.pop_front();

        // Requests that were preempted are already active and stay untouched.
        let needs_conversion = matches!(self.requests.get(&uuid), Some(Request::Direct(_)));
        if needs_conversion {
            if let Some(Request::Direct(direct)) = self.requests.remove(&uuid) {
                let sequence = ActiveSequence::new(
                    direct.tokens,
                    direct.max_output_tokens,
                    Some(args.block_size),
                    args.enable_prefix_caching,
                    args.zmq_kv_events_port.is_some(),
                );
                self.requests.insert(uuid, Request::Active(sequence));
            }
        }
        self.prefill.push_back(uuid);
        Some(uuid)
    }

    /// Returns the active sequence for `uuid` when it is in the decode queue,
    /// `None` otherwise.
    ///
    /// Panics if the uuid is queued for decode but has no active sequence.
    fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
        let is_decoding = self.decode.iter().any(|queued| *queued == uuid);
        if !is_decoding {
            return None;
        }
        match self.requests.get_mut(&uuid) {
            Some(Request::Active(sequence)) => Some(sequence),
            _ => panic!("Request does not exist."),
        }
    }

    /// Drops `uuid` from the decode queue and from the request map.
    fn complete(&mut self, uuid: &Uuid) {
        tracing::trace!("Request {uuid} will complete");
        self.decode.retain(|queued| queued != uuid);
        self.requests.remove(uuid);
    }

    /// Evicts one request from the decode queue, resets its sequence, and
    /// re-queues it at the front of the waiting queue.
    ///
    /// `Lifo` evicts the newest decode request, `Fifo` the oldest. Returns the
    /// signals produced by resetting the sequence. Panics when the decode
    /// queue is empty.
    fn preempt(&mut self, mode: PreemptionMode) -> Vec<MoveBlock> {
        let victim = match mode {
            PreemptionMode::Lifo => self.decode.pop_back(),
            PreemptionMode::Fifo => self.decode.pop_front(),
        };
        let uuid = victim.expect("Nothing to evict for preemption.");
        let request = self
            .requests
            .remove(&uuid)
            .expect("Request does not exist.");
        tracing::warn!("Request {uuid} will be preempted");

        let mut active_sequence = match request {
            Request::Active(sequence) => sequence,
            _ => panic!("Expected ActiveSequence in running queue"),
        };
        let signals = active_sequence.reset_with_signal();
        self.requests.insert(uuid, Request::Active(active_sequence));
        self.waiting.push_front(uuid);
        signals
    }
}
/// Cancels its token when dropped. Shared via Arc so the background task is
/// only cancelled when the last Scheduler clone is dropped.
struct CancelGuard(CancellationToken);
impl Drop for CancelGuard {
/// Cancels the wrapped token.
fn drop(&mut self) {
self.0.cancel();
}
}
/// Manages scheduling of requests using KvManager resources
///
/// Cloning the handle shares the same background task.
#[derive(Clone)]
pub struct Scheduler {
/// Sends new requests to the background task.
request_tx: mpsc::UnboundedSender<DirectRequest>,
/// Receives the metrics snapshot the background task publishes after each loop iteration.
metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
/// Cancels the background task's token once the last clone is dropped.
_cancel_guard: Arc<CancelGuard>,
}
impl Scheduler {
/// Create a new Scheduler with the given parameters
///
/// Spawns a background tokio task that owns the scheduler state and the
/// `KvManager`. Each loop iteration receives requests, simulates prefill
/// and decode, then publishes a metrics snapshot.
///
/// - `args`: engine configuration; validated up front.
/// - `dp_rank`: rank reported in the metrics and passed to the `KvManager`.
/// - `output_tx`: optional channel passed to the decode simulation.
/// - `kv_event_sink`: optional sink handed to the `KvManager`.
/// - `cancellation_token`: optional token that stops the background task; a
///   default token is created when `None`.
///
/// # Panics
/// Panics if `args` fails validation.
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
cancellation_token: Option<CancellationToken>,
) -> Self {
args.validate().expect("invalid MockEngineArgs");
// Create channel for request handling
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let initial_metrics = MockerMetrics {
dp_rank,
active_decode_blocks: 0,
};
let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);
// One clone of the token goes to the background task; the original is wrapped
// in the guard so dropping the last Scheduler clone cancels the task.
let cancel_token = cancellation_token.unwrap_or_default();
let cancel_token_clone = cancel_token.clone();
let cancel_guard = Arc::new(CancelGuard(cancel_token));
// Spawn main background task with cancellation token
tokio::spawn(async move {
// Create state and kv_manager as local variables owned by this task
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new_with_event_sink(
args.num_gpu_blocks,
args.block_size,
kv_event_sink,
dp_rank,
);
let mut hit_rates = RunningMean::new(1000);
loop {
// 1. Receive requests
// `None` means the token was cancelled or the request channel closed.
if receive_requests(&mut state, &mut request_rx, &cancel_token_clone)
.await
.is_none()
{
break;
}
// 2. Simulate prefill + decode
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
// 3. Send metrics once per forward pass (after all prefill and decode processing)
// A send error is ignored on purpose.
let _ = metrics_tx.send(MockerMetrics {
dp_rank,
active_decode_blocks: kv_manager.num_active_blocks() as u64,
});
}
});
Self {
request_tx,
metrics_rx,
_cancel_guard: cancel_guard,
}
}
}
impl super::SchedulerHandle for Scheduler {
/// Queues `request` for the background task; a send failure is silently ignored.
fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request);
}
/// Returns a clone of the request sender.
fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
/// Returns a clone of the metrics watch receiver.
fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
self.metrics_rx.clone()
}
}
/// Pulls incoming requests from the channel into `state`.
///
/// When the scheduler already has work, all pending requests are drained
/// without blocking. When it is idle, the call waits for either one new
/// request or cancellation.
///
/// Returns `Some(())` to keep the scheduler loop running, and `None` when the
/// token is cancelled or the request channel is closed.
async fn receive_requests(
    state: &mut SchedulerState,
    request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
    cancel_token: &CancellationToken,
) -> Option<()> {
    if cancel_token.is_cancelled() {
        return None;
    }

    if !state.is_empty() {
        // Active or waiting work exists: take whatever is queued and move on.
        while let Ok(pending) = request_rx.try_recv() {
            state.receive(pending);
        }
        return Some(());
    }

    // Fully idle: park until a request arrives or shutdown is requested.
    tokio::select! {
        biased;
        _ = cancel_token.cancelled() => None,
        incoming = request_rx.recv() => incoming.map(|request| {
            state.receive(request);
        }),
    }
}
/// Runs one prefill step and then sleeps so that wall-clock time tracks the
/// simulated prefill time.
///
/// The step itself is done by `simulate_prefill_step`. When
/// `args.speedup_ratio` is positive and the step took non-zero simulated time,
/// this function sleeps until `start + simulated_time / speedup_ratio`.
///
/// Returns the simulated (unscaled) prefill time.
async fn simulate_prefill(
    state: &mut SchedulerState,
    kv_manager: &mut KvManager,
    hit_rates: &mut RunningMean<f32>,
    args: &MockEngineArgs,
) -> Duration {
    let started_at = Instant::now();
    let total_time = simulate_prefill_step(state, kv_manager, hit_rates, args, None, 0.0, false);

    let throttled = args.speedup_ratio > 0.0 && total_time > Duration::ZERO;
    if throttled {
        let scaled_secs = total_time.as_secs_f64() / args.speedup_ratio;
        sleep_until_precise(started_at + Duration::from_secs_f64(scaled_secs)).await;
    }
    total_time
}
/// Runs one decode step and then sleeps so that wall-clock time tracks the
/// simulated decode time.
///
/// The sleep is scaled by `speedup_ratio * decode_speedup_ratio` and is
/// skipped when that product is not positive or the step took no simulated
/// time.
///
/// Returns the simulated (unscaled) decode time.
async fn simulate_decode(
    state: &mut SchedulerState,
    kv_manager: &mut KvManager,
    output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
    args: &MockEngineArgs,
) -> Duration {
    let started_at = Instant::now();
    let total_time = simulate_decode_step(state, kv_manager, output_tx, args, None, 0.0, false);

    let effective_ratio = args.speedup_ratio * args.decode_speedup_ratio;
    let throttled = effective_ratio > 0.0 && total_time > Duration::ZERO;
    if throttled {
        let scaled_secs = total_time.as_secs_f64() / effective_ratio;
        sleep_until_precise(started_at + Duration::from_secs_f64(scaled_secs)).await;
    }
    total_time
}
/// Runs one prefill scheduling pass and returns the predicted time it takes.
///
/// Working from the head of `state.prefill` (and pulling one request at a time from
/// `state.admit_one` whenever the prefill queue is empty), KV blocks are allocated for as
/// many prompt tokens as the token budget allows. A request whose whole sequence has been
/// allocated is moved to the back of `state.decode`.
///
/// * `hit_rates` - receives one sample per processed chunk: `1 - new_tokens / sequence_len`,
///   where `new_tokens` comes from `kv_manager.get_prefill_cost`.
/// * `collector` - when present, `on_admit` is called for every newly admitted request with
///   `current_time_ms` and the number of input tokens that are not new.
/// * `apply_speedup` - when true and `args.speedup_ratio` is positive, the predicted time is
///   divided by `args.speedup_ratio`.
///
/// Returns `Duration::ZERO` when no chunk contributed new tokens (chunks are never counted
/// when `args.worker_type` is `WorkerType::Decode`); otherwise the result of a single
/// batch-level `predict_prefill_time` call, interpreted as milliseconds.
///
/// # Panics
///
/// Panics if a UUID taken from the queues has no `Request::Active` entry in
/// `state.requests`.
fn simulate_prefill_step(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
hit_rates: &mut RunningMean<f32>,
args: &MockEngineArgs,
mut collector: Option<&mut TraceCollector>,
current_time_ms: f64,
apply_speedup: bool,
) -> Duration {
// Budget = configured batched-token cap minus one per request already in the decode
// queue; unlimited when no cap is configured.
let mut token_budget = args
.max_num_batched_tokens
.map_or(usize::MAX, |t| t.saturating_sub(state.decode.len()));
// Accumulate batch-level prefill stats for a single predict call after the loop.
let mut batch_count: usize = 0;
let mut batch_total_isl: usize = 0;
let mut batch_total_prefix: usize = 0;
'prefill: while token_budget > 0 {
// Drain prefill first, then pull from waiting one at a time.
if state.prefill.is_empty() {
let Some(admitted_uuid) = state.admit_one(args) else {
break;
};
// Report the admission (and how many input tokens need no new prefill work).
if let Some(collector) = collector.as_deref_mut() {
let Some(Request::Active(seq)) = state.requests.get(&admitted_uuid) else {
panic!("Request does not exist.");
};
let prefill_cost = kv_manager.get_prefill_cost(seq);
let reused_input_tokens = seq.len().saturating_sub(prefill_cost.new_tokens);
collector.on_admit(admitted_uuid, current_time_ms, reused_input_tokens);
}
}
// Always work on the request at the head of the prefill queue.
let uuid = state.prefill[0];
let Some(Request::Active(seq)) = state.requests.get(&uuid) else {
panic!("Request does not exist.");
};
let prefill_cost = kv_manager.get_prefill_cost(seq);
let sequence_len = seq.len();
let allocated_tokens = seq.num_allocated_tokens();
let remaining = prefill_cost.new_tokens;
// Token budget check: without chunked prefill a request is only scheduled when all of
// its not-yet-allocated tokens fit into the remaining budget.
let tokens_left = sequence_len - allocated_tokens;
if !args.enable_chunked_prefill && tokens_left > token_budget {
break;
}
let chunk = tokens_left.min(token_budget);
// Allocation boundary this iteration aims for.
let cumulative = allocated_tokens + chunk;
// Allocate blocks. process() returns the number of blocks committed.
// On partial success, preempt a decode request and retry; the next
// loop iteration re-prepares from the updated num_allocated_tokens.
let Some(Request::Active(seq)) = state.requests.get_mut(&uuid) else {
panic!("Request does not exist.");
};
if let Some(signal) = seq.prepare_allocation(cumulative) {
// NOTE(review): relies on prepare_allocation only ever yielding `MoveBlock::Use` — confirm.
let expected = match &signal {
MoveBlock::Use(blocks, ..) => blocks.len(),
_ => unreachable!(),
};
let allocated = kv_manager.process(&signal);
// Commit the blocks that were actually allocated.
let committed_tokens = if allocated == expected {
cumulative
} else {
// Partial success: compute token boundary from block count.
let prev_blocks = allocated_tokens
.div_ceil(seq.block_size())
.min(seq.unique_blocks().len());
(prev_blocks + allocated) * seq.block_size()
};
// Never commit past the boundary this iteration asked for.
seq.commit_allocation(committed_tokens.min(cumulative));
if allocated < expected {
// Nothing in the decode queue to preempt: stop prefilling for this step.
if state.decode.is_empty() {
break;
}
for signal in state.preempt(args.preemption_mode) {
kv_manager.process(&signal);
}
continue 'prefill; // Retry with freed capacity.
}
} else {
// No allocation signal was produced; only the allocation boundary is recorded.
seq.commit_allocation(cumulative);
}
// Accumulate per-request (isl, prefix) for batch-level prediction.
let new_tokens_in_chunk = chunk.min(remaining);
if args.worker_type != WorkerType::Decode && new_tokens_in_chunk > 0 {
let isl = prefill_cost.cached_tokens + new_tokens_in_chunk;
batch_total_isl += isl;
batch_total_prefix += prefill_cost.cached_tokens;
batch_count += 1;
}
// Hit rate: fraction of tokens that were already cached.
let hit_rate = if sequence_len > 0 {
1.0 - (remaining as f32 / sequence_len as f32)
} else {
0.0
};
hit_rates.push(hit_rate);
token_budget -= chunk;
if cumulative >= sequence_len {
// Fully prefilled: promote to decode queue.
state.prefill.pop_front();
state.decode.push_back(uuid);
} else {
// Partially prefilled: resume next iteration with updated allocation state.
break;
}
}
// One batch-level prefill prediction instead of summing per-request predictions.
// The means use integer division over the chunks counted above.
let total_time = if batch_count > 0 {
let mean_isl = batch_total_isl / batch_count;
let mean_prefix = batch_total_prefix / batch_count;
let ms = args
.perf_model
.predict_prefill_time(batch_count, mean_isl, mean_prefix);
Duration::from_secs_f64(ms / 1000.0)
} else {
Duration::ZERO
};
// Return the unscaled prediction unless a positive speedup was requested.
if !apply_speedup || args.speedup_ratio <= 0.0 || total_time <= Duration::ZERO {
return total_time;
}
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio)
}
/// Runs one decode iteration over every request currently in `state.decode`.
///
/// The step duration is predicted once, up front, from the number of `Request::Active`
/// entries in the decode queue, the active KV tokens
/// (`kv_manager.num_active_blocks() * args.block_size`) and the integer mean of their
/// sequence lengths. Each request then tries to generate one token; when the KV manager
/// cannot serve the resulting signals the token is popped again, a request is preempted via
/// `state.preempt`, and the attempt is repeated unless the current request itself left the
/// decode queue.
///
/// * `output_tx` - when present, one `OutputSignal` is sent per generated token. If sending
///   fails, the request's blocks are freed and the request is completed immediately.
/// * `collector` - when present, `on_token` is called per generated token with the
///   predicted end time of this step.
/// * `apply_speedup` - when true and `args.speedup_ratio * args.decode_speedup_ratio` is
///   positive, the predicted time is divided by that product.
///
/// Returns `Duration::ZERO` when the decode queue is empty, holds no active sequences, or
/// no request managed to emit a token; otherwise the (possibly scaled) predicted step time.
///
/// # Panics
///
/// Panics if a UUID in `state.decode` is missing from `state.requests`, or if
/// `process_signals` rejects a failed signal as unexpected.
fn simulate_decode_step(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
args: &MockEngineArgs,
mut collector: Option<&mut TraceCollector>,
current_time_ms: f64,
apply_speedup: bool,
) -> Duration {
if state.decode.is_empty() {
return Duration::ZERO;
}
let decode_start_ms = current_time_ms;
// Sequence lengths of the active requests only; `Request::Direct` entries are ignored.
let decode_lengths = state
.decode
.iter()
.filter_map(|uuid| match state.requests.get(uuid).unwrap() {
Request::Active(seq) => Some(seq.len()),
Request::Direct(_) => None,
})
.collect::<Vec<_>>();
if decode_lengths.is_empty() {
return Duration::ZERO;
}
let count = decode_lengths.len();
let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size;
let total_length: usize = decode_lengths.iter().sum();
// Integer mean of the sequence lengths.
let context_length = total_length / count;
let decoding_time =
args.perf_model
.predict_decode_time(count, active_kv_tokens, context_length);
let unscaled_time = Duration::from_secs_f64(decoding_time / 1000.0);
let effective_ratio = args.speedup_ratio * args.decode_speedup_ratio;
let total_time = if apply_speedup && effective_ratio > 0.0 && unscaled_time > Duration::ZERO {
Duration::from_secs_f64(unscaled_time.as_secs_f64() / effective_ratio)
} else {
unscaled_time
};
// Every token produced in this step is timestamped with the end of the step.
let decode_end_ms = decode_start_ms + total_time.as_secs_f64() * 1000.0;
// Process decoding.
// Iterate over a snapshot: preemption and completion mutate `state.decode` in the loop.
let uuids: Vec<Uuid> = state.decode.iter().copied().collect();
let mut emitted_any = false;
for uuid in uuids {
let mut allocated = false;
loop {
// NOTE(review): `state.run` presumably yields the sequence only while the request
// is still runnable — confirm against its definition.
let Some(sequence) = state.run(uuid) else {
break;
};
let signals = sequence.generate();
if process_signals(kv_manager, &signals) {
allocated = true;
break;
}
sequence.pop(); // revert the failed generation
if state.decode.is_empty() {
break;
}
// Preempt one request and free its blocks
for signal in state.preempt(args.preemption_mode) {
kv_manager.process(&signal);
}
// If the current request was the one preempted, stop retrying
if !state.decode.contains(&uuid) {
break;
}
}
if !allocated {
continue;
}
let Some(sequence) = state.run(uuid) else {
continue;
};
emitted_any = true;
if let Some(collector) = collector.as_deref_mut() {
collector.on_token(uuid, decode_end_ms);
}
// Check completion and send notification.
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
// Without an output channel nothing is sent and the send cannot fail.
let send_failed = output_tx.as_ref().is_some_and(|tx| {
tx.send(OutputSignal {
uuid,
completed: is_complete,
})
.is_err()
});
// A failed send means the receiving side is gone: release this request's blocks now.
if send_failed {
for signal in &sequence.free_signal() {
kv_manager.process(signal);
}
}
if send_failed || is_complete {
state.complete(&uuid);
}
}
// No token was produced, so the step consumes no simulated time.
if !emitted_any {
return Duration::ZERO;
}
total_time
}
/// Replays a timestamped request trace through the mock scheduler in simulated time.
///
/// Requests are ordered by `arrival_timestamp_ms` and shifted so that the earliest one
/// arrives at time zero. The simulation then alternates prefill and decode steps, advancing
/// a virtual clock by the predicted duration of each step and jumping straight to the next
/// arrival whenever the scheduler is idle. No wall-clock sleeping takes place.
///
/// # Errors
///
/// Returns an error when `args.validate()` fails, when `requests` is empty, or when any
/// request has a missing or non-finite `arrival_timestamp_ms`. (A missing timestamp used to
/// panic inside the sort comparator instead of surfacing through the `Result`.)
pub fn simulate_trace(
    args: MockEngineArgs,
    mut requests: Vec<DirectRequest>,
) -> anyhow::Result<TraceSimulationReport> {
    args.validate()?;
    if requests.is_empty() {
        return Err(anyhow::anyhow!(
            "trace replay requires at least one timestamped request"
        ));
    }
    // Validate every timestamp before sorting so bad input becomes an `Err`, not a panic.
    let first_invalid = requests
        .iter()
        .position(|request| !request.arrival_timestamp_ms.is_some_and(f64::is_finite));
    if let Some(index) = first_invalid {
        return Err(anyhow::anyhow!(
            "trace replay requests must have a finite arrival timestamp (request at index {index} does not)"
        ));
    }
    let arrival_of = |request: &DirectRequest| -> f64 {
        request
            .arrival_timestamp_ms
            .expect("arrival timestamps were validated above")
    };

    // Stable sort keeps the input order of requests that share a timestamp.
    requests.sort_by(|left, right| arrival_of(left).total_cmp(&arrival_of(right)));
    let first_arrival_ms = arrival_of(&requests[0]);

    // Rebase all arrivals so that the trace starts at t = 0.
    let mut pending: VecDeque<DirectRequest> = requests
        .into_iter()
        .map(|mut request| {
            request.arrival_timestamp_ms = Some(arrival_of(&request) - first_arrival_ms);
            request
        })
        .collect();

    let mut state = SchedulerState::default();
    let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
    let mut hit_rates = RunningMean::new(1000);
    let mut collector = TraceCollector::default();
    let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
    let mut current_time_ms = 0.0;

    while !pending.is_empty() || !state.is_empty() {
        enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);

        if state.is_empty() {
            // Idle: fast-forward the virtual clock to the next arrival.
            let Some(next_arrival_ms) = pending
                .front()
                .and_then(|request| request.arrival_timestamp_ms)
            else {
                break;
            };
            current_time_ms = next_arrival_ms;
            enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
            continue;
        }

        let prefill_time = simulate_prefill_step(
            &mut state,
            &mut kv_manager,
            &mut hit_rates,
            &args,
            Some(&mut collector),
            current_time_ms,
            true,
        );
        current_time_ms += prefill_time.as_secs_f64() * 1000.0;

        // Requests that arrived while the prefill step ran join before decoding.
        enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);

        let decode_time = simulate_decode_step(
            &mut state,
            &mut kv_manager,
            &output_tx,
            &args,
            Some(&mut collector),
            current_time_ms,
            true,
        );
        current_time_ms += decode_time.as_secs_f64() * 1000.0;
    }

    Ok(collector.finish())
}
/// Replays `requests` through the mock scheduler while keeping at most `max_in_flight`
/// requests inside the scheduler at any time.
///
/// Requests are taken in the given order; any `arrival_timestamp_ms` they carry is
/// overwritten with the simulated time at which they are handed to the scheduler. The
/// simulation alternates prefill and decode steps and advances a virtual clock by the
/// predicted duration of each step. No wall-clock sleeping takes place.
///
/// # Errors
///
/// Returns an error when `args.validate()` fails, or when `max_in_flight` is zero while
/// there are requests to replay. (That combination can never admit a request and used to
/// return an empty report, silently discarding the whole input.)
pub fn simulate_concurrency(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
    args.validate()?;
    if max_in_flight == 0 && !requests.is_empty() {
        return Err(anyhow::anyhow!(
            "concurrency replay requires max_in_flight >= 1 to make progress"
        ));
    }

    let mut pending = VecDeque::from(requests);
    let mut state = SchedulerState::default();
    let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
    let mut hit_rates = RunningMean::new(1000);
    let mut collector = TraceCollector::default();
    let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
    let mut current_time_ms = 0.0;

    while !pending.is_empty() || !state.is_empty() {
        // Top the scheduler up to the concurrency limit before each iteration.
        enqueue_concurrency_arrivals(
            &mut pending,
            &mut state,
            &mut collector,
            current_time_ms,
            max_in_flight,
        );
        if state.is_empty() {
            break;
        }

        let prefill_time = simulate_prefill_step(
            &mut state,
            &mut kv_manager,
            &mut hit_rates,
            &args,
            Some(&mut collector),
            current_time_ms,
            true,
        );
        current_time_ms += prefill_time.as_secs_f64() * 1000.0;

        let decode_time = simulate_decode_step(
            &mut state,
            &mut kv_manager,
            &output_tx,
            &args,
            Some(&mut collector),
            current_time_ms,
            true,
        );
        current_time_ms += decode_time.as_secs_f64() * 1000.0;
    }

    Ok(collector.finish())
}
/// Hands every pending request that is due at `current_time_ms` to the scheduler.
///
/// `pending` is consumed from the front for as long as the front request carries an arrival
/// timestamp that is not later than `current_time_ms`. Each accepted request is passed to
/// `state.receive` and reported to `collector.on_arrival` with its own arrival timestamp,
/// prompt length and `max_output_tokens`.
fn enqueue_trace_arrivals(
    pending: &mut VecDeque<DirectRequest>,
    state: &mut SchedulerState,
    collector: &mut TraceCollector,
    current_time_ms: f64,
) {
    while let Some(head_arrival_ms) = pending.front().and_then(|head| head.arrival_timestamp_ms) {
        // The queue is processed in order, so the first future arrival ends the scan.
        if head_arrival_ms > current_time_ms {
            return;
        }
        let arrived = pending
            .pop_front()
            .expect("front request must exist when arrival is available");
        let arrival_ms = arrived
            .arrival_timestamp_ms
            .expect("trace replay requests must have an arrival timestamp");
        // Capture the sizes before the request is moved into the scheduler.
        let (input_length, output_length) = (arrived.tokens.len(), arrived.max_output_tokens);
        let uuid = state.receive(arrived);
        collector.on_arrival(uuid, arrival_ms, input_length, output_length);
    }
}
/// Tops the scheduler up to `max_in_flight` tracked requests.
///
/// While `state.requests` holds fewer than `max_in_flight` entries, the next pending request
/// is stamped with `current_time_ms` as its arrival time, passed to `state.receive` and
/// reported to `collector.on_arrival`. Stops early once `pending` runs dry.
fn enqueue_concurrency_arrivals(
    pending: &mut VecDeque<DirectRequest>,
    state: &mut SchedulerState,
    collector: &mut TraceCollector,
    current_time_ms: f64,
    max_in_flight: usize,
) {
    loop {
        // Check capacity first so that no request is popped without a free slot.
        if state.requests.len() >= max_in_flight {
            return;
        }
        let mut admitted = match pending.pop_front() {
            Some(next) => next,
            None => return,
        };
        // In concurrency mode a request "arrives" at the moment a slot frees up.
        admitted.arrival_timestamp_ms = Some(current_time_ms);
        let (input_length, output_length) = (admitted.tokens.len(), admitted.max_output_tokens);
        let uuid = state.receive(admitted);
        collector.on_arrival(uuid, current_time_ms, input_length, output_length);
    }
}
/// Applies `signals` to the KV manager in order, stopping at the first one it cannot serve.
///
/// Returns `true` when every signal was processed (including the empty case) and `false`
/// when a signal failed in the one way that is considered legitimate: a `MoveBlock::Use`
/// carrying exactly one block, and that block being a `UniqueBlock::PartialBlock` (a new
/// generation block). Signals after the failing one are not processed.
///
/// # Panics
///
/// Any other kind of failure indicates an unexpected system state and panics: a failing
/// signal that is not `MoveBlock::Use`, one that carries a block count other than one, or
/// one whose single block is not partial.
fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
    signals.iter().all(|signal| {
        if kv_manager.process(signal) > 0 {
            return true;
        }
        // The signal failed; make sure it is the expected generation-block failure.
        match signal {
            MoveBlock::Use(blocks, _hashes, ..) => {
                let num_blocks = blocks.len();
                let num_active_blocks = kv_manager.num_active_blocks();
                if num_blocks != 1 {
                    panic!(
                        "Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
                    );
                }
                if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) {
                    panic!("Failed signal is Invalid. Generation block has to be partial.");
                }
                false
            }
            _ => panic!(
                "Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
            ),
        }
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::scheduler::SchedulerHandle;
use crate::simulation::{TraceCollector, TraceRequestStatsSnapshot};
use rstest::rstest;
use std::collections::HashMap;
use std::time::Duration;
use tokio::time::interval;
/// Asserts that `metrics` reports zero active decode blocks, i.e. the scheduler is idle.
fn assert_scheduler_idle(metrics: &MockerMetrics) {
    let active_blocks = metrics.active_decode_blocks;
    assert_eq!(
        active_blocks, 0,
        "Expected 0 active blocks, got {}",
        active_blocks
    );
}
/// Runs the shared "scheduler completes all requests" check for every combination of shared
/// prompt tokens, prefix caching and chunked prefill.
#[rstest]
#[case::case_1(false, false, false)]
#[case::case_2(false, true, false)]
#[case::case_3(true, false, false)]
#[case::case_4(true, true, false)]
#[case::case_5(false, false, true)]
#[case::case_6(false, true, true)]
#[case::case_7(true, false, true)]
#[case::case_8(true, true, true)]
#[tokio::test]
async fn test_scheduler_token_generation_patterns(
    #[case] use_shared_tokens: bool,
    #[case] enable_prefix_caching: bool,
    #[case] enable_chunked_prefill: bool,
) {
    let engine_args = MockEngineArgs::builder()
        .num_gpu_blocks(500)
        .block_size(64)
        .speedup_ratio(10.0)
        .enable_prefix_caching(enable_prefix_caching)
        .enable_chunked_prefill(enable_chunked_prefill)
        .build()
        .unwrap();

    // The scheduler owns the sending half; the helper below drains the receiving half.
    let (signal_tx, mut signal_rx) = mpsc::unbounded_channel::<OutputSignal>();
    let scheduler = Scheduler::new(engine_args, 0, Some(signal_tx), None, None);

    crate::scheduler::test_utils::assert_scheduler_completes_all(
        &scheduler,
        &mut signal_rx,
        200,
        1000,
        100,
        use_shared_tokens,
    )
    .await;
}
/// Sends ten identical 65-token requests and checks that the scheduler ends up idle.
///
/// Output signals are drained until none arrives for 500 ms; the test then asserts that the
/// published metrics report zero active decode blocks. The number of received tokens is
/// only printed, not asserted.
#[tokio::test]
async fn test_cache_hit_rate_with_identical_requests() {
    let block_size: usize = 64;
    let max_output_tokens: usize = 10;
    let speedup_ratio = 10.0;
    let num_requests = 10;
    let token_length = 65;
    let quiet_period = Duration::from_millis(500);

    let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
    let scheduler = Scheduler::new(
        MockEngineArgs::builder()
            .num_gpu_blocks(100) // Large enough to not be a constraint
            .block_size(block_size)
            .speedup_ratio(speedup_ratio)
            .build()
            .unwrap(),
        0,
        Some(output_tx),
        None,
        None,
    );

    // Every request carries the same prompt: the tokens 0, 1, ..., token_length - 1.
    let identical_tokens: Vec<u32> = (0..token_length).collect();
    for _ in 0..num_requests {
        scheduler.receive(DirectRequest {
            tokens: identical_tokens.clone(),
            max_output_tokens,
            uuid: None,
            dp_rank: 0,
            arrival_timestamp_ms: None,
        });
        // Space the requests 0.1 s apart.
        tokio::time::sleep(Duration::from_millis(100)).await;
    }

    let mut received_tokens = 0;
    // Idle timer: re-armed on every received token, fires after a quiet period.
    let idle_timer = tokio::time::sleep(quiet_period);
    tokio::pin!(idle_timer);
    let metrics_rx = scheduler.metrics_receiver();
    let mut debug_interval = interval(quiet_period);

    loop {
        tokio::select! {
            biased;
            // Periodic debug dump of the forward pass metrics.
            _ = debug_interval.tick() => {
                let _metrics = metrics_rx.borrow().clone();
                tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
            }
            Some(_signal) = output_rx.recv() => {
                received_tokens += 1;
                idle_timer.set(tokio::time::sleep(quiet_period));
            }
            // No token for a whole quiet period: generation is over.
            _ = &mut idle_timer => break,
        }
    }

    // Give the scheduler a moment to publish its final metrics.
    tokio::time::sleep(Duration::from_millis(100)).await;

    let final_metrics = metrics_rx.borrow().clone();
    assert_scheduler_idle(&final_metrics);
    println!("Test passed! Received {received_tokens} tokens");
}
/// White-box unit test that directly creates SchedulerState + KvManager,
/// manually invokes simulate_prefill / simulate_decode, and asserts on
/// queue states and block counts after each step.
///
/// The configuration is deliberately tiny (block size 4, 6 GPU blocks, a 12-token batch
/// budget, chunked prefill on, prefix caching off) so that chunking and preemption are both
/// exercised with only three requests. No output channel is attached.
#[tokio::test]
async fn test_scheduler_internal_state_transitions() {
// NOTE(review): speedup_ratio(0.0) presumably keeps the async wrappers from sleeping on
// the wall clock — confirm against simulate_prefill / simulate_decode.
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(12))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
// Fixed UUIDs so that the queue contents can be asserted by identity.
let r1_uuid = Uuid::from_u128(1);
let r2_uuid = Uuid::from_u128(2);
let r3_uuid = Uuid::from_u128(3);
// ── Step 1: Receive 3 requests ──
// R1: 8 input, 2 max_output → 2 blocks
// R2: 8 input, 2 max_output → 2 blocks
// R3: 12 input, 2 max_output → 3 blocks
state.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 2,
uuid: Some(r1_uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
state.receive(DirectRequest {
tokens: (100..108).collect(),
max_output_tokens: 2,
uuid: Some(r2_uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
state.receive(DirectRequest {
tokens: (200..212).collect(),
max_output_tokens: 2,
uuid: Some(r3_uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
// Receiving only queues the requests; nothing is allocated yet.
assert_eq!(state.waiting.len(), 3);
assert_eq!(state.prefill.len(), 0);
assert_eq!(state.decode.len(), 0);
assert_eq!(kv_manager.num_active_blocks(), 0);
// ── Step 2: First simulate_prefill ──
// Budget=12. R1 takes 8 tokens (2 blocks), fully prefilled → decode.
// R2 takes 4 tokens (1 block, chunked), partially prefilled → stays in prefill.
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
assert_eq!(state.waiting.len(), 1);
assert_eq!(state.prefill.len(), 1);
assert_eq!(state.decode.len(), 1);
assert_eq!(state.decode[0], r1_uuid);
assert_eq!(state.prefill[0], r2_uuid);
assert_eq!(state.waiting[0], r3_uuid);
assert_eq!(kv_manager.num_active_blocks(), 3); // 2 for R1 + 1 for R2
let seq = match state.requests.get(&r1_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.num_allocated_tokens(), 8);
assert_eq!(seq.generated_tokens(), 0);
let seq = match state.requests.get(&r2_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.num_allocated_tokens(), 4);
assert_eq!(seq.generated_tokens(), 0);
// ── Step 3: First simulate_decode ──
// R1 generates 1 token, gains a partial block.
simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
assert_eq!(state.decode.len(), 1);
assert_eq!(state.decode[0], r1_uuid);
assert_eq!(kv_manager.num_active_blocks(), 4); // +1 partial for R1
let seq = match state.requests.get(&r1_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.generated_tokens(), 1);
// ── Step 4: Second simulate_prefill ──
// Budget=11. R2 finishes (4 more tokens, 1 block → active=5, decode).
// R3 admitted, needs 2 blocks for chunk of 7. Only 1 free slot → partial.
// Preempt R2 (LIFO) → R2 back to waiting. Retry R3 → evicts R2's
// inactive blocks, allocates 2 more → R3 allocated_tokens=11.
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
assert_eq!(state.waiting.len(), 1, "R2 preempted back to waiting");
assert_eq!(state.waiting[0], r2_uuid);
assert_eq!(state.prefill.len(), 1, "R3 partially prefilled");
assert_eq!(state.prefill[0], r3_uuid);
assert_eq!(state.decode.len(), 1, "R1 still decoding");
assert_eq!(state.decode[0], r1_uuid);
assert_eq!(kv_manager.num_active_blocks(), 6); // at capacity
let seq = match state.requests.get(&r3_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.num_allocated_tokens(), 11);
// ── Step 5: Second simulate_decode ──
// R1 generates 2nd token → complete. Frees 3 blocks (1 destroyed, 2 deactivated).
simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
assert!(!state.requests.contains_key(&r1_uuid), "R1 completed");
assert_eq!(state.decode.len(), 0);
assert_eq!(state.prefill.len(), 1);
assert_eq!(state.waiting.len(), 1);
assert_eq!(kv_manager.num_active_blocks(), 3); // only R3's 3 blocks
// ── Step 6: Third simulate_prefill ──
// R3 finishes prefill (1 token left, no new blocks) → decode.
// R2 re-admitted, fully prefilled (2 blocks via inactive eviction) → decode.
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
assert_eq!(state.waiting.len(), 0);
assert_eq!(state.prefill.len(), 0);
assert_eq!(state.decode.len(), 2);
assert!(state.decode.contains(&r3_uuid));
assert!(state.decode.contains(&r2_uuid));
assert_eq!(kv_manager.num_active_blocks(), 5); // 3 for R3 + 2 for R2
// ── Steps 7+: Cycle until all requests complete ──
// At least one prefill/decode round always runs before the emptiness check.
loop {
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
if state.is_empty() {
break;
}
}
// Everything drained: all queues are empty and no KV block remains active.
assert_eq!(state.waiting.len(), 0);
assert_eq!(state.prefill.len(), 0);
assert_eq!(state.decode.len(), 0);
assert_eq!(kv_manager.num_active_blocks(), 0);
}
/// Drops the output receiver in the middle of generation and checks that the scheduler
/// subsequently reports zero active decode blocks.
#[tokio::test]
async fn test_receiver_drop_cleans_up_resources() {
    let block_size: usize = 64;
    let max_output_tokens = 200; // More than we'll receive

    let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
    let scheduler = Scheduler::new(
        MockEngineArgs::builder()
            .num_gpu_blocks(10) // Enough for 256 tokens (4 blocks)
            .block_size(block_size)
            .speedup_ratio(100.0) // Fast simulation
            .build()
            .unwrap(),
        0,
        Some(output_tx),
        None,
        None,
    );

    // A single request with a 256-token prompt: the tokens 0, 1, ..., 255.
    scheduler.receive(DirectRequest {
        tokens: (0..256u32).collect(),
        max_output_tokens,
        uuid: None,
        dp_rank: 0,
        arrival_timestamp_ms: None,
    });

    // Consume exactly 129 output signals, then walk away.
    for _ in 0..129 {
        if output_rx.recv().await.is_none() {
            panic!("Channel closed before receiving 129 tokens");
        }
    }
    drop(output_rx);

    // Leave the scheduler one second to notice the closed channel and clean up.
    tokio::time::sleep(Duration::from_secs(1)).await;

    let metrics = scheduler.metrics_receiver().borrow().clone();
    assert_scheduler_idle(&metrics);
}
/// Outcome of `run_trace_manually`: the final report plus intermediate observations that
/// the trace replay test asserts on.
#[derive(Debug)]
struct ManualReplayResult {
// Report produced by `TraceCollector::finish` at the end of the manual replay.
report: TraceSimulationReport,
// Per-request collector snapshots, keyed by request UUID.
snapshots: HashMap<Uuid, TraceRequestStatsSnapshot>,
// Simulated time after the first idle fast-forward that landed on a time > 0
// (stays 0.0 if no such jump happened).
idle_jump_ms: f64,
// Simulated end time of the first decode step that took a non-zero duration
// (stays 0.0 if there was none).
first_decode_end_ms: f64,
}
/// Outcome of `run_concurrency_manually`: the final report plus per-request snapshots.
#[derive(Debug)]
struct ManualConcurrencyResult {
// Report produced by `TraceCollector::finish` at the end of the manual replay.
report: TraceSimulationReport,
// Per-request collector snapshots, keyed by request UUID.
snapshots: HashMap<Uuid, TraceRequestStatsSnapshot>,
}
/// Builds the small engine configuration shared by the replay tests: block size 4, 32 GPU
/// blocks, at most 8 batched tokens and 2 sequences, and a speedup ratio of 0.0. Only the
/// two feature flags vary between test cases.
fn replay_args(enable_prefix_caching: bool, enable_chunked_prefill: bool) -> MockEngineArgs {
    MockEngineArgs::builder()
        // Feature flags under test.
        .enable_prefix_caching(enable_prefix_caching)
        .enable_chunked_prefill(enable_chunked_prefill)
        // Fixed capacity and batching limits.
        .num_gpu_blocks(32)
        .block_size(4)
        .max_num_seqs(Some(2))
        .max_num_batched_tokens(Some(8))
        .speedup_ratio(0.0)
        .build()
        .unwrap()
}
/// Three-request trace used by the trace replay tests.
///
/// Requests 11 and 22 carry the same eight prompt tokens and arrive 1 ms apart (at 100 ms
/// and 101 ms); request 33 has a different prompt and arrives at 500 ms. Every request asks
/// for two output tokens.
fn replay_fixture() -> Vec<DirectRequest> {
    let request = |id: u128, tokens: Vec<u32>, arrival_ms: f64| DirectRequest {
        tokens,
        max_output_tokens: 2,
        uuid: Some(Uuid::from_u128(id)),
        dp_rank: 0,
        arrival_timestamp_ms: Some(arrival_ms),
    };
    let shared_prompt = vec![1, 1, 1, 1, 2, 2, 2, 2];

    vec![
        request(11, shared_prompt.clone(), 100.0),
        request(22, shared_prompt, 101.0),
        request(33, vec![9, 9, 9, 9, 8, 8, 8, 8], 500.0),
    ]
}
/// Independent, hand-rolled trace replay used as the reference for `simulate_trace`.
///
/// Performs the same sort / rebase / step loop and additionally records the first idle
/// fast-forward time and the end time of the first non-empty decode step. Snapshots are
/// taken for the fixture UUIDs 11, 22 and 33, so `requests` must contain those (and every
/// request must carry an arrival timestamp); otherwise this helper panics.
fn run_trace_manually(
    args: &MockEngineArgs,
    requests: Vec<DirectRequest>,
) -> ManualReplayResult {
    // Order by arrival and rebase so that the earliest request arrives at t = 0.
    let mut ordered = requests;
    ordered.sort_by(|a, b| {
        a.arrival_timestamp_ms
            .unwrap()
            .total_cmp(&b.arrival_timestamp_ms.unwrap())
    });
    let origin_ms = ordered.first().unwrap().arrival_timestamp_ms.unwrap();
    for request in &mut ordered {
        request.arrival_timestamp_ms = Some(request.arrival_timestamp_ms.unwrap() - origin_ms);
    }
    let mut pending = VecDeque::from(ordered);

    let mut state = SchedulerState::default();
    let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
    let mut hit_rates = RunningMean::new(1000);
    let mut collector = TraceCollector::default();
    let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;

    let mut current_time_ms = 0.0;
    let mut idle_jump_ms = 0.0;
    let mut first_decode_end_ms = 0.0;

    while !(pending.is_empty() && state.is_empty()) {
        enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);

        if state.is_empty() {
            // Idle: jump to the next arrival and remember the first such jump.
            current_time_ms = pending.front().unwrap().arrival_timestamp_ms.unwrap();
            if idle_jump_ms == 0.0 && current_time_ms > 0.0 {
                idle_jump_ms = current_time_ms;
            }
            enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
            continue;
        }

        let prefill_time = simulate_prefill_step(
            &mut state,
            &mut kv_manager,
            &mut hit_rates,
            args,
            Some(&mut collector),
            current_time_ms,
            true,
        );
        current_time_ms += prefill_time.as_secs_f64() * 1000.0;
        enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);

        let decode_time = simulate_decode_step(
            &mut state,
            &mut kv_manager,
            &output_tx,
            args,
            Some(&mut collector),
            current_time_ms,
            true,
        );
        let decode_ms = decode_time.as_secs_f64() * 1000.0;
        if first_decode_end_ms == 0.0 && decode_time > Duration::ZERO {
            first_decode_end_ms = current_time_ms + decode_ms;
        }
        current_time_ms += decode_ms;
    }

    // Snapshots must be taken before `finish` consumes the collector.
    let snapshots = [11u128, 22, 33]
        .into_iter()
        .map(Uuid::from_u128)
        .map(|uuid| (uuid, collector.snapshot(uuid).unwrap()))
        .collect();

    ManualReplayResult {
        report: collector.finish(),
        snapshots,
        idle_jump_ms,
        first_decode_end_ms,
    }
}
/// Independent, hand-rolled concurrency replay used as the reference for
/// `simulate_concurrency`.
///
/// Snapshots are taken for the fixture UUIDs 11, 22 and 33, so `requests` must contain
/// those; otherwise this helper panics.
fn run_concurrency_manually(
    args: &MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
) -> ManualConcurrencyResult {
    let mut state = SchedulerState::default();
    let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
    let mut hit_rates = RunningMean::new(1000);
    let mut collector = TraceCollector::default();
    let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;

    let mut pending: VecDeque<DirectRequest> = requests.into();
    let mut current_time_ms = 0.0;

    loop {
        if pending.is_empty() && state.is_empty() {
            break;
        }
        // Refill the scheduler up to the concurrency limit.
        enqueue_concurrency_arrivals(
            &mut pending,
            &mut state,
            &mut collector,
            current_time_ms,
            max_in_flight,
        );
        if state.is_empty() {
            break;
        }

        let prefill_time = simulate_prefill_step(
            &mut state,
            &mut kv_manager,
            &mut hit_rates,
            args,
            Some(&mut collector),
            current_time_ms,
            true,
        );
        current_time_ms += prefill_time.as_secs_f64() * 1000.0;

        let decode_time = simulate_decode_step(
            &mut state,
            &mut kv_manager,
            &output_tx,
            args,
            Some(&mut collector),
            current_time_ms,
            true,
        );
        current_time_ms += decode_time.as_secs_f64() * 1000.0;
    }

    // Snapshots must be taken before `finish` consumes the collector.
    let snapshots = [11u128, 22, 33]
        .into_iter()
        .map(Uuid::from_u128)
        .map(|uuid| (uuid, collector.snapshot(uuid).unwrap()))
        .collect();

    ManualConcurrencyResult {
        report: collector.finish(),
        snapshots,
    }
}
/// Asserts that two simulation reports agree.
///
/// Request counters must be exactly equal. Throughput figures, the prefix-cache reuse ratio
/// and every latency statistic (mean, min, max, median, p75, p90, p95, p99 and std of each
/// distribution, plus the ITL maximum) must agree within an absolute tolerance of `1e-9`.
///
/// # Panics
///
/// Panics on the first mismatch. The message names the offending report field and, for
/// tolerance checks, shows both values.
fn assert_report_close(left: &TraceSimulationReport, right: &TraceSimulationReport) {
    let epsilon = 1e-9;

    // Exact equality on one field path of both reports.
    macro_rules! assert_same {
        ($($field:tt)+) => {
            assert_eq!(
                left.$($field)+,
                right.$($field)+,
                "report field `{}` differs",
                stringify!($($field)+),
            );
        };
    }

    // Absolute-tolerance comparison on one field path of both reports.
    macro_rules! assert_close {
        ($($field:tt)+) => {{
            let (lhs, rhs) = (left.$($field)+, right.$($field)+);
            assert!(
                (lhs - rhs).abs() <= epsilon,
                "report field `{}` differs: left = {:?}, right = {:?}",
                stringify!($($field)+),
                lhs,
                rhs,
            );
        }};
    }

    // Tolerance comparison on all nine statistics of one latency distribution.
    macro_rules! assert_distribution_close {
        ($($dist:tt)+) => {{
            assert_close!($($dist)+.mean_ms);
            assert_close!($($dist)+.min_ms);
            assert_close!($($dist)+.max_ms);
            assert_close!($($dist)+.median_ms);
            assert_close!($($dist)+.p75_ms);
            assert_close!($($dist)+.p90_ms);
            assert_close!($($dist)+.p95_ms);
            assert_close!($($dist)+.p99_ms);
            assert_close!($($dist)+.std_ms);
        }};
    }

    assert_same!(request_counts.num_requests);
    assert_same!(request_counts.completed_requests);
    assert_same!(request_counts.total_input_tokens);
    assert_same!(request_counts.total_output_tokens);

    assert_close!(throughput.duration_ms);
    assert_close!(throughput.request_throughput_rps);
    assert_close!(throughput.input_throughput_tok_s);
    assert_close!(throughput.output_throughput_tok_s);
    assert_close!(throughput.total_throughput_tok_s);

    assert_close!(prefix_cache_reused_ratio);

    assert_distribution_close!(latency.ttft);
    assert_distribution_close!(latency.ttst);
    assert_distribution_close!(latency.tpot);
    assert_distribution_close!(latency.itl.distribution);
    assert_close!(latency.itl.max_ms);
    assert_distribution_close!(latency.e2e);
    assert_distribution_close!(latency.output_token_throughput_per_user);
}
/// Checks `simulate_trace` against the hand-rolled replay for every combination of prefix
/// caching and chunked prefill, and sanity-checks the manual run itself (rebased arrival
/// times, the idle jump to the third request, first-token timing, prefix reuse).
#[rstest]
#[case(false, false)]
#[case(false, true)]
#[case(true, false)]
#[case(true, true)]
fn test_trace_replay_matches_manual_steps(
    #[case] enable_prefix_caching: bool,
    #[case] enable_chunked_prefill: bool,
) {
    let args = replay_args(enable_prefix_caching, enable_chunked_prefill);
    let manual = run_trace_manually(&args, replay_fixture());
    let replay_report = simulate_trace(args, replay_fixture()).unwrap();

    let snapshot_of = |id: u128| manual.snapshots.get(&Uuid::from_u128(id)).unwrap();
    let (request_1, request_2, request_3) = (snapshot_of(11), snapshot_of(22), snapshot_of(33));

    // Arrival times are rebased onto the first request (100 ms in the fixture).
    assert_eq!(request_1.arrival_time_ms, 0.0);
    assert_eq!(request_2.arrival_time_ms, 1.0);
    assert_eq!(request_3.arrival_time_ms, 400.0);
    assert_eq!(manual.idle_jump_ms, 400.0);

    assert_eq!(
        request_1.first_token_ms.unwrap(),
        manual.first_decode_end_ms,
    );
    for late_request in [request_2, request_3] {
        assert!(late_request.first_admit_ms.unwrap() >= late_request.arrival_time_ms);
    }
    assert!(manual.report.latency.e2e.mean_ms >= manual.report.latency.ttft.mean_ms);

    // Request 2 repeats request 1's prompt, so reuse shows up only with prefix caching.
    match enable_prefix_caching {
        true => {
            assert!(request_2.reused_input_tokens > 0);
            assert!(manual.report.prefix_cache_reused_ratio > 0.0);
        }
        false => {
            assert_eq!(request_2.reused_input_tokens, 0);
            assert_eq!(manual.report.prefix_cache_reused_ratio, 0.0);
        }
    }

    assert_report_close(&replay_report, &manual.report);
}
/// Concurrency-limited replay (limit 2) must match the manually stepped run.
///
/// The requests carry arrival timestamps, but the assertions below pin the
/// arrival times to when a concurrency slot is available instead.
#[test]
fn test_concurrency_replay_matches_manual_steps() {
    let args = replay_args(false, false);

    // The three requests differ only in id, prompt tokens and nominal arrival time.
    let make_request = |id, tokens, arrival_ms| DirectRequest {
        tokens,
        max_output_tokens: 2,
        uuid: Some(Uuid::from_u128(id)),
        dp_rank: 0,
        arrival_timestamp_ms: Some(arrival_ms),
    };
    let requests = vec![
        make_request(11, vec![1, 2, 3, 4, 5, 6, 7, 8], 900.0),
        make_request(22, vec![1, 2, 3, 4, 5, 9, 10, 11], 1000.0),
        make_request(33, vec![12, 13, 14, 15, 16, 17, 18, 19], 100.0),
    ];

    let manual = run_concurrency_manually(&args, requests.clone(), 2);
    let replay_report = simulate_concurrency(args, requests, 2).unwrap();

    let [request_1, request_2, request_3] =
        [11, 22, 33].map(|id| manual.snapshots.get(&Uuid::from_u128(id)).unwrap());

    // The first two requests start immediately; the third starts when the first finishes.
    assert_eq!(request_1.arrival_time_ms, 0.0);
    assert_eq!(request_2.arrival_time_ms, 0.0);
    assert_eq!(request_3.arrival_time_ms, request_1.last_token_ms.unwrap());
    assert!(request_3.arrival_time_ms < request_2.last_token_ms.unwrap());

    // 3 requests x 8 prompt tokens and 3 x 2 output tokens.
    assert_eq!(manual.report.request_counts.completed_requests, 3);
    assert_eq!(manual.report.request_counts.total_input_tokens, 24);
    assert_eq!(manual.report.request_counts.total_output_tokens, 6);

    assert_report_close(&replay_report, &manual.report);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet, VecDeque};
use std::time::Duration;
use dynamo_kv_router::protocols::WorkerId;
use dynamo_tokens::blocks::UniqueBlock;
use tokio::sync::mpsc;
use uuid::Uuid;
use crate::common::protocols::{
DirectRequest, KvEventPublishers, MockEngineArgs, MoveBlock, OutputSignal, PreemptionMode,
WorkerType,
};
use crate::common::sequence::ActiveSequence;
use crate::kv_manager::KvManager;
use crate::replay::TraceCollector;
use crate::scheduler::{
AdmissionEvent, CapturedRouterEventBuffer, EnginePassResult, RouterEventVisibility,
capture_router_event_sink,
};
/// Lifecycle state of a request tracked by the scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RequestStatus {
/// Queued and not yet admitted; the status a request gets when it is received.
Waiting,
/// Admitted into the running set (set by `transition_to_running`).
Running,
/// Evicted from the running set by `preempt` and put back at the front of the waiting queue.
Preempted,
}
/// Per-request scheduling state kept in `SchedulerState::requests`.
pub(crate) struct VllmRequestState {
/// Token sequence of the request together with its KV allocation bookkeeping.
pub(crate) sequence: ActiveSequence,
/// Where the request currently sits in its lifecycle.
pub(crate) status: RequestStatus,
/// Number of sequence tokens already scheduled for compute; reset to 0 on preemption.
pub(crate) num_computed_tokens: usize,
/// How many times this request has been preempted.
pub(crate) num_preemptions: usize,
}
/// Waiting/running queues plus the request map they refer to.
///
/// The deques only record ordering and may contain stale entries; the
/// `*_members` sets together with each request's status decide whether an
/// entry is still live. Stale entries are dropped lazily by
/// `next_waiting_uuid`, `compact_running` and `preempt`.
#[derive(Default)]
pub(crate) struct SchedulerState {
/// Admission order of waiting requests (front is scheduled first).
pub(crate) waiting: VecDeque<Uuid>,
/// Requests that are currently considered to be waiting.
waiting_members: HashSet<Uuid>,
/// Order of running requests.
pub(crate) running: VecDeque<Uuid>,
/// Requests that are currently considered to be running.
running_members: HashSet<Uuid>,
/// Every live request, keyed by its id.
pub(crate) requests: HashMap<Uuid, VllmRequestState>,
}
/// Result of preempting one running request.
struct PreemptedRequest {
/// Id of the request that was preempted.
uuid: Uuid,
/// Signals returned by resetting the request's sequence; callers feed them to the KV manager.
signals: Vec<MoveBlock>,
}
/// Work booked for one request during the current pass, kept so it can be undone
/// if the request is preempted later in the same pass.
#[derive(Clone, Copy, Debug, Default)]
struct ScheduledWork {
/// Tokens charged against the pass token budget (refunded on preemption).
total_tokens: usize,
/// How many of the scheduled tokens belong to the prompt.
prompt_tokens: usize,
/// Prompt tokens that were already covered before this pass scheduled anything.
prefix_tokens: usize,
}
/// Outcome of trying to schedule a single request in a pass.
enum ScheduleOutcome {
/// The request was scheduled.
Scheduled {
/// Tokens consumed from the pass token budget.
tokens_used: usize,
/// Present only when the request was admitted from the waiting queue.
admission: Option<AdmissionEvent>,
},
/// The request could not make progress (budget, allocation, or missing request).
Blocked,
/// The request being scheduled was itself chosen as the preemption victim.
CurrentPreempted,
}
impl SchedulerState {
/// Returns `true` when no request is tracked at all (neither waiting nor running).
pub(crate) fn is_empty(&self) -> bool {
self.requests.is_empty()
}
/// Enqueues `uuid` at the back of the waiting queue, unless it is already a waiting member.
fn push_waiting(&mut self, uuid: Uuid) {
    // `insert` reports whether the id was newly added to the membership set.
    let newly_waiting = self.waiting_members.insert(uuid);
    if newly_waiting {
        self.waiting.push_back(uuid);
    }
}
/// Enqueues `uuid` at the front of the waiting queue, unless it is already a waiting member.
fn prepend_waiting(&mut self, uuid: Uuid) {
    // `insert` reports whether the id was newly added to the membership set.
    let newly_waiting = self.waiting_members.insert(uuid);
    if newly_waiting {
        self.waiting.push_front(uuid);
    }
}
/// Returns the id at the front of the waiting queue that is still eligible for
/// scheduling, discarding stale front entries on the way.
///
/// An entry is eligible when its request still exists, it is a waiting member,
/// and its status is not `Running`. The returned id stays in the queue.
/// Returns `None` once the queue is exhausted.
fn next_waiting_uuid(&mut self) -> Option<Uuid> {
    while let Some(&uuid) = self.waiting.front() {
        let eligible = self.waiting_members.contains(&uuid)
            && self
                .requests
                .get(&uuid)
                .is_some_and(|request| request.status != RequestStatus::Running);
        if eligible {
            return Some(uuid);
        }
        // Stale entry: drop it from both the queue and the membership set.
        self.waiting.pop_front();
        self.waiting_members.remove(&uuid);
    }
    None
}
fn compact_running(&mut self) {
let mut compacted = VecDeque::with_capacity(self.running.len());
while let Some(uuid) = self.running.pop_front() {
let is_running = self.running_members.contains(&uuid)
&& self
.requests
.get(&uuid)
.is_some_and(|request| request.status == RequestStatus::Running);
if is_running {
compacted.push_back(uuid);
continue;
}
self.running_members.remove(&uuid);
}
self.running = compacted;
}
/// Moves `uuid` from the waiting set to the running set and marks it `Running`.
///
/// The waiting queue entry is only popped when it sits at the front; an entry
/// elsewhere is left behind and no longer counts as a waiting member.
fn transition_to_running(&mut self, uuid: Uuid) {
    if self.waiting.front() == Some(&uuid) {
        self.waiting.pop_front();
    }
    self.waiting_members.remove(&uuid);

    // Only enqueue when the id was not already a running member.
    let newly_running = self.running_members.insert(uuid);
    if newly_running {
        self.running.push_back(uuid);
    }

    if let Some(request) = self.requests.get_mut(&uuid) {
        request.status = RequestStatus::Running;
    }
}
/// Forgets a request: drops it from both membership sets and from the request map.
///
/// The `waiting`/`running` deques are not touched here, so an entry may remain
/// in them after this call.
pub(crate) fn complete(&mut self, uuid: &Uuid) {
self.waiting_members.remove(uuid);
self.running_members.remove(uuid);
self.requests.remove(uuid);
}
/// Returns mutable access to the sequence of `uuid`, but only while it is a running member.
///
/// Returns `None` when the id is not a running member or the request no longer exists.
pub(crate) fn running_sequence_mut(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
    if self.running_members.contains(&uuid) {
        let request = self.requests.get_mut(&uuid)?;
        Some(&mut request.sequence)
    } else {
        None
    }
}
/// Preempts one running request and puts it back at the front of the waiting queue.
///
/// The victim is taken from the back of the running queue for
/// `PreemptionMode::Lifo` and from the front for `PreemptionMode::Fifo`.
/// Returns `None` when no live running request is left to preempt.
/// The returned signals come from resetting the victim's sequence; this
/// method does not apply them to the KV manager itself.
fn preempt(&mut self, mode: PreemptionMode) -> Option<PreemptedRequest> {
// Pop candidates until a live running request is found; stale entries are discarded.
let uuid = loop {
let candidate = match mode {
PreemptionMode::Lifo => self.running.pop_back(),
PreemptionMode::Fifo => self.running.pop_front(),
}?;
let is_running = self.running_members.contains(&candidate)
&& self
.requests
.get(&candidate)
.is_some_and(|request| request.status == RequestStatus::Running);
if is_running {
break candidate;
}
self.running_members.remove(&candidate);
};
self.running_members.remove(&uuid);
let request = self.requests.get_mut(&uuid)?;
// Preemption discards all compute progress.
request.status = RequestStatus::Preempted;
request.num_computed_tokens = 0;
request.num_preemptions += 1;
let signals = request.sequence.reset_with_signal();
debug_assert_vllm_request_invariants(uuid, request);
#[cfg(debug_assertions)]
{
debug_assert_eq!(
request.sequence.num_allocated_tokens(),
0,
"preempted request {uuid} should release all allocated KV"
);
}
// Front of the queue, so the preempted request is the next one considered for admission.
self.prepend_waiting(uuid);
Some(PreemptedRequest { uuid, signals })
}
/// Test-only helper: registers `uuid` as running (membership set and queue)
/// without touching the request map or any request status.
#[cfg(test)]
pub(super) fn insert_running_for_test(&mut self, uuid: Uuid) {
self.running_members.insert(uuid);
self.running.push_back(uuid);
}
}
/// Scheduler core: owns the request queues and the KV manager and runs one
/// scheduling pass at a time.
pub(crate) struct VllmCore {
/// Engine configuration, stored after `normalized()` has been applied.
args: MockEngineArgs,
/// Waiting/running queues and per-request state.
pub(super) state: SchedulerState,
/// KV block manager that processes the allocation/free signals.
pub(super) kv_manager: KvManager,
/// Capture buffer for router events; only set by `new_with_kv_capture` and drained at the end of each pass.
kv_event_buffer: Option<CapturedRouterEventBuffer>,
}
impl VllmCore {
/// Creates a core for dp rank 0 with default (empty) KV event publishers and no event capture.
pub(crate) fn new(args: MockEngineArgs) -> Self {
Self::new_internal(args, 0, None, KvEventPublishers::default())
}
/// Creates a core for dp rank 0 whose KV events are captured in a buffer for `worker_id`.
///
/// The captured events are returned in `EnginePassResult::kv_events` after each pass.
pub(crate) fn new_with_kv_capture(args: MockEngineArgs, worker_id: WorkerId) -> Self {
let (buffer, sink) = capture_router_event_sink(worker_id);
Self::new_internal(
args,
0,
Some(buffer),
KvEventPublishers::new(Some(sink), None),
)
}
/// Creates a core for `dp_rank` that sends KV events to the given publishers; no capture buffer is kept.
pub(super) fn new_with_sink(
args: MockEngineArgs,
dp_rank: u32,
kv_event_publishers: KvEventPublishers,
) -> Self {
Self::new_internal(args, dp_rank, None, kv_event_publishers)
}
/// Shared constructor behind the public `new*` variants.
///
/// Normalizes `args`, builds the KV manager from the normalized block count and
/// block size, and starts with an empty scheduler state.
///
/// # Panics
///
/// Panics with "invalid MockEngineArgs" when `args.normalized()` fails.
fn new_internal(
    args: MockEngineArgs,
    dp_rank: u32,
    kv_event_buffer: Option<CapturedRouterEventBuffer>,
    kv_event_publishers: KvEventPublishers,
) -> Self {
    let args = args.normalized().expect("invalid MockEngineArgs");
    // Build the KV manager first: it reads sizes from `args` before `args` is moved.
    let kv_manager = KvManager::new_with_event_sink(
        args.num_gpu_blocks,
        args.block_size,
        kv_event_publishers,
        dp_rank,
    );
    Self {
        args,
        state: SchedulerState::default(),
        kv_manager,
        kv_event_buffer,
    }
}
pub(crate) fn receive(&mut self, request: DirectRequest) -> Uuid {
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
let sequence = ActiveSequence::new(
request.tokens,
request.max_output_tokens,
Some(self.args.block_size),
self.args.enable_prefix_caching,
self.args.zmq_kv_events_port.is_some(),
);
self.state.requests.insert(
uuid,
VllmRequestState {
sequence,
status: RequestStatus::Waiting,
num_computed_tokens: 0,
num_preemptions: 0,
},
);
self.state.push_waiting(uuid);
if let Some(request) = self.state.requests.get(&uuid) {
debug_assert_vllm_request_progress(uuid, request);
}
uuid
}
/// Returns `true` when the core tracks no requests at all.
pub(crate) fn is_empty(&self) -> bool {
self.state.is_empty()
}
/// Number of requests currently tracked, regardless of their status.
pub(crate) fn num_requests(&self) -> usize {
self.state.requests.len()
}
/// Runs one scheduling pass starting at `now_ms`, recording admissions and
/// tokens in `collector`. No admission channel is used.
pub(crate) fn execute_pass(
&mut self,
collector: &mut TraceCollector,
now_ms: f64,
) -> EnginePassResult {
self.execute_pass_internal(Some(collector), now_ms, None)
}
/// Runs one scheduling pass and returns its result.
///
/// Steps, in order:
/// 1. drop stale entries from the running queue;
/// 2. schedule running requests in queue order until the token budget is
///    used up or a request is blocked;
/// 3. admit waiting requests while the running queue is below `max_num_seqs`,
///    but only if nothing was preempted in this pass;
/// 4. predict the prefill duration for the batch and emit one token for every
///    request that is ready to decode.
///
/// `collector`, when present, is told about admissions and emitted tokens.
/// `admission_tx`, when present, receives a copy of every admission event;
/// send errors are ignored. `now_ms` is the pass start time; the returned
/// `end_ms` is `now_ms` plus the predicted prefill and decode durations.
pub(super) fn execute_pass_internal(
&mut self,
mut collector: Option<&mut TraceCollector>,
now_ms: f64,
admission_tx: Option<&mpsc::UnboundedSender<AdmissionEvent>>,
) -> EnginePassResult {
// Used at the end to derive how many requests completed during this pass.
let requests_before = self.state.requests.len();
self.state.compact_running();
// No configured limit means an effectively unlimited token budget.
let mut token_budget = self.args.max_num_batched_tokens.unwrap_or(usize::MAX);
let mut scheduled = HashMap::new();
let mut batch_count = 0usize;
let mut batch_total_isl = 0usize;
let mut batch_total_prefix = 0usize;
let mut admissions = Vec::new();
let mut preempted_any = false;
let mut req_index = 0usize;
// Phase 1: requests that are already running get the budget first.
// NOTE(review): with `PreemptionMode::Fifo` a victim popped from the front of
// `running` shifts later entries left while `req_index` is unchanged, so the
// entry after the current request may be skipped for this pass — confirm this is intended.
while req_index < self.state.running.len() && token_budget > 0 {
let uuid = self.state.running[req_index];
match self.schedule_request(
uuid,
false,
&mut token_budget,
&mut scheduled,
&mut batch_count,
&mut batch_total_isl,
&mut batch_total_prefix,
&mut preempted_any,
) {
ScheduleOutcome::Scheduled { admission, .. } => {
if let Some(admission) = admission {
if let Some(collector) = collector.as_deref_mut() {
collector.on_admit(
admission.uuid,
now_ms,
admission.reused_input_tokens,
);
}
if let Some(admission_tx) = admission_tx {
let _ = admission_tx.send(admission.clone());
}
admissions.push(admission);
}
req_index += 1;
}
ScheduleOutcome::Blocked => break,
// The current request was removed from `running`, so the same index is re-examined.
ScheduleOutcome::CurrentPreempted => {}
}
}
// Phase 2: admit waiting requests, skipped entirely if anything was preempted above.
let max_num_running = self.args.max_num_seqs.unwrap_or(usize::MAX);
while !preempted_any && self.state.running.len() < max_num_running {
let Some(uuid) = self.state.next_waiting_uuid() else {
break;
};
match self.schedule_request(
uuid,
true,
&mut token_budget,
&mut scheduled,
&mut batch_count,
&mut batch_total_isl,
&mut batch_total_prefix,
&mut preempted_any,
) {
ScheduleOutcome::Scheduled {
admission,
tokens_used,
} => {
if let Some(admission) = admission {
if let Some(collector) = collector.as_deref_mut() {
collector.on_admit(
admission.uuid,
now_ms,
admission.reused_input_tokens,
);
}
if let Some(admission_tx) = admission_tx {
let _ = admission_tx.send(admission.clone());
}
admissions.push(admission);
}
// Stop admitting once a request used no tokens and the budget is exhausted.
if tokens_used == 0 && token_budget == 0 {
break;
}
}
ScheduleOutcome::Blocked | ScheduleOutcome::CurrentPreempted => break,
}
}
// Timing: prefill comes first, decode starts once the prefill has finished.
let prefill_time =
predict_prefill_duration(batch_count, batch_total_isl, batch_total_prefix, &self.args);
let decode_start_ms = now_ms + prefill_time.as_secs_f64() * 1000.0;
let (decode_time, output_signals) = self.emit_ready_tokens(collector, decode_start_ms);
let end_ms = decode_start_ms + decode_time.as_secs_f64() * 1000.0;
debug_assert_vllm_scheduler_state(&self.state);
EnginePassResult {
end_ms,
// Requests are only removed on completion, so the shrinkage is the completed count.
completed_requests: requests_before.saturating_sub(self.state.requests.len()),
output_signals,
admissions,
active_decode_blocks: self.kv_manager.num_active_blocks() as u64,
router_event_visibility: RouterEventVisibility::PassStart,
// Drain captured router events if a capture buffer exists, otherwise report none.
kv_events: self
.kv_event_buffer
.as_ref()
.map(CapturedRouterEventBuffer::drain)
.unwrap_or_default(),
}
}
/// Drops a request: applies its free signals to the KV manager and forgets it.
///
/// Does nothing when the request is not tracked (for example, already completed).
pub(super) fn drop_request(&mut self, uuid: Uuid) {
    if let Some(request) = self.state.requests.get(&uuid) {
        for signal in request.sequence.free_signal() {
            self.kv_manager.process(&signal);
        }
        self.state.complete(&uuid);
    }
}
/// Tries to schedule compute for one request within the remaining token budget.
///
/// `from_waiting` is `true` when the request is being admitted from the
/// waiting queue; only then is it moved to the running set and an
/// `AdmissionEvent` produced. `token_budget`, `scheduled` and the three
/// `batch_*` counters are the per-pass accumulators owned by the caller;
/// they are updated for this request and rolled back for any request
/// preempted here. `preempted_any` is set when any request is preempted.
///
/// Returns `Blocked` when the request is missing or cannot make progress,
/// `CurrentPreempted` when this request itself became the preemption victim,
/// and `Scheduled` otherwise.
#[allow(clippy::too_many_arguments)]
fn schedule_request(
&mut self,
uuid: Uuid,
from_waiting: bool,
token_budget: &mut usize,
scheduled: &mut HashMap<Uuid, ScheduledWork>,
batch_count: &mut usize,
batch_total_isl: &mut usize,
batch_total_prefix: &mut usize,
preempted_any: &mut bool,
) -> ScheduleOutcome {
let Some(request) = self.state.requests.get(&uuid) else {
return ScheduleOutcome::Blocked;
};
debug_assert_vllm_request_invariants(uuid, request);
let prefill_cost = self.kv_manager.get_prefill_cost(&request.sequence);
// A prefix-cache hit is only credited while the request has no computed tokens yet.
let cached_prefix_tokens = if request.num_computed_tokens == 0 {
prefill_cost.cached_tokens
} else {
0
};
let effective_computed_before = request.num_computed_tokens + cached_prefix_tokens;
let prompt_len = request.sequence.num_input_tokens();
let prompt_before = effective_computed_before.min(prompt_len);
let remaining_known_tokens = request
.sequence
.len()
.saturating_sub(effective_computed_before);
let prompt_remaining = prompt_len.saturating_sub(prompt_before);
// Without chunked prefill the rest of the prompt has to fit in the budget in one go.
if prompt_remaining > 0
&& !self.args.enable_chunked_prefill
&& prompt_remaining > *token_budget
{
return ScheduleOutcome::Blocked;
}
let desired_tokens = remaining_known_tokens.min(*token_budget);
if desired_tokens == 0 && remaining_known_tokens > 0 {
return ScheduleOutcome::Blocked;
}
let desired_computed_after = effective_computed_before + desired_tokens;
let mut actual_computed_after = desired_computed_after;
// Allocation loop: try to allocate up to the target; on a partial allocation,
// preempt a running request to free KV blocks and try again.
loop {
let allocation = {
let Some(request) = self.state.requests.get_mut(&uuid) else {
return ScheduleOutcome::Blocked;
};
let allocation_target = desired_computed_after;
let prev_allocated_tokens = request.sequence.num_allocated_tokens();
if allocation_target <= prev_allocated_tokens {
// Already covered by the existing allocation; nothing to allocate.
request.num_computed_tokens = actual_computed_after;
None
} else {
let maybe_signal = request.sequence.prepare_allocation(allocation_target);
Some((allocation_target, prev_allocated_tokens, maybe_signal))
}
};
let Some((allocation_target, prev_allocated_tokens, maybe_signal)) = allocation else {
break;
};
// No signal: commit the target directly without involving the KV manager.
let Some(signal) = maybe_signal else {
let Some(request) = self.state.requests.get_mut(&uuid) else {
return ScheduleOutcome::Blocked;
};
request.sequence.commit_allocation(allocation_target);
request.num_computed_tokens = actual_computed_after;
break;
};
// Only `MoveBlock::Use` is handled here; any other variant hits `unreachable!`.
let expected = match &signal {
MoveBlock::Use(blocks, ..) => blocks.len(),
_ => unreachable!(),
};
let allocated = self.kv_manager.process(&signal);
let (_committed_tokens, current_computed_tokens) = {
let Some(request) = self.state.requests.get_mut(&uuid) else {
return ScheduleOutcome::Blocked;
};
// On a partial allocation, commit only the blocks that were actually granted.
let committed_tokens = if allocated == expected {
allocation_target
} else {
let prev_blocks = prev_allocated_tokens
.div_ceil(request.sequence.block_size())
.min(request.sequence.unique_blocks().len());
(prev_blocks + allocated) * request.sequence.block_size()
};
request
.sequence
.commit_allocation(committed_tokens.min(allocation_target));
request.num_computed_tokens = actual_computed_after.min(committed_tokens);
(committed_tokens, request.num_computed_tokens)
};
if allocated == expected {
break;
}
// Not enough blocks: preempt a victim; with no victim left, settle for what was committed.
let Some(preempted) = self.state.preempt(self.args.preemption_mode) else {
actual_computed_after = current_computed_tokens;
break;
};
for signal in preempted.signals {
self.kv_manager.process(&signal);
}
*preempted_any = true;
// Roll back whatever the victim had booked earlier in this pass.
if let Some(undone) = scheduled.remove(&preempted.uuid) {
*token_budget += undone.total_tokens;
if undone.prompt_tokens > 0 && self.args.worker_type != WorkerType::Decode {
*batch_count = batch_count.saturating_sub(1);
*batch_total_isl =
batch_total_isl.saturating_sub(undone.prefix_tokens + undone.prompt_tokens);
*batch_total_prefix = batch_total_prefix.saturating_sub(undone.prefix_tokens);
}
}
if preempted.uuid == uuid {
return ScheduleOutcome::CurrentPreempted;
}
}
if let Some(request) = self.state.requests.get(&uuid) {
debug_assert_vllm_request_invariants(uuid, request);
}
let tokens_used = actual_computed_after.saturating_sub(effective_computed_before);
// No progress although tokens are still outstanding counts as blocked.
if tokens_used == 0
&& actual_computed_after < request_sequence_len(&self.state.requests, uuid)
{
return ScheduleOutcome::Blocked;
}
let prompt_after = actual_computed_after.min(prompt_len);
let prompt_tokens = prompt_after.saturating_sub(prompt_before);
scheduled.insert(
uuid,
ScheduledWork {
total_tokens: tokens_used,
prompt_tokens,
prefix_tokens: prompt_before,
},
);
// Prefill batch statistics are only tracked when prompt tokens were scheduled
// and the worker is not a decode worker.
if prompt_tokens > 0 && self.args.worker_type != WorkerType::Decode {
*batch_count += 1;
*batch_total_isl += prompt_before + prompt_tokens;
*batch_total_prefix += prompt_before;
}
if from_waiting {
self.state.transition_to_running(uuid);
}
*token_budget = token_budget.saturating_sub(tokens_used);
let admission = if from_waiting {
Some(AdmissionEvent {
uuid,
reused_input_tokens: cached_prefix_tokens,
})
} else {
None
};
ScheduleOutcome::Scheduled {
tokens_used,
admission,
}
}
/// Generates one output token for every running request that is ready to decode.
///
/// A request is ready when all tokens of its sequence are computed and it has
/// not reached `max_output_tokens` yet. Returns the predicted decode duration
/// together with one `OutputSignal` per emitted token. The duration is
/// `Duration::ZERO` when nothing was ready or nothing could be emitted.
/// `collector`, when present, is notified of each token at the decode end time.
fn emit_ready_tokens(
&mut self,
mut collector: Option<&mut TraceCollector>,
decode_start_ms: f64,
) -> (Duration, Vec<OutputSignal>) {
let ready = self
.state
.running
.iter()
.copied()
.filter(|uuid| {
let Some(request) = self.state.requests.get(uuid) else {
return false;
};
request.num_computed_tokens >= request.sequence.len()
&& request.sequence.generated_tokens() < request.sequence.max_output_tokens()
})
.collect::<Vec<_>>();
if ready.is_empty() {
return (Duration::ZERO, Vec::new());
}
// Inputs of the decode-time prediction: batch size, active KV tokens, mean context length.
let active_kv_tokens = self.kv_manager.num_active_blocks() * self.args.block_size;
let total_length = ready
.iter()
.filter_map(|uuid| self.state.requests.get(uuid))
.map(|request| request.sequence.len())
.sum::<usize>();
let context_length = total_length / ready.len();
let decode_ms =
self.args
.perf_model
.predict_decode_time(ready.len(), active_kv_tokens, context_length);
let decode_time = scale_decode_time(decode_ms, &self.args);
let decode_end_ms = decode_start_ms + decode_time.as_secs_f64() * 1000.0;
let mut output_signals = Vec::with_capacity(ready.len());
for uuid in ready {
let mut emitted = false;
let mut completed = false;
// Retry generation, preempting other requests while the KV manager rejects the allocation.
loop {
debug_assert_vllm_ready_to_decode(&self.state.requests, uuid);
// `None` means the request is no longer running (e.g. preempted earlier in this loop).
let Some(sequence) = self.state.running_sequence_mut(uuid) else {
break;
};
let signals = sequence.generate();
if process_signals(&mut self.kv_manager, &signals) {
if sequence.generated_tokens() < sequence.max_output_tokens() {
sequence.commit_allocation(sequence.len());
}
emitted = true;
completed = sequence.generated_tokens() >= sequence.max_output_tokens();
break;
}
// Allocation failed: undo the generated token, then free space by preempting.
sequence.pop();
let Some(preempted) = self.state.preempt(self.args.preemption_mode) else {
break;
};
for signal in preempted.signals {
self.kv_manager.process(&signal);
}
if preempted.uuid == uuid {
break;
}
}
if !emitted {
continue;
}
if let Some(collector) = collector.as_deref_mut() {
collector.on_token(uuid, decode_end_ms);
}
if let Some(request) = self.state.requests.get(&uuid) {
debug_assert_vllm_request_progress(uuid, request);
}
output_signals.push(OutputSignal { uuid, completed });
if completed {
self.state.complete(&uuid);
}
}
// If nothing was emitted, the pass is charged no decode time.
if output_signals.is_empty() {
return (Duration::ZERO, output_signals);
}
self.state.compact_running();
(decode_time, output_signals)
}
}
/// Length of the sequence of request `uuid`, or 0 when the request is not tracked.
fn request_sequence_len(requests: &HashMap<Uuid, VllmRequestState>, uuid: Uuid) -> usize {
    requests
        .get(&uuid)
        .map_or(0, |request| request.sequence.len())
}
/// Debug-only check that a request never has more computed or allocated
/// tokens than its sequence length. The body is compiled out without
/// `debug_assertions`.
fn debug_assert_vllm_request_invariants(uuid: Uuid, request: &VllmRequestState) {
#[cfg(debug_assertions)]
{
let seq_len = request.sequence.len();
let allocated = request.sequence.num_allocated_tokens();
debug_assert!(
request.num_computed_tokens <= seq_len,
"request {uuid} computed {} tokens but sequence length is {seq_len}",
request.num_computed_tokens
);
debug_assert!(
allocated <= seq_len,
"request {uuid} allocated {allocated} tokens but sequence length is {seq_len}"
);
}
}
/// Debug-only check that, on top of the basic request invariants, the
/// allocation covers at least every computed token. The body is compiled out
/// without `debug_assertions`.
fn debug_assert_vllm_request_progress(uuid: Uuid, request: &VllmRequestState) {
#[cfg(debug_assertions)]
{
debug_assert_vllm_request_invariants(uuid, request);
let allocated = request.sequence.num_allocated_tokens();
debug_assert!(
allocated >= request.num_computed_tokens,
"request {uuid} allocated {allocated} tokens but computed {}",
request.num_computed_tokens
);
}
}
/// Debug-only check that a decode-ready request has exactly its sequence
/// length allocated. Missing requests and requests that still have
/// uncomputed tokens are ignored. The body is compiled out without
/// `debug_assertions`.
fn debug_assert_vllm_ready_to_decode(requests: &HashMap<Uuid, VllmRequestState>, uuid: Uuid) {
#[cfg(debug_assertions)]
{
let Some(request) = requests.get(&uuid) else {
return;
};
let seq_len = request.sequence.len();
// Not decode-ready yet: nothing to verify.
if request.num_computed_tokens < seq_len {
return;
}
let allocated = request.sequence.num_allocated_tokens();
debug_assert_eq!(
allocated, seq_len,
"request {uuid} is decode-ready but allocated {allocated} tokens for sequence length {seq_len}"
);
}
}
/// Debug-only consistency check of the whole scheduler state.
///
/// Verifies that no request is a member of both the waiting and running sets,
/// that every member exists in the request map with a matching status, and
/// that each queue holds at least as many entries as its membership set
/// (stale extra entries are allowed). The body is compiled out without
/// `debug_assertions`.
///
/// # Panics
///
/// In debug builds, panics when a member is missing from the request map or
/// any of the checks above fails.
fn debug_assert_vllm_scheduler_state(state: &SchedulerState) {
#[cfg(debug_assertions)]
{
// Shared across both loops so an id present in both sets is detected.
let mut seen = std::collections::HashSet::new();
for uuid in &state.waiting_members {
debug_assert!(
seen.insert(*uuid),
"request {uuid} appears multiple times across waiting/running queues"
);
let request = state
.requests
.get(uuid)
.expect("waiting request missing from state map");
debug_assert!(
request.status != RequestStatus::Running,
"request {uuid} is queued in waiting but marked Running"
);
debug_assert_vllm_request_invariants(*uuid, request);
}
for uuid in &state.running_members {
debug_assert!(
seen.insert(*uuid),
"request {uuid} appears multiple times across waiting/running queues"
);
let request = state
.requests
.get(uuid)
.expect("running request missing from state map");
debug_assert_eq!(
request.status,
RequestStatus::Running,
"request {uuid} is queued in running but marked {:?}",
request.status
);
debug_assert_vllm_request_invariants(*uuid, request);
}
// Queues may be longer than the sets because of stale entries, never shorter.
debug_assert!(
state.waiting.len() >= state.waiting_members.len(),
"waiting queue dropped live membership entries"
);
debug_assert!(
state.running.len() >= state.running_members.len(),
"running queue dropped live membership entries"
);
}
}
/// Predicts how long the prefill part of a pass takes.
///
/// Returns `Duration::ZERO` for an empty batch or a decode worker. Otherwise
/// the perf model is queried with the batch size and the per-request mean
/// input length and mean prefix length (integer division). The result is
/// divided by `args.speedup_ratio` unless that ratio is `<= 0.0` or the
/// unscaled duration is already zero.
fn predict_prefill_duration(
    batch_count: usize,
    batch_total_isl: usize,
    batch_total_prefix: usize,
    args: &MockEngineArgs,
) -> Duration {
    let is_decode_worker = args.worker_type == WorkerType::Decode;
    if is_decode_worker || batch_count == 0 {
        return Duration::ZERO;
    }

    let prefill_ms = args.perf_model.predict_prefill_time(
        batch_count,
        batch_total_isl / batch_count,
        batch_total_prefix / batch_count,
    );
    let unscaled = Duration::from_secs_f64(prefill_ms / 1000.0);

    if args.speedup_ratio <= 0.0 || unscaled.is_zero() {
        unscaled
    } else {
        Duration::from_secs_f64(unscaled.as_secs_f64() / args.speedup_ratio)
    }
}
/// Converts a predicted decode time in milliseconds into a `Duration`, divided
/// by `speedup_ratio * decode_speedup_ratio`.
///
/// The unscaled duration is returned when that product is `<= 0.0` or the
/// duration is zero.
fn scale_decode_time(decode_ms: f64, args: &MockEngineArgs) -> Duration {
    let base = Duration::from_secs_f64(decode_ms / 1000.0);
    let combined_ratio = args.speedup_ratio * args.decode_speedup_ratio;
    if combined_ratio <= 0.0 || base.is_zero() {
        base
    } else {
        Duration::from_secs_f64(base.as_secs_f64() / combined_ratio)
    }
}
/// Applies decode signals to the KV manager in order, stopping at the first failure.
///
/// Returns `true` when every signal was processed with a positive result and
/// `false` at the first signal for which the KV manager returns 0.
///
/// # Panics
///
/// Panics when the failing signal is not a `MoveBlock::Use` carrying exactly
/// one `UniqueBlock::PartialBlock`, the only failure shape accepted here.
fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
    // `all` stops at the first `false`, so later signals are not processed.
    signals.iter().all(|signal| {
        if kv_manager.process(signal) > 0 {
            return true;
        }
        match signal {
            MoveBlock::Use(blocks, ..) => {
                if blocks.len() != 1 {
                    panic!(
                        "Failed signal is invalid. Tried to allocate {} blocks during decode.",
                        blocks.len()
                    );
                }
                if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) {
                    panic!("Failed signal is invalid. Decode allocation must use a partial block.");
                }
            }
            _ => {
                panic!("Failed signal is invalid. Expected decode allocation failure, got {signal:?}")
            }
        }
        false
    })
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, KvEventPublishers, MockEngineArgs, OutputSignal};
use crate::common::utils::sleep_until_precise;
use crate::scheduler::{
AdmissionEvent, RouterEventVisibility, SchedulerHandle, capture_deferred_kv_publish_sink,
publish_deferred_kv_events,
};
use super::core::VllmCore;
/// Snapshot of KV cache usage published by a scheduler.
#[derive(Clone, Default, Debug)]
pub struct MockerMetrics {
/// Data-parallel rank the metrics belong to.
pub dp_rank: dynamo_kv_router::protocols::DpRank,
/// Number of KV blocks currently active.
pub active_decode_blocks: u64,
/// Total number of KV blocks available.
pub total_blocks: u64,
/// `active_decode_blocks / total_blocks` as a fraction, or 0.0 when `total_blocks` is 0.
pub gpu_cache_usage_perc: f64,
}
impl MockerMetrics {
/// Builds a metrics snapshot and derives the cache usage fraction.
///
/// The usage is `active_decode_blocks / total_blocks`; it is 0.0 when
/// `total_blocks` is 0, which avoids a division by zero.
pub fn new(
    dp_rank: dynamo_kv_router::protocols::DpRank,
    active_decode_blocks: u64,
    total_blocks: u64,
) -> Self {
    let gpu_cache_usage_perc = match total_blocks {
        0 => 0.0,
        total => active_decode_blocks as f64 / total as f64,
    };
    Self {
        dp_rank,
        active_decode_blocks,
        total_blocks,
        gpu_cache_usage_perc,
    }
}
}
/// Handle to a scheduler running as a background Tokio task.
///
/// Clones share the same task; the task's cancellation token is cancelled
/// once the last clone is dropped.
#[derive(Clone)]
pub struct Scheduler {
/// Channel used to hand requests to the background task.
request_tx: mpsc::UnboundedSender<DirectRequest>,
/// Receives the metrics the background task publishes after each pass.
metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
/// Shared guard that cancels the task's token when the last handle is dropped.
_cancel_guard: Arc<CancelGuard>,
}
/// Cancels the wrapped token when dropped.
struct CancelGuard(CancellationToken);
impl Drop for CancelGuard {
fn drop(&mut self) {
self.0.cancel();
}
}
impl Scheduler {
/// Starts a scheduler task without an admission event channel.
///
/// See `new_with_admission` for the variant that also reports admissions.
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Self {
Self::new_internal(
args,
dp_rank,
output_tx,
kv_event_publishers,
cancellation_token,
None,
)
}
/// Starts a scheduler task that additionally sends every admission event to
/// `admission_tx` (when provided).
pub(crate) fn new_with_admission(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
) -> Self {
Self::new_internal(
args,
dp_rank,
output_tx,
kv_event_publishers,
cancellation_token,
admission_tx,
)
}
/// Spawns the background scheduling task and returns a handle to it.
///
/// The task loops: wait for or drain incoming requests, run one engine pass,
/// sleep for the pass's simulated duration, forward output signals, publish
/// deferred KV events and publish fresh metrics. The loop ends when
/// `receive_requests` returns `None` (cancellation or closed request channel).
///
/// A default cancellation token is created when none is supplied. Uses
/// `tokio::spawn`, so it is expected to be called from within a Tokio runtime.
fn new_internal(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
) -> Self {
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let total_blocks = args.num_gpu_blocks as u64;
let initial_metrics = MockerMetrics::new(dp_rank, 0, total_blocks);
let (metrics_tx, metrics_rx) = tokio::sync::watch::channel(initial_metrics);
let cancel_token = cancellation_token.unwrap_or_default();
// The task keeps a clone; the guard owns the original and cancels it on drop.
let cancel_token_clone = cancel_token.clone();
let cancel_guard = Arc::new(CancelGuard(cancel_token));
tokio::spawn(async move {
// The core writes KV events into a deferred buffer; they are published explicitly below.
let (deferred_kv_events, buffering_publishers) =
capture_deferred_kv_publish_sink(kv_event_publishers.raw_enabled());
let mut core = VllmCore::new_with_sink(args, dp_rank, buffering_publishers);
loop {
if receive_requests(&mut core, &mut request_rx, &cancel_token_clone)
.await
.is_none()
{
break;
}
let iteration_start = Instant::now();
// The pass starts at 0.0 ms, so `end_ms` is the simulated duration of this pass.
let pass = core.execute_pass_internal(None, 0.0, admission_tx.as_ref());
let total_time = std::time::Duration::from_secs_f64(pass.end_ms / 1000.0);
// Events become visible either before or after the simulated pass time elapses.
if pass.router_event_visibility == RouterEventVisibility::PassStart {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
if total_time > std::time::Duration::ZERO {
sleep_until_precise(iteration_start + total_time).await;
}
if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
flush_output_signals(&mut core, &output_tx, &pass.output_signals);
// Flushing may drop requests, so publish whatever events that produced as well.
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
let _ = metrics_tx.send(MockerMetrics::new(
dp_rank,
core.kv_manager.num_active_blocks() as u64,
total_blocks,
));
}
});
Self {
request_tx,
metrics_rx,
_cancel_guard: cancel_guard,
}
}
}
impl SchedulerHandle for Scheduler {
/// Queues a request for the background task; a send error (closed channel) is ignored.
fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request);
}
/// Returns a clone of the request sender.
fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
/// Returns a clone of the metrics watch receiver.
fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
self.metrics_rx.clone()
}
}
/// Feeds pending requests into `core` before the next pass.
///
/// When the core is idle this waits until a request arrives or the token is
/// cancelled; when it already has work, it only drains whatever is queued
/// without waiting. Returns `None` when the token is cancelled or the
/// request channel is closed while waiting, and `Some(())` otherwise.
async fn receive_requests(
core: &mut VllmCore,
request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
cancel_token: &CancellationToken,
) -> Option<()> {
if cancel_token.is_cancelled() {
return None;
}
if core.is_empty() {
tokio::select! {
// `biased` polls the branches in order, so cancellation is checked first.
biased;
_ = cancel_token.cancelled() => return None,
result = request_rx.recv() => {
let request = result?;
core.receive(request);
}
}
}
// Drain everything that is already queued, without blocking.
while let Ok(request) = request_rx.try_recv() {
core.receive(request);
}
Some(())
}
/// Forwards the output signals of a pass to `output_tx`.
///
/// Does nothing when no output channel is configured. When sending a signal
/// fails, the corresponding request is dropped from the core.
fn flush_output_signals(
    core: &mut VllmCore,
    output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
    output_signals: &[OutputSignal],
) {
    if let Some(tx) = output_tx {
        for signal in output_signals {
            let delivery_failed = tx.send(signal.clone()).is_err();
            if delivery_failed {
                core.drop_request(signal.uuid);
            }
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! vLLM scheduler simulation around a unified waiting/running request model.
//!
//! Reference: vllm/vllm/v1/core/sched/scheduler.py
mod core;
mod live;
pub(crate) use core::VllmCore;
pub use live::{MockerMetrics, Scheduler};
#[cfg(test)]
mod tests;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::{Arc, Mutex};
use std::time::Duration;
use dynamo_kv_router::indexer::{METRIC_EVENT_REMOVED, METRIC_EVENT_STORED};
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData, WorkerId};
use rstest::rstest;
use tokio::sync::mpsc;
use tokio::time::interval;
use uuid::Uuid;
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs, OutputSignal,
PreemptionMode, RawKvEvent, RawKvEventSink,
};
use crate::common::sequence::ActiveSequence;
use crate::scheduler::RouterEventVisibility;
use crate::scheduler::SchedulerHandle;
use crate::scheduler::test_utils::{RouterIndexerHarness, removed_event_count, stored_hashes};
use super::core::{RequestStatus, VllmCore, VllmRequestState};
use super::live::{MockerMetrics, Scheduler};
// Worker id shared by this test module; presumably used by the router-facing tests — verify against its uses.
const ROUTER_TEST_WORKER_ID: WorkerId = 23;
fn assert_scheduler_idle(metrics: &MockerMetrics) {
assert_eq!(
metrics.active_decode_blocks, 0,
"Expected 0 active blocks, got {}",
metrics.active_decode_blocks
);
assert_eq!(
metrics.gpu_cache_usage_perc, 0.0,
"Expected 0.0 cache usage, got {}",
metrics.gpu_cache_usage_perc
);
assert!(
metrics.total_blocks > 0,
"Expected total_blocks to be populated, got {}",
metrics.total_blocks
);
}
/// Engine args for the core tests: block size 4, 6 GPU blocks, at most 8
/// batched tokens and 3 sequences, chunked prefill on, prefix caching off.
/// A `speedup_ratio` of 0.0 leaves predicted durations unscaled.
fn make_args() -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap()
}
/// Engine args with prefix caching enabled: block size 4, 12 GPU blocks, at
/// most 12 batched tokens and 3 sequences, chunked prefill on.
/// A `speedup_ratio` of 0.0 leaves predicted durations unscaled.
fn router_args() -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(12)
.max_num_batched_tokens(Some(12))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(true)
.speedup_ratio(0.0)
.build()
.unwrap()
}
mod core_behavior {
use super::*;
/// With a 12-token budget and two 8-token prompts, one pass fully prefills the
/// first request (which emits a token) and partially prefills the second,
/// which must nevertheless end up in the running queue.
#[test]
fn test_unified_pass_keeps_partial_prefill_in_running() {
    let args = MockEngineArgs::builder()
        .block_size(4)
        .num_gpu_blocks(6)
        .max_num_batched_tokens(Some(12))
        .max_num_seqs(Some(3))
        .enable_chunked_prefill(true)
        .enable_prefix_caching(false)
        .speedup_ratio(0.0)
        .build()
        .unwrap();
    let mut core = VllmCore::new(args);
    let r1 = Uuid::from_u128(1);
    let r2 = Uuid::from_u128(2);

    // Two 8-token prompts with disjoint token ids.
    for (uuid, tokens) in [(r1, 0..8), (r2, 100..108)] {
        core.receive(DirectRequest {
            tokens: tokens.collect(),
            max_output_tokens: 2,
            uuid: Some(uuid),
            dp_rank: 0,
            arrival_timestamp_ms: None,
        });
    }

    let mut collector = crate::replay::TraceCollector::default();
    let pass = core.execute_pass(&mut collector, 0.0);

    assert_eq!(
        pass.output_signals.len(),
        1,
        "first request should emit immediately"
    );
    assert_eq!(core.state.waiting.len(), 0);
    let running_order: Vec<_> = core.state.running.iter().copied().collect();
    assert_eq!(running_order, vec![r1, r2]);

    let state_of = |uuid: Uuid| core.state.requests.get(&uuid).unwrap();
    assert_eq!(state_of(r1).num_computed_tokens, 8);
    assert_eq!(state_of(r2).num_computed_tokens, 4);
    assert_eq!(state_of(r1).sequence.generated_tokens(), 1);
    assert_eq!(state_of(r2).status, RequestStatus::Running);
    assert_eq!(core.kv_manager.num_active_blocks(), 4);
}
/// Submits two 8-token prompts with a 4-token batch budget and runs two
/// passes. The second pass must emit a signal for `r1`, and `r2` must still
/// have zero computed tokens afterwards.
#[test]
fn test_running_requests_consume_budget_before_waiting() {
    let mut core = VllmCore::new(
        MockEngineArgs::builder()
            .num_gpu_blocks(16)
            .block_size(4)
            .max_num_seqs(Some(3))
            .max_num_batched_tokens(Some(4))
            .enable_prefix_caching(false)
            .enable_chunked_prefill(true)
            .speedup_ratio(0.0)
            .build()
            .unwrap(),
    );

    let r1 = Uuid::from_u128(1);
    let r2 = Uuid::from_u128(2);
    // Requests are received in this order: r1 first, then r2.
    for (uuid, prompt) in [(r1, 0u32..8u32), (r2, 100u32..108u32)] {
        core.receive(DirectRequest {
            tokens: prompt.collect(),
            max_output_tokens: 2,
            uuid: Some(uuid),
            dp_rank: 0,
            arrival_timestamp_ms: None,
        });
    }

    let mut collector = crate::replay::TraceCollector::default();
    core.execute_pass(&mut collector, 0.0);
    let pass = core.execute_pass(&mut collector, 1.0);

    assert!(pass.output_signals.iter().any(|signal| signal.uuid == r1));
    let waiting_computed = core.state.requests.get(&r2).unwrap().num_computed_tokens;
    assert_eq!(
        waiting_computed,
        0,
        "waiting request should not steal budget before the running request catches up"
    );
}
/// A single 8-token prompt under `make_args()` (8-token batch budget) must
/// produce its first, non-final output signal in the very first pass.
#[test]
fn test_first_token_can_arrive_on_prompt_completion_pass() {
    let uuid = Uuid::from_u128(11);
    let request = DirectRequest {
        tokens: (0..8).collect(),
        max_output_tokens: 2,
        uuid: Some(uuid),
        dp_rank: 0,
        arrival_timestamp_ms: None,
    };

    let mut core = VllmCore::new(make_args());
    core.receive(request);

    let mut collector = crate::replay::TraceCollector::default();
    let pass = core.execute_pass(&mut collector, 0.0);

    assert_eq!(pass.output_signals.len(), 1);
    assert_eq!(pass.output_signals[0].uuid, uuid);
    assert!(!pass.output_signals[0].completed);

    let state = core.state.requests.get(&uuid).unwrap();
    assert_eq!(state.sequence.generated_tokens(), 1);
}
// Three prompts (8, 8 and 12 tokens) compete for 6 GPU blocks of 4 tokens
// with LIFO preemption configured. After two passes the test expects r2 to
// have been preempted exactly once, reset to 0 computed tokens, and placed
// at the front of the waiting queue.
#[test]
fn test_preemption_requeues_newest_running_request() {
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(12))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.preemption_mode(PreemptionMode::Lifo)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
let r1 = Uuid::from_u128(1);
let r2 = Uuid::from_u128(2);
let r3 = Uuid::from_u128(3);
// Requests are received in array order: r1, r2, r3.
for (uuid, range) in [(r1, 0u32..8u32), (r2, 100u32..108u32), (r3, 200u32..212u32)] {
core.receive(DirectRequest {
tokens: range.collect(),
max_output_tokens: 2,
uuid: Some(uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let mut collector = crate::replay::TraceCollector::default();
// Only the state after the second pass is inspected.
core.execute_pass(&mut collector, 0.0);
core.execute_pass(&mut collector, 1.0);
let request = core.state.requests.get(&r2).unwrap();
assert_eq!(request.status, RequestStatus::Preempted);
assert_eq!(request.num_computed_tokens, 0);
assert_eq!(request.num_preemptions, 1);
assert_eq!(core.state.waiting.front().copied(), Some(r2));
}
// Hand-builds a request that is already in the running set with
// `num_computed_tokens: 9` and `num_preemptions: 1`, whose sequence was
// generated, reset, and then re-allocated for its prompt only. One pass is
// then expected to emit a single output signal and leave the request with
// 12 computed tokens, 13 allocated tokens and 4 active KV blocks.
#[test]
fn test_running_request_catches_up_decode_tail_before_promote() {
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(8)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(1))
.enable_chunked_prefill(true)
.enable_prefix_caching(true)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
let uuid = Uuid::from_u128(99);
// NOTE(review): 6 prompt tokens; the remaining positional arguments are
// presumably max output tokens (16), block size (4) and two feature
// flags - confirm against `ActiveSequence::new`.
let mut sequence = ActiveSequence::new((0..6).collect(), 16, Some(4), true, false);
let signal = sequence.take_creation_signal().unwrap();
// Processing the creation signal is expected to return 2.
assert_eq!(core.kv_manager.process(&signal), 2);
// Six generate steps, feeding every emitted signal to the KV manager and
// committing the allocation while the output limit has not been reached.
for _ in 0..6 {
let signals = sequence.generate();
for signal in &signals {
core.kv_manager.process(signal);
}
if sequence.generated_tokens() < sequence.max_output_tokens() {
sequence.commit_allocation(sequence.len());
}
}
// Reset the sequence and forward the signals the reset produced.
let free = sequence.reset_with_signal();
for signal in &free {
core.kv_manager.process(signal);
}
// Re-allocate for the input tokens only.
let prompt_only = sequence
.prepare_allocation(sequence.num_input_tokens())
.unwrap();
assert_eq!(core.kv_manager.process(&prompt_only), 2);
sequence.commit_allocation(sequence.num_input_tokens());
// Register the request directly as Running, bypassing `receive`.
core.state.insert_running_for_test(uuid);
core.state.requests.insert(
uuid,
VllmRequestState {
sequence,
status: RequestStatus::Running,
num_computed_tokens: 9,
num_preemptions: 1,
},
);
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
let request = core.state.requests.get(&uuid).unwrap();
assert_eq!(pass.output_signals.len(), 1);
assert_eq!(request.num_computed_tokens, 12);
assert_eq!(request.sequence.num_allocated_tokens(), 13);
assert_eq!(core.kv_manager.num_active_blocks(), 4);
}
/// Runs two identical 8-token requests to completion and checks that the
/// core ends with empty waiting/running sets and no active KV blocks.
#[test]
fn test_completion_returns_scheduler_to_idle() {
    let mut core = VllmCore::new(make_args());
    for id in [1u128, 2u128] {
        let request = DirectRequest {
            tokens: (0..8).collect(),
            max_output_tokens: 2,
            uuid: Some(Uuid::from_u128(id)),
            dp_rank: 0,
            arrival_timestamp_ms: None,
        };
        core.receive(request);
    }

    let mut collector = crate::replay::TraceCollector::default();
    // Keep executing passes (always at timestamp 0.0) until the core drains.
    loop {
        if core.is_empty() {
            break;
        }
        core.execute_pass(&mut collector, 0.0);
    }

    assert!(core.state.waiting.is_empty());
    assert!(core.state.running.is_empty());
    assert_eq!(core.kv_manager.num_active_blocks(), 0);
}
}
mod router_events {
use super::*;
/// A pass executed by a KV-capturing `VllmCore` must report
/// `RouterEventVisibility::PassStart`.
#[test]
fn test_vllm_pass_visibility_is_pass_start() {
    let request = DirectRequest {
        tokens: (0..8).collect(),
        max_output_tokens: 2,
        uuid: Some(Uuid::from_u128(71)),
        dp_rank: 0,
        arrival_timestamp_ms: None,
    };

    let mut core = VllmCore::new_with_kv_capture(router_args(), ROUTER_TEST_WORKER_ID);
    core.receive(request);

    let mut collector = crate::replay::TraceCollector::default();
    let pass = core.execute_pass(&mut collector, 0.0);

    let expected = RouterEventVisibility::PassStart;
    assert_eq!(pass.router_event_visibility, expected);
}
/// Drives one request to completion, forwarding every pass's KV events to a
/// `RouterIndexerHarness`, and checks that at least one store event was seen
/// and counted as OK by the harness.
#[tokio::test]
async fn test_completion_events_apply_cleanly() {
    let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
    let mut core = VllmCore::new_with_kv_capture(router_args(), ROUTER_TEST_WORKER_ID);

    let request = DirectRequest {
        tokens: (0..8).collect(),
        max_output_tokens: 4,
        uuid: Some(Uuid::from_u128(41)),
        dp_rank: 0,
        arrival_timestamp_ms: None,
    };
    core.receive(request);

    let mut collector = crate::replay::TraceCollector::default();
    let mut saw_store = false;
    let mut now_ms = 0.0;
    loop {
        if core.is_empty() {
            break;
        }
        let pass = core.execute_pass(&mut collector, now_ms);
        if !stored_hashes(&pass.kv_events).is_empty() {
            saw_store = true;
        }
        // The next pass starts where this one ended.
        now_ms = pass.end_ms;
        harness.apply_events(pass.kv_events).await;
    }

    assert!(saw_store);
    assert!(harness.ok_count(METRIC_EVENT_STORED) > 0);
    assert_eq!(core.kv_manager.num_active_blocks(), 0);
    harness.shutdown();
}
// Same three-request setup as the core preemption test, but with prefix
// caching enabled and KV events captured. Each pass's KV events are applied
// to a `RouterIndexerHarness`; the test expects r2 to be preempted and at
// least one remove event to have been produced and counted as OK.
#[tokio::test]
async fn test_preemption_recompute_events_apply_cleanly() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(12))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(true)
.preemption_mode(PreemptionMode::Lifo)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new_with_kv_capture(args, ROUTER_TEST_WORKER_ID);
let r1 = Uuid::from_u128(51);
let r2 = Uuid::from_u128(52);
let r3 = Uuid::from_u128(53);
for (uuid, range) in [(r1, 0u32..8u32), (r2, 100u32..108u32), (r3, 200u32..212u32)] {
core.receive(DirectRequest {
tokens: range.collect(),
max_output_tokens: 2,
uuid: Some(uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let mut collector = crate::replay::TraceCollector::default();
let mut now_ms = 0.0;
let mut saw_remove = false;
// Exactly two passes; each pass starts at the previous pass's `end_ms`.
for _ in 0..2 {
let pass = core.execute_pass(&mut collector, now_ms);
saw_remove |= removed_event_count(&pass.kv_events) > 0;
now_ms = pass.end_ms;
harness.apply_events(pass.kv_events).await;
}
let request = core.state.requests.get(&r2).unwrap();
assert_eq!(request.status, RequestStatus::Preempted);
assert_eq!(request.num_computed_tokens, 0);
assert_eq!(request.num_preemptions, 1);
assert_eq!(core.state.waiting.front().copied(), Some(r2));
assert!(saw_remove);
assert!(harness.ok_count(METRIC_EVENT_REMOVED) > 0);
harness.shutdown();
}
}
mod live_scheduler {
use super::*;
/// A captured KV cache event paired with the block token ids that came with
/// it (`None` when the event arrived through the plain `KvCacheEventSink`).
type CapturedKvEvent = (KvCacheEvent, Option<Vec<Vec<u32>>>);

/// Test sink that records every published event in memory.
#[derive(Default)]
struct CapturingKvSink {
    events: Mutex<Vec<CapturedKvEvent>>,
}

impl CapturingKvSink {
    /// Returns all events recorded so far and leaves the sink empty.
    fn take(&self) -> Vec<CapturedKvEvent> {
        let mut recorded = self.events.lock().unwrap();
        std::mem::take(&mut *recorded)
    }
}

impl KvCacheEventSink for CapturingKvSink {
    /// Records the event with no block token ids.
    fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
        let mut recorded = self.events.lock().unwrap();
        recorded.push((event, None));
        Ok(())
    }
}

impl RawKvEventSink for CapturingKvSink {
    /// Records the inner event together with its block token ids.
    fn publish(&self, event: RawKvEvent) -> anyhow::Result<()> {
        let entry: CapturedKvEvent = (event.event, event.block_token_ids);
        let mut recorded = self.events.lock().unwrap();
        recorded.push(entry);
        Ok(())
    }
}
// Parameterised over all eight combinations of the three boolean switches
// (shared tokens, prefix caching, chunked prefill).
#[rstest]
#[case::case_1(false, false, false)]
#[case::case_2(false, true, false)]
#[case::case_3(true, false, false)]
#[case::case_4(true, true, false)]
#[case::case_5(false, false, true)]
#[case::case_6(false, true, true)]
#[case::case_7(true, false, true)]
#[case::case_8(true, true, true)]
#[tokio::test]
async fn test_scheduler_token_generation_patterns(
#[case] use_shared_tokens: bool,
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
.block_size(64)
.speedup_ratio(10.0)
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.build()
.unwrap();
let scheduler =
Scheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
// NOTE(review): the three numeric arguments (200, 1000, 100) are defined by
// the shared helper in `scheduler::test_utils`; their meaning is not visible
// here - confirm there before changing them.
crate::scheduler::test_utils::assert_scheduler_completes_all(
&scheduler,
&mut output_rx,
200,
1000,
100,
use_shared_tokens,
)
.await;
}
// Sends 10 identical 65-token requests, 100 ms apart, to a live scheduler,
// counts every output signal until the output channel has been quiet for
// 500 ms, then expects idle metrics and exactly
// `num_requests * max_output_tokens` signals.
#[tokio::test]
async fn test_cache_hit_rate_with_identical_requests() {
let block_size: usize = 64;
let max_output_tokens: usize = 10;
let speedup_ratio = 10.0;
let num_requests = 10;
// One token more than a single 64-token block.
let token_length = 65;
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(100)
.block_size(block_size)
.speedup_ratio(speedup_ratio)
.build()
.unwrap();
let scheduler =
Scheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
let identical_tokens: Vec<u32> = (0..token_length).collect();
for _ in 0..num_requests {
scheduler.receive(DirectRequest {
tokens: identical_tokens.clone(),
max_output_tokens,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
tokio::time::sleep(Duration::from_millis(100)).await;
}
let mut received_tokens = 0;
// Inactivity timer: re-armed on every received signal below.
let timeout = tokio::time::sleep(Duration::from_millis(500));
tokio::pin!(timeout);
let metrics_rx = scheduler.metrics_receiver();
let mut debug_interval = interval(Duration::from_millis(500));
loop {
tokio::select! {
// `biased` makes the branches be polled in the order written.
biased;
_ = debug_interval.tick() => {
let _metrics = metrics_rx.borrow().clone();
tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
}
Some(_signal) = output_rx.recv() => {
received_tokens += 1;
timeout.set(tokio::time::sleep(Duration::from_millis(500)));
}
// No signal for 500 ms: stop collecting.
_ = &mut timeout => break,
}
}
// Extra grace period before the final metrics snapshot is read.
tokio::time::sleep(Duration::from_millis(100)).await;
let metrics = metrics_rx.borrow().clone();
assert_scheduler_idle(&metrics);
assert_eq!(received_tokens, num_requests * max_output_tokens);
}
/// Receives 129 output signals of a 200-token request, drops the receiver,
/// and then waits (up to 5 seconds) for the scheduler metrics to report no
/// active decode blocks before asserting that the scheduler is idle.
#[tokio::test]
async fn test_receiver_drop_cleans_up_resources() {
    let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
    let scheduler = Scheduler::new(
        MockEngineArgs::builder()
            .num_gpu_blocks(10)
            .block_size(64)
            .speedup_ratio(100.0)
            .build()
            .unwrap(),
        0,
        Some(output_tx),
        KvEventPublishers::default(),
        None,
    );

    let request = DirectRequest {
        tokens: (0..256).collect(),
        max_output_tokens: 200,
        uuid: None,
        dp_rank: 0,
        arrival_timestamp_ms: None,
    };
    scheduler.receive(request);

    for _ in 0..129 {
        if output_rx.recv().await.is_none() {
            panic!("Channel closed before receiving 129 tokens");
        }
    }
    drop(output_rx);

    let metrics_rx = scheduler.metrics_receiver();
    let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
    loop {
        // Read the flag in its own statement so the watch borrow is released
        // before the sleep below is awaited.
        let blocks_released = metrics_rx.borrow().active_decode_blocks == 0;
        if blocks_released || tokio::time::Instant::now() >= deadline {
            break;
        }
        tokio::time::sleep(Duration::from_millis(50)).await;
    }

    let metrics = metrics_rx.borrow().clone();
    assert_scheduler_idle(&metrics);
}
// Runs one single-output-token request through a live scheduler whose raw
// KV event publisher is a `CapturingKvSink`, then checks that a `Stored`
// event was captured together with non-empty block token ids.
#[tokio::test]
async fn test_live_scheduler_forwards_buffered_kv_token_ids() {
let sink = Arc::new(CapturingKvSink::default());
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(12)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(1))
.enable_chunked_prefill(true)
.enable_prefix_caching(true)
.speedup_ratio(1000.0)
.zmq_kv_events_port(Some(12345))
.build()
.unwrap();
// The sink is installed as the second publisher argument; the first is None.
let scheduler = Scheduler::new(
args,
0,
Some(output_tx),
KvEventPublishers::new(None, Some(sink.clone())),
None,
);
scheduler.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(72)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
// Outer `expect`: the 2 s timeout elapsed. Inner `expect`: channel closed.
let signal = tokio::time::timeout(Duration::from_secs(2), output_rx.recv())
.await
.expect("scheduler should emit output")
.expect("output channel should stay open");
// With `max_output_tokens: 1` the first signal must also be the last one.
assert!(signal.completed);
// Short pause before draining the sink.
tokio::time::sleep(Duration::from_millis(50)).await;
let events = sink.take();
// First captured `Stored` event that carries block token ids.
let stored = events
.into_iter()
.find_map(|(event, block_token_ids)| match event.data {
KvCacheEventData::Stored(_) => block_token_ids,
_ => None,
})
.expect("live scheduler should forward stored KV event token ids");
assert!(!stored.is_empty());
assert!(stored.iter().all(|block| !block.is_empty()));
}
// Floods a small live scheduler (6 blocks of 4 tokens) with eight identical
// requests, waits for all 32 output signals (or 5 seconds), then shuts the
// event pipeline down and checks that the router harness recorded no event
// errors and at least one successful store event.
#[tokio::test]
async fn test_live_pathological_load_no_router_event_errors() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let (sink, forward_task) = harness.spawn_forwarder();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler = Scheduler::new(
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(3))
.enable_prefix_caching(true)
.enable_chunked_prefill(true)
.speedup_ratio(1000.0)
.build()
.unwrap(),
0,
Some(output_tx),
KvEventPublishers::new(Some(sink.clone()), None),
None,
);
// Every request carries the same prompt: eight copies of token 42.
for _ in 0..8 {
scheduler.receive(DirectRequest {
tokens: vec![42; 8],
max_output_tokens: 4,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
// 8 requests x 4 output tokens each.
let expected = 8 * 4;
let mut seen = 0;
let timeout = tokio::time::sleep(Duration::from_secs(5));
tokio::pin!(timeout);
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
seen += 1;
if seen == expected {
break;
}
}
// Overall 5 s limit; the assertion below then reports the shortfall.
_ = &mut timeout => {
break;
}
}
}
assert_eq!(seen, expected);
// Drop the scheduler and the local sink handle before awaiting the
// forwarder task.
drop(scheduler);
drop(sink);
forward_task.await.unwrap();
harness.flush().await;
harness.assert_no_event_errors();
assert!(harness.ok_count(METRIC_EVENT_STORED) > 0);
harness.shutdown();
}
}
......@@ -333,6 +333,9 @@ pub mod model {
/// KV Router configuration environment variables
pub mod router {
/// Minimum number of workers required before KV router startup continues.
pub const DYN_ROUTER_MIN_INITIAL_WORKERS: &str = "DYN_ROUTER_MIN_INITIAL_WORKERS";
/// Queue threshold fraction for prefill token capacity.
/// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens.
pub const DYN_ROUTER_QUEUE_THRESHOLD: &str = "DYN_ROUTER_QUEUE_THRESHOLD";
......@@ -492,6 +495,7 @@ mod tests {
model::huggingface::HF_HOME,
model::huggingface::HF_HUB_OFFLINE,
// Router
router::DYN_ROUTER_MIN_INITIAL_WORKERS,
router::DYN_ROUTER_QUEUE_THRESHOLD,
router::DYN_ROUTER_QUEUE_POLICY,
// Event Plane
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from pathlib import Path
from types import SimpleNamespace
import pytest
from dynamo.llm import EngineType, EntrypointArgs
# Load components/src/dynamo/mocker/config.py straight from its file path
# (resolved relative to this test file) under the module name
# "dynamo_mocker_config" and execute it, so the tests below can reach its
# functions through CONFIG.
MODULE_PATH = (
Path(__file__).resolve().parents[2] / "components/src/dynamo/mocker/config.py"
)
SPEC = importlib.util.spec_from_file_location("dynamo_mocker_config", MODULE_PATH)
# spec_from_file_location may return None, and a spec may have no loader.
assert SPEC is not None
assert SPEC.loader is not None
CONFIG = importlib.util.module_from_spec(SPEC)
SPEC.loader.exec_module(CONFIG)
# Markers applied to every test in this module.
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.gpu_0,
pytest.mark.parallel,
pytest.mark.unit,
]
def make_args(**overrides):
    """Build a ``SimpleNamespace`` of mocker arguments for the tests.

    Every keyword argument replaces the default of the same name; names that
    have no default are simply added to the namespace.
    """
    defaults = dict(
        extra_engine_args=None,
        engine_type="vllm",
        num_gpu_blocks=16384,
        block_size=None,
        max_num_seqs=256,
        max_num_batched_tokens=8192,
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
        preemption_mode="lifo",
        speedup_ratio=1.0,
        decode_speedup_ratio=1.0,
        dp_size=1,
        startup_time=None,
        durable_kv_events=False,
        kv_transfer_bandwidth=64.0,
        reasoning=None,
        sglang_schedule_policy=None,
        sglang_page_size=None,
        sglang_max_prefill_tokens=None,
        sglang_chunked_prefill_size=None,
        sglang_clip_max_new_tokens=None,
        sglang_schedule_conservativeness=None,
        aic_perf_model=False,
        aic_system=None,
        aic_backend_version=None,
        aic_tp_size=None,
        model_path=None,
        is_prefill_worker=False,
        is_decode_worker=False,
    )
    return SimpleNamespace(**{**defaults, **overrides})
def test_build_runtime_config_uses_normalized_sglang_page_size_alias():
    """With no block_size and sglang_page_size=16, the runtime block size is 16."""
    cli_args = make_args(engine_type="sglang", block_size=None, sglang_page_size=16)
    engine_args = CONFIG.build_mocker_engine_args(cli_args)

    block_size, runtime_config = CONFIG.build_runtime_config(engine_args)

    assert block_size == 16
    # The remaining values are the make_args() defaults passed straight through.
    assert runtime_config.total_kv_blocks == 16384
    assert runtime_config.max_num_seqs == 256
    assert runtime_config.max_num_batched_tokens == 8192
def test_build_mocker_engine_args_rejects_mismatched_sglang_sizes():
    """block_size=8 combined with sglang_page_size=4 must be rejected."""
    mismatched = make_args(engine_type="sglang", block_size=8, sglang_page_size=4)
    expected_message = "block_size and sglang.page_size to match"
    with pytest.raises(Exception, match=expected_message):
        CONFIG.build_mocker_engine_args(mismatched)
def test_load_mocker_engine_args_from_json_file_normalizes_page_size(tmp_path):
    """A JSON file that only sets sglang.page_size=32 yields block_size == 32."""
    payload = '{"engine_type":"sglang","sglang":{"page_size":32},"num_gpu_blocks":1024}'
    config_path = tmp_path / "engine_args.json"
    config_path.write_text(payload)

    cli_args = make_args(extra_engine_args=config_path)
    engine_args = CONFIG.load_mocker_engine_args(cli_args)

    assert engine_args.block_size == 32
    assert engine_args.num_gpu_blocks == 1024
def test_worker_overrides_drive_runtime_config_for_prefill_worker():
    """A bootstrap port set through worker overrides reaches the runtime config."""
    base_args = CONFIG.build_mocker_engine_args(make_args(is_prefill_worker=True))
    overrides = {"bootstrap_port": 9001, "kv_bytes_per_token": 128}
    worker_args = CONFIG.apply_worker_engine_args_overrides(base_args, **overrides)

    block_size, runtime_config = CONFIG.build_runtime_config(worker_args)

    assert block_size == 64
    assert worker_args.bootstrap_port == 9001
    assert runtime_config.bootstrap_port == 9001
    assert runtime_config.bootstrap_host is not None
def test_runtime_config_disables_local_indexer_for_decode_worker():
    """For a decode worker the engine args keep enable_local_indexer True,
    while the derived runtime config reports it as False."""
    cli_args = make_args(is_decode_worker=True, durable_kv_events=False)
    engine_args = CONFIG.build_mocker_engine_args(cli_args)

    runtime_config = CONFIG.build_runtime_config(engine_args)[1]

    assert engine_args.enable_local_indexer is True
    assert runtime_config.enable_local_indexer is False
def test_entrypoint_args_accept_typed_mocker_engine_args():
    """EntrypointArgs can be constructed from typed mocker engine args."""
    engine_args = CONFIG.build_mocker_engine_args(make_args())
    constructor_kwargs = {
        "engine_type": EngineType.Mocker,
        "mocker_engine_args": engine_args,
        "kv_cache_block_size": engine_args.block_size,
    }

    entrypoint_args = EntrypointArgs(**constructor_kwargs)

    assert entrypoint_args is not None
......@@ -5,7 +5,6 @@ import asyncio
import json
import logging
import random
import time
from typing import TYPE_CHECKING, Any, Optional
import aiohttp
......@@ -164,19 +163,16 @@ def _test_router_two_routers(
for i, port in enumerate(router_ports):
logger.info(f"Starting KV router frontend on port {port}")
kv_router = KVRouterProcess(
request, block_size, port, engine_workers.namespace, store_backend
request,
block_size,
port,
engine_workers.namespace,
store_backend,
min_initial_workers=engine_workers.num_workers,
)
kv_router.__enter__()
kv_routers.append(kv_router)
# Add delay between routers for file backend to ensure first router's
# registration is visible before second router starts its cleanup
if i == 0 and store_backend == "file":
logger.info(
"Waiting 0.5s for first router to fully register (file backend)"
)
time.sleep(0.5)
# Wait for workers to be ready on both routers
logger.info("Waiting for workers to register with both routers...")
for i, port in enumerate(router_ports):
......@@ -331,7 +327,7 @@ def _test_python_router_bindings(
AssertionError: If requests fail or router doesn't work correctly
"""
# Create KvRouterConfig with default settings
kv_router_config = KvRouterConfig()
kv_router_config = KvRouterConfig(min_initial_workers=num_workers)
# Create KvRouter Python object
kv_router = KvRouter(
......@@ -834,6 +830,7 @@ def _test_router_indexers_sync(
router_snapshot_threshold=20,
durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
min_initial_workers=num_workers,
)
# If standalone indexer mode, launch mockers one-by-one and register.
......@@ -1478,11 +1475,15 @@ def _test_router_decisions(
if standalone_indexer_url:
await engine_workers.launch_mockers_with_indexer(endpoint)
# Workers register one instance per process (not per dp_rank)
expected_num_instances = engine_workers.num_workers
kv_router_config = KvRouterConfig(
router_snapshot_threshold=20,
use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads,
min_initial_workers=expected_num_instances,
)
kv_router = KvRouter(
endpoint=endpoint,
......@@ -1490,9 +1491,6 @@ def _test_router_decisions(
kv_router_config=kv_router_config,
)
# Workers register one instance per process (not per dp_rank)
expected_num_instances = engine_workers.num_workers
# Wait for workers to be ready and get their instance IDs
worker_ids = await wait_for_workers_ready(
endpoint,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import time
from typing import Any
from tests.router.common import (
_test_router_basic,
_test_router_decisions,
_test_router_indexers_sync,
)
from tests.router.helper import get_runtime
from tests.utils.constants import DefaultPort
from tests.utils.port_utils import allocate_ports, deallocate_ports
from tests.utils.test_output import resolve_test_output_path
logger = logging.getLogger(__name__)
# Fixed user message used as the chat content by build_test_payload().
TEST_PROMPT = (
"In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on "
"clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint "
"rustle of leaves, but it remained calm, confident in the safety of its burrow "
"just a few hops away. The late afternoon sun warmed its fur, and tiny dust "
"motes danced in the golden light as bees hummed lazily nearby. Though the "
"rabbit lived a simple life, every day was an adventure of scents, shadows, and "
"snacks-an endless search for the tastiest patch of greens and the softest spot "
"to nap."
)
def allocate_frontend_ports(request, count: int) -> list[int]:
    """Allocate ``count`` ports starting from the default frontend port.

    A finalizer is registered on ``request`` so the ports are deallocated
    again when the requesting test is torn down.
    """
    ports = allocate_ports(count, DefaultPort.FRONTEND.value)

    def _release():
        deallocate_ports(ports)

    request.addfinalizer(_release)
    return ports
def build_test_payload(model_name: str) -> dict[str, Any]:
    """Return a streaming chat request for ``model_name`` that sends
    TEST_PROMPT as the single user message and asks for 10 tokens."""
    user_message = {"role": "user", "content": TEST_PROMPT}
    payload: dict[str, Any] = {}
    payload["model"] = model_name
    payload["messages"] = [user_message]
    payload["stream"] = True
    payload["max_tokens"] = 10
    return payload
class ManagedEngineProcessMixin:
    """Context-manager mixin that starts a group of worker processes in turn.

    The host class must provide ``self.worker_processes``, a sequence of
    process objects. Each one is driven through the attributes and helpers
    used below (``command``, ``log_dir``, ``data_dir``, ``delayed_start``,
    ``timeout``, ``proc``, ``_start_process``, ``_check_ports``,
    ``_check_urls``, ``_check_funcs`` and ``__exit__``).

    Subclasses can tune the class attributes: ``process_name`` and
    ``cleanup_name`` only affect log messages, ``init_delay_seconds`` /
    ``init_delay_reason`` control the pause between two launches, and
    ``cleanup_delay_seconds`` is the pause after all workers were stopped.
    """

    process_name = "worker"
    cleanup_name = "worker resources"
    init_delay_seconds = 5
    init_delay_reason = "initialize before starting next worker"
    cleanup_delay_seconds = 2

    def _stop_launched_workers(self, processes):
        """Best-effort stop of ``processes`` (newest first); errors are logged."""
        for process in reversed(processes):
            try:
                process.__exit__(None, None, None)
            except Exception as cleanup_err:
                logger.warning(
                    "[%s] Error during cleanup: %s",
                    self.__class__.__name__,
                    cleanup_err,
                )

    def __enter__(self):
        """Launch every worker sequentially, then run their health checks.

        Returns:
            ``self`` once all workers were started and passed their checks.

        Raises:
            Exception: whatever a launch or health check raised. Before the
            error propagates, every worker launched so far is stopped.
        """
        logger.info(
            "[%s] Starting %d worker processes sequentially...",
            self.__class__.__name__,
            len(self.worker_processes),
        )

        for i, process in enumerate(self.worker_processes):
            logger.info(
                "[%s] Starting %s %d...", self.__class__.__name__, self.process_name, i
            )
            try:
                # Prepare logging state and the log file location.
                process._logger = logging.getLogger(process.__class__.__name__)
                process._command_name = process.command[0]
                process.log_dir = resolve_test_output_path(process.log_dir)
                os.makedirs(process.log_dir, exist_ok=True)
                log_name = f"{process._command_name}.log.txt"
                process._log_path = os.path.join(process.log_dir, log_name)

                if process.data_dir:
                    process._remove_directory(process.data_dir)

                process._terminate_all_matching_process_names()

                logger.info(
                    "[%s] Launching process %d (pid will be assigned)...",
                    self.__class__.__name__,
                    i,
                )
                process._start_process()
                logger.info(
                    "[%s] Worker %d launched with PID: %s",
                    self.__class__.__name__,
                    i,
                    process.proc.pid if process.proc else "unknown",
                )
                time.sleep(process.delayed_start)

                # Give this worker time to initialize before the next launch;
                # no pause is needed after the last one.
                if i < len(self.worker_processes) - 1:
                    logger.info(
                        "[%s] Waiting %ss for worker %d to %s...",
                        self.__class__.__name__,
                        self.init_delay_seconds,
                        i,
                        self.init_delay_reason,
                    )
                    time.sleep(self.init_delay_seconds)
            except Exception:
                logger.exception(
                    "[%s] Failed to start worker %d", self.__class__.__name__, i
                )
                # __enter__ is about to raise, so the `with` statement will
                # never call __exit__. Stop the failed worker AND every worker
                # launched before it; otherwise those processes are leaked.
                self._stop_launched_workers(self.worker_processes[: i + 1])
                raise

        logger.info(
            "[%s] All %d workers launched with sequential initialization.",
            self.__class__.__name__,
            len(self.worker_processes),
        )
        logger.info(
            "[%s] Waiting for health checks to complete...", self.__class__.__name__
        )

        for i, process in enumerate(self.worker_processes):
            logger.info(
                "[%s] Checking health for worker %d...", self.__class__.__name__, i
            )
            try:
                # The value returned by _check_ports is deducted from the
                # timeout handed to the remaining checks.
                elapsed = process._check_ports(process.timeout)
                process._check_urls(process.timeout - elapsed)
                process._check_funcs(process.timeout - elapsed)
                logger.info(
                    "[%s] Worker %d health checks passed", self.__class__.__name__, i
                )
            except Exception:
                logger.error(
                    "[%s] Worker %d health check failed", self.__class__.__name__, i
                )
                self.__exit__(None, None, None)
                raise

        logger.info(
            "[%s] All workers started successfully and passed health checks!",
            self.__class__.__name__,
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop every worker, then wait ``cleanup_delay_seconds``.

        A worker that fails to stop no longer prevents the remaining workers
        from being stopped: the failure is logged, the loop continues, and the
        first such error is re-raised once all workers were handled.
        """
        first_error = None
        for i, process in enumerate(self.worker_processes):
            logger.info("Stopping %s %d", self.process_name, i)
            try:
                process.__exit__(exc_type, exc_val, exc_tb)
            except Exception as stop_err:
                logger.warning(
                    "[%s] Error while stopping %s %d: %s",
                    self.__class__.__name__,
                    self.process_name,
                    i,
                    stop_err,
                )
                if first_error is None:
                    first_error = stop_err

        logger.info("Waiting for %s to fully clean up...", self.cleanup_name)
        time.sleep(self.cleanup_delay_seconds)

        if first_error is not None:
            raise first_error
def get_engine_endpoint(engine_workers, request_plane: str, component_name: str):
    """Return the ``<namespace>.<component_name>.generate`` endpoint of the
    workers, obtained from a runtime created for ``request_plane``."""
    endpoint_path = f"{engine_workers.namespace}.{component_name}.generate"
    runtime = get_runtime(request_plane=request_plane)
    return runtime.endpoint(endpoint_path)
def run_basic_router_test(
    *,
    engine_process_cls,
    engine_args_name: str,
    engine_args: dict[str, Any],
    num_workers: int,
    single_gpu: bool,
    request,
    request_plane: str,
    block_size: int,
    model_name: str,
    frontend_timeout: int = 180,
):
    """Start the engine workers and run ``_test_router_basic`` against them.

    ``engine_args`` is handed to ``engine_process_cls`` under the keyword
    named by ``engine_args_name``. The test always sends 10 requests and
    uses the "etcd" store backend.
    """
    worker_kwargs = {
        "num_workers": num_workers,
        "single_gpu": single_gpu,
        "request_plane": request_plane,
    }
    # Kept as a separate unpacking so a clashing keyword name still raises.
    engine_kwargs = {engine_args_name: engine_args}

    with engine_process_cls(request, **worker_kwargs, **engine_kwargs) as engine_workers:
        frontend_port = allocate_frontend_ports(request, 1)[0]
        payload = build_test_payload(model_name)
        _test_router_basic(
            engine_workers=engine_workers,
            block_size=block_size,
            request=request,
            frontend_port=frontend_port,
            test_payload=payload,
            num_requests=10,
            frontend_timeout=frontend_timeout,
            store_backend="etcd",
            request_plane=request_plane,
        )
def run_router_decisions_test(
    *,
    engine_process_cls,
    engine_args_name: str,
    engine_args: dict[str, Any],
    request,
    request_plane: str,
    model_name: str,
    block_size: int,
    component_name: str,
    num_workers: int,
    single_gpu: bool,
    test_dp_rank: bool,
    extra_process_kwargs: dict[str, Any] | None = None,
):
    """Start the engine workers and run ``_test_router_decisions`` on their
    ``generate`` endpoint.

    ``engine_args`` is passed under the keyword named by
    ``engine_args_name``; ``extra_process_kwargs`` (if given and non-empty)
    is forwarded to ``engine_process_cls`` as additional keyword arguments.
    """
    process_kwargs = extra_process_kwargs if extra_process_kwargs else {}
    worker_kwargs = {
        "num_workers": num_workers,
        "single_gpu": single_gpu,
        "request_plane": request_plane,
    }
    # Separate unpackings keep the "duplicate keyword" error of the call.
    engine_kwargs = {engine_args_name: engine_args}

    with engine_process_cls(
        request, **worker_kwargs, **engine_kwargs, **process_kwargs
    ) as engine_workers:
        endpoint = get_engine_endpoint(engine_workers, request_plane, component_name)
        _test_router_decisions(
            engine_workers,
            endpoint,
            model_name,
            request,
            test_dp_rank=test_dp_rank,
            block_size=block_size,
        )
def run_indexers_sync_test(
    *,
    engine_process_cls,
    engine_args_name: str,
    engine_args: dict[str, Any],
    request,
    runtime_services_dynamic_ports,
    store_backend: str,
    durable_kv_events: bool,
    request_plane: str,
    block_size: int,
    model_name: str,
    num_workers: int,
):
    """Start single-GPU engine workers and run ``_test_router_indexers_sync``.

    The NATS interruption scenario is only enabled, and the NATS process
    only handed over, when ``durable_kv_events`` is false.
    """
    nats_process, _etcd_process = runtime_services_dynamic_ports

    interrupt_nats = not durable_kv_events
    nats_server = nats_process if interrupt_nats else None

    worker_kwargs = {
        "num_workers": num_workers,
        "single_gpu": True,
        "request_plane": request_plane,
        "store_backend": store_backend,
        "durable_kv_events": durable_kv_events,
    }
    # Separate unpackings keep the "duplicate keyword" error of the call.
    engine_kwargs = {engine_args_name: engine_args}

    with engine_process_cls(request, **worker_kwargs, **engine_kwargs) as engine_workers:
        _test_router_indexers_sync(
            engine_workers=engine_workers,
            block_size=block_size,
            model_name=model_name,
            num_workers=num_workers,
            store_backend=store_backend,
            request_plane=request_plane,
            test_nats_interruption=interrupt_nats,
            nats_server=nats_server,
            durable_kv_events=durable_kv_events,
        )
......@@ -28,6 +28,7 @@ class FrontendRouterProcess(ManagedProcess):
request_plane: str = "nats",
durable_kv_events: bool = False,
router_mode: str = "kv",
min_initial_workers: int | None = None,
):
command = [
"python3",
......@@ -65,6 +66,8 @@ class FrontendRouterProcess(ManagedProcess):
env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane
if min_initial_workers is not None:
env["DYN_ROUTER_MIN_INITIAL_WORKERS"] = str(min_initial_workers)
super().__init__(
command=command,
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import os
import subprocess
import sys
from pathlib import Path
import pytest
from tests.utils.constants import ROUTER_MODEL_NAME
# Model passed to the replay command via --model-path and to the model marker.
MODEL_NAME = ROUTER_MODEL_NAME
# Value passed to the replay command via --block-size.
MOONCAKE_TRACE_BLOCK_SIZE = 512
MOONCAKE_TRACE_SAMPLE_LINES = [
'{"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}',
'{"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}',
'{"timestamp": 0, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]}',
'{"timestamp": 0, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]}',
'{"timestamp": 0, "input_length": 9013, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]}',
'{"timestamp": 0, "input_length": 6506, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 64]}',
'{"timestamp": 0, "input_length": 4824, "output_length": 173, "hash_ids": [0, 65, 66, 67, 68, 69, 70, 71, 72, 73]}',
'{"timestamp": 0, "input_length": 3119, "output_length": 20, "hash_ids": [74, 75, 76, 77, 78, 79, 80]}',
'{"timestamp": 0, "input_length": 23090, "output_length": 453, "hash_ids": [0, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]}',
'{"timestamp": 0, "input_length": 3135, "output_length": 19, "hash_ids": [74, 75, 76, 77, 78, 126, 127]}',
'{"timestamp": 0, "input_length": 26874, "output_length": 458, "hash_ids": [0, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179]}',
'{"timestamp": 0, "input_length": 10487, "output_length": 402, "hash_ids": [0, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]}',
'{"timestamp": 0, "input_length": 17448, "output_length": 610, "hash_ids": [0, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233]}',
'{"timestamp": 0, "input_length": 6253, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 234]}',
'{"timestamp": 0, "input_length": 6725, "output_length": 32, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 235, 236]}',
'{"timestamp": 3052, "input_length": 13538, "output_length": 71, "hash_ids": [0, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262]}',
'{"timestamp": 3052, "input_length": 87162, "output_length": 402, "hash_ids": [0, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432]}',
'{"timestamp": 3052, "input_length": 6166, "output_length": 24, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 433]}',
'{"timestamp": 3052, "input_length": 6320, "output_length": 548, "hash_ids": [0, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445]}',
'{"timestamp": 3052, "input_length": 2007, "output_length": 354, "hash_ids": [0, 446, 447, 448]}',
]
# Markers applied to every test in this module.
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.gpu_0,
pytest.mark.integration,
pytest.mark.parallel,
pytest.mark.router,
pytest.mark.model(MODEL_NAME),
]
@pytest.mark.timeout(120)
def test_mocker_trace_file_replay(tmp_path):
    """Replay a small Mooncake trace through ``python -m dynamo.mocker``.

    The sample trace lines are written to a JSONL file under ``tmp_path``, the
    mocker CLI is run against that file in a subprocess, and the test then
    checks the exit code, the printed summary, and the JSON report that is
    expected to appear next to the trace file.
    """
    # Source directories are resolved relative to the current working
    # directory, which is also used as the subprocess working directory.
    checkout_dir = Path.cwd()

    trace_file = tmp_path / "mooncake_trace.jsonl"
    trace_text = "\n".join(MOONCAKE_TRACE_SAMPLE_LINES) + "\n"
    trace_file.write_text(trace_text, encoding="utf-8")
    replay_report = trace_file.with_name(f"{trace_file.stem}.replay.json")

    # Put the two source directories first on PYTHONPATH and keep any inherited
    # value after them.
    child_env = os.environ.copy()
    inherited_pythonpath = child_env.get("PYTHONPATH")
    search_path = [
        str(checkout_dir / "components/src"),
        str(checkout_dir / "lib/bindings/python/src"),
    ]
    if inherited_pythonpath:
        search_path.append(inherited_pythonpath)
    child_env["PYTHONPATH"] = os.pathsep.join(search_path)

    mocker_command = [
        sys.executable,
        "-m",
        "dynamo.mocker",
        "--trace-file",
        str(trace_file),
        "--model-path",
        MODEL_NAME,
        "--num-workers",
        "1",
        "--block-size",
        str(MOONCAKE_TRACE_BLOCK_SIZE),
        "--speedup-ratio",
        "0",
    ]
    # check=False: the exit code is asserted below so that the failure message
    # can include the captured stdout/stderr.
    result = subprocess.run(
        mocker_command,
        cwd=checkout_dir,
        env=child_env,
        capture_output=True,
        text=True,
        timeout=120,
        check=False,
    )

    assert result.returncode == 0, (
        f"dynamo.mocker trace replay failed with exit code {result.returncode}\n"
        f"stdout:\n{result.stdout}\n"
        f"stderr:\n{result.stderr}"
    )
    assert replay_report.exists(), (
        "Expected default replay report next to the temp trace file, "
        f"but {replay_report} was not created.\nstdout:\n{result.stdout}\n"
        f"stderr:\n{result.stderr}"
    )
    assert "Replay Summary" in result.stdout
    assert f"JSON report: {replay_report}" in result.stdout

    # Every trace line must be counted both as submitted and as completed.
    expected_requests = len(MOONCAKE_TRACE_SAMPLE_LINES)
    report = json.loads(replay_report.read_text(encoding="utf-8"))
    assert report["num_requests"] == expected_requests
    assert report["completed_requests"] == expected_requests

    for metric in ("total_input_tokens", "total_output_tokens", "duration_ms"):
        assert report[metric] > 0
    assert report["wall_time_ms"] >= 0
    assert report["request_throughput_rps"] > 0
......@@ -7,21 +7,20 @@
# so we set explicit pytest timeouts to fail fast on hangs (see per-test markers below).
import logging
import os
import time
from typing import Any, Dict, Optional
import pytest
from tests.router.common import (
_test_router_basic,
_test_router_decisions,
_test_router_indexers_sync,
from tests.router.e2e_harness import (
ManagedEngineProcessMixin,
run_basic_router_test,
run_indexers_sync_test,
run_router_decisions_test,
)
from tests.router.helper import generate_random_suffix, get_runtime
from tests.router.helper import generate_random_suffix
from tests.utils.constants import DefaultPort
from tests.utils.managed_process import ManagedProcess
from tests.utils.port_utils import allocate_ports, deallocate_ports
from tests.utils.test_output import resolve_test_output_path
logger = logging.getLogger(__name__)
......@@ -33,31 +32,8 @@ pytestmark = [
pytest.mark.sglang,
pytest.mark.model(MODEL_NAME),
]
SPEEDUP_RATIO = 10.0
NUM_REQUESTS = 10
PAGE_SIZE = 16 # SGLang uses "page_size" instead of "block_size"
def allocate_frontend_ports(request, count: int) -> list[int]:
    """Allocate ``count`` frontend ports for xdist-safe execution.

    Args:
        request: Object exposing ``addfinalizer`` (the pytest ``request``
            fixture at the call sites); used to schedule the release.
        count: Number of ports to allocate.

    Returns:
        The ports returned by ``allocate_ports``. They are handed back through
        ``deallocate_ports`` when the registered finalizer runs.
    """
    reserved = allocate_ports(count, DefaultPort.FRONTEND.value)

    def _release_ports():
        deallocate_ports(reserved)

    request.addfinalizer(_release_ports)
    return reserved
# Shared test payload for all tests
TEST_PAYLOAD: Dict[str, Any] = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.",
}
],
"stream": True,
"max_tokens": 10,
}
# Shared SGLang configuration for all tests
# mem_fraction_static limits actual VRAM allocation (required for multi-worker on same GPU)
SGLANG_ARGS: Dict[str, Any] = {
......@@ -69,7 +45,7 @@ SGLANG_ARGS: Dict[str, Any] = {
}
class SGLangProcess:
class SGLangProcess(ManagedEngineProcessMixin):
"""Manages SGLang workers using dynamo.sglang (HTTP API + KV events).
This is a drop-in replacement for MockerProcess that uses real SGLang workers.
......@@ -242,97 +218,8 @@ class SGLangProcess:
f"with endpoint: {self.endpoint}"
)
def __enter__(self):
"""Start all SGLang worker processes with sequential initialization.
Workers are started sequentially with a delay between each to avoid
resource contention during initialization. This prevents
shared memory handle allocation failures when multiple workers
try to initialize simultaneously on the same GPU.

The method runs in two phases: every entry of ``self.worker_processes`` is
first launched without waiting on health checks, and only afterwards are the
port/URL/function checks run for each worker in turn.

Returns:
    ``self``, so the instance can be bound by a ``with`` statement.

Raises:
    Exception: any error raised while launching a worker or while running
        its health checks is logged and re-raised.
"""
logger.info(
f"[SGLangProcess] Starting {len(self.worker_processes)} worker processes sequentially..."
)
# Start each process sequentially, waiting for initialization before next
for i, process in enumerate(self.worker_processes):
logger.info(f"[SGLangProcess] Starting SGLang worker {i}...")
try:
# Manually initialize the process without blocking on health checks
# NOTE(review): the steps below set private attributes and call private
# helpers on the worker object directly; they presumably mirror what the
# worker's own __enter__ does before its health checks - keep in sync.
process._logger = logging.getLogger(process.__class__.__name__)
process._command_name = process.command[0]
process.log_dir = resolve_test_output_path(process.log_dir)
os.makedirs(process.log_dir, exist_ok=True)
log_name = f"{process._command_name}.log.txt"
process._log_path = os.path.join(process.log_dir, log_name)
if process.data_dir:
process._remove_directory(process.data_dir)
process._terminate_all_matching_process_names()
logger.info(
f"[SGLangProcess] Launching process {i} (pid will be assigned)..."
)
process._start_process() # Start the process but don't wait
logger.info(
f"[SGLangProcess] Worker {i} launched with PID: {process.proc.pid if process.proc else 'unknown'}"
)
# Honor the worker's own configured start delay before the fixed
# inter-worker pause below.
time.sleep(process.delayed_start)
# Wait for initialization before starting next worker
# This prevents shared memory contention
# The pause is skipped after the last worker, since nothing follows it.
if i < len(self.worker_processes) - 1:
init_delay = 5 # seconds
logger.info(
f"[SGLangProcess] Waiting {init_delay}s for worker {i} to initialize before starting next worker..."
)
time.sleep(init_delay)
except Exception:
logger.exception(f"[SGLangProcess] Failed to start worker {i}")
# Clean up on failure
# NOTE(review): only the worker that failed is exited here; workers
# launched in earlier iterations are not stopped by this handler.
try:
process.__exit__(None, None, None)
except Exception as cleanup_err:
# Cleanup errors are only logged so the original launch failure is
# the exception that propagates.
logger.warning(
f"[SGLangProcess] Error during cleanup: {cleanup_err}"
)
raise
logger.info(
f"[SGLangProcess] All {len(self.worker_processes)} workers launched with sequential initialization."
)
logger.info("[SGLangProcess] Waiting for health checks to complete...")
# Now wait for health checks for all processes
for i, process in enumerate(self.worker_processes):
logger.info(f"[SGLangProcess] Checking health for worker {i}...")
try:
# NOTE(review): `elapsed` comes from _check_ports only, and the same value
# is subtracted for both later checks, so time spent in _check_urls is not
# deducted from the budget passed to _check_funcs - confirm this is intended.
elapsed = process._check_ports(process.timeout)
process._check_urls(process.timeout - elapsed)
process._check_funcs(process.timeout - elapsed)
logger.info(f"[SGLangProcess] Worker {i} health checks passed")
except Exception:
logger.error(f"[SGLangProcess] Worker {i} health check failed")
# Clean up all processes on failure
self.__exit__(None, None, None)
raise
logger.info(
"[SGLangProcess] All workers started successfully and passed health checks!"
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop all SGLang worker processes gracefully.

The exception triple received from the ``with`` statement is forwarded
unchanged to each worker's own ``__exit__``. Nothing is returned (implicit
``None``), so an exception raised inside the ``with`` body is not suppressed.
"""
for i, process in enumerate(self.worker_processes):
logger.info(f"Stopping SGLang worker {i}")
process.__exit__(exc_type, exc_val, exc_tb)
# Add delay to ensure full cleanup of NATS/ETCD/ZMQ resources
# This prevents test isolation issues when running multiple tests
logger.info("Waiting for SGLang worker resources to fully clean up...")
# Fixed 2-second pause; there is no readiness signal being polled here.
time.sleep(2)
process_name = "SGLang worker"
cleanup_name = "SGLang worker resources"
@pytest.mark.pre_merge
......@@ -346,41 +233,17 @@ def test_sglang_kv_router_basic(
set_ucx_tls_no_mm,
request_plane,
):
"""
Quick e2e sanity test for KV router with SGLang engine instances.
Tests both NATS and TCP request planes.
"""
# runtime_services starts etcd and nats
N_SGLANG_WORKERS = 2
logger.info(
f"Starting SGLang KV router test with {N_SGLANG_WORKERS} workers using request_plane={request_plane}"
)
with SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU
run_basic_router_test(
engine_process_cls=SGLangProcess,
engine_args_name="sglang_args",
engine_args=SGLANG_ARGS,
num_workers=2,
single_gpu=True,
request=request,
request_plane=request_plane,
) as sglang_workers:
# Start SGLang workers
logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers")
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
# Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0]
_test_router_basic(
engine_workers=sglang_workers,
block_size=PAGE_SIZE,
request=request,
frontend_port=frontend_port,
test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS,
frontend_timeout=180, # 3 minutes should be plenty for TinyLlama
store_backend="etcd", # Explicit for clarity
request_plane=request_plane,
)
block_size=PAGE_SIZE,
model_name=MODEL_NAME,
)
@pytest.mark.pre_merge
......@@ -393,32 +256,19 @@ def test_router_decisions_sglang_multiple_workers(
set_ucx_tls_no_mm,
request_plane,
):
# runtime_services starts etcd and nats
logger.info("Starting SGLang router prefix reuse test with two workers")
N_WORKERS = 2
with SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0
run_router_decisions_test(
engine_process_cls=SGLangProcess,
engine_args_name="sglang_args",
engine_args=SGLANG_ARGS,
request=request,
request_plane=request_plane,
) as sglang_workers:
# Start 2 worker processes on the same GPU
logger.info("Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)")
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
runtime = get_runtime(request_plane=request_plane)
endpoint = runtime.endpoint(f"{sglang_workers.namespace}.backend.generate")
_test_router_decisions(
sglang_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=False,
block_size=PAGE_SIZE,
)
model_name=MODEL_NAME,
block_size=PAGE_SIZE,
component_name="backend",
num_workers=2,
single_gpu=True,
test_dp_rank=False,
)
@pytest.mark.gpu_2
......@@ -442,33 +292,20 @@ def test_router_decisions_sglang_dp(
* The (worker_id, dp_rank) with events should have exactly 4 events (one per request)
* All events should be on the forced (worker_id, dp_rank=1) (verifying forced routing and prefix reuse)
"""
N_WORKERS = 1
DP_SIZE = 2
with SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS, # Ignored when data_parallel_size is set
single_gpu=False,
data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank)
run_router_decisions_test(
engine_process_cls=SGLangProcess,
engine_args_name="sglang_args",
engine_args=SGLANG_ARGS,
request=request,
request_plane=request_plane,
) as sglang_workers:
logger.info("Starting 2 SGLang DP ranks (dp_size=2) (mem_frac=0.4)")
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the SGLang workers
endpoint = runtime.endpoint(f"{sglang_workers.namespace}.backend.generate")
_test_router_decisions(
sglang_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=True,
block_size=PAGE_SIZE,
)
model_name=MODEL_NAME,
block_size=PAGE_SIZE,
component_name="backend",
num_workers=1,
single_gpu=False,
test_dp_rank=True,
extra_process_kwargs={"data_parallel_size": 2},
)
@pytest.mark.pre_merge
......@@ -492,50 +329,16 @@ def test_sglang_indexers_sync(
durable_kv_events,
request_plane,
):
"""
Test that two KV routers have synchronized indexer states after processing requests
with SGLang workers. This test verifies that both routers converge to the same internal state.
Tests with configuration:
- nats_core: etcd backend, local indexer with NATS Core, TCP request plane
(includes NATS interruption/recovery testing)
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
nats_process, _etcd_process = runtime_services_dynamic_ports
logger.info(
f"Starting SGLang indexers sync test: store_backend={store_backend}, "
f"durable_kv_events={durable_kv_events}, request_plane={request_plane}"
)
N_SGLANG_WORKERS = 2
with SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
run_indexers_sync_test(
engine_process_cls=SGLangProcess,
engine_args_name="sglang_args",
engine_args=SGLANG_ARGS,
request=request,
runtime_services_dynamic_ports=runtime_services_dynamic_ports,
store_backend=store_backend,
durable_kv_events=durable_kv_events,
) as sglang_workers:
# Start SGLang workers
logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers")
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# When using durable_kv_events=True, use JetStream mode for the router
_test_router_indexers_sync(
engine_workers=sglang_workers,
block_size=PAGE_SIZE,
model_name=MODEL_NAME,
num_workers=N_SGLANG_WORKERS,
store_backend=store_backend,
request_plane=request_plane,
test_nats_interruption=not durable_kv_events,
nats_server=nats_process if not durable_kv_events else None,
durable_kv_events=durable_kv_events,
)
logger.info("SGLang indexers sync test completed successfully")
request_plane=request_plane,
block_size=PAGE_SIZE,
model_name=MODEL_NAME,
num_workers=2,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment