Unverified Commit b7fe46b1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(mocker): add multi-worker replay and router startup fixes (#7553)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 82794761
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::time::Duration;
use crate::common::protocols::OutputSignal;
use crate::kv_manager::SglangKvManager;
use super::config::{SglangConfig, floor_to_block};
use super::request::SglangRequest;
#[derive(Default)]
pub(super) struct DecodeResult {
pub(super) requests: Vec<SglangRequest>,
pub(super) output_signals: Vec<OutputSignal>,
pub(super) retracted_any: bool,
pub(super) end_ms: f64,
}
fn decode_capacity_state(
running: &[SglangRequest],
kv_manager: &SglangKvManager,
config: &SglangConfig,
) -> (usize, usize) {
let actual_available =
kv_manager.cache().available_tokens() + kv_manager.cache().evictable_size;
let reserved_tokens = running
.iter()
.map(SglangRequest::extra_reserved_tokens)
.sum::<usize>();
let logical_available = actual_available.saturating_sub(reserved_tokens);
let page_growth_needed = running
.iter()
.map(|req| {
if req.current_sequence_len() + 1 > req.allocated_tokens {
config.block_size
} else {
0
}
})
.sum();
(
actual_available,
logical_available.saturating_sub(page_growth_needed),
)
}
pub(super) fn cache_materialized_prefix(
req: &mut SglangRequest,
kv_manager: &mut SglangKvManager,
config: &SglangConfig,
) {
let aligned_tokens = req.page_aligned_materialized_tokens(config.block_size);
if aligned_tokens == 0 || aligned_tokens <= req.cached_tokens {
return;
}
let Some(last_node) = req.last_node else {
return;
};
let sequence = req.sequence_prefix(aligned_tokens);
let new_last =
kv_manager.cache_unfinished_req(&sequence, &req.kv_indices[..aligned_tokens], last_node);
req.last_node = Some(new_last);
req.cached_tokens = aligned_tokens;
req.debug_assert_invariants(config.block_size);
}
pub(super) fn check_decode_mem(
running: &mut Vec<SglangRequest>,
kv_manager: &mut SglangKvManager,
config: &SglangConfig,
) -> Vec<SglangRequest> {
let mut retracted = Vec::new();
loop {
let (actual_available, logical_available_after_growth) =
decode_capacity_state(running, kv_manager, config);
if actual_available >= running.len() && logical_available_after_growth > 0 {
break;
}
if running.len() <= 1 {
break;
}
let Some((idx, _)) = running
.iter()
.enumerate()
.min_by_key(|(_, req)| req.output_len())
else {
break;
};
let mut req = running.swap_remove(idx);
kv_manager.free_indices(&req.kv_indices[req.cached_tokens..]);
if let Some(last_node) = req.last_node.take() {
kv_manager.free_request(last_node);
}
req.reset_for_retract();
req.debug_assert_invariants(config.block_size);
retracted.push(req);
}
let available = kv_manager.cache().token_pool.available();
let needed = running.len();
if available < needed {
kv_manager.evict(needed - available);
}
if !retracted.is_empty() {
tracing::warn!(
num_retracted = retracted.len(),
remaining = running.len(),
"SGLang decode retract requests because KV pool is full"
);
}
retracted
}
pub(super) fn simulate_decode_step(
running: &mut Vec<SglangRequest>,
kv_manager: &mut SglangKvManager,
config: &SglangConfig,
current_time_ms: f64,
apply_speedup: bool,
) -> DecodeResult {
if running.is_empty() {
return DecodeResult {
end_ms: current_time_ms,
..DecodeResult::default()
};
}
let total_context: usize = running
.iter()
.map(SglangRequest::current_sequence_len)
.sum();
let avg_context = total_context / running.len();
let decode_time =
config
.perf_model
.predict_decode_time(running.len(), total_context, avg_context);
let unscaled_time = Duration::from_secs_f64(decode_time / 1000.0);
let effective_ratio = config.speedup_ratio * config.decode_speedup_ratio;
let total_time = if apply_speedup && effective_ratio > 0.0 && unscaled_time > Duration::ZERO {
Duration::from_secs_f64(unscaled_time.as_secs_f64() / effective_ratio)
} else {
unscaled_time
};
let retracted = check_decode_mem(running, kv_manager, config);
let retracted_any = !retracted.is_empty();
let mut output_signals = Vec::with_capacity(running.len());
let mut completed_indices = Vec::new();
for (idx, req) in running.iter_mut().enumerate() {
if kv_manager.cache().token_pool.available() == 0 {
kv_manager.evict(1);
}
let crossing_page_boundary = req.current_sequence_len() + 1 > req.allocated_tokens;
let last_idx = req.kv_indices.last().copied();
let Some(new_idx) = kv_manager.allocate_decode_token(last_idx) else {
tracing::warn!(uuid = %req.uuid, "Failed to allocate decode token, skipping output");
continue;
};
req.kv_indices.push(new_idx);
if crossing_page_boundary {
req.allocated_tokens += config.block_size;
}
req.append_output_token(req.next_output_token());
req.debug_assert_invariants(config.block_size);
let is_complete = req.output_len() >= req.max_output_tokens;
output_signals.push(OutputSignal {
uuid: req.uuid,
completed: is_complete,
});
if is_complete {
let sequence = req.sequence_tokens();
let tokens_to_cache = floor_to_block(sequence.len(), config.block_size);
if req.kv_indices.len() > tokens_to_cache {
kv_manager.free_indices(&req.kv_indices[tokens_to_cache..]);
}
if let Some(last_node) = req.last_node.take() {
if tokens_to_cache > 0 {
kv_manager.cache_finished_req(
&sequence[..tokens_to_cache],
&req.kv_indices[..tokens_to_cache],
last_node,
);
} else {
kv_manager.free_request(last_node);
}
}
completed_indices.push(idx);
continue;
}
cache_materialized_prefix(req, kv_manager, config);
req.debug_assert_invariants(config.block_size);
}
for &idx in completed_indices.iter().rev() {
running.swap_remove(idx);
}
DecodeResult {
requests: retracted,
output_signals,
retracted_any,
end_ms: current_time_ms + total_time.as_secs_f64() * 1000.0,
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, KvEventPublishers, MockEngineArgs, OutputSignal};
use crate::common::utils::sleep_until_precise;
use crate::scheduler::{
AdmissionEvent, MockerMetrics, RouterEventVisibility, SchedulerHandle,
capture_deferred_kv_publish_sink, publish_deferred_kv_events,
};
use super::core::SglangCore;
use super::request::{SglangRequest, direct_to_sglang};
#[derive(Clone)]
pub struct SglangScheduler {
request_tx: mpsc::UnboundedSender<DirectRequest>,
metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
_cancel_guard: Arc<CancelGuard>,
}
struct CancelGuard(CancellationToken);
impl Drop for CancelGuard {
fn drop(&mut self) {
self.0.cancel();
}
}
impl SglangScheduler {
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Self {
Self::new_internal(
args,
dp_rank,
output_tx,
kv_event_publishers,
cancellation_token,
None,
)
}
pub(crate) fn new_with_admission(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
) -> Self {
Self::new_internal(
args,
dp_rank,
output_tx,
kv_event_publishers,
cancellation_token,
admission_tx,
)
}
fn new_internal(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
) -> Self {
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let total_blocks = args.num_gpu_blocks as u64;
let initial_metrics = MockerMetrics::new(dp_rank, 0, total_blocks);
let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);
let cancel_token = cancellation_token.unwrap_or_default();
let cancel_token_clone = cancel_token.clone();
let cancel_guard = Arc::new(CancelGuard(cancel_token));
tokio::spawn(async move {
let (deferred_kv_events, buffering_publishers) =
capture_deferred_kv_publish_sink(kv_event_publishers.raw_enabled());
let mut core = SglangCore::new_with_sink(args, dp_rank, buffering_publishers);
loop {
if receive_requests(
&mut core.waiting,
&mut request_rx,
&cancel_token_clone,
&core.running,
)
.await
.is_none()
{
break;
}
let iteration_start = Instant::now();
let pass = core.execute_pass_internal(None, 0.0);
if let Some(admission_tx) = admission_tx.as_ref() {
for admission in &pass.admissions {
let _ = admission_tx.send(admission.clone());
}
}
if pass.router_event_visibility == RouterEventVisibility::PassStart {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
let total_time = std::time::Duration::from_secs_f64(pass.end_ms / 1000.0);
if total_time > std::time::Duration::ZERO {
sleep_until_precise(iteration_start + total_time).await;
}
if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
flush_output_signals(&output_tx, &pass.output_signals);
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
let _ = metrics_tx.send(MockerMetrics::new(
dp_rank,
pass.active_decode_blocks,
total_blocks,
));
}
});
Self {
request_tx,
metrics_rx,
_cancel_guard: cancel_guard,
}
}
}
impl SchedulerHandle for SglangScheduler {
fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request);
}
fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
self.metrics_rx.clone()
}
}
async fn receive_requests(
waiting: &mut std::collections::VecDeque<SglangRequest>,
request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
cancel_token: &CancellationToken,
running: &[SglangRequest],
) -> Option<()> {
if cancel_token.is_cancelled() {
return None;
}
if waiting.is_empty() && running.is_empty() {
tokio::select! {
biased;
_ = cancel_token.cancelled() => return None,
result = request_rx.recv() => {
let request = result?;
waiting.push_back(direct_to_sglang(request));
}
}
}
while let Ok(request) = request_rx.try_recv() {
waiting.push_back(direct_to_sglang(request));
}
Some(())
}
fn flush_output_signals(
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
output_signals: &[OutputSignal],
) {
let Some(tx) = output_tx.as_ref() else {
return;
};
for signal in output_signals {
let _ = tx.send(signal.clone());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! SGLang scheduler simulation with adaptive admission control.
//!
//! Reference: sglang/python/sglang/srt/managers/scheduler.py
mod config;
mod core;
mod decode;
mod live;
mod policy;
mod prefill;
mod request;
pub(crate) use core::SglangCore;
pub use live::SglangScheduler;
#[cfg(test)]
mod tests;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashSet, VecDeque};
use crate::cache::radix_cache::RadixCache;
use crate::kv_manager::SglangKvManager;
use super::config::{
IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD, IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD,
LPM_FALLBACK_THRESHOLD, SchedulePolicy, SglangConfig,
};
use super::request::SglangRequest;
pub(super) fn apply_schedule_policy(
waiting: &mut VecDeque<SglangRequest>,
kv_manager: &SglangKvManager,
config: &SglangConfig,
) {
match config.schedule_policy {
SchedulePolicy::Fifo => {}
SchedulePolicy::Lpm => {
if waiting.len() > LPM_FALLBACK_THRESHOLD {
return;
}
let page_size = config.block_size.max(1);
let total_tokens = waiting
.iter()
.map(SglangRequest::current_sequence_len)
.sum::<usize>()
.max(page_size);
let mut waiting_queue_cache = RadixCache::new(total_tokens, page_size);
let mut temporary_deprioritized = HashSet::new();
let mut scored = Vec::with_capacity(waiting.len());
for req in waiting.drain(..) {
let sequence = req.sequence_tokens();
let prefix_len = kv_manager.cache().prefix_match_len(&sequence);
if prefix_len <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD {
let in_batch_prefix = waiting_queue_cache.prefix_match_len(&sequence);
if in_batch_prefix >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD {
temporary_deprioritized.insert(req.uuid);
} else if !sequence.is_empty() {
let values: Vec<usize> = (0..sequence.len()).collect();
waiting_queue_cache.insert(&sequence, &values);
}
}
scored.push((prefix_len, temporary_deprioritized.contains(&req.uuid), req));
}
scored.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| b.0.cmp(&a.0)));
for (_, _, req) in scored {
waiting.push_back(req);
}
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use super::super::AdmissionEvent;
use super::config::{SglangConfig, ceil_to_block};
use super::request::SglangRequest;
use crate::kv_manager::SglangKvManager;
pub(super) struct AdmitResult {
pub(super) can_run: Vec<SglangRequest>,
pub(super) admissions: Vec<AdmissionEvent>,
pub(super) total_isl: usize,
pub(super) total_prefix: usize,
pub(super) oom: bool,
}
pub(super) fn get_new_batch_prefill(
waiting: &mut VecDeque<SglangRequest>,
kv_manager: &mut SglangKvManager,
config: &SglangConfig,
new_token_ratio: f64,
running: &[SglangRequest],
) -> AdmitResult {
let cache = kv_manager.cache();
let reserved_decode_output: f64 = running
.iter()
.map(|req| {
let remaining_output = req
.remaining_output_tokens()
.min(config.clip_max_new_tokens);
remaining_output as f64 * new_token_ratio
})
.sum();
let reserved_page_overhead = waiting
.iter()
.map(SglangRequest::extra_reserved_tokens)
.sum::<usize>()
+ running
.iter()
.map(SglangRequest::extra_reserved_tokens)
.sum::<usize>();
let mut rem_total_tokens = (cache.available_tokens() + cache.evictable_size)
.saturating_sub(reserved_page_overhead) as f64
- reserved_decode_output;
let mut rem_input_tokens = config.max_prefill_tokens as f64;
let mut rem_chunk_tokens = config.chunked_prefill_size as f64;
let mut can_run = Vec::new();
let mut admissions = Vec::new();
let mut rejected = VecDeque::new();
let mut oom = false;
let mut total_isl = 0usize;
let mut total_prefix = 0usize;
while let Some(mut req) = waiting.pop_front() {
let extend_input = req.extend_input_len();
if extend_input == 0 {
rejected.push_back(req);
break;
}
let total_needed = req.total_tokens_needed(config.clip_max_new_tokens) as f64;
if total_needed >= rem_total_tokens {
rejected.push_back(req);
break;
}
let chunk_tokens = if extend_input <= config.chunked_prefill_size {
extend_input
} else {
let chunk = (rem_chunk_tokens as usize / config.block_size) * config.block_size;
if chunk == 0 {
rejected.push_back(req);
break;
}
chunk.min(extend_input)
};
let charged_input_tokens = ceil_to_block(chunk_tokens, config.block_size) as f64;
if charged_input_tokens > rem_input_tokens || charged_input_tokens > rem_chunk_tokens {
rejected.push_back(req);
break;
}
let chunk_end = req.materialized_tokens + chunk_tokens;
let old_allocated_tokens = req.allocated_tokens;
let prev_node = req.last_node.take();
let alloc_tokens = req.sequence_prefix(chunk_end);
let actual_new_tokens = alloc_tokens.len().saturating_sub(req.materialized_tokens);
let available = kv_manager.cache().token_pool.available();
if available < actual_new_tokens {
kv_manager.evict(actual_new_tokens - available);
}
let alloc = if req.materialized_tokens > 0 {
let Some(last_node) = prev_node else {
rejected.push_back(req);
break;
};
kv_manager.allocate_after_prefix(
&alloc_tokens,
req.materialized_tokens,
&req.kv_indices[..req.materialized_tokens],
last_node,
)
} else {
kv_manager.allocate_for_request(&alloc_tokens)
};
let Some(alloc) = alloc else {
req.last_node = prev_node;
rejected.push_back(req);
oom = true;
break;
};
if let Some(node) = prev_node {
kv_manager.free_request(node);
}
req.last_node = Some(alloc.last_node);
req.kv_indices = alloc.kv_indices;
req.materialized_tokens = chunk_end;
req.allocated_tokens = ceil_to_block(chunk_end, config.block_size);
req.debug_assert_invariants(config.block_size);
let is_truncated = chunk_end < req.current_sequence_len();
let output_reserve = if is_truncated {
0
} else {
req.remaining_output_tokens()
.min(config.clip_max_new_tokens)
};
admissions.push(AdmissionEvent {
uuid: req.uuid,
reused_input_tokens: alloc.prefix_len,
});
total_isl += chunk_end;
total_prefix += alloc.prefix_len;
rem_total_tokens -= (req.allocated_tokens - old_allocated_tokens + output_reserve) as f64;
rem_input_tokens -= charged_input_tokens;
rem_chunk_tokens -= charged_input_tokens;
can_run.push(req);
if rem_chunk_tokens <= 0.0 {
break;
}
}
while let Some(req) = rejected.pop_back() {
waiting.push_front(req);
}
AdmitResult {
can_run,
admissions,
total_isl,
total_prefix,
oom,
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use uuid::Uuid;
use crate::cache::radix_cache::NodeId;
use crate::common::protocols::DirectRequest;
#[derive(Clone, Debug)]
pub(super) struct SglangRequest {
pub(super) uuid: Uuid,
pub(super) prompt_tokens: Vec<u64>,
pub(super) max_output_tokens: usize,
pub(super) output_ids: Vec<u64>,
pub(super) last_node: Option<NodeId>,
pub(super) kv_indices: Vec<usize>,
pub(super) materialized_tokens: usize,
pub(super) cached_tokens: usize,
pub(super) allocated_tokens: usize,
}
impl SglangRequest {
pub(super) fn prompt_len(&self) -> usize {
self.prompt_tokens.len()
}
pub(super) fn output_len(&self) -> usize {
self.output_ids.len()
}
pub(super) fn current_sequence_len(&self) -> usize {
self.prompt_len() + self.output_len()
}
pub(super) fn extend_input_len(&self) -> usize {
self.current_sequence_len()
.saturating_sub(self.materialized_tokens)
}
pub(super) fn total_tokens_needed(&self, clip_max_new_tokens: usize) -> usize {
let remaining_input = self.extend_input_len();
let remaining_output = self.remaining_output_tokens().min(clip_max_new_tokens);
remaining_input + remaining_output
}
pub(super) fn remaining_output_tokens(&self) -> usize {
self.max_output_tokens.saturating_sub(self.output_len())
}
pub(super) fn extra_reserved_tokens(&self) -> usize {
self.allocated_tokens.saturating_sub(self.kv_indices.len())
}
pub(super) fn page_aligned_materialized_tokens(&self, block_size: usize) -> usize {
self.materialized_tokens / block_size * block_size
}
pub(super) fn sequence_tokens(&self) -> Vec<u64> {
let mut sequence = self.prompt_tokens.clone();
sequence.extend_from_slice(&self.output_ids);
sequence
}
pub(super) fn sequence_prefix(&self, len: usize) -> Vec<u64> {
let prompt_len = self.prompt_len();
if len <= prompt_len {
return self.prompt_tokens[..len].to_vec();
}
let mut prefix = self.prompt_tokens.clone();
prefix.extend_from_slice(&self.output_ids[..len - prompt_len]);
prefix
}
pub(super) fn next_output_token(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.uuid.hash(&mut hasher);
self.output_len().hash(&mut hasher);
hasher.finish()
}
pub(super) fn append_output_token(&mut self, token: u64) {
self.output_ids.push(token);
self.materialized_tokens += 1;
}
pub(super) fn debug_assert_invariants(&self, block_size: usize) {
#[cfg(debug_assertions)]
{
let sequence_len = self.current_sequence_len();
debug_assert!(
self.cached_tokens <= self.materialized_tokens,
"request {} cached {} tokens but materialized {}",
self.uuid,
self.cached_tokens,
self.materialized_tokens
);
debug_assert!(
self.materialized_tokens <= sequence_len,
"request {} materialized {} tokens but sequence length is {sequence_len}",
self.uuid,
self.materialized_tokens
);
debug_assert_eq!(
self.kv_indices.len(),
self.materialized_tokens,
"request {} has {} kv indices but {} materialized tokens",
self.uuid,
self.kv_indices.len(),
self.materialized_tokens
);
debug_assert!(
self.allocated_tokens >= self.materialized_tokens,
"request {} allocated {} tokens but materialized {}",
self.uuid,
self.allocated_tokens,
self.materialized_tokens
);
debug_assert_eq!(
self.cached_tokens % block_size,
0,
"request {} cached tokens {} are not page-aligned to block size {block_size}",
self.uuid,
self.cached_tokens
);
debug_assert!(
self.allocated_tokens == 0 || self.allocated_tokens.is_multiple_of(block_size),
"request {} allocated tokens {} are not page-aligned to block size {block_size}",
self.uuid,
self.allocated_tokens
);
debug_assert!(
self.extra_reserved_tokens() < block_size,
"request {} reserves {} extra tokens with block size {block_size}",
self.uuid,
self.extra_reserved_tokens()
);
debug_assert_eq!(
self.last_node.is_some(),
self.materialized_tokens > 0,
"request {} has last_node={} but materialized_tokens={}",
self.uuid,
self.last_node.is_some(),
self.materialized_tokens
);
}
}
pub(super) fn reset_for_retract(&mut self) {
self.last_node = None;
self.kv_indices.clear();
self.materialized_tokens = 0;
self.cached_tokens = 0;
self.allocated_tokens = 0;
}
}
pub(super) fn direct_to_sglang(req: DirectRequest) -> SglangRequest {
SglangRequest {
uuid: req.uuid.unwrap_or_else(Uuid::new_v4),
prompt_tokens: req.tokens.iter().map(|&t| t as u64).collect(),
max_output_tokens: req.max_output_tokens,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::time::Duration;
use dynamo_kv_router::indexer::{METRIC_EVENT_REMOVED, METRIC_EVENT_STORED};
use dynamo_kv_router::protocols::WorkerId;
use rstest::rstest;
use tokio::sync::mpsc;
use uuid::Uuid;
use super::config::{SchedulePolicy, SglangConfig, ceil_to_block};
use super::core::SglangCore;
use super::decode;
use super::decode::simulate_decode_step;
use super::live::SglangScheduler;
use super::policy::apply_schedule_policy;
use super::prefill::get_new_batch_prefill;
use super::request::SglangRequest;
use crate::common::protocols::{
DirectRequest, EngineType, KvEventPublishers, MockEngineArgs, OutputSignal, SglangArgs,
};
use crate::kv_manager::SglangKvManager;
use crate::scheduler::test_utils::{
RouterIndexerHarness, nth_stored_hashes, removed_event_count, stored_hashes,
};
use crate::scheduler::{RouterEventVisibility, SchedulerHandle, capture_router_event_sink};
const ROUTER_TEST_WORKER_ID: WorkerId = 17;
fn test_args(
num_gpu_blocks: usize,
block_size: usize,
chunked_prefill_size: usize,
) -> MockEngineArgs {
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.num_gpu_blocks(num_gpu_blocks)
.block_size(block_size)
.speedup_ratio(1.0)
.sglang(Some(SglangArgs {
page_size: Some(block_size),
chunked_prefill_size: Some(chunked_prefill_size),
..Default::default()
}))
.build()
.unwrap()
}
fn direct_request(tokens: Vec<u32>, max_output_tokens: usize) -> DirectRequest {
DirectRequest {
tokens,
max_output_tokens,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
}
}
fn make_decoded_request(
kv_manager: &mut SglangKvManager,
config: &SglangConfig,
prompt_tokens: Vec<u64>,
max_output_tokens: usize,
) -> SglangRequest {
let prompt_len = prompt_tokens.len();
let alloc = kv_manager.allocate_for_request(&prompt_tokens).unwrap();
let mut running = vec![SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens,
max_output_tokens,
output_ids: Vec::new(),
last_node: Some(alloc.last_node),
kv_indices: alloc.kv_indices,
materialized_tokens: prompt_len,
cached_tokens: 0,
allocated_tokens: ceil_to_block(prompt_len, config.block_size),
}];
let result = simulate_decode_step(&mut running, kv_manager, config, 0.0, false);
assert_eq!(result.output_signals.len(), 1);
running.pop().unwrap()
}
mod scheduling {
use super::*;
#[tokio::test]
async fn test_sglang_scheduler_fifo_ordering() {
let args = MockEngineArgs::builder()
.num_gpu_blocks(100)
.block_size(64)
.speedup_ratio(100.0)
.build()
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler =
SglangScheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
let num_requests = 5;
let max_output = 3;
for i in 0..num_requests {
scheduler.receive(crate::common::protocols::DirectRequest {
tokens: vec![i as u32; 10],
max_output_tokens: max_output,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let expected_signals = num_requests * max_output;
let mut received = 0;
let timeout = tokio::time::sleep(Duration::from_secs(5));
tokio::pin!(timeout);
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
received += 1;
if received >= expected_signals {
break;
}
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => break,
}
}
assert_eq!(received, expected_signals);
}
#[test]
fn test_lpm_reorders_by_current_sequence_prefix_match() {
let mut kv_manager = SglangKvManager::new(1000, 1, KvEventPublishers::default(), 0);
kv_manager
.cache_mut()
.insert(&[1, 2, 3, 4, 5], &[0, 1, 2, 3, 4]);
let config = SglangConfig {
schedule_policy: SchedulePolicy::Lpm,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let no_match_uuid = Uuid::new_v4();
let match_uuid = Uuid::new_v4();
let mut waiting = VecDeque::from([
SglangRequest {
uuid: no_match_uuid,
prompt_tokens: vec![9, 8, 7],
max_output_tokens: 1,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
},
SglangRequest {
uuid: match_uuid,
prompt_tokens: vec![1, 2, 3, 4, 5],
max_output_tokens: 1,
output_ids: vec![6, 7],
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
},
]);
apply_schedule_policy(&mut waiting, &kv_manager, &config);
assert_eq!(waiting[0].uuid, match_uuid);
assert_eq!(waiting[1].uuid, no_match_uuid);
}
#[test]
fn test_lpm_deprioritizes_duplicate_short_prefixes() {
let config = SglangConfig {
schedule_policy: SchedulePolicy::Lpm,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.block_size(1)
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let kv_manager = SglangKvManager::new(1000, 1, KvEventPublishers::default(), 0);
let duplicate_prefix = (0..32).collect::<Vec<_>>();
let mut waiting = VecDeque::new();
for _ in 0..33 {
waiting.push_back(SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: duplicate_prefix.clone(),
max_output_tokens: 1,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
});
}
let unique_uuid = Uuid::new_v4();
waiting.push_back(SglangRequest {
uuid: unique_uuid,
prompt_tokens: (100..132).collect(),
max_output_tokens: 1,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
});
apply_schedule_policy(&mut waiting, &kv_manager, &config);
assert_eq!(
waiting.iter().position(|req| req.uuid == unique_uuid),
Some(1)
);
}
}
mod core_behavior {
use super::*;
#[test]
fn test_chunked_prefill_budget_is_page_aware() {
let config = SglangConfig {
chunked_prefill_size: 8,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.block_size(4)
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let mut kv_manager = SglangKvManager::new(10000, 4, KvEventPublishers::default(), 0);
let mut waiting = VecDeque::from([SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1; 6],
max_output_tokens: 3,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
}]);
let admit = get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &[]);
assert_eq!(admit.can_run.len(), 1);
assert_eq!(admit.can_run[0].materialized_tokens, 6);
assert_eq!(admit.can_run[0].allocated_tokens, 8);
}
#[test]
fn test_chunked_prefill_subpage_budget_defers_next_request() {
let config = SglangConfig {
chunked_prefill_size: 8,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.block_size(4)
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let first_uuid = Uuid::new_v4();
let second_uuid = Uuid::new_v4();
let mut kv_manager = SglangKvManager::new(10000, 4, KvEventPublishers::default(), 0);
let mut waiting = VecDeque::from([
SglangRequest {
uuid: first_uuid,
prompt_tokens: vec![1; 7],
max_output_tokens: 3,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
},
SglangRequest {
uuid: second_uuid,
prompt_tokens: vec![2; 8],
max_output_tokens: 3,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
},
]);
let admit = get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &[]);
assert_eq!(admit.can_run.len(), 1);
assert_eq!(admit.can_run[0].uuid, first_uuid);
assert_eq!(waiting.len(), 1);
assert_eq!(waiting[0].uuid, second_uuid);
}
#[test]
fn test_decode_allocation_is_page_aware() {
let config = SglangConfig::from_args(
&MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.speedup_ratio(1.0)
.build()
.unwrap(),
);
let mut kv_manager = SglangKvManager::new(64, 4, KvEventPublishers::default(), 0);
let alloc = kv_manager
.allocate_for_request(&[1, 2, 3, 4, 5, 6])
.unwrap();
let mut running = vec![SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1, 2, 3, 4, 5, 6],
max_output_tokens: 4,
output_ids: Vec::new(),
last_node: Some(alloc.last_node),
kv_indices: alloc.kv_indices,
materialized_tokens: 6,
cached_tokens: 4,
allocated_tokens: 8,
}];
let first = simulate_decode_step(&mut running, &mut kv_manager, &config, 0.0, false);
assert_eq!(running[0].allocated_tokens, 8);
assert_eq!(running[0].output_len(), 1);
assert_eq!(first.output_signals.len(), 1);
simulate_decode_step(&mut running, &mut kv_manager, &config, 0.0, false);
assert_eq!(running[0].allocated_tokens, 8);
simulate_decode_step(&mut running, &mut kv_manager, &config, 0.0, false);
assert_eq!(running[0].allocated_tokens, 12);
}
#[test]
fn test_decode_speedup_ratio_scales_sglang_decode_time() {
let base_args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.speedup_ratio(2.0)
.decode_speedup_ratio(1.0)
.build()
.unwrap();
let fast_args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.speedup_ratio(2.0)
.decode_speedup_ratio(4.0)
.build()
.unwrap();
let base_config = SglangConfig::from_args(&base_args);
let fast_config = SglangConfig::from_args(&fast_args);
let mut base_kv_manager = SglangKvManager::new(64, 4, KvEventPublishers::default(), 0);
let base_alloc = base_kv_manager.allocate_for_request(&[1, 2, 3, 4]).unwrap();
let mut base_running = vec![SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1, 2, 3, 4],
max_output_tokens: 4,
output_ids: Vec::new(),
last_node: Some(base_alloc.last_node),
kv_indices: base_alloc.kv_indices,
materialized_tokens: 4,
cached_tokens: 0,
allocated_tokens: 4,
}];
let mut fast_kv_manager = SglangKvManager::new(64, 4, KvEventPublishers::default(), 0);
let fast_alloc = fast_kv_manager.allocate_for_request(&[1, 2, 3, 4]).unwrap();
let mut fast_running = vec![SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1, 2, 3, 4],
max_output_tokens: 4,
output_ids: Vec::new(),
last_node: Some(fast_alloc.last_node),
kv_indices: fast_alloc.kv_indices,
materialized_tokens: 4,
cached_tokens: 0,
allocated_tokens: 4,
}];
let base = simulate_decode_step(
&mut base_running,
&mut base_kv_manager,
&base_config,
0.0,
true,
);
let fast = simulate_decode_step(
&mut fast_running,
&mut fast_kv_manager,
&fast_config,
0.0,
true,
);
let ratio = base.end_ms / fast.end_ms;
assert!(base.end_ms > fast.end_ms);
assert!(
(ratio - 4.0).abs() < 1e-3,
"expected 4x decode speedup ratio, got {ratio}"
);
}
#[test]
fn test_check_decode_mem_preserves_generated_output_on_retract() {
let config = SglangConfig::from_args(
&MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.speedup_ratio(1.0)
.build()
.unwrap(),
);
let mut kv_manager = SglangKvManager::new(8, 4, KvEventPublishers::default(), 0);
let first = kv_manager.cache_mut().token_pool.allocate(4).unwrap();
let second = kv_manager.cache_mut().token_pool.allocate(4).unwrap();
let mut running = vec![
SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1, 2, 3, 4],
max_output_tokens: 10,
output_ids: vec![11, 12, 13],
last_node: None,
kv_indices: first,
materialized_tokens: 7,
cached_tokens: 4,
allocated_tokens: 8,
},
SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![9, 8, 7, 6],
max_output_tokens: 10,
output_ids: vec![21],
last_node: None,
kv_indices: second,
materialized_tokens: 5,
cached_tokens: 4,
allocated_tokens: 8,
},
];
let retracted = decode::check_decode_mem(&mut running, &mut kv_manager, &config);
assert_eq!(retracted.len(), 1);
assert_eq!(retracted[0].output_ids, vec![21]);
assert_eq!(retracted[0].materialized_tokens, 0);
assert!(retracted[0].kv_indices.is_empty());
}
#[test]
fn test_unfinished_decode_request_is_cached_after_output() {
let config = SglangConfig::from_args(
&MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.speedup_ratio(1.0)
.build()
.unwrap(),
);
let mut kv_manager = SglangKvManager::new(64, 4, KvEventPublishers::default(), 0);
let alloc = kv_manager.allocate_for_request(&[1, 2, 3, 4]).unwrap();
let mut running = vec![SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1, 2, 3, 4],
max_output_tokens: 4,
output_ids: Vec::new(),
last_node: Some(alloc.last_node),
kv_indices: alloc.kv_indices,
materialized_tokens: 4,
cached_tokens: 0,
allocated_tokens: 4,
}];
simulate_decode_step(&mut running, &mut kv_manager, &config, 0.0, false);
let prefix = running[0].sequence_prefix(4);
assert_eq!(kv_manager.cache().prefix_match_len(&prefix), 4);
}
#[test]
fn test_active_decode_blocks_tracks_page_reserved_occupancy_in_blocks() {
let args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.num_gpu_blocks(32)
.block_size(4)
.speedup_ratio(1.0)
.sglang(Some(SglangArgs {
chunked_prefill_size: Some(8),
page_size: Some(4),
..Default::default()
}))
.build()
.unwrap();
let mut core = SglangCore::new(args);
core.receive(crate::common::protocols::DirectRequest {
tokens: vec![1; 6],
max_output_tokens: 2,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
let pass = core.execute_pass_internal(None, 0.0);
assert_eq!(pass.completed_requests, 0);
assert_eq!(pass.active_decode_blocks, 2);
}
#[test]
fn test_sglang_pass_visibility_is_pass_end() {
let mut core = SglangCore::new_with_kv_capture(test_args(32, 4, 4), ROUTER_TEST_WORKER_ID);
core.receive(direct_request(vec![1, 2, 3, 4], 1));
let pass = core.execute_pass_internal(None, 0.0);
assert_eq!(pass.router_event_visibility, RouterEventVisibility::PassEnd);
}
}
async fn assert_sglang_scheduler_completes_all(
scheduler: &SglangScheduler,
output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
num_requests: usize,
prompt_len: usize,
max_output_tokens: usize,
use_shared_tokens: bool,
) {
let shared_prefix = vec![1u32; prompt_len / 2];
for i in 0..num_requests {
let mut input_tokens = if use_shared_tokens {
shared_prefix.clone()
} else {
Vec::new()
};
let unique_len = prompt_len - input_tokens.len();
input_tokens.extend((0..unique_len).map(|j| (i * unique_len + j) as u32 + 1000));
scheduler.receive(crate::common::protocols::DirectRequest {
tokens: input_tokens,
max_output_tokens,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let expected_tokens = num_requests * max_output_tokens;
let mut received_tokens = 0;
let timeout = tokio::time::sleep(Duration::from_secs(2));
tokio::pin!(timeout);
loop {
tokio::select! {
biased;
Some(_) = output_rx.recv() => {
received_tokens += 1;
if received_tokens >= expected_tokens {
break;
}
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => break,
}
}
assert_eq!(received_tokens, expected_tokens);
tokio::time::sleep(Duration::from_millis(100)).await;
let metrics = scheduler.metrics_receiver().borrow().clone();
assert!(metrics.active_decode_blocks > 0);
assert!(metrics.total_blocks > 0);
assert!((0.0..=1.0).contains(&metrics.gpu_cache_usage_perc));
}
mod router_events {
use super::*;
#[rstest]
#[case::case_1(false, "fifo", 1)]
#[case::case_2(true, "fifo", 1)]
#[case::case_3(false, "lpm", 1)]
#[case::case_4(true, "lpm", 1)]
#[case::case_5(false, "fifo", 4)]
#[case::case_6(true, "fifo", 4)]
#[case::case_7(false, "lpm", 4)]
#[case::case_8(true, "lpm", 4)]
#[tokio::test]
async fn test_sglang_scheduler_token_generation_patterns(
#[case] use_shared_tokens: bool,
#[case] schedule_policy: &str,
#[case] page_size: usize,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
.block_size(64)
.speedup_ratio(10.0)
.sglang(Some(SglangArgs {
schedule_policy: Some(schedule_policy.to_string()),
page_size: Some(page_size),
..Default::default()
}))
.build()
.unwrap();
let scheduler =
SglangScheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
assert_sglang_scheduler_completes_all(
&scheduler,
&mut output_rx,
200,
1000,
100,
use_shared_tokens,
)
.await;
}
#[tokio::test]
async fn test_chunked_prefill_events_apply_cleanly() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let mut core = SglangCore::new_with_kv_capture(test_args(32, 4, 4), ROUTER_TEST_WORKER_ID);
core.receive(direct_request(vec![1, 2, 3, 4, 5, 6], 2));
let pass1 = core.execute_pass_internal(None, 0.0);
let mut prompt_hashes = stored_hashes(&pass1.kv_events);
assert_eq!(prompt_hashes.len(), 4);
harness.apply_events(pass1.kv_events).await;
let pass2 = core.execute_pass_internal(None, pass1.end_ms);
prompt_hashes.extend(nth_stored_hashes(&pass2.kv_events, 0));
harness.apply_events(pass2.kv_events).await;
assert_eq!(prompt_hashes.len(), 6);
assert!(harness.ok_count(METRIC_EVENT_STORED) >= 2);
harness.shutdown();
}
#[tokio::test]
async fn test_decode_growth_events_apply_cleanly() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let mut core = SglangCore::new_with_kv_capture(test_args(32, 4, 16), ROUTER_TEST_WORKER_ID);
core.receive(direct_request(vec![7, 8, 9, 10], 5));
let pass1 = core.execute_pass_internal(None, 0.0);
let mut full_hashes = stored_hashes(&pass1.kv_events);
harness.apply_events(pass1.kv_events).await;
let pass2 = core.execute_pass_internal(None, pass1.end_ms);
full_hashes.extend(stored_hashes(&pass2.kv_events));
harness.apply_events(pass2.kv_events).await;
assert_eq!(full_hashes.len(), 6);
assert!(harness.ok_count(METRIC_EVENT_STORED) >= 2);
harness.shutdown();
}
#[tokio::test]
async fn test_retract_frees_do_not_leave_stale_blocks() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let args = test_args(8, 4, 16);
let config = SglangConfig::from_args(&args);
let (buffer, sink) = capture_router_event_sink(ROUTER_TEST_WORKER_ID);
let mut kv_manager =
SglangKvManager::new(10, 4, KvEventPublishers::new(Some(sink), None), 0);
let req1 = make_decoded_request(&mut kv_manager, &config, vec![1, 2, 3, 4], 4);
let req1_events = buffer.drain();
let req1_hashes = stored_hashes(&req1_events);
harness.apply_events(req1_events).await;
let req2 = make_decoded_request(&mut kv_manager, &config, vec![9, 8, 7, 6], 4);
harness.apply_events(buffer.drain()).await;
let mut running = vec![req1, req2];
let retracted = decode::check_decode_mem(&mut running, &mut kv_manager, &config);
assert_eq!(retracted.len(), 1);
let retract_events = buffer.drain();
assert!(removed_event_count(&retract_events) > 0);
harness.apply_events(retract_events).await;
assert_eq!(harness.overlap_for_hashes(req1_hashes).await, 4);
assert!(harness.ok_count(METRIC_EVENT_REMOVED) > 0);
harness.shutdown();
}
#[tokio::test]
async fn test_completion_tail_free_emits_valid_removals() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let mut core = SglangCore::new_with_kv_capture(test_args(32, 4, 16), ROUTER_TEST_WORKER_ID);
core.receive(direct_request(vec![11, 12, 13, 14], 3));
let pass1 = core.execute_pass_internal(None, 0.0);
let prompt_hashes = nth_stored_hashes(&pass1.kv_events, 0);
let mut full_hashes = stored_hashes(&pass1.kv_events);
harness.apply_events(pass1.kv_events).await;
let pass2 = core.execute_pass_internal(None, pass1.end_ms);
full_hashes.extend(stored_hashes(&pass2.kv_events));
harness.apply_events(pass2.kv_events).await;
let pass3 = core.execute_pass_internal(None, pass2.end_ms);
assert!(removed_event_count(&pass3.kv_events) > 0);
full_hashes.extend(stored_hashes(&pass3.kv_events));
harness.apply_events(pass3.kv_events).await;
assert_eq!(prompt_hashes.len(), 4);
assert!(full_hashes.len() >= prompt_hashes.len());
assert!(harness.ok_count(METRIC_EVENT_REMOVED) > 0);
harness.shutdown();
}
#[tokio::test]
async fn test_mixed_chunk_decode_retract_reprefill_complete_events_apply_cleanly() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let args = test_args(8, 4, 4);
let config = SglangConfig::from_args(&args);
let (buffer, sink) = capture_router_event_sink(ROUTER_TEST_WORKER_ID);
let mut kv_manager =
SglangKvManager::new(12, 4, KvEventPublishers::new(Some(sink), None), 0);
let mut waiting = VecDeque::from([SglangRequest {
uuid: Uuid::new_v4(),
prompt_tokens: vec![1, 2, 3, 4, 5, 6],
max_output_tokens: 3,
output_ids: Vec::new(),
last_node: None,
kv_indices: Vec::new(),
materialized_tokens: 0,
cached_tokens: 0,
allocated_tokens: 0,
}]);
let chunk1 = get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &[]);
let mut req1 = chunk1.can_run.into_iter().next().unwrap();
decode::cache_materialized_prefix(&mut req1, &mut kv_manager, &config);
waiting.push_front(req1);
harness.apply_events(buffer.drain()).await;
let chunk2 = get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &[]);
let mut running = chunk2.can_run;
let decode1 = simulate_decode_step(&mut running, &mut kv_manager, &config, 0.0, false);
assert_eq!(decode1.output_signals.len(), 1);
harness.apply_events(buffer.drain()).await;
let req1 = running.pop().unwrap();
let req2 = make_decoded_request(&mut kv_manager, &config, vec![9, 10, 11, 12], 3);
harness.apply_events(buffer.drain()).await;
let mut running = vec![req1, req2];
let mut retracted = decode::check_decode_mem(&mut running, &mut kv_manager, &config);
assert_eq!(retracted.len(), 1);
harness.apply_events(buffer.drain()).await;
let mut waiting = VecDeque::from([retracted.pop().unwrap()]);
let mut now_ms = 0.0;
let mut saw_remove = harness.ok_count(METRIC_EVENT_REMOVED) > 0;
loop {
let admit =
get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &running);
for mut req in admit.can_run {
if req.materialized_tokens < req.current_sequence_len() {
decode::cache_materialized_prefix(&mut req, &mut kv_manager, &config);
waiting.push_front(req);
} else {
running.push(req);
}
}
let events = buffer.drain();
saw_remove |= removed_event_count(&events) > 0;
harness.apply_events(events).await;
if running.is_empty() {
if waiting.is_empty() {
break;
}
continue;
}
let decode =
simulate_decode_step(&mut running, &mut kv_manager, &config, now_ms, false);
now_ms = decode.end_ms;
for req in decode.requests.into_iter().rev() {
waiting.push_front(req);
}
let events = buffer.drain();
saw_remove |= removed_event_count(&events) > 0;
harness.apply_events(events).await;
if running.is_empty() && waiting.is_empty() {
break;
}
}
assert!(saw_remove);
harness.assert_no_event_errors();
harness.shutdown();
}
#[tokio::test]
async fn test_live_pathological_load_no_router_event_errors() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let (sink, forward_task) = harness.spawn_forwarder();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler = SglangScheduler::new(
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.num_gpu_blocks(4)
.block_size(4)
.speedup_ratio(1000.0)
.sglang(Some(SglangArgs {
page_size: Some(4),
chunked_prefill_size: Some(4),
..Default::default()
}))
.build()
.unwrap(),
0,
Some(output_tx),
KvEventPublishers::new(Some(sink.clone()), None),
None,
);
for _ in 0..8 {
scheduler.receive(direct_request(vec![42], 4));
}
let expected = 8 * 4;
let mut seen = 0;
let timeout = tokio::time::sleep(Duration::from_secs(5));
tokio::pin!(timeout);
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
seen += 1;
if seen == expected {
break;
}
}
_ = &mut timeout => {
break;
}
}
}
assert_eq!(seen, expected);
drop(scheduler);
drop(sink);
forward_task.await.unwrap();
harness.flush().await;
harness.assert_no_event_errors();
assert!(harness.ok_count(METRIC_EVENT_REMOVED) > 0);
harness.shutdown();
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use anyhow::anyhow;
use dynamo_kv_router::indexer::{
KvIndexerInterface, KvIndexerMetrics, LocalKvIndexer, METRIC_EVENT_REMOVED,
METRIC_EVENT_STORED, METRIC_STATUS_BLOCK_NOT_FOUND, METRIC_STATUS_INVALID_BLOCK,
METRIC_STATUS_OK, METRIC_STATUS_PARENT_NOT_FOUND,
};
use dynamo_kv_router::protocols::{
KvCacheEvent, KvCacheEventData, LocalBlockHash, RouterEvent, WorkerId, WorkerWithDpRank,
};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use super::{DirectRequest, OutputSignal, SchedulerHandle};
use crate::common::protocols::KvCacheEventSink;
pub(crate) struct RouterIndexerHarness {
indexer: Arc<LocalKvIndexer>,
metrics: Arc<KvIndexerMetrics>,
worker: WorkerWithDpRank,
}
impl RouterIndexerHarness {
pub(crate) fn new(block_size: u32, worker_id: WorkerId) -> Self {
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let indexer = Arc::new(LocalKvIndexer::new(
CancellationToken::new(),
block_size,
metrics.clone(),
4096,
));
Self {
indexer,
metrics,
worker: WorkerWithDpRank::new(worker_id, 0),
}
}
pub(crate) async fn apply_events<I>(&self, events: I)
where
I: IntoIterator<Item = RouterEvent>,
{
for event in events {
self.indexer.apply_event_with_buffer(event).await.unwrap();
}
let _ = self.indexer.flush().await;
self.assert_no_event_errors();
}
pub(crate) async fn overlap_for_hashes(&self, local_hashes: Vec<LocalBlockHash>) -> u32 {
self.indexer
.find_matches(local_hashes)
.await
.unwrap()
.scores
.get(&self.worker)
.copied()
.unwrap_or(0)
}
pub(crate) fn ok_count(&self, event_type: &'static str) -> u64 {
metric_value(&self.metrics, event_type, METRIC_STATUS_OK)
}
pub(crate) fn status_count(&self, event_type: &'static str, status: &'static str) -> u64 {
metric_value(&self.metrics, event_type, status)
}
pub(crate) fn invalid_counts(&self) -> Vec<(&'static str, &'static str, u64)> {
[METRIC_EVENT_STORED, METRIC_EVENT_REMOVED]
.into_iter()
.flat_map(|event_type| {
[
METRIC_STATUS_PARENT_NOT_FOUND,
METRIC_STATUS_BLOCK_NOT_FOUND,
METRIC_STATUS_INVALID_BLOCK,
]
.into_iter()
.map(move |status| (event_type, status, self.status_count(event_type, status)))
})
.collect()
}
pub(crate) fn invalid_event_count(&self) -> u64 {
self.invalid_counts()
.into_iter()
.map(|(_, _, count)| count)
.sum()
}
pub(crate) fn spawn_forwarder(&self) -> (Arc<TestKvEventSink>, JoinHandle<()>) {
let (event_tx, mut event_rx) = mpsc::unbounded_channel::<RouterEvent>();
let sink = Arc::new(TestKvEventSink {
worker_id: self.worker.worker_id,
event_tx,
});
let indexer = self.indexer.clone();
let forwarder = tokio::spawn(async move {
while let Some(event) = event_rx.recv().await {
indexer.apply_event_with_buffer(event).await.unwrap();
}
let _ = indexer.flush().await;
});
(sink, forwarder)
}
pub(crate) async fn flush(&self) {
let _ = self.indexer.flush().await;
}
pub(crate) fn assert_no_event_errors(&self) {
let breakdown = self
.invalid_counts()
.into_iter()
.filter(|(_, _, count)| *count > 0)
.map(|(event_type, status, count)| format!("{event_type}/{status}={count}"))
.collect::<Vec<_>>()
.join(", ");
assert_eq!(
self.invalid_event_count(),
0,
"router indexer reported invalid KV events{}",
if breakdown.is_empty() {
String::new()
} else {
format!(": {breakdown}")
}
);
}
pub(crate) fn shutdown(&self) {
self.indexer.shutdown();
}
}
#[derive(Clone)]
pub(crate) struct TestKvEventSink {
worker_id: WorkerId,
event_tx: mpsc::UnboundedSender<RouterEvent>,
}
impl KvCacheEventSink for TestKvEventSink {
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
self.event_tx
.send(RouterEvent::new(self.worker_id, event))
.map_err(|_| anyhow!("router test event channel closed"))
}
}
pub(crate) fn metric_value(
metrics: &KvIndexerMetrics,
event_type: &'static str,
status: &'static str,
) -> u64 {
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[event_type, status])
.unwrap()
.get()
}
pub(crate) fn stored_hashes(events: &[RouterEvent]) -> Vec<LocalBlockHash> {
events
.iter()
.filter_map(|event| match &event.event.data {
KvCacheEventData::Stored(store) => Some(
store
.blocks
.iter()
.map(|block| block.tokens_hash)
.collect::<Vec<_>>(),
),
_ => None,
})
.flatten()
.collect()
}
pub(crate) fn nth_stored_hashes(events: &[RouterEvent], nth: usize) -> Vec<LocalBlockHash> {
events
.iter()
.filter_map(|event| match &event.event.data {
KvCacheEventData::Stored(store) => Some(
store
.blocks
.iter()
.map(|block| block.tokens_hash)
.collect::<Vec<_>>(),
),
_ => None,
})
.nth(nth)
.unwrap_or_default()
}
pub(crate) fn removed_event_count(events: &[RouterEvent]) -> usize {
events
.iter()
.filter(|event| matches!(event.event.data, KvCacheEventData::Removed(_)))
.count()
}
/// Send `num_requests` to a scheduler, collect all output signals, and assert
/// that the scheduler produces exactly `num_requests * max_output_tokens` signals
/// and returns to idle (0 active decode blocks).
///
/// When `use_shared_tokens` is true, the first half of each request shares a
/// common prefix to exercise prefix caching / radix tree reuse.
pub(crate) async fn assert_scheduler_completes_all(
scheduler: &dyn SchedulerHandle,
output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
num_requests: usize,
input_len: usize,
max_output_tokens: usize,
use_shared_tokens: bool,
) {
let shared_tokens = if use_shared_tokens {
Some(
(0..input_len / 2)
.map(|_| rand::random::<u32>() % 50000)
.collect::<Vec<_>>(),
)
} else {
None
};
for _ in 0..num_requests {
let input_tokens = if let Some(ref shared) = shared_tokens {
let mut tokens = shared.clone();
tokens.extend((0..input_len / 2).map(|_| rand::random::<u32>() % 50000));
tokens
} else {
(0..input_len)
.map(|_| rand::random::<u32>() % 50000)
.collect::<Vec<_>>()
};
scheduler.receive(DirectRequest {
tokens: input_tokens,
max_output_tokens,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let expected_tokens = num_requests * max_output_tokens;
let mut received_tokens = 0;
let timeout = tokio::time::sleep(Duration::from_secs(2));
tokio::pin!(timeout);
loop {
tokio::select! {
biased;
Some(_) = output_rx.recv() => {
received_tokens += 1;
if received_tokens >= expected_tokens {
break;
}
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => break,
}
}
assert_eq!(
received_tokens, expected_tokens,
"Expected {expected_tokens} output signals, got {received_tokens}"
);
tokio::time::sleep(Duration::from_millis(100)).await;
let metrics = scheduler.metrics_receiver().borrow().clone();
assert_eq!(
metrics.active_decode_blocks, 0,
"Scheduler should be idle after all requests complete, got {} active blocks",
metrics.active_decode_blocks
);
assert_eq!(
metrics.gpu_cache_usage_perc, 0.0,
"Scheduler should report zero cache usage after draining, got {}",
metrics.gpu_cache_usage_perc
);
assert!(
metrics.total_blocks > 0,
"Scheduler should populate total_blocks, got {}",
metrics.total_blocks
);
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Asynchronous Scheduler for LLM Request Management
//!
//! This module implements an asynchronous scheduler that handles three main functions:
//! 1. Receiving new requests and placing them in the waiting queue
//! 2. Scheduling waiting requests against available KV cache resources
//! 3. Simulating the execution of running requests with realistic timing
//!
//! ## Scheduling Process
//! The scheduler checks direct block capacity to determine if there's sufficient
//! KV cache space for new requests. It also enforces a batched tokens budget to prevent
//! oversubscription of computational resources. Only requests that can be allocated
//! these resources are moved from waiting to running state.
//!
//! ## Request Simulation
//! The simulation models two key phases:
//! - Prefill phase: Uses a quadratic cost function: (cached_tokens + new_tokens) * new_tokens
//! - Decode phase: Uses a cost function proportional to active KV blocks (linear)
//!
//! ## Resource Management
//! The scheduler communicates with the KvManager through MoveBlock signals at each
//! stage of request processing. When resources become constrained, it employs an
//! preemption strategy (LIFO by default, matching vLLM v1) where a running request
//! is evicted and placed at the front of the waiting queue to be rescheduled later.
//!
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, MoveBlock, OutputSignal, PreemptionMode,
WorkerType,
};
use crate::common::running_mean::RunningMean;
use crate::common::sequence::ActiveSequence;
use crate::common::utils::sleep_until_precise;
use crate::kv_manager::KvManager;
use crate::simulation::{TraceCollector, TraceSimulationReport};
use dynamo_kv_router::protocols::DpRank;
use dynamo_tokens::blocks::UniqueBlock;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use validator::Validate;
/// Simple metrics struct for mocker's internal use
#[derive(Clone, Default, Debug)]
pub struct MockerMetrics {
pub dp_rank: DpRank,
pub active_decode_blocks: u64,
}
/// Enum representing either a direct request or an active sequence
pub enum Request {
Direct(DirectRequest),
Active(ActiveSequence),
}
#[derive(Default)]
struct SchedulerState {
waiting: VecDeque<Uuid>,
prefill: VecDeque<Uuid>,
decode: VecDeque<Uuid>,
requests: HashMap<Uuid, Request>,
}
impl SchedulerState {
fn is_empty(&self) -> bool {
self.requests.is_empty()
}
fn receive(&mut self, request: DirectRequest) -> Uuid {
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
self.requests.insert(uuid, Request::Direct(request));
self.waiting.push_back(uuid);
uuid
}
/// Try to admit one request from waiting → prefill.
/// Converts DirectRequest → ActiveSequence if needed. PrefillCost is computed
/// later in simulate_prefill when the request reaches the front of the queue.
fn admit_one(&mut self, args: &MockEngineArgs) -> Option<Uuid> {
let &uuid = self.waiting.front()?;
let num_active = self.prefill.len() + self.decode.len();
if args.max_num_seqs.is_some_and(|limit| num_active >= limit) {
return None;
}
self.waiting.pop_front();
// Convert DirectRequest → ActiveSequence if needed
if let Some(Request::Direct(_)) = self.requests.get(&uuid) {
let Some(Request::Direct(direct)) = self.requests.remove(&uuid) else {
unreachable!()
};
self.requests.insert(
uuid,
Request::Active(ActiveSequence::new(
direct.tokens,
direct.max_output_tokens,
Some(args.block_size),
args.enable_prefix_caching,
args.zmq_kv_events_port.is_some(),
)),
);
}
self.prefill.push_back(uuid);
Some(uuid)
}
fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
if !self.decode.contains(&uuid) {
return None;
}
let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
panic!("Request does not exist.");
};
Some(sequence)
}
/// Remove a UUID and its associated Request from collections.
fn complete(&mut self, uuid: &Uuid) {
tracing::trace!("Request {uuid} will complete");
self.decode.retain(|u| u != uuid);
self.requests.remove(uuid);
}
/// Preempt a running request by evicting it from decode, resetting the sequence,
/// and adding it back to the front of the waiting queue.
/// In LIFO mode, evicts the newest request (matches vLLM v1).
/// In FIFO mode, evicts the oldest request.
fn preempt(&mut self, mode: PreemptionMode) -> Vec<MoveBlock> {
let uuid = match mode {
PreemptionMode::Lifo => self.decode.pop_back(),
PreemptionMode::Fifo => self.decode.pop_front(),
}
.expect("Nothing to evict for preemption.");
let request = self
.requests
.remove(&uuid)
.expect("Request does not exist.");
tracing::warn!("Request {uuid} will be preempted");
// Reset the sequence and re-queue for prefill
let Request::Active(mut active_sequence) = request else {
panic!("Expected ActiveSequence in running queue")
};
let signals = active_sequence.reset_with_signal();
self.requests.insert(uuid, Request::Active(active_sequence));
self.waiting.push_front(uuid);
signals
}
}
/// Cancels its token when dropped. Shared via Arc so the background task is
/// only cancelled when the last Scheduler clone is dropped.
struct CancelGuard(CancellationToken);
impl Drop for CancelGuard {
fn drop(&mut self) {
self.0.cancel();
}
}
/// Manages scheduling of requests using KvManager resources
#[derive(Clone)]
pub struct Scheduler {
request_tx: mpsc::UnboundedSender<DirectRequest>,
metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
_cancel_guard: Arc<CancelGuard>,
}
impl Scheduler {
/// Create a new Scheduler with the given parameters
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
cancellation_token: Option<CancellationToken>,
) -> Self {
args.validate().expect("invalid MockEngineArgs");
// Create channel for request handling
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let initial_metrics = MockerMetrics {
dp_rank,
active_decode_blocks: 0,
};
let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);
let cancel_token = cancellation_token.unwrap_or_default();
let cancel_token_clone = cancel_token.clone();
let cancel_guard = Arc::new(CancelGuard(cancel_token));
// Spawn main background task with cancellation token
tokio::spawn(async move {
// Create state and kv_manager as local variables owned by this task
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new_with_event_sink(
args.num_gpu_blocks,
args.block_size,
kv_event_sink,
dp_rank,
);
let mut hit_rates = RunningMean::new(1000);
loop {
// 1. Receive requests
if receive_requests(&mut state, &mut request_rx, &cancel_token_clone)
.await
.is_none()
{
break;
}
// 2. Simulate prefill + decode
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
// 3. Send metrics once per forward pass (after all prefill and decode processing)
let _ = metrics_tx.send(MockerMetrics {
dp_rank,
active_decode_blocks: kv_manager.num_active_blocks() as u64,
});
}
});
Self {
request_tx,
metrics_rx,
_cancel_guard: cancel_guard,
}
}
}
impl super::SchedulerHandle for Scheduler {
fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request);
}
fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
self.metrics_rx.clone()
}
}
/// Receive requests from the channel.
/// Returns `Some(())` to continue the loop, `None` to break (on cancellation).
async fn receive_requests(
state: &mut SchedulerState,
request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
cancel_token: &CancellationToken,
) -> Option<()> {
if cancel_token.is_cancelled() {
return None;
}
if state.is_empty() {
// Fully idle - block until new request arrives or shutdown
tokio::select! {
biased;
_ = cancel_token.cancelled() => {
return None;
}
result = request_rx.recv() => {
let Some(request) = result else {
return None; // channel closed
};
state.receive(request);
return Some(());
}
}
}
// Has active/waiting work - collect any pending requests without blocking
while let Ok(request) = request_rx.try_recv() {
state.receive(request);
}
Some(())
}
/// Simulate prefill phase for all pending prefill requests.
///
/// Handles token budget, block allocation, and preemption inline.
/// Token budget: `max_num_batched_tokens - decode.len()` (1 token per decode request).
/// When blocks are unavailable, decode requests are preempted (LIFO by default)
/// to free capacity, matching vLLM v1 behavior.
async fn simulate_prefill(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
hit_rates: &mut RunningMean<f32>,
args: &MockEngineArgs,
) -> Duration {
let start_time = Instant::now();
let total_time = simulate_prefill_step(state, kv_manager, hit_rates, args, None, 0.0, false);
if args.speedup_ratio > 0.0 && total_time > Duration::ZERO {
let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
let deadline = start_time + sleep_duration;
sleep_until_precise(deadline).await;
}
total_time
}
/// Simulate decode phase for all active decode requests.
/// Returns the total decode compute time.
async fn simulate_decode(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
args: &MockEngineArgs,
) -> Duration {
let start_time = Instant::now();
let total_time = simulate_decode_step(state, kv_manager, output_tx, args, None, 0.0, false);
let effective_ratio = args.speedup_ratio * args.decode_speedup_ratio;
if effective_ratio > 0.0 && total_time > Duration::ZERO {
let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / effective_ratio);
let deadline = start_time + sleep_duration;
sleep_until_precise(deadline).await;
}
total_time
}
fn simulate_prefill_step(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
hit_rates: &mut RunningMean<f32>,
args: &MockEngineArgs,
mut collector: Option<&mut TraceCollector>,
current_time_ms: f64,
apply_speedup: bool,
) -> Duration {
let mut token_budget = args
.max_num_batched_tokens
.map_or(usize::MAX, |t| t.saturating_sub(state.decode.len()));
// Accumulate batch-level prefill stats for a single predict call after the loop.
let mut batch_count: usize = 0;
let mut batch_total_isl: usize = 0;
let mut batch_total_prefix: usize = 0;
'prefill: while token_budget > 0 {
// Drain prefill first, then pull from waiting one at a time.
if state.prefill.is_empty() {
let Some(admitted_uuid) = state.admit_one(args) else {
break;
};
if let Some(collector) = collector.as_deref_mut() {
let Some(Request::Active(seq)) = state.requests.get(&admitted_uuid) else {
panic!("Request does not exist.");
};
let prefill_cost = kv_manager.get_prefill_cost(seq);
let reused_input_tokens = seq.len().saturating_sub(prefill_cost.new_tokens);
collector.on_admit(admitted_uuid, current_time_ms, reused_input_tokens);
}
}
let uuid = state.prefill[0];
let Some(Request::Active(seq)) = state.requests.get(&uuid) else {
panic!("Request does not exist.");
};
let prefill_cost = kv_manager.get_prefill_cost(seq);
let sequence_len = seq.len();
let allocated_tokens = seq.num_allocated_tokens();
let remaining = prefill_cost.new_tokens;
// Token budget check.
let tokens_left = sequence_len - allocated_tokens;
if !args.enable_chunked_prefill && tokens_left > token_budget {
break;
}
let chunk = tokens_left.min(token_budget);
let cumulative = allocated_tokens + chunk;
// Allocate blocks. process() returns the number of blocks committed.
// On partial success, preempt a decode request and retry; the next
// loop iteration re-prepares from the updated num_allocated_tokens.
let Some(Request::Active(seq)) = state.requests.get_mut(&uuid) else {
panic!("Request does not exist.");
};
if let Some(signal) = seq.prepare_allocation(cumulative) {
let expected = match &signal {
MoveBlock::Use(blocks, ..) => blocks.len(),
_ => unreachable!(),
};
let allocated = kv_manager.process(&signal);
// Commit the blocks that were actually allocated.
let committed_tokens = if allocated == expected {
cumulative
} else {
// Partial success: compute token boundary from block count.
let prev_blocks = allocated_tokens
.div_ceil(seq.block_size())
.min(seq.unique_blocks().len());
(prev_blocks + allocated) * seq.block_size()
};
seq.commit_allocation(committed_tokens.min(cumulative));
if allocated < expected {
if state.decode.is_empty() {
break;
}
for signal in state.preempt(args.preemption_mode) {
kv_manager.process(&signal);
}
continue 'prefill; // Retry with freed capacity.
}
} else {
seq.commit_allocation(cumulative);
}
// Accumulate per-request (isl, prefix) for batch-level prediction.
let new_tokens_in_chunk = chunk.min(remaining);
if args.worker_type != WorkerType::Decode && new_tokens_in_chunk > 0 {
let isl = prefill_cost.cached_tokens + new_tokens_in_chunk;
batch_total_isl += isl;
batch_total_prefix += prefill_cost.cached_tokens;
batch_count += 1;
}
// Hit rate: fraction of tokens that were already cached.
let hit_rate = if sequence_len > 0 {
1.0 - (remaining as f32 / sequence_len as f32)
} else {
0.0
};
hit_rates.push(hit_rate);
token_budget -= chunk;
if cumulative >= sequence_len {
// Fully prefilled: promote to decode queue.
state.prefill.pop_front();
state.decode.push_back(uuid);
} else {
// Partially prefilled: resume next iteration with updated allocation state.
break;
}
}
// One batch-level prefill prediction instead of summing per-request predictions.
let total_time = if batch_count > 0 {
let mean_isl = batch_total_isl / batch_count;
let mean_prefix = batch_total_prefix / batch_count;
let ms = args
.perf_model
.predict_prefill_time(batch_count, mean_isl, mean_prefix);
Duration::from_secs_f64(ms / 1000.0)
} else {
Duration::ZERO
};
if !apply_speedup || args.speedup_ratio <= 0.0 || total_time <= Duration::ZERO {
return total_time;
}
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio)
}
fn simulate_decode_step(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
args: &MockEngineArgs,
mut collector: Option<&mut TraceCollector>,
current_time_ms: f64,
apply_speedup: bool,
) -> Duration {
if state.decode.is_empty() {
return Duration::ZERO;
}
let decode_start_ms = current_time_ms;
let decode_lengths = state
.decode
.iter()
.filter_map(|uuid| match state.requests.get(uuid).unwrap() {
Request::Active(seq) => Some(seq.len()),
Request::Direct(_) => None,
})
.collect::<Vec<_>>();
if decode_lengths.is_empty() {
return Duration::ZERO;
}
let count = decode_lengths.len();
let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size;
let total_length: usize = decode_lengths.iter().sum();
let context_length = total_length / count;
let decoding_time =
args.perf_model
.predict_decode_time(count, active_kv_tokens, context_length);
let unscaled_time = Duration::from_secs_f64(decoding_time / 1000.0);
let effective_ratio = args.speedup_ratio * args.decode_speedup_ratio;
let total_time = if apply_speedup && effective_ratio > 0.0 && unscaled_time > Duration::ZERO {
Duration::from_secs_f64(unscaled_time.as_secs_f64() / effective_ratio)
} else {
unscaled_time
};
let decode_end_ms = decode_start_ms + total_time.as_secs_f64() * 1000.0;
// Process decoding.
let uuids: Vec<Uuid> = state.decode.iter().copied().collect();
let mut emitted_any = false;
for uuid in uuids {
let mut allocated = false;
loop {
let Some(sequence) = state.run(uuid) else {
break;
};
let signals = sequence.generate();
if process_signals(kv_manager, &signals) {
allocated = true;
break;
}
sequence.pop(); // revert the failed generation
if state.decode.is_empty() {
break;
}
// Preempt one request and free its blocks
for signal in state.preempt(args.preemption_mode) {
kv_manager.process(&signal);
}
// If the current request was the one preempted, stop retrying
if !state.decode.contains(&uuid) {
break;
}
}
if !allocated {
continue;
}
let Some(sequence) = state.run(uuid) else {
continue;
};
emitted_any = true;
if let Some(collector) = collector.as_deref_mut() {
collector.on_token(uuid, decode_end_ms);
}
// Check completion and send notification.
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let send_failed = output_tx.as_ref().is_some_and(|tx| {
tx.send(OutputSignal {
uuid,
completed: is_complete,
})
.is_err()
});
if send_failed {
for signal in &sequence.free_signal() {
kv_manager.process(signal);
}
}
if send_failed || is_complete {
state.complete(&uuid);
}
}
if !emitted_any {
return Duration::ZERO;
}
total_time
}
pub fn simulate_trace(
args: MockEngineArgs,
mut requests: Vec<DirectRequest>,
) -> anyhow::Result<TraceSimulationReport> {
args.validate()?;
requests.sort_by(|left, right| {
let left_ts = left
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
let right_ts = right
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
left_ts.total_cmp(&right_ts)
});
let first_arrival_ms = requests
.first()
.and_then(|request| request.arrival_timestamp_ms)
.ok_or_else(|| anyhow::anyhow!("trace replay requires at least one timestamped request"))?;
let mut pending = VecDeque::from(
requests
.into_iter()
.map(|mut request| {
let arrival_timestamp_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp")
- first_arrival_ms;
request.arrival_timestamp_ms = Some(arrival_timestamp_ms);
request
})
.collect::<Vec<_>>(),
);
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let mut collector = TraceCollector::default();
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
let mut current_time_ms = 0.0;
while !pending.is_empty() || !state.is_empty() {
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
if state.is_empty() {
let Some(next_arrival_ms) = pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
break;
};
current_time_ms = next_arrival_ms;
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
continue;
}
let prefill_time = simulate_prefill_step(
&mut state,
&mut kv_manager,
&mut hit_rates,
&args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += prefill_time.as_secs_f64() * 1000.0;
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
let decode_time = simulate_decode_step(
&mut state,
&mut kv_manager,
&output_tx,
&args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += decode_time.as_secs_f64() * 1000.0;
}
Ok(collector.finish())
}
pub fn simulate_concurrency(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
args.validate()?;
let mut pending = VecDeque::from(requests);
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let mut collector = TraceCollector::default();
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
let mut current_time_ms = 0.0;
while !pending.is_empty() || !state.is_empty() {
enqueue_concurrency_arrivals(
&mut pending,
&mut state,
&mut collector,
current_time_ms,
max_in_flight,
);
if state.is_empty() {
break;
}
let prefill_time = simulate_prefill_step(
&mut state,
&mut kv_manager,
&mut hit_rates,
&args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += prefill_time.as_secs_f64() * 1000.0;
let decode_time = simulate_decode_step(
&mut state,
&mut kv_manager,
&output_tx,
&args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += decode_time.as_secs_f64() * 1000.0;
}
Ok(collector.finish())
}
fn enqueue_trace_arrivals(
pending: &mut VecDeque<DirectRequest>,
state: &mut SchedulerState,
collector: &mut TraceCollector,
current_time_ms: f64,
) {
loop {
let Some(next_arrival_ms) = pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
break;
};
if next_arrival_ms > current_time_ms {
break;
}
let request = pending
.pop_front()
.expect("front request must exist when arrival is available");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
let input_length = request.tokens.len();
let output_length = request.max_output_tokens;
let uuid = state.receive(request);
collector.on_arrival(uuid, arrival_ms, input_length, output_length);
}
}
fn enqueue_concurrency_arrivals(
pending: &mut VecDeque<DirectRequest>,
state: &mut SchedulerState,
collector: &mut TraceCollector,
current_time_ms: f64,
max_in_flight: usize,
) {
while state.requests.len() < max_in_flight {
let Some(mut request) = pending.pop_front() else {
break;
};
request.arrival_timestamp_ms = Some(current_time_ms);
let input_length = request.tokens.len();
let output_length = request.max_output_tokens;
let uuid = state.receive(request);
collector.on_arrival(uuid, current_time_ms, input_length, output_length);
}
}
/// Processes MoveBlock signals with the KvManager.
///
/// When a signal fails, this function verifies that the failure is for an expected case:
/// specifically a single signal attempting to create a single partial (generation) block.
/// This validation is important because in normal operation, the only legitimate failure
/// case should be when trying to acquire a new generation block - any other failures would
/// indicate an unexpected state in the system.
fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
for signal in signals {
if kv_manager.process(signal) > 0 {
continue;
}
// Check we have a Use signal with blocks
let MoveBlock::Use(blocks, _hashes, ..) = signal else {
panic!(
"Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
);
};
// Verify the signal contains exactly one block
let num_blocks = blocks.len();
let num_active_blocks = kv_manager.num_active_blocks();
if num_blocks != 1 {
panic!(
"Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
);
}
// Verify the block is a PartialBlock (generation block)
if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) {
panic!("Failed signal is Invalid. Generation block has to be partial.");
}
return false;
}
true
}
#[cfg(test)]
mod tests {
use super::*;
use crate::scheduler::SchedulerHandle;
use crate::simulation::{TraceCollector, TraceRequestStatsSnapshot};
use rstest::rstest;
use std::collections::HashMap;
use std::time::Duration;
use tokio::time::interval;
/// Helper function to verify that the scheduler is idle (no active KV blocks)
fn assert_scheduler_idle(metrics: &MockerMetrics) {
assert_eq!(
metrics.active_decode_blocks, 0,
"Expected 0 active blocks, got {}",
metrics.active_decode_blocks
);
}
#[rstest]
#[case::case_1(false, false, false)]
#[case::case_2(false, true, false)]
#[case::case_3(true, false, false)]
#[case::case_4(true, true, false)]
#[case::case_5(false, false, true)]
#[case::case_6(false, true, true)]
#[case::case_7(true, false, true)]
#[case::case_8(true, true, true)]
#[tokio::test]
async fn test_scheduler_token_generation_patterns(
#[case] use_shared_tokens: bool,
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
.block_size(64)
.speedup_ratio(10.0)
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.build()
.unwrap();
let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
crate::scheduler::test_utils::assert_scheduler_completes_all(
&scheduler,
&mut output_rx,
200,
1000,
100,
use_shared_tokens,
)
.await;
}
#[tokio::test]
async fn test_cache_hit_rate_with_identical_requests() {
let block_size: usize = 64;
let max_output_tokens: usize = 10;
let speedup_ratio = 10.0;
let num_requests = 10;
let token_length = 65;
// Create channel for token output
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create scheduler args
let args = MockEngineArgs::builder()
.num_gpu_blocks(100) // Large enough to not be a constraint
.block_size(block_size)
.speedup_ratio(speedup_ratio)
.build()
.unwrap();
// Create scheduler
let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
// Create identical tokens for all requests
let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect();
// Send all requests with identical tokens
for _ in 0..num_requests {
let request = DirectRequest {
tokens: identical_tokens.clone(),
max_output_tokens,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
};
scheduler.receive(request);
// Sleep for 0.1 second after each request
tokio::time::sleep(Duration::from_millis(100)).await;
}
// Collect all generated tokens
let mut received_tokens = 0;
// Set up a timeout that resets to 0.5 seconds on each received token
let timeout = tokio::time::sleep(Duration::from_millis(500));
tokio::pin!(timeout);
// Get metrics receiver
let metrics_rx = scheduler.metrics_receiver();
// Set up debug ticker interval
let mut debug_interval = interval(Duration::from_millis(500));
loop {
tokio::select! {
biased;
// Manual debug ticker that prints forward pass metrics
_ = debug_interval.tick() => {
let _metrics = metrics_rx.borrow().clone();
tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
}
Some(_signal) = output_rx.recv() => {
received_tokens += 1;
// Reset timeout whenever we receive a token
timeout.set(tokio::time::sleep(Duration::from_millis(500)));
}
_ = &mut timeout => {
// Break when timeout occurs (no more tokens for 0.5 seconds)
break;
}
}
}
// Wait a bit for final metrics update
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify forward pass metrics - scheduler should be idle after completing all requests
let metrics = metrics_rx.borrow().clone();
assert_scheduler_idle(&metrics);
println!("Test passed! Received {received_tokens} tokens");
}
/// White-box unit test that directly creates SchedulerState + KvManager,
/// manually invokes simulate_prefill / simulate_decode, and asserts on
/// queue states and block counts after each step.
#[tokio::test]
async fn test_scheduler_internal_state_transitions() {
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(12))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
let r1_uuid = Uuid::from_u128(1);
let r2_uuid = Uuid::from_u128(2);
let r3_uuid = Uuid::from_u128(3);
// ── Step 1: Receive 3 requests ──
// R1: 8 input, 2 max_output → 2 blocks
// R2: 8 input, 2 max_output → 2 blocks
// R3: 12 input, 2 max_output → 3 blocks
state.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 2,
uuid: Some(r1_uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
state.receive(DirectRequest {
tokens: (100..108).collect(),
max_output_tokens: 2,
uuid: Some(r2_uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
state.receive(DirectRequest {
tokens: (200..212).collect(),
max_output_tokens: 2,
uuid: Some(r3_uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
assert_eq!(state.waiting.len(), 3);
assert_eq!(state.prefill.len(), 0);
assert_eq!(state.decode.len(), 0);
assert_eq!(kv_manager.num_active_blocks(), 0);
// ── Step 2: First simulate_prefill ──
// Budget=12. R1 takes 8 tokens (2 blocks), fully prefilled → decode.
// R2 takes 4 tokens (1 block, chunked), partially prefilled → stays in prefill.
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
assert_eq!(state.waiting.len(), 1);
assert_eq!(state.prefill.len(), 1);
assert_eq!(state.decode.len(), 1);
assert_eq!(state.decode[0], r1_uuid);
assert_eq!(state.prefill[0], r2_uuid);
assert_eq!(state.waiting[0], r3_uuid);
assert_eq!(kv_manager.num_active_blocks(), 3); // 2 for R1 + 1 for R2
let seq = match state.requests.get(&r1_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.num_allocated_tokens(), 8);
assert_eq!(seq.generated_tokens(), 0);
let seq = match state.requests.get(&r2_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.num_allocated_tokens(), 4);
assert_eq!(seq.generated_tokens(), 0);
// ── Step 3: First simulate_decode ──
// R1 generates 1 token, gains a partial block.
simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
assert_eq!(state.decode.len(), 1);
assert_eq!(state.decode[0], r1_uuid);
assert_eq!(kv_manager.num_active_blocks(), 4); // +1 partial for R1
let seq = match state.requests.get(&r1_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.generated_tokens(), 1);
// ── Step 4: Second simulate_prefill ──
// Budget=11. R2 finishes (4 more tokens, 1 block → active=5, decode).
// R3 admitted, needs 2 blocks for chunk of 7. Only 1 free slot → partial.
// Preempt R2 (LIFO) → R2 back to waiting. Retry R3 → evicts R2's
// inactive blocks, allocates 2 more → R3 allocated_tokens=11.
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
assert_eq!(state.waiting.len(), 1, "R2 preempted back to waiting");
assert_eq!(state.waiting[0], r2_uuid);
assert_eq!(state.prefill.len(), 1, "R3 partially prefilled");
assert_eq!(state.prefill[0], r3_uuid);
assert_eq!(state.decode.len(), 1, "R1 still decoding");
assert_eq!(state.decode[0], r1_uuid);
assert_eq!(kv_manager.num_active_blocks(), 6); // at capacity
let seq = match state.requests.get(&r3_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.num_allocated_tokens(), 11);
// ── Step 5: Second simulate_decode ──
// R1 generates 2nd token → complete. Frees 3 blocks (1 destroyed, 2 deactivated).
simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
assert!(!state.requests.contains_key(&r1_uuid), "R1 completed");
assert_eq!(state.decode.len(), 0);
assert_eq!(state.prefill.len(), 1);
assert_eq!(state.waiting.len(), 1);
assert_eq!(kv_manager.num_active_blocks(), 3); // only R3's 3 blocks
// ── Step 6: Third simulate_prefill ──
// R3 finishes prefill (1 token left, no new blocks) → decode.
// R2 re-admitted, fully prefilled (2 blocks via inactive eviction) → decode.
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
assert_eq!(state.waiting.len(), 0);
assert_eq!(state.prefill.len(), 0);
assert_eq!(state.decode.len(), 2);
assert!(state.decode.contains(&r3_uuid));
assert!(state.decode.contains(&r2_uuid));
assert_eq!(kv_manager.num_active_blocks(), 5); // 3 for R3 + 2 for R2
// ── Steps 7+: Cycle until all requests complete ──
loop {
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
if state.is_empty() {
break;
}
}
assert_eq!(state.waiting.len(), 0);
assert_eq!(state.prefill.len(), 0);
assert_eq!(state.decode.len(), 0);
assert_eq!(kv_manager.num_active_blocks(), 0);
}
#[tokio::test]
async fn test_receiver_drop_cleans_up_resources() {
let block_size: usize = 64;
let input_tokens = 256;
let max_output_tokens = 200; // More than we'll receive
// Create channel for token output
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create scheduler args
let args = MockEngineArgs::builder()
.num_gpu_blocks(10) // Enough for 256 tokens (4 blocks)
.block_size(block_size)
.speedup_ratio(100.0) // Fast simulation
.build()
.unwrap();
// Create scheduler
let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
// Create request with 256 tokens
let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect();
let request = DirectRequest {
tokens,
max_output_tokens,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
};
scheduler.receive(request);
// Receive exactly 129 tokens
let mut received_count = 0;
while received_count < 129 {
if let Some(_signal) = output_rx.recv().await {
received_count += 1;
} else {
panic!("Channel closed before receiving 129 tokens");
}
}
// Drop the receiver immediately
drop(output_rx);
// Wait for 1 second to allow cleanup
tokio::time::sleep(Duration::from_secs(1)).await;
// Check forward pass metrics
let metrics_rx = scheduler.metrics_receiver();
let metrics = metrics_rx.borrow().clone();
assert_scheduler_idle(&metrics);
}
#[derive(Debug)]
struct ManualReplayResult {
report: TraceSimulationReport,
snapshots: HashMap<Uuid, TraceRequestStatsSnapshot>,
idle_jump_ms: f64,
first_decode_end_ms: f64,
}
#[derive(Debug)]
struct ManualConcurrencyResult {
report: TraceSimulationReport,
snapshots: HashMap<Uuid, TraceRequestStatsSnapshot>,
}
fn replay_args(enable_prefix_caching: bool, enable_chunked_prefill: bool) -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(32)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(2))
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.speedup_ratio(0.0)
.build()
.unwrap()
}
fn replay_fixture() -> Vec<DirectRequest> {
vec![
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(101.0),
},
DirectRequest {
tokens: vec![9, 9, 9, 9, 8, 8, 8, 8],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(500.0),
},
]
}
fn run_trace_manually(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
) -> ManualReplayResult {
let mut requests = requests;
requests.sort_by(|left, right| {
let left_ts = left.arrival_timestamp_ms.unwrap();
let right_ts = right.arrival_timestamp_ms.unwrap();
left_ts.total_cmp(&right_ts)
});
let first_arrival_ms = requests.first().unwrap().arrival_timestamp_ms.unwrap();
let mut pending = VecDeque::from(
requests
.into_iter()
.map(|mut request| {
request.arrival_timestamp_ms =
Some(request.arrival_timestamp_ms.unwrap() - first_arrival_ms);
request
})
.collect::<Vec<_>>(),
);
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let mut collector = TraceCollector::default();
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
let mut current_time_ms = 0.0;
let mut idle_jump_ms = 0.0;
let mut first_decode_end_ms = 0.0;
while !pending.is_empty() || !state.is_empty() {
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
if state.is_empty() {
let next_arrival_ms = pending.front().unwrap().arrival_timestamp_ms.unwrap();
current_time_ms = next_arrival_ms;
if idle_jump_ms == 0.0 && current_time_ms > 0.0 {
idle_jump_ms = current_time_ms;
}
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
continue;
}
let prefill_time = simulate_prefill_step(
&mut state,
&mut kv_manager,
&mut hit_rates,
args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += prefill_time.as_secs_f64() * 1000.0;
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
let decode_time = simulate_decode_step(
&mut state,
&mut kv_manager,
&output_tx,
args,
Some(&mut collector),
current_time_ms,
true,
);
if first_decode_end_ms == 0.0 && decode_time > Duration::ZERO {
first_decode_end_ms = current_time_ms + decode_time.as_secs_f64() * 1000.0;
}
current_time_ms += decode_time.as_secs_f64() * 1000.0;
}
let snapshots = [
Uuid::from_u128(11),
Uuid::from_u128(22),
Uuid::from_u128(33),
]
.into_iter()
.map(|uuid| (uuid, collector.snapshot(uuid).unwrap()))
.collect();
ManualReplayResult {
report: collector.finish(),
snapshots,
idle_jump_ms,
first_decode_end_ms,
}
}
fn run_concurrency_manually(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> ManualConcurrencyResult {
let mut pending = VecDeque::from(requests);
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let mut collector = TraceCollector::default();
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
let mut current_time_ms = 0.0;
while !pending.is_empty() || !state.is_empty() {
enqueue_concurrency_arrivals(
&mut pending,
&mut state,
&mut collector,
current_time_ms,
max_in_flight,
);
if state.is_empty() {
break;
}
let prefill_time = simulate_prefill_step(
&mut state,
&mut kv_manager,
&mut hit_rates,
args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += prefill_time.as_secs_f64() * 1000.0;
let decode_time = simulate_decode_step(
&mut state,
&mut kv_manager,
&output_tx,
args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += decode_time.as_secs_f64() * 1000.0;
}
let snapshots = [
Uuid::from_u128(11),
Uuid::from_u128(22),
Uuid::from_u128(33),
]
.into_iter()
.map(|uuid| (uuid, collector.snapshot(uuid).unwrap()))
.collect();
ManualConcurrencyResult {
report: collector.finish(),
snapshots,
}
}
fn assert_report_close(left: &TraceSimulationReport, right: &TraceSimulationReport) {
let epsilon = 1e-9;
assert_eq!(
left.request_counts.num_requests,
right.request_counts.num_requests
);
assert_eq!(
left.request_counts.completed_requests,
right.request_counts.completed_requests
);
assert_eq!(
left.request_counts.total_input_tokens,
right.request_counts.total_input_tokens
);
assert_eq!(
left.request_counts.total_output_tokens,
right.request_counts.total_output_tokens
);
assert!((left.throughput.duration_ms - right.throughput.duration_ms).abs() <= epsilon);
assert!(
(left.throughput.request_throughput_rps - right.throughput.request_throughput_rps)
.abs()
<= epsilon
);
assert!(
(left.throughput.input_throughput_tok_s - right.throughput.input_throughput_tok_s)
.abs()
<= epsilon
);
assert!(
(left.throughput.output_throughput_tok_s - right.throughput.output_throughput_tok_s)
.abs()
<= epsilon
);
assert!(
(left.throughput.total_throughput_tok_s - right.throughput.total_throughput_tok_s)
.abs()
<= epsilon
);
assert!(
(left.prefix_cache_reused_ratio - right.prefix_cache_reused_ratio).abs() <= epsilon
);
assert!((left.latency.ttft.mean_ms - right.latency.ttft.mean_ms).abs() <= epsilon);
assert!((left.latency.ttft.min_ms - right.latency.ttft.min_ms).abs() <= epsilon);
assert!((left.latency.ttft.max_ms - right.latency.ttft.max_ms).abs() <= epsilon);
assert!((left.latency.ttft.median_ms - right.latency.ttft.median_ms).abs() <= epsilon);
assert!((left.latency.ttft.p75_ms - right.latency.ttft.p75_ms).abs() <= epsilon);
assert!((left.latency.ttft.p90_ms - right.latency.ttft.p90_ms).abs() <= epsilon);
assert!((left.latency.ttft.p95_ms - right.latency.ttft.p95_ms).abs() <= epsilon);
assert!((left.latency.ttft.p99_ms - right.latency.ttft.p99_ms).abs() <= epsilon);
assert!((left.latency.ttft.std_ms - right.latency.ttft.std_ms).abs() <= epsilon);
assert!((left.latency.ttst.mean_ms - right.latency.ttst.mean_ms).abs() <= epsilon);
assert!((left.latency.ttst.min_ms - right.latency.ttst.min_ms).abs() <= epsilon);
assert!((left.latency.ttst.max_ms - right.latency.ttst.max_ms).abs() <= epsilon);
assert!((left.latency.ttst.median_ms - right.latency.ttst.median_ms).abs() <= epsilon);
assert!((left.latency.ttst.p75_ms - right.latency.ttst.p75_ms).abs() <= epsilon);
assert!((left.latency.ttst.p90_ms - right.latency.ttst.p90_ms).abs() <= epsilon);
assert!((left.latency.ttst.p95_ms - right.latency.ttst.p95_ms).abs() <= epsilon);
assert!((left.latency.ttst.p99_ms - right.latency.ttst.p99_ms).abs() <= epsilon);
assert!((left.latency.ttst.std_ms - right.latency.ttst.std_ms).abs() <= epsilon);
assert!((left.latency.tpot.mean_ms - right.latency.tpot.mean_ms).abs() <= epsilon);
assert!((left.latency.tpot.min_ms - right.latency.tpot.min_ms).abs() <= epsilon);
assert!((left.latency.tpot.max_ms - right.latency.tpot.max_ms).abs() <= epsilon);
assert!((left.latency.tpot.median_ms - right.latency.tpot.median_ms).abs() <= epsilon);
assert!((left.latency.tpot.p75_ms - right.latency.tpot.p75_ms).abs() <= epsilon);
assert!((left.latency.tpot.p90_ms - right.latency.tpot.p90_ms).abs() <= epsilon);
assert!((left.latency.tpot.p95_ms - right.latency.tpot.p95_ms).abs() <= epsilon);
assert!((left.latency.tpot.p99_ms - right.latency.tpot.p99_ms).abs() <= epsilon);
assert!((left.latency.tpot.std_ms - right.latency.tpot.std_ms).abs() <= epsilon);
assert!(
(left.latency.itl.distribution.mean_ms - right.latency.itl.distribution.mean_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.min_ms - right.latency.itl.distribution.min_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.max_ms - right.latency.itl.distribution.max_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.median_ms - right.latency.itl.distribution.median_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p75_ms - right.latency.itl.distribution.p75_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p90_ms - right.latency.itl.distribution.p90_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p95_ms - right.latency.itl.distribution.p95_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p99_ms - right.latency.itl.distribution.p99_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.std_ms - right.latency.itl.distribution.std_ms).abs()
<= epsilon
);
assert!((left.latency.itl.max_ms - right.latency.itl.max_ms).abs() <= epsilon);
assert!((left.latency.e2e.mean_ms - right.latency.e2e.mean_ms).abs() <= epsilon);
assert!((left.latency.e2e.min_ms - right.latency.e2e.min_ms).abs() <= epsilon);
assert!((left.latency.e2e.max_ms - right.latency.e2e.max_ms).abs() <= epsilon);
assert!((left.latency.e2e.median_ms - right.latency.e2e.median_ms).abs() <= epsilon);
assert!((left.latency.e2e.p75_ms - right.latency.e2e.p75_ms).abs() <= epsilon);
assert!((left.latency.e2e.p90_ms - right.latency.e2e.p90_ms).abs() <= epsilon);
assert!((left.latency.e2e.p95_ms - right.latency.e2e.p95_ms).abs() <= epsilon);
assert!((left.latency.e2e.p99_ms - right.latency.e2e.p99_ms).abs() <= epsilon);
assert!((left.latency.e2e.std_ms - right.latency.e2e.std_ms).abs() <= epsilon);
assert!(
(left.latency.output_token_throughput_per_user.mean_ms
- right.latency.output_token_throughput_per_user.mean_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.min_ms
- right.latency.output_token_throughput_per_user.min_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.max_ms
- right.latency.output_token_throughput_per_user.max_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.median_ms
- right.latency.output_token_throughput_per_user.median_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p75_ms
- right.latency.output_token_throughput_per_user.p75_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p90_ms
- right.latency.output_token_throughput_per_user.p90_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p95_ms
- right.latency.output_token_throughput_per_user.p95_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p99_ms
- right.latency.output_token_throughput_per_user.p99_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.std_ms
- right.latency.output_token_throughput_per_user.std_ms)
.abs()
<= epsilon
);
}
#[rstest]
#[case(false, false)]
#[case(false, true)]
#[case(true, false)]
#[case(true, true)]
fn test_trace_replay_matches_manual_steps(
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
let args = replay_args(enable_prefix_caching, enable_chunked_prefill);
let manual = run_trace_manually(&args, replay_fixture());
let replay_report = simulate_trace(args, replay_fixture()).unwrap();
let request_1 = manual.snapshots.get(&Uuid::from_u128(11)).unwrap();
let request_2 = manual.snapshots.get(&Uuid::from_u128(22)).unwrap();
let request_3 = manual.snapshots.get(&Uuid::from_u128(33)).unwrap();
assert_eq!(request_1.arrival_time_ms, 0.0);
assert_eq!(request_2.arrival_time_ms, 1.0);
assert_eq!(request_3.arrival_time_ms, 400.0);
assert_eq!(manual.idle_jump_ms, 400.0);
assert_eq!(
request_1.first_token_ms.unwrap(),
manual.first_decode_end_ms,
);
assert!(request_2.first_admit_ms.unwrap() >= request_2.arrival_time_ms);
assert!(request_3.first_admit_ms.unwrap() >= request_3.arrival_time_ms);
assert!(manual.report.latency.e2e.mean_ms >= manual.report.latency.ttft.mean_ms);
if enable_prefix_caching {
assert!(request_2.reused_input_tokens > 0);
assert!(manual.report.prefix_cache_reused_ratio > 0.0);
} else {
assert_eq!(request_2.reused_input_tokens, 0);
assert_eq!(manual.report.prefix_cache_reused_ratio, 0.0);
}
assert_report_close(&replay_report, &manual.report);
}
#[test]
fn test_concurrency_replay_matches_manual_steps() {
let args = replay_args(false, false);
let requests = vec![
DirectRequest {
tokens: vec![1, 2, 3, 4, 5, 6, 7, 8],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(900.0),
},
DirectRequest {
tokens: vec![1, 2, 3, 4, 5, 9, 10, 11],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(1000.0),
},
DirectRequest {
tokens: vec![12, 13, 14, 15, 16, 17, 18, 19],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
];
let manual = run_concurrency_manually(&args, requests.clone(), 2);
let replay_report = simulate_concurrency(args, requests, 2).unwrap();
let request_1 = manual.snapshots.get(&Uuid::from_u128(11)).unwrap();
let request_2 = manual.snapshots.get(&Uuid::from_u128(22)).unwrap();
let request_3 = manual.snapshots.get(&Uuid::from_u128(33)).unwrap();
assert_eq!(request_1.arrival_time_ms, 0.0);
assert_eq!(request_2.arrival_time_ms, 0.0);
assert_eq!(request_3.arrival_time_ms, request_1.last_token_ms.unwrap());
assert!(request_3.arrival_time_ms < request_2.last_token_ms.unwrap());
assert_eq!(manual.report.request_counts.completed_requests, 3);
assert_eq!(manual.report.request_counts.total_input_tokens, 24);
assert_eq!(manual.report.request_counts.total_output_tokens, 6);
assert_report_close(&replay_report, &manual.report);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet, VecDeque};
use std::time::Duration;
use dynamo_kv_router::protocols::WorkerId;
use dynamo_tokens::blocks::UniqueBlock;
use tokio::sync::mpsc;
use uuid::Uuid;
use crate::common::protocols::{
DirectRequest, KvEventPublishers, MockEngineArgs, MoveBlock, OutputSignal, PreemptionMode,
WorkerType,
};
use crate::common::sequence::ActiveSequence;
use crate::kv_manager::KvManager;
use crate::replay::TraceCollector;
use crate::scheduler::{
AdmissionEvent, CapturedRouterEventBuffer, EnginePassResult, RouterEventVisibility,
capture_router_event_sink,
};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RequestStatus {
Waiting,
Running,
Preempted,
}
pub(crate) struct VllmRequestState {
pub(crate) sequence: ActiveSequence,
pub(crate) status: RequestStatus,
pub(crate) num_computed_tokens: usize,
pub(crate) num_preemptions: usize,
}
#[derive(Default)]
pub(crate) struct SchedulerState {
pub(crate) waiting: VecDeque<Uuid>,
waiting_members: HashSet<Uuid>,
pub(crate) running: VecDeque<Uuid>,
running_members: HashSet<Uuid>,
pub(crate) requests: HashMap<Uuid, VllmRequestState>,
}
struct PreemptedRequest {
uuid: Uuid,
signals: Vec<MoveBlock>,
}
#[derive(Clone, Copy, Debug, Default)]
struct ScheduledWork {
total_tokens: usize,
prompt_tokens: usize,
prefix_tokens: usize,
}
enum ScheduleOutcome {
Scheduled {
tokens_used: usize,
admission: Option<AdmissionEvent>,
},
Blocked,
CurrentPreempted,
}
impl SchedulerState {
pub(crate) fn is_empty(&self) -> bool {
self.requests.is_empty()
}
fn push_waiting(&mut self, uuid: Uuid) {
if !self.waiting_members.insert(uuid) {
return;
}
self.waiting.push_back(uuid);
}
fn prepend_waiting(&mut self, uuid: Uuid) {
if !self.waiting_members.insert(uuid) {
return;
}
self.waiting.push_front(uuid);
}
fn next_waiting_uuid(&mut self) -> Option<Uuid> {
loop {
let uuid = *self.waiting.front()?;
let Some(request) = self.requests.get(&uuid) else {
self.waiting.pop_front();
self.waiting_members.remove(&uuid);
continue;
};
if self.waiting_members.contains(&uuid) && request.status != RequestStatus::Running {
return Some(uuid);
}
self.waiting.pop_front();
self.waiting_members.remove(&uuid);
}
}
fn compact_running(&mut self) {
let mut compacted = VecDeque::with_capacity(self.running.len());
while let Some(uuid) = self.running.pop_front() {
let is_running = self.running_members.contains(&uuid)
&& self
.requests
.get(&uuid)
.is_some_and(|request| request.status == RequestStatus::Running);
if is_running {
compacted.push_back(uuid);
continue;
}
self.running_members.remove(&uuid);
}
self.running = compacted;
}
fn transition_to_running(&mut self, uuid: Uuid) {
if self.waiting.front().copied() == Some(uuid) {
self.waiting.pop_front();
}
self.waiting_members.remove(&uuid);
if self.running_members.insert(uuid) {
self.running.push_back(uuid);
}
if let Some(request) = self.requests.get_mut(&uuid) {
request.status = RequestStatus::Running;
}
}
pub(crate) fn complete(&mut self, uuid: &Uuid) {
self.waiting_members.remove(uuid);
self.running_members.remove(uuid);
self.requests.remove(uuid);
}
pub(crate) fn running_sequence_mut(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
if !self.running_members.contains(&uuid) {
return None;
}
self.requests
.get_mut(&uuid)
.map(|request| &mut request.sequence)
}
fn preempt(&mut self, mode: PreemptionMode) -> Option<PreemptedRequest> {
let uuid = loop {
let candidate = match mode {
PreemptionMode::Lifo => self.running.pop_back(),
PreemptionMode::Fifo => self.running.pop_front(),
}?;
let is_running = self.running_members.contains(&candidate)
&& self
.requests
.get(&candidate)
.is_some_and(|request| request.status == RequestStatus::Running);
if is_running {
break candidate;
}
self.running_members.remove(&candidate);
};
self.running_members.remove(&uuid);
let request = self.requests.get_mut(&uuid)?;
request.status = RequestStatus::Preempted;
request.num_computed_tokens = 0;
request.num_preemptions += 1;
let signals = request.sequence.reset_with_signal();
debug_assert_vllm_request_invariants(uuid, request);
#[cfg(debug_assertions)]
{
debug_assert_eq!(
request.sequence.num_allocated_tokens(),
0,
"preempted request {uuid} should release all allocated KV"
);
}
self.prepend_waiting(uuid);
Some(PreemptedRequest { uuid, signals })
}
#[cfg(test)]
pub(super) fn insert_running_for_test(&mut self, uuid: Uuid) {
self.running_members.insert(uuid);
self.running.push_back(uuid);
}
}
pub(crate) struct VllmCore {
args: MockEngineArgs,
pub(super) state: SchedulerState,
pub(super) kv_manager: KvManager,
kv_event_buffer: Option<CapturedRouterEventBuffer>,
}
impl VllmCore {
pub(crate) fn new(args: MockEngineArgs) -> Self {
Self::new_internal(args, 0, None, KvEventPublishers::default())
}
pub(crate) fn new_with_kv_capture(args: MockEngineArgs, worker_id: WorkerId) -> Self {
let (buffer, sink) = capture_router_event_sink(worker_id);
Self::new_internal(
args,
0,
Some(buffer),
KvEventPublishers::new(Some(sink), None),
)
}
pub(super) fn new_with_sink(
args: MockEngineArgs,
dp_rank: u32,
kv_event_publishers: KvEventPublishers,
) -> Self {
Self::new_internal(args, dp_rank, None, kv_event_publishers)
}
fn new_internal(
args: MockEngineArgs,
dp_rank: u32,
kv_event_buffer: Option<CapturedRouterEventBuffer>,
kv_event_publishers: KvEventPublishers,
) -> Self {
let args = args.normalized().expect("invalid MockEngineArgs");
Self {
kv_manager: KvManager::new_with_event_sink(
args.num_gpu_blocks,
args.block_size,
kv_event_publishers,
dp_rank,
),
args,
state: SchedulerState::default(),
kv_event_buffer,
}
}
pub(crate) fn receive(&mut self, request: DirectRequest) -> Uuid {
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
let sequence = ActiveSequence::new(
request.tokens,
request.max_output_tokens,
Some(self.args.block_size),
self.args.enable_prefix_caching,
self.args.zmq_kv_events_port.is_some(),
);
self.state.requests.insert(
uuid,
VllmRequestState {
sequence,
status: RequestStatus::Waiting,
num_computed_tokens: 0,
num_preemptions: 0,
},
);
self.state.push_waiting(uuid);
if let Some(request) = self.state.requests.get(&uuid) {
debug_assert_vllm_request_progress(uuid, request);
}
uuid
}
pub(crate) fn is_empty(&self) -> bool {
self.state.is_empty()
}
pub(crate) fn num_requests(&self) -> usize {
self.state.requests.len()
}
pub(crate) fn execute_pass(
&mut self,
collector: &mut TraceCollector,
now_ms: f64,
) -> EnginePassResult {
self.execute_pass_internal(Some(collector), now_ms, None)
}
pub(super) fn execute_pass_internal(
&mut self,
mut collector: Option<&mut TraceCollector>,
now_ms: f64,
admission_tx: Option<&mpsc::UnboundedSender<AdmissionEvent>>,
) -> EnginePassResult {
let requests_before = self.state.requests.len();
self.state.compact_running();
let mut token_budget = self.args.max_num_batched_tokens.unwrap_or(usize::MAX);
let mut scheduled = HashMap::new();
let mut batch_count = 0usize;
let mut batch_total_isl = 0usize;
let mut batch_total_prefix = 0usize;
let mut admissions = Vec::new();
let mut preempted_any = false;
let mut req_index = 0usize;
while req_index < self.state.running.len() && token_budget > 0 {
let uuid = self.state.running[req_index];
match self.schedule_request(
uuid,
false,
&mut token_budget,
&mut scheduled,
&mut batch_count,
&mut batch_total_isl,
&mut batch_total_prefix,
&mut preempted_any,
) {
ScheduleOutcome::Scheduled { admission, .. } => {
if let Some(admission) = admission {
if let Some(collector) = collector.as_deref_mut() {
collector.on_admit(
admission.uuid,
now_ms,
admission.reused_input_tokens,
);
}
if let Some(admission_tx) = admission_tx {
let _ = admission_tx.send(admission.clone());
}
admissions.push(admission);
}
req_index += 1;
}
ScheduleOutcome::Blocked => break,
ScheduleOutcome::CurrentPreempted => {}
}
}
let max_num_running = self.args.max_num_seqs.unwrap_or(usize::MAX);
while !preempted_any && self.state.running.len() < max_num_running {
let Some(uuid) = self.state.next_waiting_uuid() else {
break;
};
match self.schedule_request(
uuid,
true,
&mut token_budget,
&mut scheduled,
&mut batch_count,
&mut batch_total_isl,
&mut batch_total_prefix,
&mut preempted_any,
) {
ScheduleOutcome::Scheduled {
admission,
tokens_used,
} => {
if let Some(admission) = admission {
if let Some(collector) = collector.as_deref_mut() {
collector.on_admit(
admission.uuid,
now_ms,
admission.reused_input_tokens,
);
}
if let Some(admission_tx) = admission_tx {
let _ = admission_tx.send(admission.clone());
}
admissions.push(admission);
}
if tokens_used == 0 && token_budget == 0 {
break;
}
}
ScheduleOutcome::Blocked | ScheduleOutcome::CurrentPreempted => break,
}
}
let prefill_time =
predict_prefill_duration(batch_count, batch_total_isl, batch_total_prefix, &self.args);
let decode_start_ms = now_ms + prefill_time.as_secs_f64() * 1000.0;
let (decode_time, output_signals) = self.emit_ready_tokens(collector, decode_start_ms);
let end_ms = decode_start_ms + decode_time.as_secs_f64() * 1000.0;
debug_assert_vllm_scheduler_state(&self.state);
EnginePassResult {
end_ms,
completed_requests: requests_before.saturating_sub(self.state.requests.len()),
output_signals,
admissions,
active_decode_blocks: self.kv_manager.num_active_blocks() as u64,
router_event_visibility: RouterEventVisibility::PassStart,
kv_events: self
.kv_event_buffer
.as_ref()
.map(CapturedRouterEventBuffer::drain)
.unwrap_or_default(),
}
}
pub(super) fn drop_request(&mut self, uuid: Uuid) {
let Some(request) = self.state.requests.get(&uuid) else {
return;
};
for signal in request.sequence.free_signal() {
self.kv_manager.process(&signal);
}
self.state.complete(&uuid);
}
#[allow(clippy::too_many_arguments)]
fn schedule_request(
&mut self,
uuid: Uuid,
from_waiting: bool,
token_budget: &mut usize,
scheduled: &mut HashMap<Uuid, ScheduledWork>,
batch_count: &mut usize,
batch_total_isl: &mut usize,
batch_total_prefix: &mut usize,
preempted_any: &mut bool,
) -> ScheduleOutcome {
let Some(request) = self.state.requests.get(&uuid) else {
return ScheduleOutcome::Blocked;
};
debug_assert_vllm_request_invariants(uuid, request);
let prefill_cost = self.kv_manager.get_prefill_cost(&request.sequence);
let cached_prefix_tokens = if request.num_computed_tokens == 0 {
prefill_cost.cached_tokens
} else {
0
};
let effective_computed_before = request.num_computed_tokens + cached_prefix_tokens;
let prompt_len = request.sequence.num_input_tokens();
let prompt_before = effective_computed_before.min(prompt_len);
let remaining_known_tokens = request
.sequence
.len()
.saturating_sub(effective_computed_before);
let prompt_remaining = prompt_len.saturating_sub(prompt_before);
if prompt_remaining > 0
&& !self.args.enable_chunked_prefill
&& prompt_remaining > *token_budget
{
return ScheduleOutcome::Blocked;
}
let desired_tokens = remaining_known_tokens.min(*token_budget);
if desired_tokens == 0 && remaining_known_tokens > 0 {
return ScheduleOutcome::Blocked;
}
let desired_computed_after = effective_computed_before + desired_tokens;
let mut actual_computed_after = desired_computed_after;
loop {
let allocation = {
let Some(request) = self.state.requests.get_mut(&uuid) else {
return ScheduleOutcome::Blocked;
};
let allocation_target = desired_computed_after;
let prev_allocated_tokens = request.sequence.num_allocated_tokens();
if allocation_target <= prev_allocated_tokens {
request.num_computed_tokens = actual_computed_after;
None
} else {
let maybe_signal = request.sequence.prepare_allocation(allocation_target);
Some((allocation_target, prev_allocated_tokens, maybe_signal))
}
};
let Some((allocation_target, prev_allocated_tokens, maybe_signal)) = allocation else {
break;
};
let Some(signal) = maybe_signal else {
let Some(request) = self.state.requests.get_mut(&uuid) else {
return ScheduleOutcome::Blocked;
};
request.sequence.commit_allocation(allocation_target);
request.num_computed_tokens = actual_computed_after;
break;
};
let expected = match &signal {
MoveBlock::Use(blocks, ..) => blocks.len(),
_ => unreachable!(),
};
let allocated = self.kv_manager.process(&signal);
let (_committed_tokens, current_computed_tokens) = {
let Some(request) = self.state.requests.get_mut(&uuid) else {
return ScheduleOutcome::Blocked;
};
let committed_tokens = if allocated == expected {
allocation_target
} else {
let prev_blocks = prev_allocated_tokens
.div_ceil(request.sequence.block_size())
.min(request.sequence.unique_blocks().len());
(prev_blocks + allocated) * request.sequence.block_size()
};
request
.sequence
.commit_allocation(committed_tokens.min(allocation_target));
request.num_computed_tokens = actual_computed_after.min(committed_tokens);
(committed_tokens, request.num_computed_tokens)
};
if allocated == expected {
break;
}
let Some(preempted) = self.state.preempt(self.args.preemption_mode) else {
actual_computed_after = current_computed_tokens;
break;
};
for signal in preempted.signals {
self.kv_manager.process(&signal);
}
*preempted_any = true;
if let Some(undone) = scheduled.remove(&preempted.uuid) {
*token_budget += undone.total_tokens;
if undone.prompt_tokens > 0 && self.args.worker_type != WorkerType::Decode {
*batch_count = batch_count.saturating_sub(1);
*batch_total_isl =
batch_total_isl.saturating_sub(undone.prefix_tokens + undone.prompt_tokens);
*batch_total_prefix = batch_total_prefix.saturating_sub(undone.prefix_tokens);
}
}
if preempted.uuid == uuid {
return ScheduleOutcome::CurrentPreempted;
}
}
if let Some(request) = self.state.requests.get(&uuid) {
debug_assert_vllm_request_invariants(uuid, request);
}
let tokens_used = actual_computed_after.saturating_sub(effective_computed_before);
if tokens_used == 0
&& actual_computed_after < request_sequence_len(&self.state.requests, uuid)
{
return ScheduleOutcome::Blocked;
}
let prompt_after = actual_computed_after.min(prompt_len);
let prompt_tokens = prompt_after.saturating_sub(prompt_before);
scheduled.insert(
uuid,
ScheduledWork {
total_tokens: tokens_used,
prompt_tokens,
prefix_tokens: prompt_before,
},
);
if prompt_tokens > 0 && self.args.worker_type != WorkerType::Decode {
*batch_count += 1;
*batch_total_isl += prompt_before + prompt_tokens;
*batch_total_prefix += prompt_before;
}
if from_waiting {
self.state.transition_to_running(uuid);
}
*token_budget = token_budget.saturating_sub(tokens_used);
let admission = if from_waiting {
Some(AdmissionEvent {
uuid,
reused_input_tokens: cached_prefix_tokens,
})
} else {
None
};
ScheduleOutcome::Scheduled {
tokens_used,
admission,
}
}
fn emit_ready_tokens(
&mut self,
mut collector: Option<&mut TraceCollector>,
decode_start_ms: f64,
) -> (Duration, Vec<OutputSignal>) {
let ready = self
.state
.running
.iter()
.copied()
.filter(|uuid| {
let Some(request) = self.state.requests.get(uuid) else {
return false;
};
request.num_computed_tokens >= request.sequence.len()
&& request.sequence.generated_tokens() < request.sequence.max_output_tokens()
})
.collect::<Vec<_>>();
if ready.is_empty() {
return (Duration::ZERO, Vec::new());
}
let active_kv_tokens = self.kv_manager.num_active_blocks() * self.args.block_size;
let total_length = ready
.iter()
.filter_map(|uuid| self.state.requests.get(uuid))
.map(|request| request.sequence.len())
.sum::<usize>();
let context_length = total_length / ready.len();
let decode_ms =
self.args
.perf_model
.predict_decode_time(ready.len(), active_kv_tokens, context_length);
let decode_time = scale_decode_time(decode_ms, &self.args);
let decode_end_ms = decode_start_ms + decode_time.as_secs_f64() * 1000.0;
let mut output_signals = Vec::with_capacity(ready.len());
for uuid in ready {
let mut emitted = false;
let mut completed = false;
loop {
debug_assert_vllm_ready_to_decode(&self.state.requests, uuid);
let Some(sequence) = self.state.running_sequence_mut(uuid) else {
break;
};
let signals = sequence.generate();
if process_signals(&mut self.kv_manager, &signals) {
if sequence.generated_tokens() < sequence.max_output_tokens() {
sequence.commit_allocation(sequence.len());
}
emitted = true;
completed = sequence.generated_tokens() >= sequence.max_output_tokens();
break;
}
sequence.pop();
let Some(preempted) = self.state.preempt(self.args.preemption_mode) else {
break;
};
for signal in preempted.signals {
self.kv_manager.process(&signal);
}
if preempted.uuid == uuid {
break;
}
}
if !emitted {
continue;
}
if let Some(collector) = collector.as_deref_mut() {
collector.on_token(uuid, decode_end_ms);
}
if let Some(request) = self.state.requests.get(&uuid) {
debug_assert_vllm_request_progress(uuid, request);
}
output_signals.push(OutputSignal { uuid, completed });
if completed {
self.state.complete(&uuid);
}
}
if output_signals.is_empty() {
return (Duration::ZERO, output_signals);
}
self.state.compact_running();
(decode_time, output_signals)
}
}
fn request_sequence_len(requests: &HashMap<Uuid, VllmRequestState>, uuid: Uuid) -> usize {
requests
.get(&uuid)
.map(|request| request.sequence.len())
.unwrap_or_default()
}
fn debug_assert_vllm_request_invariants(uuid: Uuid, request: &VllmRequestState) {
#[cfg(debug_assertions)]
{
let seq_len = request.sequence.len();
let allocated = request.sequence.num_allocated_tokens();
debug_assert!(
request.num_computed_tokens <= seq_len,
"request {uuid} computed {} tokens but sequence length is {seq_len}",
request.num_computed_tokens
);
debug_assert!(
allocated <= seq_len,
"request {uuid} allocated {allocated} tokens but sequence length is {seq_len}"
);
}
}
fn debug_assert_vllm_request_progress(uuid: Uuid, request: &VllmRequestState) {
#[cfg(debug_assertions)]
{
debug_assert_vllm_request_invariants(uuid, request);
let allocated = request.sequence.num_allocated_tokens();
debug_assert!(
allocated >= request.num_computed_tokens,
"request {uuid} allocated {allocated} tokens but computed {}",
request.num_computed_tokens
);
}
}
fn debug_assert_vllm_ready_to_decode(requests: &HashMap<Uuid, VllmRequestState>, uuid: Uuid) {
#[cfg(debug_assertions)]
{
let Some(request) = requests.get(&uuid) else {
return;
};
let seq_len = request.sequence.len();
if request.num_computed_tokens < seq_len {
return;
}
let allocated = request.sequence.num_allocated_tokens();
debug_assert_eq!(
allocated, seq_len,
"request {uuid} is decode-ready but allocated {allocated} tokens for sequence length {seq_len}"
);
}
}
fn debug_assert_vllm_scheduler_state(state: &SchedulerState) {
#[cfg(debug_assertions)]
{
let mut seen = std::collections::HashSet::new();
for uuid in &state.waiting_members {
debug_assert!(
seen.insert(*uuid),
"request {uuid} appears multiple times across waiting/running queues"
);
let request = state
.requests
.get(uuid)
.expect("waiting request missing from state map");
debug_assert!(
request.status != RequestStatus::Running,
"request {uuid} is queued in waiting but marked Running"
);
debug_assert_vllm_request_invariants(*uuid, request);
}
for uuid in &state.running_members {
debug_assert!(
seen.insert(*uuid),
"request {uuid} appears multiple times across waiting/running queues"
);
let request = state
.requests
.get(uuid)
.expect("running request missing from state map");
debug_assert_eq!(
request.status,
RequestStatus::Running,
"request {uuid} is queued in running but marked {:?}",
request.status
);
debug_assert_vllm_request_invariants(*uuid, request);
}
debug_assert!(
state.waiting.len() >= state.waiting_members.len(),
"waiting queue dropped live membership entries"
);
debug_assert!(
state.running.len() >= state.running_members.len(),
"running queue dropped live membership entries"
);
}
}
fn predict_prefill_duration(
batch_count: usize,
batch_total_isl: usize,
batch_total_prefix: usize,
args: &MockEngineArgs,
) -> Duration {
if batch_count == 0 || args.worker_type == WorkerType::Decode {
return Duration::ZERO;
}
let mean_isl = batch_total_isl / batch_count;
let mean_prefix = batch_total_prefix / batch_count;
let prefill_ms = args
.perf_model
.predict_prefill_time(batch_count, mean_isl, mean_prefix);
let total_time = Duration::from_secs_f64(prefill_ms / 1000.0);
if args.speedup_ratio <= 0.0 || total_time <= Duration::ZERO {
return total_time;
}
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio)
}
fn scale_decode_time(decode_ms: f64, args: &MockEngineArgs) -> Duration {
let unscaled = Duration::from_secs_f64(decode_ms / 1000.0);
let effective_ratio = args.speedup_ratio * args.decode_speedup_ratio;
if effective_ratio <= 0.0 || unscaled <= Duration::ZERO {
return unscaled;
}
Duration::from_secs_f64(unscaled.as_secs_f64() / effective_ratio)
}
fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
for signal in signals {
if kv_manager.process(signal) > 0 {
continue;
}
let MoveBlock::Use(blocks, ..) = signal else {
panic!("Failed signal is invalid. Expected decode allocation failure, got {signal:?}");
};
if blocks.len() != 1 {
panic!(
"Failed signal is invalid. Tried to allocate {} blocks during decode.",
blocks.len()
);
}
if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) {
panic!("Failed signal is invalid. Decode allocation must use a partial block.");
}
return false;
}
true
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, KvEventPublishers, MockEngineArgs, OutputSignal};
use crate::common::utils::sleep_until_precise;
use crate::scheduler::{
AdmissionEvent, RouterEventVisibility, SchedulerHandle, capture_deferred_kv_publish_sink,
publish_deferred_kv_events,
};
use super::core::VllmCore;
#[derive(Clone, Default, Debug)]
pub struct MockerMetrics {
pub dp_rank: dynamo_kv_router::protocols::DpRank,
pub active_decode_blocks: u64,
pub total_blocks: u64,
pub gpu_cache_usage_perc: f64,
}
impl MockerMetrics {
pub fn new(
dp_rank: dynamo_kv_router::protocols::DpRank,
active_decode_blocks: u64,
total_blocks: u64,
) -> Self {
let gpu_cache_usage_perc = if total_blocks == 0 {
0.0
} else {
active_decode_blocks as f64 / total_blocks as f64
};
Self {
dp_rank,
active_decode_blocks,
total_blocks,
gpu_cache_usage_perc,
}
}
}
#[derive(Clone)]
pub struct Scheduler {
request_tx: mpsc::UnboundedSender<DirectRequest>,
metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
_cancel_guard: Arc<CancelGuard>,
}
struct CancelGuard(CancellationToken);
impl Drop for CancelGuard {
fn drop(&mut self) {
self.0.cancel();
}
}
impl Scheduler {
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Self {
Self::new_internal(
args,
dp_rank,
output_tx,
kv_event_publishers,
cancellation_token,
None,
)
}
pub(crate) fn new_with_admission(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
) -> Self {
Self::new_internal(
args,
dp_rank,
output_tx,
kv_event_publishers,
cancellation_token,
admission_tx,
)
}
fn new_internal(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
) -> Self {
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let total_blocks = args.num_gpu_blocks as u64;
let initial_metrics = MockerMetrics::new(dp_rank, 0, total_blocks);
let (metrics_tx, metrics_rx) = tokio::sync::watch::channel(initial_metrics);
let cancel_token = cancellation_token.unwrap_or_default();
let cancel_token_clone = cancel_token.clone();
let cancel_guard = Arc::new(CancelGuard(cancel_token));
tokio::spawn(async move {
let (deferred_kv_events, buffering_publishers) =
capture_deferred_kv_publish_sink(kv_event_publishers.raw_enabled());
let mut core = VllmCore::new_with_sink(args, dp_rank, buffering_publishers);
loop {
if receive_requests(&mut core, &mut request_rx, &cancel_token_clone)
.await
.is_none()
{
break;
}
let iteration_start = Instant::now();
let pass = core.execute_pass_internal(None, 0.0, admission_tx.as_ref());
let total_time = std::time::Duration::from_secs_f64(pass.end_ms / 1000.0);
if pass.router_event_visibility == RouterEventVisibility::PassStart {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
if total_time > std::time::Duration::ZERO {
sleep_until_precise(iteration_start + total_time).await;
}
if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
flush_output_signals(&mut core, &output_tx, &pass.output_signals);
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
let _ = metrics_tx.send(MockerMetrics::new(
dp_rank,
core.kv_manager.num_active_blocks() as u64,
total_blocks,
));
}
});
Self {
request_tx,
metrics_rx,
_cancel_guard: cancel_guard,
}
}
}
impl SchedulerHandle for Scheduler {
fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request);
}
fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
self.metrics_rx.clone()
}
}
async fn receive_requests(
core: &mut VllmCore,
request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
cancel_token: &CancellationToken,
) -> Option<()> {
if cancel_token.is_cancelled() {
return None;
}
if core.is_empty() {
tokio::select! {
biased;
_ = cancel_token.cancelled() => return None,
result = request_rx.recv() => {
let request = result?;
core.receive(request);
}
}
}
while let Ok(request) = request_rx.try_recv() {
core.receive(request);
}
Some(())
}
fn flush_output_signals(
core: &mut VllmCore,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
output_signals: &[OutputSignal],
) {
let Some(tx) = output_tx.as_ref() else {
return;
};
for signal in output_signals {
if tx.send(signal.clone()).is_ok() {
continue;
}
core.drop_request(signal.uuid);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! vLLM scheduler simulation around a unified waiting/running request model.
//!
//! Reference: vllm/vllm/v1/core/sched/scheduler.py
mod core;
mod live;
pub(crate) use core::VllmCore;
pub use live::{MockerMetrics, Scheduler};
#[cfg(test)]
mod tests;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::{Arc, Mutex};
use std::time::Duration;
use dynamo_kv_router::indexer::{METRIC_EVENT_REMOVED, METRIC_EVENT_STORED};
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData, WorkerId};
use rstest::rstest;
use tokio::sync::mpsc;
use tokio::time::interval;
use uuid::Uuid;
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs, OutputSignal,
PreemptionMode, RawKvEvent, RawKvEventSink,
};
use crate::common::sequence::ActiveSequence;
use crate::scheduler::RouterEventVisibility;
use crate::scheduler::SchedulerHandle;
use crate::scheduler::test_utils::{RouterIndexerHarness, removed_event_count, stored_hashes};
use super::core::{RequestStatus, VllmCore, VllmRequestState};
use super::live::{MockerMetrics, Scheduler};
const ROUTER_TEST_WORKER_ID: WorkerId = 23;
fn assert_scheduler_idle(metrics: &MockerMetrics) {
assert_eq!(
metrics.active_decode_blocks, 0,
"Expected 0 active blocks, got {}",
metrics.active_decode_blocks
);
assert_eq!(
metrics.gpu_cache_usage_perc, 0.0,
"Expected 0.0 cache usage, got {}",
metrics.gpu_cache_usage_perc
);
assert!(
metrics.total_blocks > 0,
"Expected total_blocks to be populated, got {}",
metrics.total_blocks
);
}
fn make_args() -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap()
}
fn router_args() -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(12)
.max_num_batched_tokens(Some(12))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(true)
.speedup_ratio(0.0)
.build()
.unwrap()
}
mod core_behavior {
use super::*;
#[test]
fn test_unified_pass_keeps_partial_prefill_in_running() {
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(12))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
let r1 = Uuid::from_u128(1);
let r2 = Uuid::from_u128(2);
core.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 2,
uuid: Some(r1),
dp_rank: 0,
arrival_timestamp_ms: None,
});
core.receive(DirectRequest {
tokens: (100..108).collect(),
max_output_tokens: 2,
uuid: Some(r2),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
assert_eq!(
pass.output_signals.len(),
1,
"first request should emit immediately"
);
assert_eq!(core.state.waiting.len(), 0);
assert_eq!(
core.state.running.iter().copied().collect::<Vec<_>>(),
vec![r1, r2]
);
assert_eq!(core.state.requests.get(&r1).unwrap().num_computed_tokens, 8);
assert_eq!(core.state.requests.get(&r2).unwrap().num_computed_tokens, 4);
assert_eq!(
core.state
.requests
.get(&r1)
.unwrap()
.sequence
.generated_tokens(),
1
);
assert_eq!(
core.state.requests.get(&r2).unwrap().status,
RequestStatus::Running
);
assert_eq!(core.kv_manager.num_active_blocks(), 4);
}
#[test]
fn test_running_requests_consume_budget_before_waiting() {
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(16)
.max_num_batched_tokens(Some(4))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
let r1 = Uuid::from_u128(1);
let r2 = Uuid::from_u128(2);
core.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 2,
uuid: Some(r1),
dp_rank: 0,
arrival_timestamp_ms: None,
});
core.receive(DirectRequest {
tokens: (100..108).collect(),
max_output_tokens: 2,
uuid: Some(r2),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
core.execute_pass(&mut collector, 0.0);
let pass = core.execute_pass(&mut collector, 1.0);
assert!(pass.output_signals.iter().any(|signal| signal.uuid == r1));
assert_eq!(
core.state.requests.get(&r2).unwrap().num_computed_tokens,
0,
"waiting request should not steal budget before the running request catches up"
);
}
#[test]
fn test_first_token_can_arrive_on_prompt_completion_pass() {
let mut core = VllmCore::new(make_args());
let uuid = Uuid::from_u128(11);
core.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 2,
uuid: Some(uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
assert_eq!(pass.output_signals.len(), 1);
assert_eq!(pass.output_signals[0].uuid, uuid);
assert!(!pass.output_signals[0].completed);
assert_eq!(
core.state
.requests
.get(&uuid)
.unwrap()
.sequence
.generated_tokens(),
1
);
}
#[test]
fn test_preemption_requeues_newest_running_request() {
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(12))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.preemption_mode(PreemptionMode::Lifo)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
let r1 = Uuid::from_u128(1);
let r2 = Uuid::from_u128(2);
let r3 = Uuid::from_u128(3);
for (uuid, range) in [(r1, 0u32..8u32), (r2, 100u32..108u32), (r3, 200u32..212u32)] {
core.receive(DirectRequest {
tokens: range.collect(),
max_output_tokens: 2,
uuid: Some(uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let mut collector = crate::replay::TraceCollector::default();
core.execute_pass(&mut collector, 0.0);
core.execute_pass(&mut collector, 1.0);
let request = core.state.requests.get(&r2).unwrap();
assert_eq!(request.status, RequestStatus::Preempted);
assert_eq!(request.num_computed_tokens, 0);
assert_eq!(request.num_preemptions, 1);
assert_eq!(core.state.waiting.front().copied(), Some(r2));
}
#[test]
fn test_running_request_catches_up_decode_tail_before_promote() {
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(8)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(1))
.enable_chunked_prefill(true)
.enable_prefix_caching(true)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
let uuid = Uuid::from_u128(99);
let mut sequence = ActiveSequence::new((0..6).collect(), 16, Some(4), true, false);
let signal = sequence.take_creation_signal().unwrap();
assert_eq!(core.kv_manager.process(&signal), 2);
for _ in 0..6 {
let signals = sequence.generate();
for signal in &signals {
core.kv_manager.process(signal);
}
if sequence.generated_tokens() < sequence.max_output_tokens() {
sequence.commit_allocation(sequence.len());
}
}
let free = sequence.reset_with_signal();
for signal in &free {
core.kv_manager.process(signal);
}
let prompt_only = sequence
.prepare_allocation(sequence.num_input_tokens())
.unwrap();
assert_eq!(core.kv_manager.process(&prompt_only), 2);
sequence.commit_allocation(sequence.num_input_tokens());
core.state.insert_running_for_test(uuid);
core.state.requests.insert(
uuid,
VllmRequestState {
sequence,
status: RequestStatus::Running,
num_computed_tokens: 9,
num_preemptions: 1,
},
);
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
let request = core.state.requests.get(&uuid).unwrap();
assert_eq!(pass.output_signals.len(), 1);
assert_eq!(request.num_computed_tokens, 12);
assert_eq!(request.sequence.num_allocated_tokens(), 13);
assert_eq!(core.kv_manager.num_active_blocks(), 4);
}
#[test]
fn test_completion_returns_scheduler_to_idle() {
let mut core = VllmCore::new(make_args());
for uuid in [Uuid::from_u128(1), Uuid::from_u128(2)] {
core.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 2,
uuid: Some(uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let mut collector = crate::replay::TraceCollector::default();
while !core.is_empty() {
core.execute_pass(&mut collector, 0.0);
}
assert!(core.state.waiting.is_empty());
assert!(core.state.running.is_empty());
assert_eq!(core.kv_manager.num_active_blocks(), 0);
}
}
mod router_events {
use super::*;
#[test]
fn test_vllm_pass_visibility_is_pass_start() {
let mut core = VllmCore::new_with_kv_capture(router_args(), ROUTER_TEST_WORKER_ID);
core.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(71)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
assert_eq!(
pass.router_event_visibility,
RouterEventVisibility::PassStart
);
}
#[tokio::test]
async fn test_completion_events_apply_cleanly() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let mut core = VllmCore::new_with_kv_capture(router_args(), ROUTER_TEST_WORKER_ID);
core.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 4,
uuid: Some(Uuid::from_u128(41)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let mut now_ms = 0.0;
let mut saw_store = false;
while !core.is_empty() {
let pass = core.execute_pass(&mut collector, now_ms);
saw_store |= !stored_hashes(&pass.kv_events).is_empty();
now_ms = pass.end_ms;
harness.apply_events(pass.kv_events).await;
}
assert!(saw_store);
assert!(harness.ok_count(METRIC_EVENT_STORED) > 0);
assert_eq!(core.kv_manager.num_active_blocks(), 0);
harness.shutdown();
}
#[tokio::test]
async fn test_preemption_recompute_events_apply_cleanly() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(12))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(true)
.preemption_mode(PreemptionMode::Lifo)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new_with_kv_capture(args, ROUTER_TEST_WORKER_ID);
let r1 = Uuid::from_u128(51);
let r2 = Uuid::from_u128(52);
let r3 = Uuid::from_u128(53);
for (uuid, range) in [(r1, 0u32..8u32), (r2, 100u32..108u32), (r3, 200u32..212u32)] {
core.receive(DirectRequest {
tokens: range.collect(),
max_output_tokens: 2,
uuid: Some(uuid),
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let mut collector = crate::replay::TraceCollector::default();
let mut now_ms = 0.0;
let mut saw_remove = false;
for _ in 0..2 {
let pass = core.execute_pass(&mut collector, now_ms);
saw_remove |= removed_event_count(&pass.kv_events) > 0;
now_ms = pass.end_ms;
harness.apply_events(pass.kv_events).await;
}
let request = core.state.requests.get(&r2).unwrap();
assert_eq!(request.status, RequestStatus::Preempted);
assert_eq!(request.num_computed_tokens, 0);
assert_eq!(request.num_preemptions, 1);
assert_eq!(core.state.waiting.front().copied(), Some(r2));
assert!(saw_remove);
assert!(harness.ok_count(METRIC_EVENT_REMOVED) > 0);
harness.shutdown();
}
}
mod live_scheduler {
use super::*;
type CapturedKvEvent = (KvCacheEvent, Option<Vec<Vec<u32>>>);
#[derive(Default)]
struct CapturingKvSink {
events: Mutex<Vec<CapturedKvEvent>>,
}
impl CapturingKvSink {
fn take(&self) -> Vec<CapturedKvEvent> {
std::mem::take(&mut *self.events.lock().unwrap())
}
}
impl KvCacheEventSink for CapturingKvSink {
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
self.events.lock().unwrap().push((event, None));
Ok(())
}
}
impl RawKvEventSink for CapturingKvSink {
fn publish(&self, event: RawKvEvent) -> anyhow::Result<()> {
self.events
.lock()
.unwrap()
.push((event.event, event.block_token_ids));
Ok(())
}
}
#[rstest]
#[case::case_1(false, false, false)]
#[case::case_2(false, true, false)]
#[case::case_3(true, false, false)]
#[case::case_4(true, true, false)]
#[case::case_5(false, false, true)]
#[case::case_6(false, true, true)]
#[case::case_7(true, false, true)]
#[case::case_8(true, true, true)]
#[tokio::test]
async fn test_scheduler_token_generation_patterns(
#[case] use_shared_tokens: bool,
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
.block_size(64)
.speedup_ratio(10.0)
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.build()
.unwrap();
let scheduler =
Scheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
crate::scheduler::test_utils::assert_scheduler_completes_all(
&scheduler,
&mut output_rx,
200,
1000,
100,
use_shared_tokens,
)
.await;
}
#[tokio::test]
async fn test_cache_hit_rate_with_identical_requests() {
let block_size: usize = 64;
let max_output_tokens: usize = 10;
let speedup_ratio = 10.0;
let num_requests = 10;
let token_length = 65;
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(100)
.block_size(block_size)
.speedup_ratio(speedup_ratio)
.build()
.unwrap();
let scheduler =
Scheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
let identical_tokens: Vec<u32> = (0..token_length).collect();
for _ in 0..num_requests {
scheduler.receive(DirectRequest {
tokens: identical_tokens.clone(),
max_output_tokens,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
tokio::time::sleep(Duration::from_millis(100)).await;
}
let mut received_tokens = 0;
let timeout = tokio::time::sleep(Duration::from_millis(500));
tokio::pin!(timeout);
let metrics_rx = scheduler.metrics_receiver();
let mut debug_interval = interval(Duration::from_millis(500));
loop {
tokio::select! {
biased;
_ = debug_interval.tick() => {
let _metrics = metrics_rx.borrow().clone();
tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
}
Some(_signal) = output_rx.recv() => {
received_tokens += 1;
timeout.set(tokio::time::sleep(Duration::from_millis(500)));
}
_ = &mut timeout => break,
}
}
tokio::time::sleep(Duration::from_millis(100)).await;
let metrics = metrics_rx.borrow().clone();
assert_scheduler_idle(&metrics);
assert_eq!(received_tokens, num_requests * max_output_tokens);
}
#[tokio::test]
async fn test_receiver_drop_cleans_up_resources() {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(10)
.block_size(64)
.speedup_ratio(100.0)
.build()
.unwrap();
let scheduler =
Scheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
scheduler.receive(DirectRequest {
tokens: (0..256).collect(),
max_output_tokens: 200,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut received_count = 0;
while received_count < 129 {
if output_rx.recv().await.is_some() {
received_count += 1;
continue;
}
panic!("Channel closed before receiving 129 tokens");
}
drop(output_rx);
let metrics_rx = scheduler.metrics_receiver();
let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
loop {
if metrics_rx.borrow().active_decode_blocks == 0 {
break;
}
if tokio::time::Instant::now() >= deadline {
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
let metrics = metrics_rx.borrow().clone();
assert_scheduler_idle(&metrics);
}
#[tokio::test]
async fn test_live_scheduler_forwards_buffered_kv_token_ids() {
let sink = Arc::new(CapturingKvSink::default());
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(12)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(1))
.enable_chunked_prefill(true)
.enable_prefix_caching(true)
.speedup_ratio(1000.0)
.zmq_kv_events_port(Some(12345))
.build()
.unwrap();
let scheduler = Scheduler::new(
args,
0,
Some(output_tx),
KvEventPublishers::new(None, Some(sink.clone())),
None,
);
scheduler.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(72)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let signal = tokio::time::timeout(Duration::from_secs(2), output_rx.recv())
.await
.expect("scheduler should emit output")
.expect("output channel should stay open");
assert!(signal.completed);
tokio::time::sleep(Duration::from_millis(50)).await;
let events = sink.take();
let stored = events
.into_iter()
.find_map(|(event, block_token_ids)| match event.data {
KvCacheEventData::Stored(_) => block_token_ids,
_ => None,
})
.expect("live scheduler should forward stored KV event token ids");
assert!(!stored.is_empty());
assert!(stored.iter().all(|block| !block.is_empty()));
}
#[tokio::test]
async fn test_live_pathological_load_no_router_event_errors() {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let (sink, forward_task) = harness.spawn_forwarder();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler = Scheduler::new(
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(3))
.enable_prefix_caching(true)
.enable_chunked_prefill(true)
.speedup_ratio(1000.0)
.build()
.unwrap(),
0,
Some(output_tx),
KvEventPublishers::new(Some(sink.clone()), None),
None,
);
for _ in 0..8 {
scheduler.receive(DirectRequest {
tokens: vec![42; 8],
max_output_tokens: 4,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let expected = 8 * 4;
let mut seen = 0;
let timeout = tokio::time::sleep(Duration::from_secs(5));
tokio::pin!(timeout);
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
seen += 1;
if seen == expected {
break;
}
}
_ = &mut timeout => {
break;
}
}
}
assert_eq!(seen, expected);
drop(scheduler);
drop(sink);
forward_task.await.unwrap();
harness.flush().await;
harness.assert_no_event_errors();
assert!(harness.ok_count(METRIC_EVENT_STORED) > 0);
harness.shutdown();
}
}
...@@ -333,6 +333,9 @@ pub mod model { ...@@ -333,6 +333,9 @@ pub mod model {
/// KV Router configuration environment variables /// KV Router configuration environment variables
pub mod router { pub mod router {
/// Minimum number of workers required before KV router startup continues.
pub const DYN_ROUTER_MIN_INITIAL_WORKERS: &str = "DYN_ROUTER_MIN_INITIAL_WORKERS";
/// Queue threshold fraction for prefill token capacity. /// Queue threshold fraction for prefill token capacity.
/// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens. /// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens.
pub const DYN_ROUTER_QUEUE_THRESHOLD: &str = "DYN_ROUTER_QUEUE_THRESHOLD"; pub const DYN_ROUTER_QUEUE_THRESHOLD: &str = "DYN_ROUTER_QUEUE_THRESHOLD";
...@@ -492,6 +495,7 @@ mod tests { ...@@ -492,6 +495,7 @@ mod tests {
model::huggingface::HF_HOME, model::huggingface::HF_HOME,
model::huggingface::HF_HUB_OFFLINE, model::huggingface::HF_HUB_OFFLINE,
// Router // Router
router::DYN_ROUTER_MIN_INITIAL_WORKERS,
router::DYN_ROUTER_QUEUE_THRESHOLD, router::DYN_ROUTER_QUEUE_THRESHOLD,
router::DYN_ROUTER_QUEUE_POLICY, router::DYN_ROUTER_QUEUE_POLICY,
// Event Plane // Event Plane
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from pathlib import Path
from types import SimpleNamespace
import pytest
from dynamo.llm import EngineType, EntrypointArgs
MODULE_PATH = (
Path(__file__).resolve().parents[2] / "components/src/dynamo/mocker/config.py"
)
SPEC = importlib.util.spec_from_file_location("dynamo_mocker_config", MODULE_PATH)
assert SPEC is not None
assert SPEC.loader is not None
CONFIG = importlib.util.module_from_spec(SPEC)
SPEC.loader.exec_module(CONFIG)
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.gpu_0,
pytest.mark.parallel,
pytest.mark.unit,
]
def make_args(**overrides):
defaults = {
"extra_engine_args": None,
"engine_type": "vllm",
"num_gpu_blocks": 16384,
"block_size": None,
"max_num_seqs": 256,
"max_num_batched_tokens": 8192,
"enable_prefix_caching": True,
"enable_chunked_prefill": True,
"preemption_mode": "lifo",
"speedup_ratio": 1.0,
"decode_speedup_ratio": 1.0,
"dp_size": 1,
"startup_time": None,
"durable_kv_events": False,
"kv_transfer_bandwidth": 64.0,
"reasoning": None,
"sglang_schedule_policy": None,
"sglang_page_size": None,
"sglang_max_prefill_tokens": None,
"sglang_chunked_prefill_size": None,
"sglang_clip_max_new_tokens": None,
"sglang_schedule_conservativeness": None,
"aic_perf_model": False,
"aic_system": None,
"aic_backend_version": None,
"aic_tp_size": None,
"model_path": None,
"is_prefill_worker": False,
"is_decode_worker": False,
}
defaults.update(overrides)
return SimpleNamespace(**defaults)
def test_build_runtime_config_uses_normalized_sglang_page_size_alias():
engine_args = CONFIG.build_mocker_engine_args(
make_args(engine_type="sglang", block_size=None, sglang_page_size=16)
)
block_size, runtime_config = CONFIG.build_runtime_config(engine_args)
assert block_size == 16
assert runtime_config.total_kv_blocks == 16384
assert runtime_config.max_num_seqs == 256
assert runtime_config.max_num_batched_tokens == 8192
def test_build_mocker_engine_args_rejects_mismatched_sglang_sizes():
with pytest.raises(Exception, match="block_size and sglang.page_size to match"):
CONFIG.build_mocker_engine_args(
make_args(engine_type="sglang", block_size=8, sglang_page_size=4)
)
def test_load_mocker_engine_args_from_json_file_normalizes_page_size(tmp_path):
config_path = tmp_path / "engine_args.json"
config_path.write_text(
'{"engine_type":"sglang","sglang":{"page_size":32},"num_gpu_blocks":1024}'
)
engine_args = CONFIG.load_mocker_engine_args(
make_args(extra_engine_args=config_path)
)
assert engine_args.block_size == 32
assert engine_args.num_gpu_blocks == 1024
def test_worker_overrides_drive_runtime_config_for_prefill_worker():
engine_args = CONFIG.build_mocker_engine_args(make_args(is_prefill_worker=True))
worker_args = CONFIG.apply_worker_engine_args_overrides(
engine_args,
bootstrap_port=9001,
kv_bytes_per_token=128,
)
block_size, runtime_config = CONFIG.build_runtime_config(worker_args)
assert block_size == 64
assert worker_args.bootstrap_port == 9001
assert runtime_config.bootstrap_port == 9001
assert runtime_config.bootstrap_host is not None
def test_runtime_config_disables_local_indexer_for_decode_worker():
engine_args = CONFIG.build_mocker_engine_args(
make_args(is_decode_worker=True, durable_kv_events=False)
)
_, runtime_config = CONFIG.build_runtime_config(engine_args)
assert engine_args.enable_local_indexer is True
assert runtime_config.enable_local_indexer is False
def test_entrypoint_args_accept_typed_mocker_engine_args():
engine_args = CONFIG.build_mocker_engine_args(make_args())
entrypoint_args = EntrypointArgs(
engine_type=EngineType.Mocker,
mocker_engine_args=engine_args,
kv_cache_block_size=engine_args.block_size,
)
assert entrypoint_args is not None
...@@ -5,7 +5,6 @@ import asyncio ...@@ -5,7 +5,6 @@ import asyncio
import json import json
import logging import logging
import random import random
import time
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
import aiohttp import aiohttp
...@@ -164,19 +163,16 @@ def _test_router_two_routers( ...@@ -164,19 +163,16 @@ def _test_router_two_routers(
for i, port in enumerate(router_ports): for i, port in enumerate(router_ports):
logger.info(f"Starting KV router frontend on port {port}") logger.info(f"Starting KV router frontend on port {port}")
kv_router = KVRouterProcess( kv_router = KVRouterProcess(
request, block_size, port, engine_workers.namespace, store_backend request,
block_size,
port,
engine_workers.namespace,
store_backend,
min_initial_workers=engine_workers.num_workers,
) )
kv_router.__enter__() kv_router.__enter__()
kv_routers.append(kv_router) kv_routers.append(kv_router)
# Add delay between routers for file backend to ensure first router's
# registration is visible before second router starts its cleanup
if i == 0 and store_backend == "file":
logger.info(
"Waiting 0.5s for first router to fully register (file backend)"
)
time.sleep(0.5)
# Wait for workers to be ready on both routers # Wait for workers to be ready on both routers
logger.info("Waiting for workers to register with both routers...") logger.info("Waiting for workers to register with both routers...")
for i, port in enumerate(router_ports): for i, port in enumerate(router_ports):
...@@ -331,7 +327,7 @@ def _test_python_router_bindings( ...@@ -331,7 +327,7 @@ def _test_python_router_bindings(
AssertionError: If requests fail or router doesn't work correctly AssertionError: If requests fail or router doesn't work correctly
""" """
# Create KvRouterConfig with default settings # Create KvRouterConfig with default settings
kv_router_config = KvRouterConfig() kv_router_config = KvRouterConfig(min_initial_workers=num_workers)
# Create KvRouter Python object # Create KvRouter Python object
kv_router = KvRouter( kv_router = KvRouter(
...@@ -834,6 +830,7 @@ def _test_router_indexers_sync( ...@@ -834,6 +830,7 @@ def _test_router_indexers_sync(
router_snapshot_threshold=20, router_snapshot_threshold=20,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads, router_event_threads=router_event_threads,
min_initial_workers=num_workers,
) )
# If standalone indexer mode, launch mockers one-by-one and register. # If standalone indexer mode, launch mockers one-by-one and register.
...@@ -1478,11 +1475,15 @@ def _test_router_decisions( ...@@ -1478,11 +1475,15 @@ def _test_router_decisions(
if standalone_indexer_url: if standalone_indexer_url:
await engine_workers.launch_mockers_with_indexer(endpoint) await engine_workers.launch_mockers_with_indexer(endpoint)
# Workers register one instance per process (not per dp_rank)
expected_num_instances = engine_workers.num_workers
kv_router_config = KvRouterConfig( kv_router_config = KvRouterConfig(
router_snapshot_threshold=20, router_snapshot_threshold=20,
use_kv_events=use_kv_events, use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads, router_event_threads=router_event_threads,
min_initial_workers=expected_num_instances,
) )
kv_router = KvRouter( kv_router = KvRouter(
endpoint=endpoint, endpoint=endpoint,
...@@ -1490,9 +1491,6 @@ def _test_router_decisions( ...@@ -1490,9 +1491,6 @@ def _test_router_decisions(
kv_router_config=kv_router_config, kv_router_config=kv_router_config,
) )
# Workers register one instance per process (not per dp_rank)
expected_num_instances = engine_workers.num_workers
# Wait for workers to be ready and get their instance IDs # Wait for workers to be ready and get their instance IDs
worker_ids = await wait_for_workers_ready( worker_ids = await wait_for_workers_ready(
endpoint, endpoint,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import time
from typing import Any
from tests.router.common import (
_test_router_basic,
_test_router_decisions,
_test_router_indexers_sync,
)
from tests.router.helper import get_runtime
from tests.utils.constants import DefaultPort
from tests.utils.port_utils import allocate_ports, deallocate_ports
from tests.utils.test_output import resolve_test_output_path
logger = logging.getLogger(__name__)
TEST_PROMPT = (
"In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on "
"clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint "
"rustle of leaves, but it remained calm, confident in the safety of its burrow "
"just a few hops away. The late afternoon sun warmed its fur, and tiny dust "
"motes danced in the golden light as bees hummed lazily nearby. Though the "
"rabbit lived a simple life, every day was an adventure of scents, shadows, and "
"snacks-an endless search for the tastiest patch of greens and the softest spot "
"to nap."
)
def allocate_frontend_ports(request, count: int) -> list[int]:
ports = allocate_ports(count, DefaultPort.FRONTEND.value)
request.addfinalizer(lambda: deallocate_ports(ports))
return ports
def build_test_payload(model_name: str) -> dict[str, Any]:
return {
"model": model_name,
"messages": [{"role": "user", "content": TEST_PROMPT}],
"stream": True,
"max_tokens": 10,
}
class ManagedEngineProcessMixin:
process_name = "worker"
cleanup_name = "worker resources"
init_delay_seconds = 5
init_delay_reason = "initialize before starting next worker"
cleanup_delay_seconds = 2
def __enter__(self):
logger.info(
"[%s] Starting %d worker processes sequentially...",
self.__class__.__name__,
len(self.worker_processes),
)
for i, process in enumerate(self.worker_processes):
logger.info(
"[%s] Starting %s %d...", self.__class__.__name__, self.process_name, i
)
try:
process._logger = logging.getLogger(process.__class__.__name__)
process._command_name = process.command[0]
process.log_dir = resolve_test_output_path(process.log_dir)
os.makedirs(process.log_dir, exist_ok=True)
log_name = f"{process._command_name}.log.txt"
process._log_path = os.path.join(process.log_dir, log_name)
if process.data_dir:
process._remove_directory(process.data_dir)
process._terminate_all_matching_process_names()
logger.info(
"[%s] Launching process %d (pid will be assigned)...",
self.__class__.__name__,
i,
)
process._start_process()
logger.info(
"[%s] Worker %d launched with PID: %s",
self.__class__.__name__,
i,
process.proc.pid if process.proc else "unknown",
)
time.sleep(process.delayed_start)
if i < len(self.worker_processes) - 1:
logger.info(
"[%s] Waiting %ss for worker %d to %s...",
self.__class__.__name__,
self.init_delay_seconds,
i,
self.init_delay_reason,
)
time.sleep(self.init_delay_seconds)
except Exception:
logger.exception(
"[%s] Failed to start worker %d", self.__class__.__name__, i
)
try:
process.__exit__(None, None, None)
except Exception as cleanup_err:
logger.warning(
"[%s] Error during cleanup: %s",
self.__class__.__name__,
cleanup_err,
)
raise
logger.info(
"[%s] All %d workers launched with sequential initialization.",
self.__class__.__name__,
len(self.worker_processes),
)
logger.info(
"[%s] Waiting for health checks to complete...", self.__class__.__name__
)
for i, process in enumerate(self.worker_processes):
logger.info(
"[%s] Checking health for worker %d...", self.__class__.__name__, i
)
try:
elapsed = process._check_ports(process.timeout)
process._check_urls(process.timeout - elapsed)
process._check_funcs(process.timeout - elapsed)
logger.info(
"[%s] Worker %d health checks passed", self.__class__.__name__, i
)
except Exception:
logger.error(
"[%s] Worker %d health check failed", self.__class__.__name__, i
)
self.__exit__(None, None, None)
raise
logger.info(
"[%s] All workers started successfully and passed health checks!",
self.__class__.__name__,
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
for i, process in enumerate(self.worker_processes):
logger.info("Stopping %s %d", self.process_name, i)
process.__exit__(exc_type, exc_val, exc_tb)
logger.info("Waiting for %s to fully clean up...", self.cleanup_name)
time.sleep(self.cleanup_delay_seconds)
def get_engine_endpoint(engine_workers, request_plane: str, component_name: str):
runtime = get_runtime(request_plane=request_plane)
return runtime.endpoint(f"{engine_workers.namespace}.{component_name}.generate")
def run_basic_router_test(
*,
engine_process_cls,
engine_args_name: str,
engine_args: dict[str, Any],
num_workers: int,
single_gpu: bool,
request,
request_plane: str,
block_size: int,
model_name: str,
frontend_timeout: int = 180,
):
with engine_process_cls(
request,
num_workers=num_workers,
single_gpu=single_gpu,
request_plane=request_plane,
**{engine_args_name: engine_args},
) as engine_workers:
frontend_port = allocate_frontend_ports(request, 1)[0]
_test_router_basic(
engine_workers=engine_workers,
block_size=block_size,
request=request,
frontend_port=frontend_port,
test_payload=build_test_payload(model_name),
num_requests=10,
frontend_timeout=frontend_timeout,
store_backend="etcd",
request_plane=request_plane,
)
def run_router_decisions_test(
*,
engine_process_cls,
engine_args_name: str,
engine_args: dict[str, Any],
request,
request_plane: str,
model_name: str,
block_size: int,
component_name: str,
num_workers: int,
single_gpu: bool,
test_dp_rank: bool,
extra_process_kwargs: dict[str, Any] | None = None,
):
process_kwargs = extra_process_kwargs or {}
with engine_process_cls(
request,
num_workers=num_workers,
single_gpu=single_gpu,
request_plane=request_plane,
**{engine_args_name: engine_args},
**process_kwargs,
) as engine_workers:
endpoint = get_engine_endpoint(engine_workers, request_plane, component_name)
_test_router_decisions(
engine_workers,
endpoint,
model_name,
request,
test_dp_rank=test_dp_rank,
block_size=block_size,
)
def run_indexers_sync_test(
*,
engine_process_cls,
engine_args_name: str,
engine_args: dict[str, Any],
request,
runtime_services_dynamic_ports,
store_backend: str,
durable_kv_events: bool,
request_plane: str,
block_size: int,
model_name: str,
num_workers: int,
):
nats_process, _etcd_process = runtime_services_dynamic_ports
with engine_process_cls(
request,
num_workers=num_workers,
single_gpu=True,
request_plane=request_plane,
store_backend=store_backend,
durable_kv_events=durable_kv_events,
**{engine_args_name: engine_args},
) as engine_workers:
_test_router_indexers_sync(
engine_workers=engine_workers,
block_size=block_size,
model_name=model_name,
num_workers=num_workers,
store_backend=store_backend,
request_plane=request_plane,
test_nats_interruption=not durable_kv_events,
nats_server=nats_process if not durable_kv_events else None,
durable_kv_events=durable_kv_events,
)
...@@ -28,6 +28,7 @@ class FrontendRouterProcess(ManagedProcess): ...@@ -28,6 +28,7 @@ class FrontendRouterProcess(ManagedProcess):
request_plane: str = "nats", request_plane: str = "nats",
durable_kv_events: bool = False, durable_kv_events: bool = False,
router_mode: str = "kv", router_mode: str = "kv",
min_initial_workers: int | None = None,
): ):
command = [ command = [
"python3", "python3",
...@@ -65,6 +66,8 @@ class FrontendRouterProcess(ManagedProcess): ...@@ -65,6 +66,8 @@ class FrontendRouterProcess(ManagedProcess):
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane env["DYN_REQUEST_PLANE"] = request_plane
if min_initial_workers is not None:
env["DYN_ROUTER_MIN_INITIAL_WORKERS"] = str(min_initial_workers)
super().__init__( super().__init__(
command=command, command=command,
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import os
import subprocess
import sys
from pathlib import Path
import pytest
from tests.utils.constants import ROUTER_MODEL_NAME
MODEL_NAME = ROUTER_MODEL_NAME
MOONCAKE_TRACE_BLOCK_SIZE = 512
MOONCAKE_TRACE_SAMPLE_LINES = [
'{"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}',
'{"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}',
'{"timestamp": 0, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]}',
'{"timestamp": 0, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]}',
'{"timestamp": 0, "input_length": 9013, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]}',
'{"timestamp": 0, "input_length": 6506, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 64]}',
'{"timestamp": 0, "input_length": 4824, "output_length": 173, "hash_ids": [0, 65, 66, 67, 68, 69, 70, 71, 72, 73]}',
'{"timestamp": 0, "input_length": 3119, "output_length": 20, "hash_ids": [74, 75, 76, 77, 78, 79, 80]}',
'{"timestamp": 0, "input_length": 23090, "output_length": 453, "hash_ids": [0, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]}',
'{"timestamp": 0, "input_length": 3135, "output_length": 19, "hash_ids": [74, 75, 76, 77, 78, 126, 127]}',
'{"timestamp": 0, "input_length": 26874, "output_length": 458, "hash_ids": [0, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179]}',
'{"timestamp": 0, "input_length": 10487, "output_length": 402, "hash_ids": [0, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]}',
'{"timestamp": 0, "input_length": 17448, "output_length": 610, "hash_ids": [0, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233]}',
'{"timestamp": 0, "input_length": 6253, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 234]}',
'{"timestamp": 0, "input_length": 6725, "output_length": 32, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 235, 236]}',
'{"timestamp": 3052, "input_length": 13538, "output_length": 71, "hash_ids": [0, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262]}',
'{"timestamp": 3052, "input_length": 87162, "output_length": 402, "hash_ids": [0, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432]}',
'{"timestamp": 3052, "input_length": 6166, "output_length": 24, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 433]}',
'{"timestamp": 3052, "input_length": 6320, "output_length": 548, "hash_ids": [0, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445]}',
'{"timestamp": 3052, "input_length": 2007, "output_length": 354, "hash_ids": [0, 446, 447, 448]}',
]
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.gpu_0,
pytest.mark.integration,
pytest.mark.parallel,
pytest.mark.router,
pytest.mark.model(MODEL_NAME),
]
@pytest.mark.timeout(120)
def test_mocker_trace_file_replay(tmp_path):
repo_root = Path.cwd()
trace_file = tmp_path / "mooncake_trace.jsonl"
trace_file.write_text(
"\n".join(MOONCAKE_TRACE_SAMPLE_LINES) + "\n", encoding="utf-8"
)
replay_report = trace_file.with_name(f"{trace_file.stem}.replay.json")
pythonpath_entries = [
str(repo_root / "components/src"),
str(repo_root / "lib/bindings/python/src"),
]
existing_pythonpath = os.environ.get("PYTHONPATH")
if existing_pythonpath:
pythonpath_entries.append(existing_pythonpath)
env = os.environ.copy()
env["PYTHONPATH"] = os.pathsep.join(pythonpath_entries)
result = subprocess.run(
[
sys.executable,
"-m",
"dynamo.mocker",
"--trace-file",
str(trace_file),
"--model-path",
MODEL_NAME,
"--num-workers",
"1",
"--block-size",
str(MOONCAKE_TRACE_BLOCK_SIZE),
"--speedup-ratio",
"0",
],
cwd=repo_root,
env=env,
capture_output=True,
text=True,
timeout=120,
check=False,
)
assert result.returncode == 0, (
f"dynamo.mocker trace replay failed with exit code {result.returncode}\n"
f"stdout:\n{result.stdout}\n"
f"stderr:\n{result.stderr}"
)
assert replay_report.exists(), (
"Expected default replay report next to the temp trace file, "
f"but {replay_report} was not created.\nstdout:\n{result.stdout}\n"
f"stderr:\n{result.stderr}"
)
assert "Replay Summary" in result.stdout
assert f"JSON report: {replay_report}" in result.stdout
report = json.loads(replay_report.read_text(encoding="utf-8"))
assert report["num_requests"] == len(MOONCAKE_TRACE_SAMPLE_LINES)
assert report["completed_requests"] == len(MOONCAKE_TRACE_SAMPLE_LINES)
assert report["total_input_tokens"] > 0
assert report["total_output_tokens"] > 0
assert report["duration_ms"] > 0
assert report["wall_time_ms"] >= 0
assert report["request_throughput_rps"] > 0
...@@ -7,21 +7,20 @@ ...@@ -7,21 +7,20 @@
# so we set explicit pytest timeouts to fail fast on hangs (see per-test markers below). # so we set explicit pytest timeouts to fail fast on hangs (see per-test markers below).
import logging import logging
import os import os
import time
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import pytest import pytest
from tests.router.common import ( from tests.router.e2e_harness import (
_test_router_basic, ManagedEngineProcessMixin,
_test_router_decisions, run_basic_router_test,
_test_router_indexers_sync, run_indexers_sync_test,
run_router_decisions_test,
) )
from tests.router.helper import generate_random_suffix, get_runtime from tests.router.helper import generate_random_suffix
from tests.utils.constants import DefaultPort from tests.utils.constants import DefaultPort
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
from tests.utils.port_utils import allocate_ports, deallocate_ports from tests.utils.port_utils import allocate_ports, deallocate_ports
from tests.utils.test_output import resolve_test_output_path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -33,31 +32,8 @@ pytestmark = [ ...@@ -33,31 +32,8 @@ pytestmark = [
pytest.mark.sglang, pytest.mark.sglang,
pytest.mark.model(MODEL_NAME), pytest.mark.model(MODEL_NAME),
] ]
SPEEDUP_RATIO = 10.0
NUM_REQUESTS = 10
PAGE_SIZE = 16 # SGLang uses "page_size" instead of "block_size" PAGE_SIZE = 16 # SGLang uses "page_size" instead of "block_size"
def allocate_frontend_ports(request, count: int) -> list[int]:
"""Allocate random free frontend ports for xdist-safe execution."""
ports = allocate_ports(count, DefaultPort.FRONTEND.value)
request.addfinalizer(lambda: deallocate_ports(ports))
return ports
# Shared test payload for all tests
TEST_PAYLOAD: Dict[str, Any] = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.",
}
],
"stream": True,
"max_tokens": 10,
}
# Shared SGLang configuration for all tests # Shared SGLang configuration for all tests
# mem_fraction_static limits actual VRAM allocation (required for multi-worker on same GPU) # mem_fraction_static limits actual VRAM allocation (required for multi-worker on same GPU)
SGLANG_ARGS: Dict[str, Any] = { SGLANG_ARGS: Dict[str, Any] = {
...@@ -69,7 +45,7 @@ SGLANG_ARGS: Dict[str, Any] = { ...@@ -69,7 +45,7 @@ SGLANG_ARGS: Dict[str, Any] = {
} }
class SGLangProcess: class SGLangProcess(ManagedEngineProcessMixin):
"""Manages SGLang workers using dynamo.sglang (HTTP API + KV events). """Manages SGLang workers using dynamo.sglang (HTTP API + KV events).
This is a drop-in replacement for MockerProcess that uses real SGLang workers. This is a drop-in replacement for MockerProcess that uses real SGLang workers.
...@@ -242,97 +218,8 @@ class SGLangProcess: ...@@ -242,97 +218,8 @@ class SGLangProcess:
f"with endpoint: {self.endpoint}" f"with endpoint: {self.endpoint}"
) )
def __enter__(self): process_name = "SGLang worker"
"""Start all SGLang worker processes with sequential initialization. cleanup_name = "SGLang worker resources"
Workers are started sequentially with a delay between each to avoid
resource contention during initialization. This prevents
shared memory handle allocation failures when multiple workers
try to initialize simultaneously on the same GPU.
"""
logger.info(
f"[SGLangProcess] Starting {len(self.worker_processes)} worker processes sequentially..."
)
# Start each process sequentially, waiting for initialization before next
for i, process in enumerate(self.worker_processes):
logger.info(f"[SGLangProcess] Starting SGLang worker {i}...")
try:
# Manually initialize the process without blocking on health checks
process._logger = logging.getLogger(process.__class__.__name__)
process._command_name = process.command[0]
process.log_dir = resolve_test_output_path(process.log_dir)
os.makedirs(process.log_dir, exist_ok=True)
log_name = f"{process._command_name}.log.txt"
process._log_path = os.path.join(process.log_dir, log_name)
if process.data_dir:
process._remove_directory(process.data_dir)
process._terminate_all_matching_process_names()
logger.info(
f"[SGLangProcess] Launching process {i} (pid will be assigned)..."
)
process._start_process() # Start the process but don't wait
logger.info(
f"[SGLangProcess] Worker {i} launched with PID: {process.proc.pid if process.proc else 'unknown'}"
)
time.sleep(process.delayed_start)
# Wait for initialization before starting next worker
# This prevents shared memory contention
if i < len(self.worker_processes) - 1:
init_delay = 5 # seconds
logger.info(
f"[SGLangProcess] Waiting {init_delay}s for worker {i} to initialize before starting next worker..."
)
time.sleep(init_delay)
except Exception:
logger.exception(f"[SGLangProcess] Failed to start worker {i}")
# Clean up on failure
try:
process.__exit__(None, None, None)
except Exception as cleanup_err:
logger.warning(
f"[SGLangProcess] Error during cleanup: {cleanup_err}"
)
raise
logger.info(
f"[SGLangProcess] All {len(self.worker_processes)} workers launched with sequential initialization."
)
logger.info("[SGLangProcess] Waiting for health checks to complete...")
# Now wait for health checks for all processes
for i, process in enumerate(self.worker_processes):
logger.info(f"[SGLangProcess] Checking health for worker {i}...")
try:
elapsed = process._check_ports(process.timeout)
process._check_urls(process.timeout - elapsed)
process._check_funcs(process.timeout - elapsed)
logger.info(f"[SGLangProcess] Worker {i} health checks passed")
except Exception:
logger.error(f"[SGLangProcess] Worker {i} health check failed")
# Clean up all processes on failure
self.__exit__(None, None, None)
raise
logger.info(
"[SGLangProcess] All workers started successfully and passed health checks!"
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop all SGLang worker processes gracefully."""
for i, process in enumerate(self.worker_processes):
logger.info(f"Stopping SGLang worker {i}")
process.__exit__(exc_type, exc_val, exc_tb)
# Add delay to ensure full cleanup of NATS/ETCD/ZMQ resources
# This prevents test isolation issues when running multiple tests
logger.info("Waiting for SGLang worker resources to fully clean up...")
time.sleep(2)
@pytest.mark.pre_merge @pytest.mark.pre_merge
...@@ -346,41 +233,17 @@ def test_sglang_kv_router_basic( ...@@ -346,41 +233,17 @@ def test_sglang_kv_router_basic(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
request_plane, request_plane,
): ):
""" run_basic_router_test(
Quick e2e sanity test for KV router with SGLang engine instances. engine_process_cls=SGLangProcess,
Tests both NATS and TCP request planes. engine_args_name="sglang_args",
""" engine_args=SGLANG_ARGS,
num_workers=2,
# runtime_services starts etcd and nats single_gpu=True,
N_SGLANG_WORKERS = 2 request=request,
logger.info(
f"Starting SGLang KV router test with {N_SGLANG_WORKERS} workers using request_plane={request_plane}"
)
with SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane, request_plane=request_plane,
) as sglang_workers: block_size=PAGE_SIZE,
# Start SGLang workers model_name=MODEL_NAME,
logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers") )
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
# Run basic router test (starts router internally and waits for workers to be ready)
frontend_port = allocate_frontend_ports(request, 1)[0]
_test_router_basic(
engine_workers=sglang_workers,
block_size=PAGE_SIZE,
request=request,
frontend_port=frontend_port,
test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS,
frontend_timeout=180, # 3 minutes should be plenty for TinyLlama
store_backend="etcd", # Explicit for clarity
request_plane=request_plane,
)
@pytest.mark.pre_merge @pytest.mark.pre_merge
...@@ -393,32 +256,19 @@ def test_router_decisions_sglang_multiple_workers( ...@@ -393,32 +256,19 @@ def test_router_decisions_sglang_multiple_workers(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
request_plane, request_plane,
): ):
# runtime_services starts etcd and nats run_router_decisions_test(
logger.info("Starting SGLang router prefix reuse test with two workers") engine_process_cls=SGLangProcess,
N_WORKERS = 2 engine_args_name="sglang_args",
engine_args=SGLANG_ARGS,
with SGLangProcess( request=request,
request,
sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS,
single_gpu=True, # Worker uses GPU 0
request_plane=request_plane, request_plane=request_plane,
) as sglang_workers: model_name=MODEL_NAME,
# Start 2 worker processes on the same GPU block_size=PAGE_SIZE,
logger.info("Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)") component_name="backend",
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") num_workers=2,
single_gpu=True,
runtime = get_runtime(request_plane=request_plane) test_dp_rank=False,
endpoint = runtime.endpoint(f"{sglang_workers.namespace}.backend.generate") )
_test_router_decisions(
sglang_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=False,
block_size=PAGE_SIZE,
)
@pytest.mark.gpu_2 @pytest.mark.gpu_2
...@@ -442,33 +292,20 @@ def test_router_decisions_sglang_dp( ...@@ -442,33 +292,20 @@ def test_router_decisions_sglang_dp(
* The (worker_id, dp_rank) with events should have exactly 4 events (one per request) * The (worker_id, dp_rank) with events should have exactly 4 events (one per request)
* All events should be on the forced (worker_id, dp_rank=1) (verifying forced routing and prefix reuse) * All events should be on the forced (worker_id, dp_rank=1) (verifying forced routing and prefix reuse)
""" """
N_WORKERS = 1 run_router_decisions_test(
DP_SIZE = 2 engine_process_cls=SGLangProcess,
engine_args_name="sglang_args",
with SGLangProcess( engine_args=SGLANG_ARGS,
request, request=request,
sglang_args=SGLANG_ARGS,
num_workers=N_WORKERS, # Ignored when data_parallel_size is set
single_gpu=False,
data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank)
request_plane=request_plane, request_plane=request_plane,
) as sglang_workers: model_name=MODEL_NAME,
logger.info("Starting 2 SGLang DP ranks (dp_size=2) (mem_frac=0.4)") block_size=PAGE_SIZE,
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") component_name="backend",
num_workers=1,
# Get runtime and create endpoint single_gpu=False,
runtime = get_runtime(request_plane=request_plane) test_dp_rank=True,
# Use the namespace from the SGLang workers extra_process_kwargs={"data_parallel_size": 2},
endpoint = runtime.endpoint(f"{sglang_workers.namespace}.backend.generate") )
_test_router_decisions(
sglang_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=True,
block_size=PAGE_SIZE,
)
@pytest.mark.pre_merge @pytest.mark.pre_merge
...@@ -492,50 +329,16 @@ def test_sglang_indexers_sync( ...@@ -492,50 +329,16 @@ def test_sglang_indexers_sync(
durable_kv_events, durable_kv_events,
request_plane, request_plane,
): ):
""" run_indexers_sync_test(
Test that two KV routers have synchronized indexer states after processing requests engine_process_cls=SGLangProcess,
with SGLang workers. This test verifies that both routers converge to the same internal state. engine_args_name="sglang_args",
engine_args=SGLANG_ARGS,
Tests with configuration: request=request,
- nats_core: etcd backend, local indexer with NATS Core, TCP request plane runtime_services_dynamic_ports=runtime_services_dynamic_ports,
(includes NATS interruption/recovery testing)
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
nats_process, _etcd_process = runtime_services_dynamic_ports
logger.info(
f"Starting SGLang indexers sync test: store_backend={store_backend}, "
f"durable_kv_events={durable_kv_events}, request_plane={request_plane}"
)
N_SGLANG_WORKERS = 2
with SGLangProcess(
request,
sglang_args=SGLANG_ARGS,
num_workers=N_SGLANG_WORKERS,
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend, store_backend=store_backend,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
) as sglang_workers: request_plane=request_plane,
# Start SGLang workers block_size=PAGE_SIZE,
logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers") model_name=MODEL_NAME,
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") num_workers=2,
)
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# When using durable_kv_events=True, use JetStream mode for the router
_test_router_indexers_sync(
engine_workers=sglang_workers,
block_size=PAGE_SIZE,
model_name=MODEL_NAME,
num_workers=N_SGLANG_WORKERS,
store_backend=store_backend,
request_plane=request_plane,
test_nats_interruption=not durable_kv_events,
nats_server=nats_process if not durable_kv_events else None,
durable_kv_events=durable_kv_events,
)
logger.info("SGLang indexers sync test completed successfully")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment