Unverified Commit 930721c8 authored by Yongming Ding's avatar Yongming Ding Committed by GitHub
Browse files

feat(mocker): add SGLang engine simulation (#6977)


Signed-off-by: default avatarYongming Ding <yongmingd@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 22fb3398
...@@ -2091,6 +2091,7 @@ dependencies = [ ...@@ -2091,6 +2091,7 @@ dependencies = [
"rstest 0.18.2", "rstest 0.18.2",
"serde", "serde",
"serde_json", "serde_json",
"slotmap",
"tokio", "tokio",
"tokio-timerfd", "tokio-timerfd",
"tokio-util", "tokio-util",
...@@ -7362,6 +7363,15 @@ version = "0.4.12" ...@@ -7362,6 +7363,15 @@ version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5"
[[package]]
name = "slotmap"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdd58c3c93c3d278ca835519292445cb4b0d4dc59ccfdf7ceadaab3f8aeb4038"
dependencies = [
"version_check",
]
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.15.1" version = "1.15.1"
......
...@@ -126,6 +126,7 @@ def create_temp_engine_args_file(args: argparse.Namespace) -> Path: ...@@ -126,6 +126,7 @@ def create_temp_engine_args_file(args: argparse.Namespace) -> Path:
# - kv_bytes_per_token is auto-computed in main.py after model prefetch, # - kv_bytes_per_token is auto-computed in main.py after model prefetch,
# - kv_cache_dtype is only used Python-side for the auto-computation. # - kv_cache_dtype is only used Python-side for the auto-computation.
"kv_transfer_bandwidth": getattr(args, "kv_transfer_bandwidth", None), "kv_transfer_bandwidth": getattr(args, "kv_transfer_bandwidth", None),
"engine_type": getattr(args, "engine_type", None),
} }
# Parse --reasoning JSON string into a nested object # Parse --reasoning JSON string into a nested object
...@@ -133,6 +134,21 @@ def create_temp_engine_args_file(args: argparse.Namespace) -> Path: ...@@ -133,6 +134,21 @@ def create_temp_engine_args_file(args: argparse.Namespace) -> Path:
if reasoning_str: if reasoning_str:
engine_args["reasoning"] = json.loads(reasoning_str) engine_args["reasoning"] = json.loads(reasoning_str)
# Build nested sglang config from individual CLI flags
sglang_args = {
"schedule_policy": getattr(args, "sglang_schedule_policy", None),
"page_size": getattr(args, "sglang_page_size", None),
"max_prefill_tokens": getattr(args, "sglang_max_prefill_tokens", None),
"chunked_prefill_size": getattr(args, "sglang_chunked_prefill_size", None),
"clip_max_new_tokens": getattr(args, "sglang_clip_max_new_tokens", None),
"schedule_conservativeness": getattr(
args, "sglang_schedule_conservativeness", None
),
}
sglang_args = {k: v for k, v in sglang_args.items() if v is not None}
if sglang_args:
engine_args["sglang"] = sglang_args
# Remove None values to only include explicitly set arguments # Remove None values to only include explicitly set arguments
engine_args = {k: v for k, v in engine_args.items() if v is not None} engine_args = {k: v for k, v in engine_args.items() if v is not None}
...@@ -348,6 +364,54 @@ def parse_args() -> argparse.Namespace: ...@@ -348,6 +364,54 @@ def parse_args() -> argparse.Namespace:
'Example: \'{"start_thinking_token_id": 123, "end_thinking_token_id": 456, "thinking_ratio": 0.6}\'', 'Example: \'{"start_thinking_token_id": 123, "end_thinking_token_id": 456, "thinking_ratio": 0.6}\'',
) )
# Engine type selection
parser.add_argument(
"--engine-type",
type=str,
default=None,
choices=["vllm", "sglang"],
help="Engine simulation type: 'vllm' (default) or 'sglang'.",
)
# SGLang-specific configuration
parser.add_argument(
"--sglang-schedule-policy",
type=str,
default=None,
choices=["fifo", "fcfs", "lpm"],
help="SGLang scheduling policy: 'fifo'/'fcfs' (default) or 'lpm' (longest prefix match).",
)
parser.add_argument(
"--sglang-page-size",
type=int,
default=None,
help="SGLang radix cache page size in tokens (default: 1).",
)
parser.add_argument(
"--sglang-max-prefill-tokens",
type=int,
default=None,
help="SGLang maximum prefill tokens budget per batch (default: 16384).",
)
parser.add_argument(
"--sglang-chunked-prefill-size",
type=int,
default=None,
help="SGLang chunked prefill size — max tokens per chunk (default: 8192).",
)
parser.add_argument(
"--sglang-clip-max-new-tokens",
type=int,
default=None,
help="SGLang clip max new tokens for admission budget (default: 4096).",
)
parser.add_argument(
"--sglang-schedule-conservativeness",
type=float,
default=None,
help="SGLang schedule conservativeness factor 0.0-1.0 (default: 1.0).",
)
# Legacy support - allow direct JSON file specification # Legacy support - allow direct JSON file specification
parser.add_argument( parser.add_argument(
"--extra-engine-args", "--extra-engine-args",
......
...@@ -11,6 +11,7 @@ use dynamo_kv_router::protocols::WorkerWithDpRank; ...@@ -11,6 +11,7 @@ use dynamo_kv_router::protocols::WorkerWithDpRank;
use dynamo_kv_router::{ActiveSequencesMultiWorker, OverlapScores, SequenceRequest}; use dynamo_kv_router::{ActiveSequencesMultiWorker, OverlapScores, SequenceRequest};
use dynamo_mocker::common::protocols::{DirectRequest, OutputSignal}; use dynamo_mocker::common::protocols::{DirectRequest, OutputSignal};
use dynamo_mocker::scheduler::Scheduler; use dynamo_mocker::scheduler::Scheduler;
use dynamo_mocker::scheduler::SchedulerHandle;
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
...@@ -170,27 +171,23 @@ async fn generate_sequence_events( ...@@ -170,27 +171,23 @@ async fn generate_sequence_events(
while i < worker_trace.len() { while i < worker_trace.len() {
let prev_i = i; let prev_i = i;
scheduler scheduler.receive(DirectRequest {
.receive(DirectRequest { tokens: tokens_from_request(&worker_trace[i], block_size),
tokens: tokens_from_request(&worker_trace[i], block_size), max_output_tokens: worker_trace[i].output_length as usize,
max_output_tokens: worker_trace[i].output_length as usize, uuid: Some(worker_trace[i].uuid),
uuid: Some(worker_trace[i].uuid), dp_rank: 0,
dp_rank: 0, });
})
.await;
i += 1; i += 1;
while i < worker_trace.len() while i < worker_trace.len()
&& worker_trace[i].timestamp == worker_trace[i - 1].timestamp && worker_trace[i].timestamp == worker_trace[i - 1].timestamp
{ {
scheduler scheduler.receive(DirectRequest {
.receive(DirectRequest { tokens: tokens_from_request(&worker_trace[i], block_size),
tokens: tokens_from_request(&worker_trace[i], block_size), max_output_tokens: worker_trace[i].output_length as usize,
max_output_tokens: worker_trace[i].output_length as usize, uuid: Some(worker_trace[i].uuid),
uuid: Some(worker_trace[i].uuid), dp_rank: 0,
dp_rank: 0, });
})
.await;
i += 1; i += 1;
} }
......
...@@ -13,6 +13,7 @@ use dynamo_kv_router::protocols::{ ...@@ -13,6 +13,7 @@ use dynamo_kv_router::protocols::{
pub use dynamo_kv_router::test_utils::{NoopSequencePublisher, SimpleWorkerConfig}; pub use dynamo_kv_router::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
use dynamo_mocker::common::protocols::{DirectRequest, KvCacheEventSink, MockEngineArgs}; use dynamo_mocker::common::protocols::{DirectRequest, KvCacheEventSink, MockEngineArgs};
use dynamo_mocker::scheduler::Scheduler; use dynamo_mocker::scheduler::Scheduler;
use dynamo_mocker::scheduler::SchedulerHandle;
use dynamo_tokens::compute_hash_v2; use dynamo_tokens::compute_hash_v2;
use indicatif::{ProgressBar, ProgressStyle}; use indicatif::{ProgressBar, ProgressStyle};
use plotters::prelude::*; use plotters::prelude::*;
...@@ -367,27 +368,23 @@ pub async fn generate_kv_events( ...@@ -367,27 +368,23 @@ pub async fn generate_kv_events(
while i < worker_trace.len() { while i < worker_trace.len() {
let prev_i = i; let prev_i = i;
scheduler scheduler.receive(DirectRequest {
.receive(DirectRequest { tokens: tokens_from_request(&worker_trace[i], block_size),
tokens: tokens_from_request(&worker_trace[i], block_size), max_output_tokens: worker_trace[i].output_length as usize,
max_output_tokens: worker_trace[i].output_length as usize, uuid: Some(worker_trace[i].uuid),
uuid: Some(worker_trace[i].uuid), dp_rank: 0,
dp_rank: 0, });
})
.await;
i += 1; i += 1;
while i < worker_trace.len() while i < worker_trace.len()
&& worker_trace[i].timestamp == worker_trace[i - 1].timestamp && worker_trace[i].timestamp == worker_trace[i - 1].timestamp
{ {
scheduler scheduler.receive(DirectRequest {
.receive(DirectRequest { tokens: tokens_from_request(&worker_trace[i], block_size),
tokens: tokens_from_request(&worker_trace[i], block_size), max_output_tokens: worker_trace[i].output_length as usize,
max_output_tokens: worker_trace[i].output_length as usize, uuid: Some(worker_trace[i].uuid),
uuid: Some(worker_trace[i].uuid), dp_rank: 0,
dp_rank: 0, });
})
.await;
i += 1; i += 1;
} }
......
...@@ -1715,6 +1715,7 @@ dependencies = [ ...@@ -1715,6 +1715,7 @@ dependencies = [
"rand 0.9.2", "rand 0.9.2",
"serde", "serde",
"serde_json", "serde_json",
"slotmap",
"tokio", "tokio",
"tokio-timerfd", "tokio-timerfd",
"tokio-util", "tokio-util",
...@@ -6324,6 +6325,15 @@ version = "0.4.12" ...@@ -6324,6 +6325,15 @@ version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5"
[[package]]
name = "slotmap"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdd58c3c93c3d278ca835519292445cb4b0d4dc59ccfdf7ceadaab3f8aeb4038"
dependencies = [
"version_check",
]
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.15.1" version = "1.15.1"
......
...@@ -1731,6 +1731,7 @@ dependencies = [ ...@@ -1731,6 +1731,7 @@ dependencies = [
"rand 0.9.2", "rand 0.9.2",
"serde", "serde",
"serde_json", "serde_json",
"slotmap",
"tokio", "tokio",
"tokio-timerfd", "tokio-timerfd",
"tokio-util", "tokio-util",
...@@ -6391,6 +6392,15 @@ version = "0.4.12" ...@@ -6391,6 +6392,15 @@ version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5"
[[package]]
name = "slotmap"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdd58c3c93c3d278ca835519292445cb4b0d4dc59ccfdf7ceadaab3f8aeb4038"
dependencies = [
"version_check",
]
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.15.1" version = "1.15.1"
......
...@@ -23,7 +23,8 @@ use dynamo_mocker::common::protocols::{ ...@@ -23,7 +23,8 @@ use dynamo_mocker::common::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, OutputSignal, DirectRequest, KvCacheEventSink, MockEngineArgs, OutputSignal,
}; };
use dynamo_mocker::common::utils::{compute_kv_transfer_delay, sleep_precise}; use dynamo_mocker::common::utils::{compute_kv_transfer_delay, sleep_precise};
use dynamo_mocker::scheduler::Scheduler; use dynamo_mocker::engine::create_engine;
use dynamo_mocker::scheduler::SchedulerHandle;
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -307,7 +308,7 @@ fn generate_random_token() -> TokenIdType { ...@@ -307,7 +308,7 @@ fn generate_random_token() -> TokenIdType {
} }
/// AsyncEngine wrapper around the Scheduler that generates random character tokens /// AsyncEngine wrapper around the Scheduler that generates random character tokens
pub struct MockVllmEngine { pub struct MockEngine {
active_requests: Arc<DashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>, active_requests: Arc<DashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>,
request_senders: OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>, request_senders: OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>,
senders_ready: Notify, senders_ready: Notify,
...@@ -315,11 +316,11 @@ pub struct MockVllmEngine { ...@@ -315,11 +316,11 @@ pub struct MockVllmEngine {
/// Bootstrap server for prefill workers in disaggregated mode /// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>, bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
/// Keep schedulers alive so their CancelGuards don't fire prematurely. /// Keep schedulers alive so their CancelGuards don't fire prematurely.
_schedulers: OnceCell<Vec<Scheduler>>, _schedulers: OnceCell<Vec<Box<dyn SchedulerHandle>>>,
} }
impl MockVllmEngine { impl MockEngine {
/// Create a new MockVllmEngine with the given parameters /// Create a new MockEngine with the given parameters
pub fn new(engine_args: MockEngineArgs) -> Self { pub fn new(engine_args: MockEngineArgs) -> Self {
Self { Self {
active_requests: Arc::new(DashMap::new()), active_requests: Arc::new(DashMap::new()),
...@@ -404,9 +405,9 @@ impl MockVllmEngine { ...@@ -404,9 +405,9 @@ impl MockVllmEngine {
&self, &self,
component: Option<&Component>, component: Option<&Component>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> Vec<Scheduler> { ) -> Vec<Box<dyn SchedulerHandle>> {
let args = &self.engine_args; let args = &self.engine_args;
let mut schedulers = Vec::<Scheduler>::new(); let mut schedulers = Vec::<Box<dyn SchedulerHandle>>::new();
let mut senders = Vec::with_capacity(args.dp_size as usize); let mut senders = Vec::with_capacity(args.dp_size as usize);
for dp_rank in 0..args.dp_size { for dp_rank in 0..args.dp_size {
...@@ -485,7 +486,7 @@ impl MockVllmEngine { ...@@ -485,7 +486,7 @@ impl MockVllmEngine {
None => (None, None), None => (None, None),
}; };
let scheduler = Scheduler::new( let scheduler = create_engine(
args.clone(), args.clone(),
dp_rank, dp_rank,
Some(output_tx), Some(output_tx),
...@@ -536,7 +537,7 @@ impl MockVllmEngine { ...@@ -536,7 +537,7 @@ impl MockVllmEngine {
/// Start background tasks to publish metrics on change /// Start background tasks to publish metrics on change
async fn start_metrics_publishing( async fn start_metrics_publishing(
schedulers: &[Scheduler], schedulers: &[Box<dyn SchedulerHandle>],
component: Component, component: Component,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
...@@ -579,9 +580,7 @@ impl MockVllmEngine { ...@@ -579,9 +580,7 @@ impl MockVllmEngine {
} }
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> for MockEngine {
for MockVllmEngine
{
async fn generate( async fn generate(
&self, &self,
input: SingleIn<PreprocessedRequest>, input: SingleIn<PreprocessedRequest>,
...@@ -744,12 +743,12 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -744,12 +743,12 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
} }
pub struct AnnotatedMockEngine { pub struct AnnotatedMockEngine {
inner: Arc<MockVllmEngine>, inner: Arc<MockEngine>,
} }
impl AnnotatedMockEngine { impl AnnotatedMockEngine {
pub fn new( pub fn new(
inner: MockVllmEngine, inner: MockEngine,
distributed_runtime: DistributedRuntime, distributed_runtime: DistributedRuntime,
endpoint_id: dynamo_runtime::protocols::EndpointId, endpoint_id: dynamo_runtime::protocols::EndpointId,
) -> Self { ) -> Self {
...@@ -818,7 +817,7 @@ pub async fn make_mocker_engine( ...@@ -818,7 +817,7 @@ pub async fn make_mocker_engine(
// Create the mocker engine // Create the mocker engine
tracing::info!("Creating mocker engine with config: {args:?}"); tracing::info!("Creating mocker engine with config: {args:?}");
let annotated_engine = let annotated_engine =
AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint_id); AnnotatedMockEngine::new(MockEngine::new(args), distributed_runtime, endpoint_id);
Ok(Arc::new(annotated_engine)) Ok(Arc::new(annotated_engine))
} }
...@@ -32,6 +32,7 @@ validator = { workspace = true } ...@@ -32,6 +32,7 @@ validator = { workspace = true }
# crate-specific # crate-specific
ndarray = "0.16" ndarray = "0.16"
slotmap = "1"
ndarray-npy = "0.9" ndarray-npy = "0.9"
ndarray-interp = "0.5" ndarray-interp = "0.5"
......
...@@ -4,5 +4,7 @@ ...@@ -4,5 +4,7 @@
//! Cache data structures for KV block management. //! Cache data structures for KV block management.
pub mod hash_cache; pub mod hash_cache;
pub mod radix_cache;
pub use hash_cache::HashCache; pub use hash_cache::HashCache;
pub use radix_cache::RadixCache;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Radix-tree KV cache for SGLang engine simulation.
//!
//! Reference: sglang/python/sglang/srt/mem_cache/radix_cache.py
use slotmap::{SlotMap, new_key_type};
use std::collections::{HashMap, HashSet};
use std::time::Instant;
new_key_type! {
/// Stable identifier for a tree node inside the [`RadixCache`].
///
/// Used as the key type of the cache's `SlotMap` arena and stored in
/// `TreeNode::parent` / `TreeNode::children` to link nodes together.
pub struct NodeId;
}
/// Free-list of KV cache token slot indices.
///
/// The pool starts with every index in `0..total` free. Allocation pops slots
/// off the tail of the free list, and freed slots are pushed back onto it.
pub struct TokenPool {
    free: Vec<usize>,
    total: usize,
}

impl TokenPool {
    /// Create a pool in which all `total` slots are free.
    ///
    /// The free list is stored in descending order so the lowest indices sit
    /// at the tail and are handed out first.
    pub fn new(total: usize) -> Self {
        let mut free = Vec::with_capacity(total);
        free.extend((0..total).rev());
        Self { free, total }
    }

    /// Take `n` slots from the tail of the free list.
    ///
    /// Returns `None`, leaving the pool untouched, when fewer than `n` slots
    /// are free.
    pub fn allocate(&mut self, n: usize) -> Option<Vec<usize>> {
        let keep = self.free.len().checked_sub(n)?;
        Some(self.free.split_off(keep))
    }

    /// Push `indices` back onto the free list.
    ///
    /// The indices are not validated: nothing here checks for duplicates or
    /// for values outside `0..total`.
    pub fn free(&mut self, indices: &[usize]) {
        self.free.extend_from_slice(indices);
    }

    /// Number of slots currently on the free list.
    pub fn available(&self) -> usize {
        self.free.len()
    }

    /// Pool size given at construction.
    pub fn total(&self) -> usize {
        self.total
    }
}
/// A single node in the radix tree.
///
/// Each node owns one edge of the tree: the run of tokens in `key` together
/// with the pool slots in `value` that back those tokens.
pub struct TreeNode {
/// Children keyed by `child.key[..page_size]` (a "child key").
pub children: HashMap<Vec<u64>, NodeId>,
/// Parent node; the root is created with `None`.
pub parent: Option<NodeId>,
/// Token IDs stored at this edge.
pub key: Vec<u64>,
/// KV cache pool token indices. Length = `key.len()`.
pub value: Vec<usize>,
/// Walk-to-root reference count (protected when > 0).
/// `inc_lock_ref` / `dec_lock_ref` adjust it on every node from the
/// target up to, but not including, the root.
pub lock_ref: usize,
/// Monotonic timestamp for LRU eviction.
/// `evict` removes the evictable leaf with the smallest value first.
pub last_access_time: Instant,
}
/// Radix tree for SGLang KV cache simulation.
pub struct RadixCache {
/// Arena holding every tree node, including the root.
nodes: SlotMap<NodeId, TreeNode>,
/// Id of the root node (empty key, no parent).
root: NodeId,
/// Pool of token slot indices; `evict` returns a victim's slots to it.
pub token_pool: TokenPool,
/// Tokens per page. Child keys are the first `page_size` tokens of an edge
/// and inserted keys are truncated to a multiple of this value.
page_size: usize,
/// Leaf nodes that `evict` may currently pick as victims.
pub evictable_leaves: HashSet<NodeId>,
/// Total token count across the nodes in `evictable_leaves`.
pub evictable_size: usize,
/// Total token count in protected (locked) nodes.
pub protected_size: usize,
}
impl RadixCache {
/// Build an empty cache backed by a pool of `total_tokens` slots.
///
/// The tree starts with only a root node (empty key, no parent, unlocked).
///
/// # Panics
///
/// Panics if `page_size` is zero.
pub fn new(total_tokens: usize, page_size: usize) -> Self {
    assert!(page_size >= 1, "page_size must be >= 1");

    let root_node = TreeNode {
        children: HashMap::new(),
        parent: None,
        key: Vec::new(),
        value: Vec::new(),
        lock_ref: 0,
        last_access_time: Instant::now(),
    };
    let mut nodes = SlotMap::with_key();
    let root = nodes.insert(root_node);

    Self {
        token_pool: TokenPool::new(total_tokens),
        evictable_leaves: HashSet::new(),
        evictable_size: 0,
        protected_size: 0,
        page_size,
        nodes,
        root,
    }
}
/// Id of the root node created in `new`.
pub fn root(&self) -> NodeId {
self.root
}
/// Borrow the node stored under `id`.
///
/// Uses `SlotMap` indexing, so an `id` that is no longer in the arena
/// (e.g. a node removed by `evict`) is expected to panic — NOTE(review):
/// confirm against the slotmap `Index` documentation.
pub fn node(&self, id: NodeId) -> &TreeNode {
&self.nodes[id]
}
/// Page size (in tokens) this cache was constructed with.
pub fn page_size(&self) -> usize {
self.page_size
}
/// Number of nodes currently in the arena, root included.
pub fn num_nodes(&self) -> usize {
self.nodes.len()
}
/// Owned lookup key for a parent's `children` map: the first `page_size`
/// tokens of `key`, or all of `key` when it is shorter than one page.
fn child_key(&self, key: &[u64]) -> Vec<u64> {
    key.iter().take(self.page_size).copied().collect()
}
/// Round `len` down to the nearest multiple of `page_size`.
fn page_align(&self, len: usize) -> usize {
    len - len % self.page_size
}
/// Length of the common prefix of `key0` and `key1`, in tokens.
///
/// With `page_size == 1` the comparison is token by token. Otherwise whole
/// pages are compared and the result is always a multiple of `page_size`;
/// a trailing partial page never counts as matched.
fn key_match(&self, key0: &[u64], key1: &[u64]) -> usize {
    let page = self.page_size;
    if page == 1 {
        return key0.iter().zip(key1).take_while(|(a, b)| a == b).count();
    }

    let matching_pages = key0
        .chunks_exact(page)
        .zip(key1.chunks_exact(page))
        .take_while(|(lhs, rhs)| lhs == rhs)
        .count();
    matching_pages * page
}
/// Find the longest cached prefix of `key`, splitting an edge if the match
/// ends in the middle of it.
///
/// Returns `(matched_len, node)`, where `matched_len` is the number of leading
/// tokens of `key` found in the tree and `node` is the deepest node covering
/// exactly that prefix (the root when nothing matched).
///
/// Every node on the matched path gets its `last_access_time` refreshed. That
/// includes the intermediate node created by a split: it represents the prefix
/// that was just used, so it must not keep the stale timestamp it inherits
/// from the node it was split from, or LRU eviction could pick it too early
/// once it becomes a leaf.
pub fn match_prefix(&mut self, key: &[u64]) -> (usize, NodeId) {
    let now = Instant::now();
    self.nodes[self.root].last_access_time = now;
    let mut current = self.root;
    let mut matched: usize = 0;

    while matched < key.len() {
        let rest = &key[matched..];
        // Look up by borrowed slice: no temporary Vec is needed for the map key.
        let page = &rest[..self.page_size.min(rest.len())];
        let Some(&child_id) = self.nodes[current].children.get(page) else {
            break;
        };

        // Compare against the edge in place instead of cloning its key.
        let (common_len, child_len) = {
            let child_key = &self.nodes[child_id].key;
            (self.key_match(child_key, rest), child_key.len())
        };
        matched += common_len;

        if common_len < child_len {
            // Partial match on this edge: split it so a node ends exactly at
            // the matched prefix.
            if common_len > 0 {
                current = self.split_node(child_id, common_len);
                self.nodes[current].last_access_time = now;
            }
            break;
        }

        current = child_id;
        self.nodes[current].last_access_time = now;
    }

    (matched, current)
}
/// Read-only prefix match length (does not mutate timestamps or split nodes).
/// Used for LPM scheduling scoring.
///
/// Returns the number of leading tokens of `key` present in the tree, rounded
/// down to a page boundary. The walk borrows slices of `key` for the child
/// lookups, so it performs no heap allocation.
pub fn prefix_match_len(&self, key: &[u64]) -> usize {
    let mut current = self.root;
    let mut matched: usize = 0;

    while matched < key.len() {
        let rest = &key[matched..];
        let page = &rest[..self.page_size.min(rest.len())];
        let Some(&child_id) = self.nodes[current].children.get(page) else {
            break;
        };

        let child_key = &self.nodes[child_id].key;
        let common_len = self.key_match(child_key, rest);
        matched += common_len;
        if common_len < child_key.len() {
            // The match ends inside this edge; stop without splitting.
            break;
        }
        current = child_id;
    }

    // Round down to page boundary
    matched / self.page_size * self.page_size
}
/// Insert a token sequence into the tree. Key is page-aligned before insertion.
///
/// `key` is truncated to a multiple of `page_size`; if nothing remains the
/// call is a no-op. `value` supplies one pool slot index per token and is
/// truncated to the same length. Tokens already present in the tree are left
/// as they are (their entries in `value` are not stored); only the missing
/// suffix is added as a new leaf.
///
/// Every node the insertion passes through gets its `last_access_time`
/// refreshed, including an intermediate node produced by splitting an edge.
///
/// # Panics
///
/// Panics if `value` is shorter than the page-aligned key.
pub fn insert(&mut self, key: &[u64], value: &[usize]) {
    let aligned_len = self.page_align(key.len());
    if aligned_len == 0 {
        return;
    }
    assert!(
        value.len() >= aligned_len,
        "not enough token indices: need {aligned_len}, got {}",
        value.len()
    );
    let key = &key[..aligned_len];
    let value = &value[..aligned_len];

    let now = Instant::now();
    self.nodes[self.root].last_access_time = now;
    let mut current = self.root;
    let mut key_offset: usize = 0;

    while key_offset < key.len() {
        let rest = &key[key_offset..];
        // Look up by borrowed slice: no temporary Vec is needed for the map key.
        let page = &rest[..self.page_size.min(rest.len())];
        let Some(&child_id) = self.nodes[current].children.get(page) else {
            // Nothing below `current` starts with this page: add the rest as a leaf.
            self.create_child(current, rest, &value[key_offset..]);
            return;
        };

        // Compare against the edge in place instead of cloning its key.
        let (common_len, child_len) = {
            let child_key = &self.nodes[child_id].key;
            (self.key_match(child_key, rest), child_key.len())
        };

        if common_len < child_len {
            // The new key diverges inside this edge: split it, then hang the
            // remaining tokens (if any) off the split point.
            if common_len > 0 {
                let intermediate = self.split_node(child_id, common_len);
                self.nodes[intermediate].last_access_time = now;
                key_offset += common_len;
                if key_offset < key.len() {
                    self.create_child(intermediate, &key[key_offset..], &value[key_offset..]);
                }
            }
            return;
        }

        key_offset += common_len;
        current = child_id;
        self.nodes[current].last_access_time = now;
    }
}
/// Split the edge owned by `child_id` at `split_pos` tokens.
///
/// A new intermediate node takes the first `split_pos` tokens (and slots) and
/// replaces `child_id` in its parent's child map. `child_id` keeps the
/// remaining tokens and becomes the only child of the intermediate node. The
/// intermediate node copies the child's `lock_ref` and `last_access_time`.
///
/// Returns the id of the intermediate node.
fn split_node(&mut self, child_id: NodeId, split_pos: usize) -> NodeId {
    let page = self.page_size;
    let first_page = |edge: &[u64]| edge[..page.min(edge.len())].to_vec();

    let child = &mut self.nodes[child_id];
    let grandparent = child.parent;
    let original_len = child.key.len();
    // Key under which the grandparent currently files this edge.
    let original_ck = first_page(&child.key);

    // Cut both vectors in place; the tails stay with the child.
    let tail_key = child.key.split_off(split_pos);
    let tail_value = child.value.split_off(split_pos);
    let suffix_ck = first_page(&tail_key);

    let intermediate = TreeNode {
        children: HashMap::new(),
        parent: grandparent,
        key: std::mem::replace(&mut child.key, tail_key),
        value: std::mem::replace(&mut child.value, tail_value),
        lock_ref: child.lock_ref,
        last_access_time: child.last_access_time,
    };

    let inter_id = self.nodes.insert(intermediate);
    self.nodes[inter_id].children.insert(suffix_ck, child_id);
    self.nodes[child_id].parent = Some(inter_id);

    if let Some(parent_id) = grandparent {
        self.nodes[parent_id].children.insert(original_ck, inter_id);
    }

    // protected_size needs no adjustment: the intermediate node carries the
    // same lock_ref as the child, and head + tail add up to the original edge.
    // The intermediate node has a child, so it is never an evictable leaf;
    // only the child's own (now shorter) contribution changes.
    if self.evictable_leaves.contains(&child_id) {
        let remaining = original_len - split_pos;
        self.evictable_size = self.evictable_size - original_len + remaining;
    }

    inter_id
}
/// Attach a new unlocked leaf holding `key` / `value` under `parent_id`.
///
/// The leaf is registered as evictable. If the parent was itself an evictable
/// leaf it stops being one, and its tokens are removed from `evictable_size`.
fn create_child(&mut self, parent_id: NodeId, key: &[u64], value: &[usize]) {
    let leaf = TreeNode {
        children: HashMap::new(),
        parent: Some(parent_id),
        key: key.to_vec(),
        value: value.to_vec(),
        lock_ref: 0,
        last_access_time: Instant::now(),
    };
    let first_page = self.child_key(key);
    let leaf_id = self.nodes.insert(leaf);

    // The parent is about to gain a child, so it can no longer be evicted.
    let parent_was_evictable = self.evictable_leaves.remove(&parent_id);
    let parent = &mut self.nodes[parent_id];
    if parent_was_evictable {
        self.evictable_size -= parent.key.len();
    }
    parent.children.insert(first_page, leaf_id);

    self.evictable_leaves.insert(leaf_id);
    self.evictable_size += key.len();
}
/// `true` when the node stored under `id` has no children.
pub fn is_leaf(&self, id: NodeId) -> bool {
self.nodes[id].children.is_empty()
}
/// Lock `node_id` and all of its ancestors below the root.
///
/// Each visited node's `lock_ref` is incremented. A node that goes from
/// unlocked to locked is taken out of the evictable set (if it was in it) and
/// its tokens are added to `protected_size`. The root is never locked.
pub fn inc_lock_ref(&mut self, node_id: NodeId) {
    let mut cursor = node_id;
    while cursor != self.root {
        let node = &mut self.nodes[cursor];
        node.lock_ref += 1;
        let newly_locked = node.lock_ref == 1;
        let tokens = node.key.len();
        let parent = node.parent;

        if newly_locked {
            if self.evictable_leaves.remove(&cursor) {
                self.evictable_size -= tokens;
            }
            self.protected_size += tokens;
        }

        match parent {
            Some(next) => cursor = next,
            None => break,
        }
    }
}
/// Release one lock on `node_id` and all of its ancestors below the root.
///
/// Each visited node's `lock_ref` is decremented. A node that drops to zero
/// leaves `protected_size` and, if it is a leaf, becomes evictable again.
/// If a node on the path is already unlocked, a warning is logged and the
/// walk stops there.
pub fn dec_lock_ref(&mut self, node_id: NodeId) {
    let mut cursor = node_id;
    while cursor != self.root {
        let node = &mut self.nodes[cursor];
        if node.lock_ref == 0 {
            tracing::warn!("dec_lock_ref on node with lock_ref == 0, skipping");
            return;
        }
        node.lock_ref -= 1;
        let now_unlocked = node.lock_ref == 0;
        let tokens = node.key.len();
        let parent = node.parent;

        if now_unlocked {
            self.protected_size -= tokens;
            if self.nodes[cursor].children.is_empty() {
                self.evictable_leaves.insert(cursor);
                self.evictable_size += tokens;
            }
        }

        let Some(next) = parent else {
            break;
        };
        cursor = next;
    }
}
/// Evict evictable leaves, least recently accessed first, until at least
/// `num_tokens` tokens have been removed or no evictable leaf is left.
///
/// Whole nodes are evicted, so the evicted count may exceed `num_tokens`.
/// Each victim's slots are returned to `token_pool`. A parent left without
/// children becomes evictable itself, provided it is unlocked and is not the
/// root.
///
/// Returns `(num_tokens_evicted, evicted_slot_indices)`, where the second
/// element lists the pool slot indices of all evicted tokens.
pub fn evict(&mut self, num_tokens: usize) -> (usize, Vec<usize>) {
    let mut evicted = 0;
    let mut evicted_indices = Vec::new();

    while evicted < num_tokens {
        let oldest = self
            .evictable_leaves
            .iter()
            .copied()
            .min_by_key(|&id| self.nodes[id].last_access_time);
        let Some(victim_id) = oldest else {
            break;
        };

        // Release the victim's slots and note what is needed to unlink it.
        let (tokens, parent_id, victim_ck) = {
            let victim = &self.nodes[victim_id];
            evicted_indices.extend_from_slice(&victim.value);
            self.token_pool.free(&victim.value);
            (victim.key.len(), victim.parent, self.child_key(&victim.key))
        };

        self.evictable_leaves.remove(&victim_id);
        self.evictable_size -= tokens;
        evicted += tokens;

        if let Some(pid) = parent_id {
            let parent = &mut self.nodes[pid];
            parent.children.remove(&victim_ck);
            let parent_now_evictable =
                pid != self.root && parent.children.is_empty() && parent.lock_ref == 0;
            if parent_now_evictable {
                let parent_tokens = parent.key.len();
                self.evictable_leaves.insert(pid);
                self.evictable_size += parent_tokens;
            }
        }

        self.nodes.remove(victim_id);
    }

    (evicted, evicted_indices)
}
/// Number of free slots in the token pool.
pub fn available_tokens(&self) -> usize {
self.token_pool.available()
}
/// Size of the token pool as given at construction.
pub fn total_tokens(&self) -> usize {
self.token_pool.total()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_token_pool_allocate_and_free() {
let mut pool = TokenPool::new(10);
assert_eq!(pool.available(), 10);
let a = pool.allocate(3).unwrap();
assert_eq!(a.len(), 3);
assert_eq!(pool.available(), 7);
let b = pool.allocate(7).unwrap();
assert_eq!(pool.available(), 0);
assert!(pool.allocate(1).is_none());
pool.free(&a);
assert_eq!(pool.available(), 3);
pool.free(&b);
assert_eq!(pool.available(), 10);
}
#[test]
fn test_match_prefix() {
let mut cache = RadixCache::new(100, 1);
// Empty tree
let (len, node) = cache.match_prefix(&[1, 2, 3]);
assert_eq!(len, 0);
assert_eq!(node, cache.root());
// Full match
cache.insert(&[1, 2, 3, 4, 5], &[10, 20, 30, 40, 50]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5]).0, 5);
// Partial match with split
cache.insert(&[1, 2, 3, 4, 5, 6, 7], &[10, 20, 30, 40, 50, 60, 70]);
let (len, node) = cache.match_prefix(&[1, 2, 3, 4, 5, 9, 9]);
assert_eq!(len, 5);
let n = cache.node(node);
assert_eq!(n.key, vec![1, 2, 3, 4, 5]);
assert_eq!(n.value, vec![10, 20, 30, 40, 50]);
let &suffix_id = n.children.get(&vec![6]).unwrap();
assert_eq!(cache.node(suffix_id).value, vec![60, 70]);
}
#[test]
fn test_insert() {
let mut cache = RadixCache::new(100, 1);
// Shared prefix splits the tree
cache.insert(&[1, 2, 3, 4, 5], &[10, 20, 30, 40, 50]);
cache.insert(&[1, 2, 3, 6, 7], &[10, 20, 30, 60, 70]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5]).0, 5);
assert_eq!(cache.match_prefix(&[1, 2, 3, 6, 7]).0, 5);
assert_eq!(cache.match_prefix(&[1, 2, 3, 9]).0, 3);
// Extend existing prefix
let mut cache = RadixCache::new(100, 1);
cache.insert(&[1, 2, 3], &[10, 20, 30]);
cache.insert(&[1, 2, 3, 4, 5], &[10, 20, 30, 40, 50]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5]).0, 5);
// Duplicate insert is idempotent
cache.insert(&[1, 2, 3], &[10, 20, 30]);
// Match then insert suffix
let mut cache = RadixCache::new(100, 1);
cache.insert(&[1, 2, 3, 4, 5], &[10, 20, 30, 40, 50]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5, 6, 7, 8]).0, 5);
cache.insert(&[1, 2, 3, 4, 5, 6, 7, 8], &[10, 20, 30, 40, 50, 60, 70, 80]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5, 6, 7, 8]).0, 8);
}
#[test]
fn test_page_size() {
// Insert and match with page_size=4
let mut cache = RadixCache::new(100, 4);
assert_eq!(cache.token_pool.total(), 100);
cache.insert(&[1, 2, 3, 4, 5, 6, 7], &[0, 1, 2, 3, 4, 5, 6]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4]).0, 4);
let (_, node) = cache.match_prefix(&[1, 2, 3, 4]);
assert_eq!(cache.node(node).value, vec![0, 1, 2, 3]);
cache.insert(&[1, 2, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3, 10, 11, 12, 13]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4, 5, 6, 7, 8]).0, 8);
// Children disambiguated by first page_size tokens
let mut cache = RadixCache::new(100, 4);
cache.insert(&[1, 2, 3, 4], &[0, 1, 2, 3]);
cache.insert(&[1, 2, 3, 5], &[4, 5, 6, 7]);
assert_eq!(cache.match_prefix(&[1, 2, 3, 4]).0, 4);
assert_eq!(cache.match_prefix(&[1, 2, 3, 5]).0, 4);
assert_eq!(cache.match_prefix(&[1, 2, 3, 6]).0, 0);
// Split at page boundary preserves value
let mut cache = RadixCache::new(100, 4);
cache.insert(&[1, 2, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3, 10, 11, 12, 13]);
cache.match_prefix(&[1, 2, 3, 4, 9, 9, 9, 9]);
let (_, node) = cache.match_prefix(&[1, 2, 3, 4]);
assert_eq!(cache.node(node).value, vec![0, 1, 2, 3]);
}
#[test]
fn test_lock_unlock_shared_prefix() {
let mut cache = RadixCache::new(100, 1);
cache.insert(&[1, 2, 3, 4, 5], &[0, 1, 2, 3, 4]);
cache.insert(&[1, 2, 3, 6, 7], &[0, 1, 2, 5, 6]);
let (_, node_a) = cache.match_prefix(&[1, 2, 3, 4, 5]);
let (_, node_b) = cache.match_prefix(&[1, 2, 3, 6, 7]);
cache.inc_lock_ref(node_a);
cache.inc_lock_ref(node_b);
assert_eq!(cache.protected_size, 7); // 2+2+3
cache.dec_lock_ref(node_a);
assert!(cache.evictable_leaves.contains(&node_a));
cache.dec_lock_ref(node_b);
assert_eq!(cache.protected_size, 0);
}
#[test]
fn test_evict() {
    // LRU order: oldest evicted first
    let mut cache = RadixCache::new(100, 1);
    cache.insert(&[1, 2, 3], &[0, 1, 2]);
    let older = cache.match_prefix(&[1, 2, 3]).1;
    cache.inc_lock_ref(older);
    cache.dec_lock_ref(older);
    // Keep the two entries' access times apart so the LRU order is unambiguous.
    std::thread::sleep(std::time::Duration::from_millis(1));
    cache.insert(&[4, 5, 6], &[3, 4, 5]);
    let newer = cache.match_prefix(&[4, 5, 6]).1;
    cache.inc_lock_ref(newer);
    cache.dec_lock_ref(newer);

    let (evicted_count, mut evicted_indices) = cache.evict(3);
    assert_eq!(evicted_count, 3);
    // Evicted indices should match the pool indices originally inserted for [1,2,3]
    evicted_indices.sort();
    assert_eq!(
        evicted_indices,
        vec![0, 1, 2],
        "evicted indices should match inserted indices"
    );
    assert_eq!(cache.match_prefix(&[1, 2, 3]).0, 0); // oldest evicted
    assert_eq!(cache.match_prefix(&[4, 5, 6]).0, 3); // newer kept

    // Locked nodes are not evicted
    let mut cache = RadixCache::new(100, 1);
    cache.insert(&[1, 2, 3], &[0, 1, 2]);
    cache.insert(&[4, 5, 6], &[3, 4, 5]);
    let pinned = cache.match_prefix(&[1, 2, 3]).1;
    cache.inc_lock_ref(pinned);
    let released = cache.match_prefix(&[4, 5, 6]).1;
    cache.inc_lock_ref(released);
    cache.dec_lock_ref(released);

    let (evicted_count, mut evicted_indices) = cache.evict(6);
    assert_eq!(evicted_count, 3); // only unlocked evicted
    evicted_indices.sort();
    assert_eq!(
        evicted_indices,
        vec![3, 4, 5],
        "should evict unlocked [4,5,6] indices"
    );
    assert_eq!(cache.match_prefix(&[1, 2, 3]).0, 3);
}
#[test]
fn test_query_methods() {
    // A freshly built cache reports its whole capacity as available for both page sizes.
    for page_size in [1, 4] {
        let cache = RadixCache::new(100, page_size);
        assert_eq!(cache.available_tokens(), 100);
        assert_eq!(cache.total_tokens(), 100);
    }
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Shared KV cache trace logging for both vLLM and SGLang backends.
//!
//! Enabled by setting `DYN_MOCKER_KV_CACHE_TRACE=1` or `true`.
use dynamo_runtime::config::environment_names::mocker;
use std::env;
use std::sync::LazyLock;
use std::time::{SystemTime, UNIX_EPOCH};
/// Whether KV cache allocation/eviction trace logging is enabled.
///
/// Evaluated lazily on first access from the `mocker::DYN_MOCKER_KV_CACHE_TRACE`
/// environment variable: `"1"` or any casing of `"true"` enables tracing; an unset,
/// unreadable or different value disables it.
pub static KV_CACHE_TRACE_ENABLED: LazyLock<bool> = LazyLock::new(|| {
    env::var(mocker::DYN_MOCKER_KV_CACHE_TRACE)
        .is_ok_and(|value| value == "1" || value.eq_ignore_ascii_case("true"))
});
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn timestamp_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
/// Emit one structured `tracing::info!` line describing the vLLM block cache.
///
/// Does nothing unless [`KV_CACHE_TRACE_ENABLED`] is true.
///
/// * `event` - label for the trace line.
/// * `dp_rank` - data-parallel rank of the emitting worker.
/// * `block_size` - logged as-is.
/// * `active_blocks` / `inactive_blocks` - current block counts; their sum over
///   `total_blocks` is logged as `utilization` (0.0 when `total_blocks` is 0).
/// * `total_blocks` - cache capacity; `free_blocks` is what remains after subtracting
///   both counts, clamped at zero.
pub fn log_vllm_trace(
    event: &str,
    dp_rank: u32,
    block_size: usize,
    active_blocks: usize,
    inactive_blocks: usize,
    total_blocks: usize,
) {
    let enabled = *KV_CACHE_TRACE_ENABLED;
    if !enabled {
        return;
    }

    let utilization = match total_blocks {
        0 => 0.0,
        total => (active_blocks + inactive_blocks) as f64 / total as f64,
    };
    let free_blocks = total_blocks
        .saturating_sub(active_blocks)
        .saturating_sub(inactive_blocks);

    tracing::info!(
        engine_type = "vllm",
        event,
        timestamp_ms = timestamp_ms(),
        dp_rank,
        block_size,
        free_blocks,
        active_blocks,
        inactive_blocks,
        total_blocks,
        utilization,
        "KV cache trace"
    );
}
/// SGLang cache state snapshot for trace logging.
///
/// Each field is emitted as a structured field of the same name by `log_sglang_trace`.
pub struct SglangCacheState<'a> {
/// Label for the trace line (the SGLang KV manager passes "allocation" or "eviction").
pub event: &'a str,
/// Data-parallel rank of the worker emitting the trace.
pub dp_rank: u32,
/// Number of tokens involved in this event (tokens allocated or evicted).
pub num_tokens: usize,
/// Page size reported by the radix cache.
pub page_size: usize,
/// Value of the cache's `available_tokens()` at snapshot time.
pub available_tokens: usize,
/// Value of the cache's `evictable_size` at snapshot time.
pub evictable_tokens: usize,
/// Value of the cache's `protected_size` at snapshot time.
pub protected_tokens: usize,
/// Value of the cache's `total_tokens()`; used as the utilization denominator.
pub total_tokens: usize,
}
/// Log an SGLang KV cache trace event.
///
/// Does nothing unless [`KV_CACHE_TRACE_ENABLED`] is true. Otherwise emits one
/// `tracing::info!` line carrying every field of `state` plus a wall-clock timestamp
/// and `utilization`, the fraction of `total_tokens` that is not available
/// (0.0 when `total_tokens` is 0).
pub fn log_sglang_trace(state: &SglangCacheState) {
    if !*KV_CACHE_TRACE_ENABLED {
        return;
    }
    let utilization = if state.total_tokens > 0 {
        // Saturate like `log_vllm_trace` does: a snapshot whose `available_tokens`
        // exceeds `total_tokens` must not panic (debug) or wrap (release) inside a
        // diagnostics-only code path.
        let used_tokens = state.total_tokens.saturating_sub(state.available_tokens);
        used_tokens as f64 / state.total_tokens as f64
    } else {
        0.0
    };
    tracing::info!(
        engine_type = "sglang",
        event = state.event,
        timestamp_ms = timestamp_ms(),
        dp_rank = state.dp_rank,
        num_tokens = state.num_tokens,
        page_size = state.page_size,
        available_tokens = state.available_tokens,
        evictable_tokens = state.evictable_tokens,
        protected_tokens = state.protected_tokens,
        total_tokens = state.total_tokens,
        utilization,
        "KV cache trace"
    );
}
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
pub mod bootstrap; pub mod bootstrap;
pub mod evictor; pub mod evictor;
pub mod kv_cache_trace;
pub mod perf_model; pub mod perf_model;
pub mod protocols; pub mod protocols;
pub mod running_mean; pub mod running_mean;
......
...@@ -90,6 +90,16 @@ pub enum PreemptionMode { ...@@ -90,6 +90,16 @@ pub enum PreemptionMode {
Fifo, Fifo,
} }
/// Engine type for selecting scheduling and KV cache simulation behavior
///
/// Defaults to [`EngineType::Vllm`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EngineType {
/// vLLM-style scheduling with hash-based block KV cache
#[default]
Vllm,
/// SGLang-style scheduling with radix-tree KV cache
Sglang,
}
/// Worker type for disaggregated serving configurations /// Worker type for disaggregated serving configurations
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum WorkerType { pub enum WorkerType {
...@@ -134,10 +144,39 @@ impl ReasoningConfig { ...@@ -134,10 +144,39 @@ impl ReasoningConfig {
} }
} }
/// Configuration arguments for MockVllmEngine /// SGLang-specific configuration parameters.
///
/// Grouped into a nested struct to keep the `MockEngineArgs` namespace clean,
/// following the same pattern as [`ReasoningConfig`].
#[derive(Debug, Clone, Serialize, Deserialize, Validate, Default)]
pub struct SglangArgs {
/// Scheduling policy: "fifo"/"fcfs" or "lpm". Default: "fifo".
/// Other strings are not rejected here; the SGLang scheduler logs a warning and
/// falls back to FIFO.
pub schedule_policy: Option<String>,
/// Radix cache page size in tokens. Default: 1.
#[validate(range(min = 1))]
pub page_size: Option<usize>,
/// Maximum prefill tokens budget per batch. Default: 16384.
#[validate(range(min = 1))]
pub max_prefill_tokens: Option<usize>,
/// Chunked prefill size (max tokens per chunk). Default: 8192.
#[validate(range(min = 1))]
pub chunked_prefill_size: Option<usize>,
/// Clip max new tokens for admission budget. Default: 4096.
#[validate(range(min = 1))]
pub clip_max_new_tokens: Option<usize>,
/// Schedule conservativeness factor (0.0–1.0). Default: 1.0.
/// The SGLang scheduler multiplies its initial new-token ratio by this value.
#[validate(range(min = 0.0, max = 1.0))]
pub schedule_conservativeness: Option<f64>,
}
/// Configuration arguments for MockEngine
#[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)]
#[builder(pattern = "owned", build_fn(public))] #[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs { pub struct MockEngineArgs {
/// Engine type: vLLM or SGLang simulation
#[builder(default = "EngineType::Vllm")]
pub engine_type: EngineType,
#[builder(default = "16384")] #[builder(default = "16384")]
#[validate(range(min = 1))] #[validate(range(min = 1))]
pub num_gpu_blocks: usize, pub num_gpu_blocks: usize,
...@@ -236,6 +275,10 @@ pub struct MockEngineArgs { ...@@ -236,6 +275,10 @@ pub struct MockEngineArgs {
/// Lifo (default) evicts the newest request; Fifo evicts the oldest. /// Lifo (default) evicts the newest request; Fifo evicts the oldest.
#[builder(default)] #[builder(default)]
pub preemption_mode: PreemptionMode, pub preemption_mode: PreemptionMode,
/// SGLang-specific configuration. Only used when `engine_type == Sglang`.
#[builder(default = "None")]
pub sglang: Option<SglangArgs>,
} }
impl Default for MockEngineArgs { impl Default for MockEngineArgs {
...@@ -273,6 +316,7 @@ impl MockEngineArgs { ...@@ -273,6 +316,7 @@ impl MockEngineArgs {
// Define valid field names // Define valid field names
let valid_fields: HashSet<&str> = [ let valid_fields: HashSet<&str> = [
"engine_type",
"num_gpu_blocks", "num_gpu_blocks",
"block_size", "block_size",
"max_num_seqs", "max_num_seqs",
...@@ -294,6 +338,7 @@ impl MockEngineArgs { ...@@ -294,6 +338,7 @@ impl MockEngineArgs {
"zmq_kv_events_port", "zmq_kv_events_port",
"zmq_replay_port", "zmq_replay_port",
"preemption_mode", "preemption_mode",
"sglang",
] ]
.iter() .iter()
.cloned() .cloned()
...@@ -315,6 +360,22 @@ impl MockEngineArgs { ...@@ -315,6 +360,22 @@ impl MockEngineArgs {
} }
// Apply each extra argument to the builder // Apply each extra argument to the builder
if let Some(value) = extra_args.get("engine_type")
&& let Some(s) = value.as_str()
{
let engine_type = match s {
"vllm" => EngineType::Vllm,
"sglang" => EngineType::Sglang,
other => {
return Err(anyhow::anyhow!(
"Invalid engine_type '{}'. Must be 'vllm' or 'sglang'.",
other
));
}
};
builder = builder.engine_type(engine_type);
}
if let Some(value) = extra_args.get("num_gpu_blocks") if let Some(value) = extra_args.get("num_gpu_blocks")
&& let Some(num) = value.as_u64() && let Some(num) = value.as_u64()
{ {
...@@ -433,6 +494,12 @@ impl MockEngineArgs { ...@@ -433,6 +494,12 @@ impl MockEngineArgs {
builder = builder.preemption_mode(mode); builder = builder.preemption_mode(mode);
} }
if let Some(value) = extra_args.get("sglang") {
let cfg: SglangArgs = serde_json::from_value(value.clone())
.map_err(|e| anyhow::anyhow!("Failed to parse sglang config: {}", e))?;
builder = builder.sglang(Some(cfg));
}
// Parse worker type from is_prefill and is_decode flags // Parse worker type from is_prefill and is_decode flags
let is_prefill = extra_args let is_prefill = extra_args
.get("is_prefill") .get("is_prefill")
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Engine factory — creates the appropriate scheduler based on [`EngineType`].
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{EngineType, KvCacheEventSink, MockEngineArgs, OutputSignal};
use crate::scheduler::{Scheduler, SchedulerHandle, SglangScheduler};
/// Build the scheduler that matches `args.engine_type`.
///
/// All arguments are forwarded unchanged to the selected backend's constructor
/// ([`Scheduler::new`] for vLLM, [`SglangScheduler::new`] for SGLang). The result is
/// returned as a boxed [`SchedulerHandle`] so the engine wrapper does not need to know
/// which backend is running underneath.
pub fn create_engine(
    args: MockEngineArgs,
    dp_rank: u32,
    output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
    kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
    cancellation_token: Option<CancellationToken>,
) -> Box<dyn SchedulerHandle> {
    let handle: Box<dyn SchedulerHandle> = match args.engine_type {
        EngineType::Sglang => Box::new(SglangScheduler::new(
            args,
            dp_rank,
            output_tx,
            kv_event_sink,
            cancellation_token,
        )),
        EngineType::Vllm => Box::new(Scheduler::new(
            args,
            dp_rank,
            output_tx,
            kv_event_sink,
            cancellation_token,
        )),
    };
    handle
}
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
//! Pluggable KV cache block managers. //! Pluggable KV cache block managers.
pub mod sglang_backend;
pub mod vllm_backend; pub mod vllm_backend;
pub use sglang_backend::SglangKvManager;
pub use vllm_backend::KvManager; pub use vllm_backend::KvManager;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! SGLang KV manager — wraps [`RadixCache`] with request-level lifecycle
//! operations and KV event publishing.
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::cache::radix_cache::{NodeId, RadixCache};
use crate::common::kv_cache_trace;
use crate::common::protocols::KvCacheEventSink;
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData,
};
/// Result of `allocate_for_request`.
///
/// The caller keeps `kv_indices` and `last_node` and hands them back to the manager
/// when it caches or frees the request.
pub struct AllocResult {
/// Number of tokens matched from the prefix cache.
pub prefix_len: usize,
/// Pool token indices for the allocated input (1 per token).
/// Indices collected from the matched prefix path come first, followed by the
/// indices newly allocated from the token pool.
pub kv_indices: Vec<usize>,
/// The deepest matched node in the radix tree (used for lock/unlock).
/// This is the prefix match point, not the new tokens — new tokens are
/// only in kv_indices and get inserted into the tree on completion.
pub last_node: NodeId,
}
/// KV manager for the SGLang simulation: owns the radix cache and optionally
/// publishes Stored/Removed KV events to a sink.
pub struct SglangKvManager {
// Radix-tree prefix cache; its token pool is the allocator for new token slots.
cache: RadixCache,
// Destination for KV cache events; `None` disables event publishing.
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
// Data-parallel rank stamped on every published event and trace line.
dp_rank: u32,
// Id assigned to the next published event; incremented once per event.
next_event_id: u64,
/// Maps pool_idx → block_hash assigned during Stored events,
/// so Removed events can use the same block_hash.
idx_to_block_hash: HashMap<usize, ExternalSequenceBlockHash>,
}
impl SglangKvManager {
/// Create a manager backed by a fresh [`RadixCache`] of `total_tokens` capacity and
/// the given `page_size`.
///
/// `kv_event_sink` receives Stored/Removed events when present; `dp_rank` is stamped
/// on every event and trace line. Event ids start at 0.
pub fn new(
    total_tokens: usize,
    page_size: usize,
    kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
    dp_rank: u32,
) -> Self {
    let cache = RadixCache::new(total_tokens, page_size);
    Self {
        cache,
        kv_event_sink,
        dp_rank,
        next_event_id: 0,
        idx_to_block_hash: HashMap::new(),
    }
}
/// Shared reference to the underlying radix cache.
pub fn cache(&self) -> &RadixCache {
&self.cache
}
/// Mutable reference to the underlying radix cache.
/// NOTE(review): changes made through this handle bypass event publishing and
/// trace logging.
pub fn cache_mut(&mut self) -> &mut RadixCache {
&mut self.cache
}
/// Reserve KV cache for a new request's input tokens.
///
/// Matches `token_ids` against the radix tree, reuses the indices of the matched
/// prefix, and allocates fresh token slots for the remainder. On success the matched
/// node is locked and a Stored event covering only the new tokens is published.
///
/// Returns `None` if the pool doesn't have enough token slots (OOM); in that case no
/// lock is taken and no event is published.
pub fn allocate_for_request(&mut self, token_ids: &[u64]) -> Option<AllocResult> {
    let (prefix_len, last_node) = self.cache.match_prefix(token_ids);
    let new_tokens = token_ids.len() - prefix_len;

    // Reused prefix indices first, newly allocated ones appended after them.
    let mut kv_indices = self.collect_path_indices(last_node);
    let new_indices = self.cache.token_pool.allocate(new_tokens)?;
    kv_indices.extend_from_slice(&new_indices);

    self.cache.inc_lock_ref(last_node);

    // Chain from the hash of the prefix's last token; there is none on a cache miss.
    let parent_hash = prefix_len
        .checked_sub(1)
        .and_then(|last_prefix_pos| kv_indices.get(last_prefix_pos))
        .and_then(|idx| self.idx_to_block_hash.get(idx).copied());
    self.publish_stored_event(&token_ids[prefix_len..], &new_indices, parent_hash);
    self.log_trace("allocation", new_tokens);

    Some(AllocResult {
        prefix_len,
        kv_indices,
        last_node,
    })
}
/// Cache a completed request's full sequence into the radix tree.
///
/// Inserts the full token sequence so future requests can reuse it,
/// then unlocks the path.
///
/// `kv_indices` are the pool indices backing `token_ids`; `last_node` is the node
/// that was locked for this request and is unlocked here.
pub fn cache_finished_req(
&mut self,
token_ids: &[u64],
kv_indices: &[usize],
last_node: NodeId,
) {
self.cache.insert(token_ids, kv_indices);
self.cache.dec_lock_ref(last_node);
}
/// Record the progress of a request that is still being prefilled in chunks.
///
/// The partial sequence is inserted into the radix tree and the request's lock is
/// moved from `last_node` to the deepest node now matching `token_ids`. The request
/// is still active, so that node stays locked.
///
/// Returns the node the caller must pass as `last_node` on subsequent calls.
pub fn cache_unfinished_req(
    &mut self,
    token_ids: &[u64],
    kv_indices: &[usize],
    last_node: NodeId,
) -> NodeId {
    self.cache.insert(token_ids, kv_indices);

    // Look the sequence up again to find where the insert left its deepest node.
    let new_last_node = self.cache.match_prefix(token_ids).1;

    // Move the lock: drop it on the old path, then take it on the extended one.
    self.cache.dec_lock_ref(last_node);
    self.cache.inc_lock_ref(new_last_node);

    new_last_node
}
/// Reserve one token slot for a decode step and publish a Stored event for it.
///
/// `last_idx` is the request's previous pool index; if a block hash was recorded for
/// it, the new block's hash is chained from that hash. Returns the new pool index, or
/// `None` if the pool has no free slot.
pub fn allocate_decode_token(&mut self, last_idx: Option<usize>) -> Option<usize> {
    let idx = self.cache.token_pool.allocate(1)?[0];
    let parent_hash = match last_idx {
        Some(prev_idx) => self.idx_to_block_hash.get(&prev_idx).copied(),
        None => None,
    };
    // No token ids are passed, so the event derives its content hash from the pool index.
    self.publish_stored_event(&[], &[idx], parent_hash);
    self.log_trace("allocation", 1);
    Some(idx)
}
/// Free a request without caching (e.g., aborted request).
///
/// Unlocks the path without inserting into the tree.
/// Only the lock on `last_node` is released; this method does not return the
/// request's token slots to the pool.
pub fn free_request(&mut self, last_node: NodeId) {
self.cache.dec_lock_ref(last_node);
}
/// Gather the pool indices stored along the path from the root down to `last_node`.
///
/// Returns an empty vector when `last_node` is the root. Otherwise the values of all
/// nodes on the path are concatenated in root-to-leaf order; the parentless top node
/// itself contributes nothing.
fn collect_path_indices(&self, last_node: NodeId) -> Vec<usize> {
    if last_node == self.cache.root() {
        return Vec::new();
    }

    // Climb towards the top, remembering every node that still has a parent.
    let mut path = Vec::new();
    let mut current = last_node;
    while let Some(parent) = self.cache.node(current).parent {
        path.push(current);
        current = parent;
    }

    // The climb recorded leaf-to-root; emit the values in root-to-leaf order.
    path.iter()
        .rev()
        .flat_map(|&node_id| self.cache.node(node_id).value.iter().copied())
        .collect()
}
/// Evict tokens from the cache, publish BlockRemoved events, and log a trace.
///
/// `num_tokens` is passed to the radix cache's `evict`. A Removed event is published
/// only when at least one index was evicted; the trace line is logged either way
/// with the count the cache reported.
pub fn evict(&mut self, num_tokens: usize) {
let (evicted, evicted_indices) = self.cache.evict(num_tokens);
if !evicted_indices.is_empty() {
self.publish_removed_event(&evicted_indices);
}
self.log_trace("eviction", evicted);
}
/// Snapshot the cache counters and pass them to the shared SGLang trace logger.
///
/// `event` labels the trace line and `num_tokens` is the token count for that event.
fn log_trace(&self, event: &str, num_tokens: usize) {
    let cache = &self.cache;
    let snapshot = kv_cache_trace::SglangCacheState {
        event,
        dp_rank: self.dp_rank,
        num_tokens,
        page_size: cache.page_size(),
        available_tokens: cache.available_tokens(),
        evictable_tokens: cache.evictable_size,
        protected_tokens: cache.protected_size,
        total_tokens: cache.total_tokens(),
    };
    kv_cache_trace::log_sglang_trace(&snapshot);
}
/// Publish one Stored event describing the newly allocated pool indices.
///
/// Emits one block per entry of `indices`. `token_ids[i]` supplies the content for
/// `indices[i]`; `parent_hash` seeds the running hash that chains the blocks.
/// Does nothing when `indices` is empty or no sink is configured. A publish failure
/// is logged as a warning and otherwise ignored.
fn publish_stored_event(
&mut self,
token_ids: &[u64],
indices: &[usize],
parent_hash: Option<ExternalSequenceBlockHash>,
) {
if indices.is_empty() {
return;
}
// Without a sink nothing is hashed or recorded in `idx_to_block_hash`.
let Some(ref sink) = self.kv_event_sink else {
return;
};
let mut blocks = Vec::with_capacity(indices.len());
// With no parent the chain starts from 0.
let mut running_hash = parent_hash.map_or(0u64, |h| h.0);
for (i, &idx) in indices.iter().enumerate() {
// tokens_hash: per-token content hash for router prefix matching
// When `token_ids` has no entry at position `i` (the decode path passes an
// empty slice), the pool index stands in for the token id.
let token_bytes: Vec<u8> = token_ids
.get(i)
.unwrap_or(&(idx as u64))
.to_le_bytes()
.to_vec();
let tokens_hash = dynamo_kv_router::protocols::compute_block_hash(&token_bytes);
// block_hash: cumulative hash (parent_hash, token_id) so it's unique
// per position and uniform across workers with the same token sequence.
// NOTE(review): std documents DefaultHasher's algorithm as unspecified and
// subject to change between releases, so cross-worker uniformity presumably
// requires all workers to be built with the same Rust std — confirm.
let mut hasher = DefaultHasher::new();
running_hash.hash(&mut hasher);
tokens_hash.0.hash(&mut hasher);
running_hash = hasher.finish();
let block_hash = ExternalSequenceBlockHash(running_hash);
// Remember the hash so a later Removed event for this index can reference it.
self.idx_to_block_hash.insert(idx, block_hash);
blocks.push(KvCacheStoredBlockData {
block_hash,
tokens_hash,
mm_extra_info: None,
});
}
let event = KvCacheEvent {
event_id: self.next_event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks,
}),
dp_rank: self.dp_rank,
};
// The id is consumed even if publishing fails below.
self.next_event_id += 1;
if let Err(e) = sink.publish(event, None) {
tracing::warn!("Failed to publish SGLang KV event: {e}");
}
}
fn publish_removed_event(&mut self, evicted_indices: &[usize]) {
let Some(ref sink) = self.kv_event_sink else {
return;
};
let block_hashes: Vec<ExternalSequenceBlockHash> = evicted_indices
.iter()
.filter_map(|&idx| self.idx_to_block_hash.remove(&idx))
.collect();
if block_hashes.is_empty() {
return;
}
let event = KvCacheEvent {
event_id: self.next_event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
dp_rank: self.dp_rank,
};
self.next_event_id += 1;
if let Err(e) = sink.publish(event, None) {
tracing::warn!("Failed to publish SGLang KV remove event: {e}");
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Event sink that records every published event so tests can count them.
    struct MockSink {
        events: Mutex<Vec<KvCacheEvent>>,
    }

    impl MockSink {
        fn new() -> Self {
            let events = Mutex::new(Vec::new());
            Self { events }
        }

        fn event_count(&self) -> usize {
            let recorded = self.events.lock().unwrap();
            recorded.len()
        }
    }

    impl KvCacheEventSink for MockSink {
        fn publish(
            &self,
            event: KvCacheEvent,
            _block_token_ids: Option<&[Vec<u32>]>,
        ) -> anyhow::Result<()> {
            let mut recorded = self.events.lock().unwrap();
            recorded.push(event);
            Ok(())
        }
    }

    #[test]
    fn test_allocate_cache_miss() {
        let mut manager = SglangKvManager::new(100, 1, None, 0);

        let alloc = manager.allocate_for_request(&[1, 2, 3, 4, 5]).unwrap();

        assert_eq!(alloc.prefix_len, 0);
        assert_eq!(alloc.kv_indices.len(), 5);
        assert_eq!(manager.cache().token_pool.available(), 95);
    }

    #[test]
    fn test_allocate_cache_hit() {
        let mut manager = SglangKvManager::new(100, 1, None, 0);

        // First request: allocate and cache
        let first = manager.allocate_for_request(&[1, 2, 3, 4, 5]).unwrap();
        assert_eq!(first.kv_indices.len(), 5); // 5 pages (page_size=1)
        manager.cache_finished_req(&[1, 2, 3, 4, 5], &first.kv_indices, first.last_node);

        // Second request with shared prefix
        let second = manager
            .allocate_for_request(&[1, 2, 3, 4, 5, 6, 7])
            .unwrap();
        assert_eq!(second.prefix_len, 5);
        assert_eq!(second.kv_indices.len(), 7); // 5 reused + 2 new pages
        assert_eq!(manager.cache().token_pool.available(), 93); // 100 - 5 - 2
    }

    #[test]
    fn test_free_request_without_caching() {
        let mut manager = SglangKvManager::new(100, 1, None, 0);

        let alloc = manager.allocate_for_request(&[1, 2, 3]).unwrap();
        manager.free_request(alloc.last_node);

        // Path is unlocked, tokens still allocated in pool
        assert_eq!(manager.cache().protected_size, 0);
    }

    #[test]
    fn test_event_publishing() {
        let sink = Arc::new(MockSink::new());
        let mut manager = SglangKvManager::new(100, 1, Some(sink.clone()), 0);

        let first = manager.allocate_for_request(&[1, 2, 3]).unwrap();
        assert_eq!(sink.event_count(), 1); // BlockStored for 3 new pages
        manager.cache_finished_req(&[1, 2, 3], &first.kv_indices, first.last_node);

        // Second request with full cache hit → no new events
        let second = manager.allocate_for_request(&[1, 2, 3]).unwrap();
        assert_eq!(second.prefix_len, 3);
        assert_eq!(sink.event_count(), 1); // no new event
    }

    #[test]
    fn test_allocate_oom() {
        let mut manager = SglangKvManager::new(3, 1, None, 0);
        let _held = manager.allocate_for_request(&[1, 2, 3]).unwrap();

        // Pool is full
        let overflow = manager.allocate_for_request(&[4, 5, 6]);
        assert!(overflow.is_none());
    }
}
...@@ -35,26 +35,17 @@ ...@@ -35,26 +35,17 @@
//! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror //! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror
//! implementation of the main block manager. //! implementation of the main block manager.
use crate::cache::HashCache; use crate::cache::HashCache;
use crate::common::kv_cache_trace;
use crate::common::protocols::{KvCacheEventSink, MoveBlock, PrefillCost}; use crate::common::protocols::{KvCacheEventSink, MoveBlock, PrefillCost};
use crate::common::sequence::ActiveSequence; use crate::common::sequence::ActiveSequence;
use dynamo_kv_router::protocols::{ use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, KvCacheStoredBlockData, LocalBlockHash,
}; };
use dynamo_runtime::config::environment_names::mocker;
use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash}; use dynamo_tokens::{BlockHash, SequenceHash};
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use std::time::{SystemTime, UNIX_EPOCH};
/// Check the env var to enable KV cache allocation/eviction trace logs.
static KV_CACHE_TRACE_ENABLED: LazyLock<bool> = LazyLock::new(|| {
env::var(mocker::DYN_MOCKER_KV_CACHE_TRACE)
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false)
});
pub struct KvManager { pub struct KvManager {
cache: HashCache, cache: HashCache,
...@@ -104,32 +95,14 @@ impl KvManager { ...@@ -104,32 +95,14 @@ impl KvManager {
return; return;
} }
if *KV_CACHE_TRACE_ENABLED { kv_cache_trace::log_vllm_trace(
let active_len = self.cache.num_active(); if is_store { "allocation" } else { "eviction" },
let inactive_len = self.cache.num_inactive(); self.dp_rank,
let free_blocks = self self.block_size,
.cache self.cache.num_active(),
.max_capacity() self.cache.num_inactive(),
.saturating_sub(active_len) self.cache.max_capacity(),
.saturating_sub(inactive_len); );
let event = if is_store { "allocation" } else { "eviction" };
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
tracing::info!(
event,
timestamp_ms,
block_ids = ?&full_blocks,
block_size = self.block_size,
free_blocks_after = free_blocks,
active_blocks = active_len,
inactive_blocks = inactive_len,
total_blocks = self.cache.max_capacity(),
dp_rank = self.dp_rank,
"KV cache trace"
);
}
let Some(ref sink) = self.kv_event_sink else { let Some(ref sink) = self.kv_event_sink else {
return; return;
......
...@@ -9,5 +9,6 @@ ...@@ -9,5 +9,6 @@
pub mod cache; pub mod cache;
pub mod common; pub mod common;
pub mod engine;
pub mod kv_manager; pub mod kv_manager;
pub mod scheduler; pub mod scheduler;
...@@ -3,7 +3,112 @@ ...@@ -3,7 +3,112 @@
//! Engine-specific scheduling implementations. //! Engine-specific scheduling implementations.
pub mod sglang;
pub mod vllm; pub mod vllm;
// Backward compatibility: re-export Scheduler from vllm module use crate::common::protocols::DirectRequest;
pub use vllm::Scheduler; use tokio::sync::mpsc;
pub use sglang::SglangScheduler;
pub use vllm::{MockerMetrics, Scheduler};
/// Engine-agnostic scheduler interface.
///
/// Both vLLM and SGLang schedulers implement this trait so that the engine
/// wrapper (`MockEngine`) can work with either backend through the same API.
///
/// The `Send + Sync` bound lets a boxed handle be shared across threads/tasks.
pub trait SchedulerHandle: Send + Sync {
/// Send a request to the scheduler's waiting queue.
fn receive(&self, request: DirectRequest);
/// Get a clone of the request sender channel for direct sending.
fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest>;
/// Get a watch receiver for scheduler metrics (active decode blocks, etc.).
/// A watch receiver exposes only the most recently published `MockerMetrics` value.
fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics>;
}
/// Shared test utilities for scheduler stress tests.
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::common::protocols::OutputSignal;
    use tokio::time::Duration;

    /// Send `num_requests` to a scheduler, collect all output signals, and assert
    /// that the scheduler produces exactly `num_requests * max_output_tokens` signals
    /// and returns to idle (0 active decode blocks).
    ///
    /// Every request carries exactly `input_len` random tokens. When
    /// `use_shared_tokens` is true, the first `input_len / 2` tokens are a prefix
    /// common to all requests, to exercise prefix caching / radix tree reuse.
    ///
    /// Collection stops early if no signal arrives for 2 seconds; the final
    /// assertion then reports how many signals were received.
    pub async fn assert_scheduler_completes_all(
        scheduler: &dyn SchedulerHandle,
        output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
        num_requests: usize,
        input_len: usize,
        max_output_tokens: usize,
        use_shared_tokens: bool,
    ) {
        let shared_tokens = if use_shared_tokens {
            Some(
                (0..input_len / 2)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>(),
            )
        } else {
            None
        };

        for _ in 0..num_requests {
            let input_tokens = if let Some(ref shared) = shared_tokens {
                let mut tokens = shared.clone();
                // Top up to exactly `input_len`. Adding another `input_len / 2` would
                // leave the prompt one token short whenever `input_len` is odd.
                tokens.extend((shared.len()..input_len).map(|_| rand::random::<u32>() % 50000));
                tokens
            } else {
                (0..input_len)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>()
            };

            scheduler.receive(DirectRequest {
                tokens: input_tokens,
                max_output_tokens,
                uuid: None,
                dp_rank: 0,
            });
        }

        let expected_tokens = num_requests * max_output_tokens;
        let mut received_tokens = 0;

        // Inactivity timeout: re-armed after every received signal.
        let timeout = tokio::time::sleep(Duration::from_secs(2));
        tokio::pin!(timeout);

        loop {
            tokio::select! {
                biased;
                Some(_) = output_rx.recv() => {
                    received_tokens += 1;
                    if received_tokens >= expected_tokens {
                        break;
                    }
                    timeout.set(tokio::time::sleep(Duration::from_secs(2)));
                }
                _ = &mut timeout => break,
            }
        }

        assert_eq!(
            received_tokens, expected_tokens,
            "Expected {expected_tokens} output signals, got {received_tokens}"
        );

        // Verify scheduler returns to idle
        tokio::time::sleep(Duration::from_millis(100)).await;
        let metrics = scheduler.metrics_receiver().borrow().clone();
        assert_eq!(
            metrics.active_decode_blocks, 0,
            "Scheduler should be idle after all requests complete, got {} active blocks",
            metrics.active_decode_blocks
        );
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! SGLang scheduler simulation with adaptive admission control.
//!
//! Reference: sglang/python/sglang/srt/managers/scheduler.py
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use validator::Validate;
use crate::cache::radix_cache::NodeId;
use crate::common::perf_model::PerfModel;
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, OutputSignal, WorkerType,
};
use crate::common::utils::sleep_until_precise;
use crate::kv_manager::SglangKvManager;
use super::MockerMetrics;
// SGLang default constants
// Used when the `sglang` section leaves `max_prefill_tokens` unset.
const DEFAULT_MAX_PREFILL_TOKENS: usize = 16384;
// Used when the `sglang` section leaves `chunked_prefill_size` unset.
const DEFAULT_CHUNKED_PREFILL_SIZE: usize = 8192;
// Used when the `sglang` section leaves `clip_max_new_tokens` unset.
const DEFAULT_CLIP_MAX_NEW_TOKENS: usize = 4096;
// Base for the initial new-token ratio; multiplied by `schedule_conservativeness`.
const DEFAULT_INIT_NEW_TOKEN_RATIO: f64 = 0.7;
// The minimum new-token ratio is the initial ratio times this factor.
const DEFAULT_MIN_NEW_TOKEN_RATIO_FACTOR: f64 = 0.14;
// Number of decay steps: per-iteration decay is (initial - minimum) / this value.
const DEFAULT_NEW_TOKEN_RATIO_DECAY_STEPS: f64 = 600.0;
// NOTE(review): presumably the waiting-queue length above which LPM falls back to
// FIFO (see `SchedulePolicy::Lpm`); its use is outside this excerpt — confirm.
const LPM_FALLBACK_THRESHOLD: usize = 128;
/// Tracks a single request inside the SGLang scheduler.
struct SglangRequest {
// Request identifier.
uuid: Uuid,
// Token ids of the request; their count is the full input length used for
// prefill accounting.
token_ids: Vec<u64>,
// Requested output-token limit; clipped by `clip_max_new_tokens` for admission.
max_output_tokens: usize,
// NOTE(review): presumably the number of output tokens generated so far —
// confirm against the decode step.
output_len: usize,
/// Deepest matched node in radix tree.
last_node: Option<NodeId>,
/// Pool page indices for the full sequence.
kv_indices: Vec<usize>,
/// Number of input tokens already prefilled (for chunked prefill).
prefilled_tokens: usize,
}
impl SglangRequest {
    /// Token budget this request still needs: the input tokens not yet prefilled plus
    /// its output limit, with the output part capped at `clip_max_new_tokens`.
    fn total_tokens_needed(&self, clip_max_new_tokens: usize) -> usize {
        let output_budget = clip_max_new_tokens.min(self.max_output_tokens);
        self.extend_input_len() + output_budget
    }

    /// Number of input tokens that have not been prefilled yet.
    fn extend_input_len(&self) -> usize {
        let input_len = self.token_ids.len();
        input_len - self.prefilled_tokens
    }
}
/// SGLang scheduler with adaptive admission control.
///
/// The scheduling loop mirrors SGLang's `Scheduler.event_loop_normal`:
/// `receive_requests → apply_schedule_policy → get_new_batch_prefill →
/// simulate_prefill → simulate_decode → decay_new_token_ratio`
///
/// This struct is only a handle; the loop itself runs in a spawned Tokio task.
pub struct SglangScheduler {
// Sending half of the channel the scheduling task reads requests from.
request_tx: mpsc::UnboundedSender<DirectRequest>,
// Watch receiver for metrics published by the scheduling task.
metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
// Cancels the scheduler's token when the last clone of this Arc is dropped.
_cancel_guard: Arc<CancelGuard>,
}
/// Drop guard that cancels the wrapped token when it goes out of scope.
struct CancelGuard(CancellationToken);
impl Drop for CancelGuard {
// Cancel on drop so the owner going away signals cancellation.
fn drop(&mut self) {
self.0.cancel();
}
}
/// Scheduling policy for reordering the waiting queue.
///
/// Parsed from the `sglang.schedule_policy` string: "lpm" selects [`SchedulePolicy::Lpm`];
/// "fifo", "fcfs", an unset value, or an unknown string selects [`SchedulePolicy::Fifo`].
#[derive(Clone, Copy, Debug, Default)]
pub enum SchedulePolicy {
/// Process in arrival order.
#[default]
Fifo,
/// Longest prefix match — prioritise requests with the most cached tokens.
/// Falls back to FIFO when `waiting.len() > 128` (prefix matching is expensive).
Lpm,
}
/// Configuration extracted from MockEngineArgs for SGLang-specific params.
struct SglangConfig {
// Policy parsed from `sglang.schedule_policy`.
schedule_policy: SchedulePolicy,
// From `sglang.max_prefill_tokens`, else DEFAULT_MAX_PREFILL_TOKENS.
max_prefill_tokens: usize,
// From `sglang.chunked_prefill_size`, else DEFAULT_CHUNKED_PREFILL_SIZE.
chunked_prefill_size: usize,
// From `sglang.clip_max_new_tokens`, else DEFAULT_CLIP_MAX_NEW_TOKENS.
clip_max_new_tokens: usize,
// DEFAULT_INIT_NEW_TOKEN_RATIO scaled by `schedule_conservativeness`; the
// scheduler resets to this value after an OOM admission or a retraction.
init_new_token_ratio: f64,
// Floor of the decaying ratio: init_new_token_ratio * DEFAULT_MIN_NEW_TOKEN_RATIO_FACTOR.
min_new_token_ratio: f64,
// Amount subtracted from the ratio on every scheduler iteration.
new_token_ratio_decay_step: f64,
// Copied from `MockEngineArgs::perf_model`.
perf_model: Arc<PerfModel>,
// Copied from `MockEngineArgs::speedup_ratio`.
speedup_ratio: f64,
// Copied from `MockEngineArgs::worker_type`.
worker_type: WorkerType,
// From `sglang.page_size`, else 1.
page_size: usize,
}
impl SglangConfig {
    /// Derive the scheduler configuration from `args`.
    ///
    /// Values missing from `args.sglang` (or the whole section being absent) fall back
    /// to the SGLang defaults. The new-token ratio starts at
    /// `DEFAULT_INIT_NEW_TOKEN_RATIO * schedule_conservativeness` and decays linearly
    /// towards `DEFAULT_MIN_NEW_TOKEN_RATIO_FACTOR` times that over
    /// `DEFAULT_NEW_TOKEN_RATIO_DECAY_STEPS` steps. An unrecognised `schedule_policy`
    /// logs a warning and selects FIFO.
    fn from_args(args: &MockEngineArgs) -> Self {
        let overrides = args.sglang.as_ref();

        let conservativeness = overrides
            .and_then(|s| s.schedule_conservativeness)
            .unwrap_or(1.0);
        let init_new_token_ratio = DEFAULT_INIT_NEW_TOKEN_RATIO * conservativeness;
        let min_new_token_ratio = init_new_token_ratio * DEFAULT_MIN_NEW_TOKEN_RATIO_FACTOR;
        let new_token_ratio_decay_step =
            (init_new_token_ratio - min_new_token_ratio) / DEFAULT_NEW_TOKEN_RATIO_DECAY_STEPS;

        let schedule_policy = match overrides.and_then(|s| s.schedule_policy.as_deref()) {
            None | Some("fifo" | "fcfs") => SchedulePolicy::Fifo,
            Some("lpm") => SchedulePolicy::Lpm,
            Some(other) => {
                tracing::warn!(
                    "Unknown sglang schedule_policy '{}', falling back to FIFO",
                    other
                );
                SchedulePolicy::Fifo
            }
        };

        let max_prefill_tokens = overrides
            .and_then(|s| s.max_prefill_tokens)
            .unwrap_or(DEFAULT_MAX_PREFILL_TOKENS);
        let chunked_prefill_size = overrides
            .and_then(|s| s.chunked_prefill_size)
            .unwrap_or(DEFAULT_CHUNKED_PREFILL_SIZE);
        let clip_max_new_tokens = overrides
            .and_then(|s| s.clip_max_new_tokens)
            .unwrap_or(DEFAULT_CLIP_MAX_NEW_TOKENS);
        let page_size = overrides.and_then(|s| s.page_size).unwrap_or(1);

        Self {
            schedule_policy,
            max_prefill_tokens,
            chunked_prefill_size,
            clip_max_new_tokens,
            init_new_token_ratio,
            min_new_token_ratio,
            new_token_ratio_decay_step,
            perf_model: args.perf_model.clone(),
            speedup_ratio: args.speedup_ratio,
            worker_type: args.worker_type,
            page_size,
        }
    }
}
impl SglangScheduler {
    /// Start the SGLang scheduling loop in a background Tokio task and return a handle.
    ///
    /// * `args` - engine configuration; the token capacity of the KV cache is
    ///   `num_gpu_blocks * block_size`, SGLang-specific knobs come from `args.sglang`.
    /// * `dp_rank` - data-parallel rank reported in metrics and KV events.
    /// * `output_tx` - optional channel handed to the decode step for output signals.
    /// * `kv_event_sink` - optional sink for KV cache Stored/Removed events.
    /// * `cancellation_token` - optional token that stops the loop; a fresh one is
    ///   created when `None`. The token is cancelled when the returned handle's guard
    ///   is dropped.
    ///
    /// # Panics
    ///
    /// Panics if `args` or its nested `sglang` section fails validation. The loop is
    /// started with `tokio::spawn`, so this must be called from within a Tokio runtime.
    pub fn new(
        args: MockEngineArgs,
        dp_rank: u32,
        output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
        kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
        cancellation_token: Option<CancellationToken>,
    ) -> Self {
        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
        let initial_metrics = MockerMetrics {
            dp_rank,
            active_decode_blocks: 0,
        };
        let (metrics_tx, metrics_rx) =
            tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);

        let cancel_token = cancellation_token.unwrap_or_default();
        let cancel_token_clone = cancel_token.clone();
        let cancel_guard = Arc::new(CancelGuard(cancel_token));

        args.validate().expect("invalid MockEngineArgs");
        // `MockEngineArgs::sglang` is declared without a nested-validation attribute,
        // so the call above does not check the range constraints on `SglangArgs`.
        // Validate the section explicitly so that e.g. `page_size = 0` is rejected
        // here instead of reaching the radix cache.
        if let Some(sglang_args) = args.sglang.as_ref() {
            sglang_args.validate().expect("invalid SglangArgs");
        }

        let config = SglangConfig::from_args(&args);
        let total_tokens = args.num_gpu_blocks * args.block_size;

        tokio::spawn(async move {
            let mut kv_manager =
                SglangKvManager::new(total_tokens, config.page_size, kv_event_sink, dp_rank);
            let mut waiting: VecDeque<SglangRequest> = VecDeque::new();
            let mut running: Vec<SglangRequest> = Vec::new();
            let mut new_token_ratio = config.init_new_token_ratio;

            loop {
                // 1. Receive requests
                if receive_requests(&mut waiting, &mut request_rx, &cancel_token_clone, &running)
                    .await
                    .is_none()
                {
                    break;
                }

                // 2. Apply scheduling policy
                apply_schedule_policy(&mut waiting, &kv_manager, &config);

                // 3. Admit new requests for prefill
                let admit = get_new_batch_prefill(
                    &mut waiting,
                    &mut kv_manager,
                    &config,
                    new_token_ratio,
                    &running,
                );
                if admit.oom {
                    new_token_ratio = config.init_new_token_ratio;
                }

                // 4. Simulate prefill
                simulate_prefill(admit.total_new_tokens, admit.can_run.len(), &config).await;

                // Separate fully-prefilled from chunked requests
                for mut req in admit.can_run {
                    if req.prefilled_tokens < req.token_ids.len() {
                        // Chunked prefill: cache partial sequence, put back in waiting
                        if let Some(last_node) = req.last_node {
                            let new_last = kv_manager.cache_unfinished_req(
                                &req.token_ids[..req.prefilled_tokens],
                                &req.kv_indices,
                                last_node,
                            );
                            req.last_node = Some(new_last);
                        }
                        waiting.push_front(req);
                    } else {
                        running.push(req);
                    }
                }

                // 5. Simulate decode (may retract requests under memory pressure)
                let retracted = simulate_decode(
                    &mut running,
                    &mut kv_manager,
                    &output_tx,
                    &config,
                    dp_rank,
                    &metrics_tx,
                )
                .await;

                if !retracted.is_empty() {
                    // Retracted requests go back to the front of the waiting queue;
                    // iterating in reverse keeps their relative order.
                    for req in retracted.into_iter().rev() {
                        waiting.push_front(req);
                    }
                    // Reset new_token_ratio like SGLang does after retraction
                    new_token_ratio = config.init_new_token_ratio;
                }

                // 6. Decay new_token_ratio
                new_token_ratio = (new_token_ratio - config.new_token_ratio_decay_step)
                    .max(config.min_new_token_ratio);
            }
        });

        Self {
            request_tx,
            metrics_rx,
            _cancel_guard: cancel_guard,
        }
    }
}
impl super::SchedulerHandle for SglangScheduler {
    /// Queues `request` for the background loop. A send error is deliberately
    /// discarded, so the request is dropped silently if the channel is closed.
    fn receive(&self, request: DirectRequest) {
        drop(self.request_tx.send(request));
    }

    /// Hands out another sender feeding the same request queue.
    fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
        let sender = &self.request_tx;
        sender.clone()
    }

    /// Hands out a receiver for the metrics published by the background loop.
    fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
        let receiver = &self.metrics_rx;
        receiver.clone()
    }
}
async fn receive_requests(
waiting: &mut VecDeque<SglangRequest>,
request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
cancel_token: &CancellationToken,
running: &[SglangRequest],
) -> Option<()> {
if cancel_token.is_cancelled() {
return None;
}
if waiting.is_empty() && running.is_empty() {
// Fully idle — block until request or shutdown
tokio::select! {
biased;
_ = cancel_token.cancelled() => return None,
result = request_rx.recv() => {
let request = result?;
waiting.push_back(direct_to_sglang(request));
}
}
}
// Drain any pending requests without blocking
while let Ok(request) = request_rx.try_recv() {
waiting.push_back(direct_to_sglang(request));
}
Some(())
}
/// Converts an incoming [`DirectRequest`] into a fresh [`SglangRequest`] with no
/// decode progress, no KV allocation and no prefilled tokens. A random UUID is
/// generated when the request does not carry one.
fn direct_to_sglang(req: DirectRequest) -> SglangRequest {
    let uuid = match req.uuid {
        Some(id) => id,
        None => Uuid::new_v4(),
    };
    let mut token_ids = Vec::with_capacity(req.tokens.len());
    token_ids.extend(req.tokens.iter().map(|&tok| tok as u64));

    SglangRequest {
        uuid,
        token_ids,
        max_output_tokens: req.max_output_tokens,
        output_len: 0,
        last_node: None,
        kv_indices: Vec::new(),
        prefilled_tokens: 0,
    }
}
/// Reorders `waiting` in place according to `config.schedule_policy`.
///
/// * `Fifo` leaves the queue untouched.
/// * `Lpm` ranks requests by how many leading tokens already match the cache,
///   longest match first; ties keep their existing relative order. Queues
///   longer than `LPM_FALLBACK_THRESHOLD` are left in their current order.
fn apply_schedule_policy(
    waiting: &mut VecDeque<SglangRequest>,
    kv_manager: &SglangKvManager,
    config: &SglangConfig,
) {
    match config.schedule_policy {
        // Arrival order is already what FIFO wants.
        SchedulePolicy::Fifo => {}
        // Scoring a very long queue is too expensive; keep the current order.
        SchedulePolicy::Lpm if waiting.len() > LPM_FALLBACK_THRESHOLD => {}
        SchedulePolicy::Lpm => {
            let mut ranked: Vec<(usize, SglangRequest)> = Vec::with_capacity(waiting.len());
            while let Some(req) = waiting.pop_front() {
                let hit_len = kv_manager.cache().prefix_match_len(&req.token_ids);
                ranked.push((hit_len, req));
            }
            // Stable sort: equal scores stay in their original queue order.
            ranked.sort_by_key(|(hit_len, _)| std::cmp::Reverse(*hit_len));
            waiting.extend(ranked.into_iter().map(|(_, req)| req));
        }
    }
}
/// Outcome of one admission pass over the waiting queue
/// (produced by `get_new_batch_prefill`).
struct AdmitResult {
/// Requests admitted in this pass, in admission order. A request whose
/// `prefilled_tokens` is still below its input length was only partially
/// (chunk-)prefilled.
can_run: Vec<SglangRequest>,
/// Sum over admitted requests of the tokens past the cached prefix
/// (`chunk_end - prefix_len`, saturating); this is the token count handed to
/// the prefill timing simulation.
total_new_tokens: usize,
/// True when admission stopped because KV allocation for a request failed.
oom: bool,
}
/// Admits requests from the front of `waiting` into a prefill batch while the
/// token, input and chunk budgets allow it.
///
/// Budgets:
/// * total tokens: free + evictable pool tokens minus a reservation for the
///   remaining output of `running` requests, scaled by `new_token_ratio`;
/// * input tokens: `config.max_prefill_tokens`;
/// * chunk tokens: `config.chunked_prefill_size`. A request whose remaining
///   input exceeds the chunk budget is prefilled only up to a page-aligned
///   chunk boundary.
///
/// Admission stops at the first request that does not fit; that request is put
/// back at the front of `waiting`. `oom` is set only when KV allocation itself
/// fails.
///
/// Returns the admitted requests (with `last_node`, `kv_indices` and
/// `prefilled_tokens` updated) together with the number of cache-miss tokens to
/// charge to the prefill timing model.
fn get_new_batch_prefill(
    waiting: &mut VecDeque<SglangRequest>,
    kv_manager: &mut SglangKvManager,
    config: &SglangConfig,
    new_token_ratio: f64,
    running: &[SglangRequest],
) -> AdmitResult {
    // Reserve room for tokens the running requests are still expected to emit.
    // saturating_sub guards against underflow should a request ever carry
    // output_len > max_output_tokens.
    let reserved: f64 = running
        .iter()
        .map(|req| {
            let remaining_output = req
                .max_output_tokens
                .saturating_sub(req.output_len)
                .min(config.clip_max_new_tokens);
            remaining_output as f64 * new_token_ratio
        })
        .sum();
    let cache = kv_manager.cache();
    let mut rem_total_tokens = (cache.available_tokens() + cache.evictable_size) as f64 - reserved;
    let mut rem_input_tokens = config.max_prefill_tokens as f64;
    let mut rem_chunk_tokens = config.chunked_prefill_size as f64;
    let mut can_run = Vec::new();
    let mut rejected = VecDeque::new();
    let mut oom = false;
    let mut total_new_tokens: usize = 0;
    while let Some(mut req) = waiting.pop_front() {
        let extend_input = req.extend_input_len() as f64;
        let total_needed = req.total_tokens_needed(config.clip_max_new_tokens) as f64;
        // For chunked prefill: check against the chunk size, not the full input.
        let effective_input = extend_input.min(config.chunked_prefill_size as f64);
        if total_needed > rem_total_tokens || effective_input > rem_input_tokens {
            rejected.push_back(req);
            break;
        }
        // Determine chunk boundary before touching any locks or allocations.
        let chunk_end = if extend_input > rem_chunk_tokens && rem_chunk_tokens > 0.0 {
            // Chunks are truncated down to a whole number of pages.
            let chunk = (rem_chunk_tokens as usize) / config.page_size * config.page_size;
            if chunk > 0 {
                req.prefilled_tokens + chunk
            } else if can_run.is_empty() {
                // The whole chunk budget is smaller than one page: prefill the
                // full input, otherwise this request could never be scheduled.
                req.token_ids.len()
            } else {
                // Less than one page of chunk budget is left in this batch.
                // Prefilling the full input here would overshoot the budget, so
                // defer the request to the next batch (same stop-adding behavior
                // as SGLang's prefill adder when the truncated length is zero).
                rejected.push_back(req);
                break;
            }
        } else {
            req.token_ids.len()
        };
        // Keep previous chunk lock alive to protect cached prefix from eviction.
        // Released after allocate_for_request secures its own lock.
        let prev_node = req.last_node.take();
        let alloc_tokens = &req.token_ids[..chunk_end];
        let prefix_len = kv_manager.cache().prefix_match_len(alloc_tokens);
        let needed_new = alloc_tokens.len() - prefix_len;
        let available = kv_manager.cache().token_pool.available();
        if available < needed_new {
            kv_manager.evict(needed_new - available);
        }
        let alloc = kv_manager.allocate_for_request(alloc_tokens);
        let Some(alloc) = alloc else {
            // Restore lock on rejection so the cached prefix stays protected
            req.last_node = prev_node;
            rejected.push_back(req);
            oom = true;
            break;
        };
        // New allocation has its own lock; release the previous one
        if let Some(node) = prev_node {
            kv_manager.free_request(node);
        }
        req.last_node = Some(alloc.last_node);
        req.kv_indices = alloc.kv_indices;
        req.prefilled_tokens = chunk_end;
        // Input tokens consumed by this pass: chunk_end minus what had already
        // been prefilled before it.
        let actual_prefilled = (chunk_end - (req.token_ids.len() - extend_input as usize)) as f64;
        // Only count cache-miss tokens for prefill timing (prefix hits skip compute)
        let new_compute_tokens = chunk_end.saturating_sub(alloc.prefix_len);
        total_new_tokens += new_compute_tokens;
        rem_total_tokens -= total_needed;
        rem_input_tokens -= actual_prefilled;
        rem_chunk_tokens -= actual_prefilled;
        can_run.push(req);
        if rem_chunk_tokens <= 0.0 {
            break;
        }
    }
    // Put the request that stopped admission back at the head of the queue.
    while let Some(req) = rejected.pop_back() {
        waiting.push_front(req);
    }
    AdmitResult {
        can_run,
        total_new_tokens,
        oom,
    }
}
/// Sleeps for the modelled prefill latency of a batch.
///
/// Does nothing when the batch is empty or when this worker is a decode-only
/// worker. The predicted time (divided by 1000 to obtain seconds, so presumably
/// milliseconds) is scaled down by `config.speedup_ratio`; no sleep happens if
/// the ratio is not positive or the predicted time is zero.
async fn simulate_prefill(total_new_tokens: usize, num_reqs: usize, config: &SglangConfig) {
    let nothing_to_prefill = num_reqs == 0 || config.worker_type == WorkerType::Decode;
    if nothing_to_prefill {
        return;
    }

    let began = Instant::now();
    let predicted = config.perf_model.predict_prefill_time(total_new_tokens);
    let modelled = Duration::from_secs_f64(predicted / 1000.0);

    let should_sleep = config.speedup_ratio > 0.0 && modelled > Duration::ZERO;
    if should_sleep {
        let scaled = Duration::from_secs_f64(modelled.as_secs_f64() / config.speedup_ratio);
        sleep_until_precise(began + scaled).await;
    }
}
/// Makes sure the pool can hold one new token for every request in `running`.
///
/// If free plus evictable tokens already cover the batch, only the missing
/// amount is evicted and nothing is retracted. Otherwise requests are retracted
/// one at a time, starting with the request that has generated the fewest
/// output tokens (indices are sorted by `output_len` descending and taken from
/// the back), until the estimate fits or a single request remains. Retracted
/// requests have their KV indices freed, their radix lock released and their
/// progress reset so they re-enter as fresh prefills.
///
/// Returns the retracted requests, which the caller should put back on the
/// waiting queue.
fn check_decode_mem(
    running: &mut Vec<SglangRequest>,
    kv_manager: &mut SglangKvManager,
) -> Vec<SglangRequest> {
    let batch_size = running.len();
    let free_now = kv_manager.cache().token_pool.available();
    let reclaimable = kv_manager.cache().evictable_size;
    let headroom = free_now + reclaimable;

    // Fast path: eviction alone is enough for this decode step.
    if headroom >= batch_size {
        if batch_size > free_now {
            kv_manager.evict(batch_size - free_now);
        }
        return Vec::new();
    }

    // Candidates ordered by generated length, longest first; victims are popped
    // from the back, i.e. the shortest outputs are given up first.
    let mut survivors: Vec<usize> = (0..batch_size).collect();
    survivors.sort_by_key(|&i| std::cmp::Reverse(running[i].output_len));

    let mut victims: Vec<usize> = Vec::new();
    let mut freed = 0usize;
    // Always keep at least one request running.
    while survivors.len() > 1 && headroom + freed < survivors.len() {
        let Some(idx) = survivors.pop() else {
            break;
        };
        let victim = &running[idx];
        // Give back this request's KV slots and release its radix lock.
        kv_manager.cache_mut().token_pool.free(&victim.kv_indices);
        if let Some(node) = victim.last_node {
            kv_manager.free_request(node);
        }
        freed += victim.kv_indices.len();
        victims.push(idx);
    }

    // Remove from the highest index down so swap_remove never disturbs an index
    // that is still pending removal.
    victims.sort_unstable_by(|a, b| b.cmp(a));
    let mut retracted = Vec::with_capacity(victims.len());
    for idx in victims {
        let mut req = running.swap_remove(idx);
        // Wipe decode progress: the request starts over with a new prefill.
        req.output_len = 0;
        req.kv_indices.clear();
        req.last_node = None;
        req.prefilled_tokens = 0;
        retracted.push(req);
    }

    // Evict whatever is still missing for the requests that stay.
    let free_after = kv_manager.cache().token_pool.available();
    let still_running = running.len();
    if free_after < still_running {
        kv_manager.evict(still_running - free_after);
    }

    if !retracted.is_empty() {
        tracing::warn!(
            num_retracted = retracted.len(),
            remaining = running.len(),
            "SGLang decode retract requests because KV pool is full"
        );
    }
    retracted
}
/// Simulates one decode step for every request in `running`.
///
/// In order:
/// 1. Predict the step latency from the batch's context sizes (taken before any
///    retraction).
/// 2. Retract requests through `check_decode_mem` when the pool cannot cover
///    the step.
/// 3. Allocate one KV slot per remaining request and bump its `output_len`.
/// 4. Send one [`OutputSignal`] per request that produced a token in this step.
///    Completed requests hand their page-aligned sequence to the KV manager and
///    leave `running`.
/// 5. Publish `active_decode_blocks` and sleep until the (speedup-scaled)
///    predicted latency has elapsed.
///
/// A request whose slot allocation failed emits no signal in this step; sending
/// one anyway would let the request report more outputs than
/// `max_output_tokens`. It stays in `running` and is retried on the next step.
///
/// Returns the retracted requests so the caller can re-queue them.
async fn simulate_decode(
    running: &mut Vec<SglangRequest>,
    kv_manager: &mut SglangKvManager,
    output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
    config: &SglangConfig,
    dp_rank: u32,
    metrics_tx: &tokio::sync::watch::Sender<MockerMetrics>,
) -> Vec<SglangRequest> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    if running.is_empty() {
        return Vec::new();
    }
    let start = Instant::now();
    let total_context: usize = running
        .iter()
        .map(|r| r.token_ids.len() + r.output_len)
        .sum();
    let avg_context = total_context / running.len();
    let decode_time = config
        .perf_model
        .predict_decode_time(total_context, avg_context);
    let total_time = Duration::from_secs_f64(decode_time / 1000.0);
    // Retract requests if not enough memory for one decode step
    let retracted = check_decode_mem(running, kv_manager);

    // Allocate one slot per request; remember which requests actually got one.
    let mut produced = Vec::with_capacity(running.len());
    for req in running.iter_mut() {
        if kv_manager.cache().token_pool.available() == 0 {
            kv_manager.evict(1);
        }
        let last_idx = req.kv_indices.last().copied();
        let got_token = match kv_manager.allocate_decode_token(last_idx) {
            Some(new_idx) => {
                req.kv_indices.push(new_idx);
                req.output_len += 1;
                true
            }
            None => {
                tracing::warn!(uuid = %req.uuid, "Failed to allocate decode token, skipping output");
                false
            }
        };
        produced.push(got_token);
    }

    // Send output signals and handle completions
    let mut completed_indices = Vec::new();
    for (i, (req, got_token)) in running.iter().zip(produced).enumerate() {
        let is_complete = req.output_len >= req.max_output_tokens;
        // No token this step and not finished: nothing to report yet.
        if !got_token && !is_complete {
            continue;
        }
        if let Some(tx) = output_tx {
            let _ = tx.send(OutputSignal {
                uuid: req.uuid,
                completed: is_complete,
            });
        }
        if is_complete {
            // Synthesize deterministic output token ids from (uuid, position) so
            // the finished sequence can be inserted into the cache.
            let mut all_tokens = req.token_ids.clone();
            all_tokens.extend((0..req.output_len).map(|j| {
                let mut hasher = DefaultHasher::new();
                req.uuid.hash(&mut hasher);
                j.hash(&mut hasher);
                hasher.finish()
            }));
            // Page-align and cap by available indices.
            let aligned_tokens = (all_tokens.len() / config.page_size) * config.page_size;
            let tokens_to_cache = aligned_tokens.min(req.kv_indices.len());
            all_tokens.truncate(tokens_to_cache);
            // Free excess token indices not covered by the cached sequence.
            if req.kv_indices.len() > tokens_to_cache {
                let excess = req.kv_indices[tokens_to_cache..].to_vec();
                kv_manager.cache_mut().token_pool.free(&excess);
            }
            if let Some(last_node) = req.last_node {
                if tokens_to_cache > 0 {
                    kv_manager.cache_finished_req(
                        &all_tokens,
                        &req.kv_indices[..tokens_to_cache],
                        last_node,
                    );
                } else {
                    kv_manager.free_request(last_node);
                }
            }
            completed_indices.push(i);
        }
    }
    // Remove completed requests in reverse order so swap_remove doesn't
    // invalidate pending indices (completed_indices is built in ascending order).
    for &i in completed_indices.iter().rev() {
        running.swap_remove(i);
    }
    // Publish metrics: active blocks from running requests' total context
    let remaining_context: usize = running
        .iter()
        .map(|r| r.token_ids.len() + r.output_len)
        .sum();
    let active_blocks = remaining_context / config.page_size;
    let _ = metrics_tx.send(MockerMetrics {
        dp_rank,
        active_decode_blocks: active_blocks as u64,
    });
    if config.speedup_ratio > 0.0 && total_time > Duration::ZERO {
        let sleep_duration =
            Duration::from_secs_f64(total_time.as_secs_f64() / config.speedup_ratio);
        sleep_until_precise(start + sleep_duration).await;
    }
    retracted
}
// Unit and integration tests for the SGLang scheduler simulation.
#[cfg(test)]
mod tests {
use super::*;
use crate::common::protocols::SglangArgs;
use crate::scheduler::SchedulerHandle;
use rstest::rstest;
/// Submits 5 small requests and expects exactly `num_requests * max_output`
/// output signals. Despite the name, only the total count is asserted, not
/// the order in which requests are served.
#[tokio::test]
async fn test_sglang_scheduler_fifo_ordering() {
let args = MockEngineArgs::builder()
.num_gpu_blocks(100)
.block_size(64)
.speedup_ratio(100.0)
.build()
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler = SglangScheduler::new(args, 0, Some(output_tx), None, None);
let num_requests = 5;
let max_output = 3;
for i in 0..num_requests {
scheduler.receive(DirectRequest {
tokens: vec![i as u32; 10],
max_output_tokens: max_output,
uuid: None,
dp_rank: 0,
});
}
// Collect all output signals
let expected_signals = num_requests * max_output;
let mut received = 0;
// Initial 5s deadline; shortened to 2s of inactivity after each signal.
let timeout = tokio::time::sleep(Duration::from_secs(5));
tokio::pin!(timeout);
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
received += 1;
if received >= expected_signals {
break;
}
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => break,
}
}
assert_eq!(
received, expected_signals,
"Expected {expected_signals} signals, got {received}"
);
}
/// Submits more work than a 128-token pool can hold at once and expects every
/// request to finish anyway.
#[tokio::test]
async fn test_sglang_scheduler_admission_budget() {
// Small pool — only enough for a few requests
let args = MockEngineArgs::builder()
.num_gpu_blocks(2) // 2 * 64 = 128 tokens
.block_size(64)
.speedup_ratio(100.0)
.build()
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler = SglangScheduler::new(args, 0, Some(output_tx), None, None);
// Send requests that collectively exceed budget
for _ in 0..10 {
scheduler.receive(DirectRequest {
tokens: vec![1; 20],
max_output_tokens: 5,
uuid: None,
dp_rank: 0,
});
}
// Should still complete all eventually (as earlier ones finish, budget frees up)
let mut received = 0;
let timeout = tokio::time::sleep(Duration::from_secs(10));
tokio::pin!(timeout);
// No early exit here: the loop only ends after 2s without a signal, so
// surplus signals would also be counted and fail the assertion below.
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
received += 1;
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => break,
}
}
let expected = 10 * 5;
assert_eq!(
received, expected,
"Expected {expected} signals, got {received}"
);
}
/// With the LPM policy, a request sharing a cached prefix must be moved ahead
/// of an earlier request that shares nothing with the cache.
#[test]
fn test_lpm_reorders_by_prefix_match() {
let mut kv_manager = SglangKvManager::new(1000, 1, None, 0);
// Seed cache with [1,2,3,4,5]
kv_manager
.cache_mut()
.insert(&[1, 2, 3, 4, 5], &[0, 1, 2, 3, 4]);
let config = SglangConfig {
schedule_policy: SchedulePolicy::Lpm,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let no_match_uuid = Uuid::new_v4();
let match_uuid = Uuid::new_v4();
let mut waiting: VecDeque<SglangRequest> = VecDeque::new();
// no_match first in FIFO order
waiting.push_back(SglangRequest {
uuid: no_match_uuid,
token_ids: vec![9, 8, 7],
max_output_tokens: 1,
output_len: 0,
last_node: None,
kv_indices: Vec::new(),
prefilled_tokens: 0,
});
// match second in FIFO order
waiting.push_back(SglangRequest {
uuid: match_uuid,
token_ids: vec![1, 2, 3, 4, 5, 6, 7],
max_output_tokens: 1,
output_len: 0,
last_node: None,
kv_indices: Vec::new(),
prefilled_tokens: 0,
});
apply_schedule_policy(&mut waiting, &kv_manager, &config);
// LPM should reorder: match (5-token prefix) before no_match (0-token)
assert_eq!(waiting[0].uuid, match_uuid);
assert_eq!(waiting[1].uuid, no_match_uuid);
}
/// A 20-token input with `chunked_prefill_size = 10` must be admitted with
/// only the first 10 tokens prefilled.
#[test]
fn test_chunked_prefill_budget() {
let config = SglangConfig {
chunked_prefill_size: 10,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let mut kv_manager = SglangKvManager::new(10000, 1, None, 0);
let mut waiting: VecDeque<SglangRequest> = VecDeque::new();
waiting.push_back(SglangRequest {
uuid: Uuid::new_v4(),
token_ids: vec![1; 20], // 20 tokens > chunked_prefill_size=10
max_output_tokens: 3,
output_len: 0,
last_node: None,
kv_indices: Vec::new(),
prefilled_tokens: 0,
});
let admit = get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &[]);
assert_eq!(admit.can_run.len(), 1);
// Should only prefill 10 tokens (chunked_prefill_size), not all 20
assert_eq!(admit.can_run[0].prefilled_tokens, 10);
assert!(admit.can_run[0].prefilled_tokens < admit.can_run[0].token_ids.len());
}
/// Replays the scheduler's decay formula for 600 steps and checks the ratio
/// settles at the configured minimum, then checks the reset value is 0.7.
#[test]
fn test_new_token_ratio_decay_and_oom_reset() {
let config = SglangConfig::from_args(
&MockEngineArgs::builder()
.speedup_ratio(1.0)
.build()
.unwrap(),
);
let mut ratio = config.init_new_token_ratio;
for _ in 0..600 {
ratio = (ratio - config.new_token_ratio_decay_step).max(config.min_new_token_ratio);
}
// After 600 steps, ratio should be at or near minimum
assert!(
(ratio - config.min_new_token_ratio).abs() < 0.01,
"ratio={ratio}, min={}",
config.min_new_token_ratio
);
// Simulate OOM reset
ratio = config.init_new_token_ratio;
assert!((ratio - 0.7).abs() < 0.001);
}
/// Stress test mirroring vLLM's `test_scheduler_token_generation_patterns`.
/// Sends 200 requests × 1000 input × 100 output under heavy eviction pressure
/// and parametrises across `(shared_tokens, schedule_policy, page_size)`.
#[rstest]
#[case::case_1(false, "fifo", 1)]
#[case::case_2(true, "fifo", 1)]
#[case::case_3(false, "lpm", 1)]
#[case::case_4(true, "lpm", 1)]
#[case::case_5(false, "fifo", 4)]
#[case::case_6(true, "fifo", 4)]
#[case::case_7(false, "lpm", 4)]
#[case::case_8(true, "lpm", 4)]
#[tokio::test]
async fn test_sglang_scheduler_token_generation_patterns(
#[case] use_shared_tokens: bool,
#[case] schedule_policy: &str,
#[case] page_size: usize,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Pool of 500 * 64 = 32000 tokens, far less than the 200 * 1100 tokens the
// requests need in total.
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
.block_size(64)
.speedup_ratio(10.0)
.sglang(Some(SglangArgs {
schedule_policy: Some(schedule_policy.to_string()),
page_size: Some(page_size),
..Default::default()
}))
.build()
.unwrap();
let scheduler = SglangScheduler::new(args, 0, Some(output_tx), None, None);
crate::scheduler::test_utils::assert_scheduler_completes_all(
&scheduler,
&mut output_rx,
200,
1000,
100,
use_shared_tokens,
)
.await;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment