Unverified commit 930721c8, authored by Yongming Ding, committed by GitHub
Browse files

feat(mocker): add SGLang engine simulation (#6977)


Signed-off-by: Yongming Ding <yongmingd@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 22fb3398
......@@ -2091,6 +2091,7 @@ dependencies = [
"rstest 0.18.2",
"serde",
"serde_json",
"slotmap",
"tokio",
"tokio-timerfd",
"tokio-util",
......@@ -7362,6 +7363,15 @@ version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5"
[[package]]
name = "slotmap"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdd58c3c93c3d278ca835519292445cb4b0d4dc59ccfdf7ceadaab3f8aeb4038"
dependencies = [
"version_check",
]
[[package]]
name = "smallvec"
version = "1.15.1"
......
......@@ -126,6 +126,7 @@ def create_temp_engine_args_file(args: argparse.Namespace) -> Path:
# - kv_bytes_per_token is auto-computed in main.py after model prefetch,
# - kv_cache_dtype is only used Python-side for the auto-computation.
"kv_transfer_bandwidth": getattr(args, "kv_transfer_bandwidth", None),
"engine_type": getattr(args, "engine_type", None),
}
# Parse --reasoning JSON string into a nested object
......@@ -133,6 +134,21 @@ def create_temp_engine_args_file(args: argparse.Namespace) -> Path:
if reasoning_str:
engine_args["reasoning"] = json.loads(reasoning_str)
# Build nested sglang config from individual CLI flags
sglang_args = {
"schedule_policy": getattr(args, "sglang_schedule_policy", None),
"page_size": getattr(args, "sglang_page_size", None),
"max_prefill_tokens": getattr(args, "sglang_max_prefill_tokens", None),
"chunked_prefill_size": getattr(args, "sglang_chunked_prefill_size", None),
"clip_max_new_tokens": getattr(args, "sglang_clip_max_new_tokens", None),
"schedule_conservativeness": getattr(
args, "sglang_schedule_conservativeness", None
),
}
sglang_args = {k: v for k, v in sglang_args.items() if v is not None}
if sglang_args:
engine_args["sglang"] = sglang_args
# Remove None values to only include explicitly set arguments
engine_args = {k: v for k, v in engine_args.items() if v is not None}
......@@ -348,6 +364,54 @@ def parse_args() -> argparse.Namespace:
'Example: \'{"start_thinking_token_id": 123, "end_thinking_token_id": 456, "thinking_ratio": 0.6}\'',
)
# Engine type selection
parser.add_argument(
"--engine-type",
type=str,
default=None,
choices=["vllm", "sglang"],
help="Engine simulation type: 'vllm' (default) or 'sglang'.",
)
# SGLang-specific configuration
parser.add_argument(
"--sglang-schedule-policy",
type=str,
default=None,
choices=["fifo", "fcfs", "lpm"],
help="SGLang scheduling policy: 'fifo'/'fcfs' (default) or 'lpm' (longest prefix match).",
)
parser.add_argument(
"--sglang-page-size",
type=int,
default=None,
help="SGLang radix cache page size in tokens (default: 1).",
)
parser.add_argument(
"--sglang-max-prefill-tokens",
type=int,
default=None,
help="SGLang maximum prefill tokens budget per batch (default: 16384).",
)
parser.add_argument(
"--sglang-chunked-prefill-size",
type=int,
default=None,
help="SGLang chunked prefill size — max tokens per chunk (default: 8192).",
)
parser.add_argument(
"--sglang-clip-max-new-tokens",
type=int,
default=None,
help="SGLang clip max new tokens for admission budget (default: 4096).",
)
parser.add_argument(
"--sglang-schedule-conservativeness",
type=float,
default=None,
help="SGLang schedule conservativeness factor 0.0-1.0 (default: 1.0).",
)
# Legacy support - allow direct JSON file specification
parser.add_argument(
"--extra-engine-args",
......
......@@ -11,6 +11,7 @@ use dynamo_kv_router::protocols::WorkerWithDpRank;
use dynamo_kv_router::{ActiveSequencesMultiWorker, OverlapScores, SequenceRequest};
use dynamo_mocker::common::protocols::{DirectRequest, OutputSignal};
use dynamo_mocker::scheduler::Scheduler;
use dynamo_mocker::scheduler::SchedulerHandle;
use dynamo_tokens::SequenceHash;
use std::collections::HashMap;
use std::sync::Arc;
......@@ -170,27 +171,23 @@ async fn generate_sequence_events(
while i < worker_trace.len() {
let prev_i = i;
scheduler
.receive(DirectRequest {
tokens: tokens_from_request(&worker_trace[i], block_size),
max_output_tokens: worker_trace[i].output_length as usize,
uuid: Some(worker_trace[i].uuid),
dp_rank: 0,
})
.await;
scheduler.receive(DirectRequest {
tokens: tokens_from_request(&worker_trace[i], block_size),
max_output_tokens: worker_trace[i].output_length as usize,
uuid: Some(worker_trace[i].uuid),
dp_rank: 0,
});
i += 1;
while i < worker_trace.len()
&& worker_trace[i].timestamp == worker_trace[i - 1].timestamp
{
scheduler
.receive(DirectRequest {
tokens: tokens_from_request(&worker_trace[i], block_size),
max_output_tokens: worker_trace[i].output_length as usize,
uuid: Some(worker_trace[i].uuid),
dp_rank: 0,
})
.await;
scheduler.receive(DirectRequest {
tokens: tokens_from_request(&worker_trace[i], block_size),
max_output_tokens: worker_trace[i].output_length as usize,
uuid: Some(worker_trace[i].uuid),
dp_rank: 0,
});
i += 1;
}
......
......@@ -13,6 +13,7 @@ use dynamo_kv_router::protocols::{
pub use dynamo_kv_router::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
use dynamo_mocker::common::protocols::{DirectRequest, KvCacheEventSink, MockEngineArgs};
use dynamo_mocker::scheduler::Scheduler;
use dynamo_mocker::scheduler::SchedulerHandle;
use dynamo_tokens::compute_hash_v2;
use indicatif::{ProgressBar, ProgressStyle};
use plotters::prelude::*;
......@@ -367,27 +368,23 @@ pub async fn generate_kv_events(
while i < worker_trace.len() {
let prev_i = i;
scheduler
.receive(DirectRequest {
tokens: tokens_from_request(&worker_trace[i], block_size),
max_output_tokens: worker_trace[i].output_length as usize,
uuid: Some(worker_trace[i].uuid),
dp_rank: 0,
})
.await;
scheduler.receive(DirectRequest {
tokens: tokens_from_request(&worker_trace[i], block_size),
max_output_tokens: worker_trace[i].output_length as usize,
uuid: Some(worker_trace[i].uuid),
dp_rank: 0,
});
i += 1;
while i < worker_trace.len()
&& worker_trace[i].timestamp == worker_trace[i - 1].timestamp
{
scheduler
.receive(DirectRequest {
tokens: tokens_from_request(&worker_trace[i], block_size),
max_output_tokens: worker_trace[i].output_length as usize,
uuid: Some(worker_trace[i].uuid),
dp_rank: 0,
})
.await;
scheduler.receive(DirectRequest {
tokens: tokens_from_request(&worker_trace[i], block_size),
max_output_tokens: worker_trace[i].output_length as usize,
uuid: Some(worker_trace[i].uuid),
dp_rank: 0,
});
i += 1;
}
......
......@@ -1715,6 +1715,7 @@ dependencies = [
"rand 0.9.2",
"serde",
"serde_json",
"slotmap",
"tokio",
"tokio-timerfd",
"tokio-util",
......@@ -6324,6 +6325,15 @@ version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5"
[[package]]
name = "slotmap"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdd58c3c93c3d278ca835519292445cb4b0d4dc59ccfdf7ceadaab3f8aeb4038"
dependencies = [
"version_check",
]
[[package]]
name = "smallvec"
version = "1.15.1"
......
......@@ -1731,6 +1731,7 @@ dependencies = [
"rand 0.9.2",
"serde",
"serde_json",
"slotmap",
"tokio",
"tokio-timerfd",
"tokio-util",
......@@ -6391,6 +6392,15 @@ version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5"
[[package]]
name = "slotmap"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdd58c3c93c3d278ca835519292445cb4b0d4dc59ccfdf7ceadaab3f8aeb4038"
dependencies = [
"version_check",
]
[[package]]
name = "smallvec"
version = "1.15.1"
......
......@@ -23,7 +23,8 @@ use dynamo_mocker::common::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, OutputSignal,
};
use dynamo_mocker::common::utils::{compute_kv_transfer_delay, sleep_precise};
use dynamo_mocker::scheduler::Scheduler;
use dynamo_mocker::engine::create_engine;
use dynamo_mocker::scheduler::SchedulerHandle;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::{
......@@ -307,7 +308,7 @@ fn generate_random_token() -> TokenIdType {
}
/// AsyncEngine wrapper around the Scheduler that generates random character tokens
pub struct MockVllmEngine {
pub struct MockEngine {
active_requests: Arc<DashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>,
request_senders: OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>,
senders_ready: Notify,
......@@ -315,11 +316,11 @@ pub struct MockVllmEngine {
/// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
/// Keep schedulers alive so their CancelGuards don't fire prematurely.
_schedulers: OnceCell<Vec<Scheduler>>,
_schedulers: OnceCell<Vec<Box<dyn SchedulerHandle>>>,
}
impl MockVllmEngine {
/// Create a new MockVllmEngine with the given parameters
impl MockEngine {
/// Create a new MockEngine with the given parameters
pub fn new(engine_args: MockEngineArgs) -> Self {
Self {
active_requests: Arc::new(DashMap::new()),
......@@ -404,9 +405,9 @@ impl MockVllmEngine {
&self,
component: Option<&Component>,
cancel_token: CancellationToken,
) -> Vec<Scheduler> {
) -> Vec<Box<dyn SchedulerHandle>> {
let args = &self.engine_args;
let mut schedulers = Vec::<Scheduler>::new();
let mut schedulers = Vec::<Box<dyn SchedulerHandle>>::new();
let mut senders = Vec::with_capacity(args.dp_size as usize);
for dp_rank in 0..args.dp_size {
......@@ -485,7 +486,7 @@ impl MockVllmEngine {
None => (None, None),
};
let scheduler = Scheduler::new(
let scheduler = create_engine(
args.clone(),
dp_rank,
Some(output_tx),
......@@ -536,7 +537,7 @@ impl MockVllmEngine {
/// Start background tasks to publish metrics on change
async fn start_metrics_publishing(
schedulers: &[Scheduler],
schedulers: &[Box<dyn SchedulerHandle>],
component: Component,
cancel_token: CancellationToken,
) -> Result<()> {
......@@ -579,9 +580,7 @@ impl MockVllmEngine {
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
for MockVllmEngine
{
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> for MockEngine {
async fn generate(
&self,
input: SingleIn<PreprocessedRequest>,
......@@ -744,12 +743,12 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
}
pub struct AnnotatedMockEngine {
inner: Arc<MockVllmEngine>,
inner: Arc<MockEngine>,
}
impl AnnotatedMockEngine {
pub fn new(
inner: MockVllmEngine,
inner: MockEngine,
distributed_runtime: DistributedRuntime,
endpoint_id: dynamo_runtime::protocols::EndpointId,
) -> Self {
......@@ -818,7 +817,7 @@ pub async fn make_mocker_engine(
// Create the mocker engine
tracing::info!("Creating mocker engine with config: {args:?}");
let annotated_engine =
AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint_id);
AnnotatedMockEngine::new(MockEngine::new(args), distributed_runtime, endpoint_id);
Ok(Arc::new(annotated_engine))
}
......@@ -32,6 +32,7 @@ validator = { workspace = true }
# crate-specific
ndarray = "0.16"
slotmap = "1"
ndarray-npy = "0.9"
ndarray-interp = "0.5"
......
......@@ -4,5 +4,7 @@
//! Cache data structures for KV block management.
pub mod hash_cache;
pub mod radix_cache;
pub use hash_cache::HashCache;
pub use radix_cache::RadixCache;
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Shared KV cache trace logging for both vLLM and SGLang backends.
//!
//! Enabled by setting `DYN_MOCKER_KV_CACHE_TRACE=1` or `true`.
use dynamo_runtime::config::environment_names::mocker;
use std::env;
use std::sync::LazyLock;
use std::time::{SystemTime, UNIX_EPOCH};
/// Process-wide switch for KV cache allocation/eviction trace logs.
///
/// Evaluated lazily on first access. The flag is `true` only when the
/// environment variable named by `mocker::DYN_MOCKER_KV_CACHE_TRACE` can be
/// read and its value is `1` or, ignoring ASCII case, `true`. Any other value,
/// or a missing/unreadable variable, leaves tracing disabled.
pub static KV_CACHE_TRACE_ENABLED: LazyLock<bool> = LazyLock::new(|| {
    match env::var(mocker::DYN_MOCKER_KV_CACHE_TRACE) {
        Ok(raw) => raw == "1" || raw.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
});
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch.
fn timestamp_ms() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_millis() as u64)
}
/// Emit one `"KV cache trace"` log record describing vLLM block occupancy.
///
/// Does nothing unless [`KV_CACHE_TRACE_ENABLED`] is set.
///
/// * `event` - label recorded in the `event` field (callers in this crate
///   pass `"allocation"` or `"eviction"`).
/// * `dp_rank` - data-parallel rank recorded with the event.
/// * `block_size` - block size recorded with the event.
/// * `active_blocks` / `inactive_blocks` - current block counts.
/// * `total_blocks` - cache capacity in blocks.
///
/// Derived fields: `free_blocks` is `total - active - inactive` clamped at
/// zero, and `utilization` is `(active + inactive) / total`, reported as
/// `0.0` when `total_blocks` is zero.
pub fn log_vllm_trace(
    event: &str,
    dp_rank: u32,
    block_size: usize,
    active_blocks: usize,
    inactive_blocks: usize,
    total_blocks: usize,
) {
    let tracing_on = *KV_CACHE_TRACE_ENABLED;
    if tracing_on {
        let free_blocks = total_blocks
            .saturating_sub(active_blocks)
            .saturating_sub(inactive_blocks);
        let utilization = match total_blocks {
            // Avoid dividing by zero for an empty cache.
            0 => 0.0,
            capacity => (active_blocks + inactive_blocks) as f64 / capacity as f64,
        };

        tracing::info!(
            engine_type = "vllm",
            event,
            timestamp_ms = timestamp_ms(),
            dp_rank,
            block_size,
            free_blocks,
            active_blocks,
            inactive_blocks,
            total_blocks,
            utilization,
            "KV cache trace"
        );
    }
}
/// SGLang cache state snapshot for trace logging.
///
/// Borrowed view passed to [`log_sglang_trace`]; every field is copied
/// verbatim into the emitted log record.
pub struct SglangCacheState<'a> {
/// Event label written to the `event` log field.
pub event: &'a str,
/// Data-parallel rank of the reporting cache.
pub dp_rank: u32,
/// Number of tokens involved in this event.
pub num_tokens: usize,
/// Cache page size.
pub page_size: usize,
/// Tokens currently available; subtracted from `total_tokens` to derive utilization.
pub available_tokens: usize,
/// Tokens reported as evictable.
pub evictable_tokens: usize,
/// Tokens reported as protected.
pub protected_tokens: usize,
/// Total token capacity; the denominator of the utilization ratio.
pub total_tokens: usize,
}
/// Emit one `"KV cache trace"` log record describing SGLang cache occupancy.
///
/// Does nothing unless [`KV_CACHE_TRACE_ENABLED`] is set. All fields of
/// `state` are logged as-is, plus a derived `utilization` equal to
/// `(total_tokens - available_tokens) / total_tokens`, reported as `0.0` when
/// `total_tokens` is zero.
pub fn log_sglang_trace(state: &SglangCacheState) {
    if !*KV_CACHE_TRACE_ENABLED {
        return;
    }

    // Saturate rather than subtract directly: a snapshot whose
    // `available_tokens` exceeds `total_tokens` must not underflow (which
    // panics in debug builds and wraps in release). This mirrors the
    // saturating arithmetic in `log_vllm_trace`.
    let used_tokens = state.total_tokens.saturating_sub(state.available_tokens);
    let utilization = if state.total_tokens > 0 {
        used_tokens as f64 / state.total_tokens as f64
    } else {
        0.0
    };

    tracing::info!(
        engine_type = "sglang",
        event = state.event,
        timestamp_ms = timestamp_ms(),
        dp_rank = state.dp_rank,
        num_tokens = state.num_tokens,
        page_size = state.page_size,
        available_tokens = state.available_tokens,
        evictable_tokens = state.evictable_tokens,
        protected_tokens = state.protected_tokens,
        total_tokens = state.total_tokens,
        utilization,
        "KV cache trace"
    );
}
......@@ -5,6 +5,7 @@
pub mod bootstrap;
pub mod evictor;
pub mod kv_cache_trace;
pub mod perf_model;
pub mod protocols;
pub mod running_mean;
......
......@@ -90,6 +90,16 @@ pub enum PreemptionMode {
Fifo,
}
/// Engine type for selecting scheduling and KV cache simulation behavior
///
/// Defaults to [`EngineType::Vllm`].
///
/// NOTE(review): no serde rename attribute is present, so the derived
/// (de)serialization presumably uses the variant names (`Vllm`, `Sglang`)
/// rather than lowercase strings — confirm this matches every producer of
/// serialized `MockEngineArgs`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EngineType {
/// vLLM-style scheduling with hash-based block KV cache
#[default]
Vllm,
/// SGLang-style scheduling with radix-tree KV cache
Sglang,
}
/// Worker type for disaggregated serving configurations
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum WorkerType {
......@@ -134,10 +144,39 @@ impl ReasoningConfig {
}
}
/// Configuration arguments for MockVllmEngine
/// SGLang-specific configuration parameters.
///
/// Grouped into a nested struct to keep the `MockEngineArgs` namespace clean,
/// following the same pattern as [`ReasoningConfig`].
///
/// Every field is optional; the "Default" values quoted below are not applied
/// by this struct (its derived `Default` yields `None` for each field), so
/// they must be supplied by whoever consumes the configuration.
///
/// NOTE(review): the `#[validate(...)]` range attributes only take effect when
/// `validate()` is invoked on this struct. Confirm that the owning
/// `MockEngineArgs` validates its `sglang` field (e.g. via nested validation),
/// otherwise out-of-range values pass silently.
#[derive(Debug, Clone, Serialize, Deserialize, Validate, Default)]
pub struct SglangArgs {
/// Scheduling policy: "fifo"/"fcfs" or "lpm". Default: "fifo".
// The value is a free-form string; nothing in this struct restricts it to the names listed above.
pub schedule_policy: Option<String>,
/// Radix cache page size in tokens. Default: 1.
#[validate(range(min = 1))]
pub page_size: Option<usize>,
/// Maximum prefill tokens budget per batch. Default: 16384.
#[validate(range(min = 1))]
pub max_prefill_tokens: Option<usize>,
/// Chunked prefill size (max tokens per chunk). Default: 8192.
#[validate(range(min = 1))]
pub chunked_prefill_size: Option<usize>,
/// Clip max new tokens for admission budget. Default: 4096.
#[validate(range(min = 1))]
pub clip_max_new_tokens: Option<usize>,
/// Schedule conservativeness factor (0.0–1.0). Default: 1.0.
#[validate(range(min = 0.0, max = 1.0))]
pub schedule_conservativeness: Option<f64>,
}
/// Configuration arguments for MockEngine
#[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
/// Engine type: vLLM or SGLang simulation
#[builder(default = "EngineType::Vllm")]
pub engine_type: EngineType,
#[builder(default = "16384")]
#[validate(range(min = 1))]
pub num_gpu_blocks: usize,
......@@ -236,6 +275,10 @@ pub struct MockEngineArgs {
/// Lifo (default) evicts the newest request; Fifo evicts the oldest.
#[builder(default)]
pub preemption_mode: PreemptionMode,
/// SGLang-specific configuration. Only used when `engine_type == Sglang`.
#[builder(default = "None")]
pub sglang: Option<SglangArgs>,
}
impl Default for MockEngineArgs {
......@@ -273,6 +316,7 @@ impl MockEngineArgs {
// Define valid field names
let valid_fields: HashSet<&str> = [
"engine_type",
"num_gpu_blocks",
"block_size",
"max_num_seqs",
......@@ -294,6 +338,7 @@ impl MockEngineArgs {
"zmq_kv_events_port",
"zmq_replay_port",
"preemption_mode",
"sglang",
]
.iter()
.cloned()
......@@ -315,6 +360,22 @@ impl MockEngineArgs {
}
// Apply each extra argument to the builder
if let Some(value) = extra_args.get("engine_type")
&& let Some(s) = value.as_str()
{
let engine_type = match s {
"vllm" => EngineType::Vllm,
"sglang" => EngineType::Sglang,
other => {
return Err(anyhow::anyhow!(
"Invalid engine_type '{}'. Must be 'vllm' or 'sglang'.",
other
));
}
};
builder = builder.engine_type(engine_type);
}
if let Some(value) = extra_args.get("num_gpu_blocks")
&& let Some(num) = value.as_u64()
{
......@@ -433,6 +494,12 @@ impl MockEngineArgs {
builder = builder.preemption_mode(mode);
}
if let Some(value) = extra_args.get("sglang") {
let cfg: SglangArgs = serde_json::from_value(value.clone())
.map_err(|e| anyhow::anyhow!("Failed to parse sglang config: {}", e))?;
builder = builder.sglang(Some(cfg));
}
// Parse worker type from is_prefill and is_decode flags
let is_prefill = extra_args
.get("is_prefill")
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Engine factory — creates the appropriate scheduler based on [`EngineType`].
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{EngineType, KvCacheEventSink, MockEngineArgs, OutputSignal};
use crate::scheduler::{Scheduler, SchedulerHandle, SglangScheduler};
/// Build the scheduler that matches `args.engine_type`.
///
/// The remaining arguments are forwarded unchanged to the selected
/// scheduler's constructor. The result is type-erased behind
/// [`SchedulerHandle`] so the engine wrapper does not need to know which
/// backend it is driving.
pub fn create_engine(
    args: MockEngineArgs,
    dp_rank: u32,
    output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
    kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
    cancellation_token: Option<CancellationToken>,
) -> Box<dyn SchedulerHandle> {
    // Copy the discriminant out first so `args` can be moved into the constructor.
    let selected = args.engine_type;
    match selected {
        EngineType::Sglang => {
            let scheduler =
                SglangScheduler::new(args, dp_rank, output_tx, kv_event_sink, cancellation_token);
            Box::new(scheduler)
        }
        EngineType::Vllm => {
            let scheduler =
                Scheduler::new(args, dp_rank, output_tx, kv_event_sink, cancellation_token);
            Box::new(scheduler)
        }
    }
}
......@@ -3,6 +3,8 @@
//! Pluggable KV cache block managers.
pub mod sglang_backend;
pub mod vllm_backend;
pub use sglang_backend::SglangKvManager;
pub use vllm_backend::KvManager;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! SGLang KV manager — wraps [`RadixCache`] with request-level lifecycle
//! operations and KV event publishing.
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::cache::radix_cache::{NodeId, RadixCache};
use crate::common::kv_cache_trace;
use crate::common::protocols::KvCacheEventSink;
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData,
};
/// Result of `allocate_for_request`.
pub struct AllocResult {
/// Number of tokens matched from the prefix cache.
pub prefix_len: usize,
/// Pool token indices for the allocated input (1 per token).
/// Indices collected from the matched prefix path come first, followed by
/// the freshly allocated indices for the unmatched tail.
pub kv_indices: Vec<usize>,
/// The deepest matched node in the radix tree (used for lock/unlock).
/// This is the prefix match point, not the new tokens — new tokens are
/// only in kv_indices and get inserted into the tree on completion.
pub last_node: NodeId,
}
/// Request-level KV cache manager for the SGLang simulation.
///
/// Owns a [`RadixCache`] and, when a sink is configured, publishes
/// Stored/Removed KV cache events for the slots it allocates and evicts.
pub struct SglangKvManager {
/// Underlying radix-tree cache and its token pool.
cache: RadixCache,
/// Destination for KV cache events; `None` disables event publishing.
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
/// Data-parallel rank stamped on every published event and trace record.
dp_rank: u32,
/// Id assigned to the next published event; incremented after each publish attempt.
next_event_id: u64,
/// Maps pool_idx → block_hash assigned during Stored events,
/// so Removed events can use the same block_hash.
idx_to_block_hash: HashMap<usize, ExternalSequenceBlockHash>,
}
impl SglangKvManager {
/// Create a manager backed by a fresh [`RadixCache`].
///
/// * `total_tokens`, `page_size` - forwarded to `RadixCache::new`.
/// * `kv_event_sink` - optional destination for KV cache events.
/// * `dp_rank` - rank stamped on published events and trace records.
///
/// Event ids start at zero and the index→hash map starts empty.
pub fn new(
    total_tokens: usize,
    page_size: usize,
    kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
    dp_rank: u32,
) -> Self {
    let cache = RadixCache::new(total_tokens, page_size);
    Self {
        cache,
        kv_event_sink,
        dp_rank,
        next_event_id: 0,
        idx_to_block_hash: HashMap::default(),
    }
}
/// Shared (read-only) access to the underlying [`RadixCache`].
pub fn cache(&self) -> &RadixCache {
&self.cache
}
/// Exclusive access to the underlying [`RadixCache`] for callers that need to mutate it directly.
pub fn cache_mut(&mut self) -> &mut RadixCache {
&mut self.cache
}
/// Try to allocate KV cache for a new request.
///
/// Steps, in order:
/// 1. Match `token_ids` against the radix tree to find the cached prefix.
/// 2. Allocate pool slots for the unmatched tail.
/// 3. Lock the matched node with `inc_lock_ref`.
/// 4. Publish a Stored event for the new slots, chained from the block hash
///    of the last prefix slot when one is recorded.
///
/// Returns `None` if the pool doesn't have enough token slots (OOM); in that
/// case nothing is locked and no event is published.
pub fn allocate_for_request(&mut self, token_ids: &[u64]) -> Option<AllocResult> {
    let (prefix_len, last_node) = self.cache.match_prefix(token_ids);
    let new_tokens = token_ids.len() - prefix_len;

    let prefix_indices = self.collect_path_indices(last_node);
    let new_indices = self.cache.token_pool.allocate(new_tokens)?;

    let mut kv_indices = prefix_indices;
    kv_indices.extend_from_slice(&new_indices);

    self.cache.inc_lock_ref(last_node);

    // Chain from the prefix's last block_hash (if any). `checked_sub` makes
    // the "no prefix" case explicit instead of relying on an out-of-range
    // wrapped index.
    let parent_hash = prefix_len
        .checked_sub(1)
        .and_then(|last| kv_indices.get(last))
        .and_then(|idx| self.idx_to_block_hash.get(idx).copied());

    self.publish_stored_event(&token_ids[prefix_len..], &new_indices, parent_hash);
    self.log_trace("allocation", new_tokens);

    Some(AllocResult {
        prefix_len,
        kv_indices,
        last_node,
    })
}
/// Cache a completed request's full sequence into the radix tree.
///
/// Inserts the full token sequence so future requests can reuse it,
/// then unlocks the path.
///
/// * `token_ids` / `kv_indices` - the sequence and its pool slots, passed
///   straight to `RadixCache::insert`.
/// * `last_node` - the node whose lock is released via `dec_lock_ref`.
pub fn cache_finished_req(
&mut self,
token_ids: &[u64],
kv_indices: &[usize],
last_node: NodeId,
) {
// Insert before unlocking so the path is still locked while the tree is updated.
self.cache.insert(token_ids, kv_indices);
self.cache.dec_lock_ref(last_node);
}
/// Cache a partial sequence after a chunked prefill step.
///
/// Inserts the partial sequence, then transfers the lock from the old
/// path to the new (extended) path. The request is still active, so the
/// new deepest node stays locked.
///
/// Returns the new `last_node` that the caller should use for
/// subsequent calls.
pub fn cache_unfinished_req(
&mut self,
token_ids: &[u64],
kv_indices: &[usize],
last_node: NodeId,
) -> NodeId {
self.cache.insert(token_ids, kv_indices);
// Find the new deepest node after insert
// (re-matching the same tokens locates the node that now ends the sequence).
let (_, new_last_node) = self.cache.match_prefix(token_ids);
// Transfer lock: release old path, protect new path
self.cache.dec_lock_ref(last_node);
self.cache.inc_lock_ref(new_last_node);
new_last_node
}
/// Reserve one pool slot for a decode-step output token.
///
/// A Stored event is published for the new slot, chained from the block hash
/// recorded for `last_idx` (the request's previous pool index) when there is
/// one. No token ids are passed to the event publisher, so the slot's content
/// hash is derived from the pool index itself.
///
/// Returns the new slot index, or `None` when the pool cannot supply a slot.
pub fn allocate_decode_token(&mut self, last_idx: Option<usize>) -> Option<usize> {
    let slot = self.cache.token_pool.allocate(1)?[0];

    let parent_hash = match last_idx {
        Some(prev) => self.idx_to_block_hash.get(&prev).copied(),
        None => None,
    };

    self.publish_stored_event(&[], &[slot], parent_hash);
    self.log_trace("allocation", 1);
    Some(slot)
}
/// Free a request without caching (e.g., aborted request).
///
/// Unlocks the path without inserting into the tree.
///
/// NOTE(review): this only releases the lock on `last_node`; it does not
/// return the request's pool slots or publish a Removed event. Presumably the
/// caller reclaims those separately — confirm, otherwise aborted requests
/// leak pool capacity.
pub fn free_request(&mut self, last_node: NodeId) {
self.cache.dec_lock_ref(last_node);
}
/// Collect token indices from the matched prefix path by walking root→last_node.
///
/// Follows parent links upward from `last_node` until reaching a node that
/// has no parent (that node itself is excluded), then concatenates each
/// visited node's `value` in root-to-leaf order. Returns an empty vector when
/// `last_node` is the root.
fn collect_path_indices(&self, last_node: NodeId) -> Vec<usize> {
    if last_node == self.cache.root() {
        return Vec::new();
    }

    // Walk from last_node towards the root, recording every node that still
    // has a parent. `while let` replaces the `is_none()` + `unwrap()` pair.
    let mut path = Vec::new();
    let mut current = last_node;
    while let Some(parent) = self.cache.node(current).parent {
        path.push(current);
        current = parent;
    }

    // The walk produced leaf→root order; emit indices root→leaf.
    path.iter()
        .rev()
        .flat_map(|&node_id| self.cache.node(node_id).value.iter().copied())
        .collect()
}
/// Ask the cache to evict `num_tokens` tokens.
///
/// When the cache reports evicted slots, a Removed event is published for
/// them. An `"eviction"` trace record carrying the evicted count reported by
/// the cache is logged in every case.
pub fn evict(&mut self, num_tokens: usize) {
    let (freed_count, freed_slots) = self.cache.evict(num_tokens);

    let any_freed = !freed_slots.is_empty();
    if any_freed {
        self.publish_removed_event(&freed_slots);
    }

    self.log_trace("eviction", freed_count);
}
/// Snapshot the cache's current occupancy and hand it to the shared SGLang
/// trace logger, tagged with `event` and the `num_tokens` involved.
fn log_trace(&self, event: &str, num_tokens: usize) {
    let cache = &self.cache;
    let snapshot = kv_cache_trace::SglangCacheState {
        event,
        dp_rank: self.dp_rank,
        num_tokens,
        page_size: cache.page_size(),
        available_tokens: cache.available_tokens(),
        evictable_tokens: cache.evictable_size,
        protected_tokens: cache.protected_size,
        total_tokens: cache.total_tokens(),
    };
    kv_cache_trace::log_sglang_trace(&snapshot);
}
/// Publish one Stored event covering `indices` and record each slot's block hash.
///
/// * `token_ids` - token id for each slot, position-aligned with `indices`.
///   When a position has no token id (e.g. an empty slice), the pool index
///   itself is hashed in its place.
/// * `indices` - pool slots being stored; nothing happens if this is empty.
/// * `parent_hash` - block hash the new chain continues from, if any.
///
/// Does nothing when no event sink is configured; in particular no block
/// hashes are recorded in that case. A failed publish is logged as a warning
/// and otherwise ignored; the event id is consumed either way.
///
/// NOTE(review): block hashes are produced with `DefaultHasher`, whose
/// algorithm the standard library does not guarantee to be stable across Rust
/// releases. Hashes are therefore only comparable between workers built with
/// the same toolchain — confirm that is acceptable.
fn publish_stored_event(
    &mut self,
    token_ids: &[u64],
    indices: &[usize],
    parent_hash: Option<ExternalSequenceBlockHash>,
) {
    if indices.is_empty() {
        return;
    }
    let Some(ref sink) = self.kv_event_sink else {
        return;
    };

    let mut blocks = Vec::with_capacity(indices.len());
    let mut running_hash = parent_hash.map_or(0u64, |h| h.0);

    for (i, &idx) in indices.iter().enumerate() {
        // tokens_hash: per-token content hash for router prefix matching.
        // The little-endian bytes are hashed straight from the stack array;
        // no per-token heap allocation is needed.
        let token_bytes = token_ids
            .get(i)
            .copied()
            .unwrap_or(idx as u64)
            .to_le_bytes();
        let tokens_hash = dynamo_kv_router::protocols::compute_block_hash(&token_bytes);

        // block_hash: cumulative hash (parent_hash, token_id) so it's unique
        // per position and uniform across workers with the same token sequence.
        let mut hasher = DefaultHasher::new();
        running_hash.hash(&mut hasher);
        tokens_hash.0.hash(&mut hasher);
        running_hash = hasher.finish();

        let block_hash = ExternalSequenceBlockHash(running_hash);
        // Remember the hash so a later Removed event can refer to the same block.
        self.idx_to_block_hash.insert(idx, block_hash);

        blocks.push(KvCacheStoredBlockData {
            block_hash,
            tokens_hash,
            mm_extra_info: None,
        });
    }

    let event = KvCacheEvent {
        event_id: self.next_event_id,
        data: KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash,
            blocks,
        }),
        dp_rank: self.dp_rank,
    };
    self.next_event_id += 1;

    if let Err(e) = sink.publish(event, None) {
        tracing::warn!("Failed to publish SGLang KV event: {e}");
    }
}
/// Publish one Removed event for the evicted pool slots.
///
/// Each slot's recorded block hash is taken out of the index→hash map; slots
/// with no recorded hash are skipped. Nothing is published when no sink is
/// configured or when none of the slots had a recorded hash. A failed publish
/// is logged as a warning and otherwise ignored.
fn publish_removed_event(&mut self, evicted_indices: &[usize]) {
    let Some(ref sink) = self.kv_event_sink else {
        return;
    };

    let mut block_hashes: Vec<ExternalSequenceBlockHash> = Vec::new();
    for slot in evicted_indices {
        if let Some(hash) = self.idx_to_block_hash.remove(slot) {
            block_hashes.push(hash);
        }
    }
    if block_hashes.is_empty() {
        return;
    }

    let event_id = self.next_event_id;
    self.next_event_id += 1;
    let event = KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
        dp_rank: self.dp_rank,
    };

    if let Err(e) = sink.publish(event, None) {
        tracing::warn!("Failed to publish SGLang KV remove event: {e}");
    }
}
}
// Unit tests for `SglangKvManager`: allocation with and without prefix hits,
// unlocking without caching, event publishing, and pool exhaustion.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
/// Event sink that records every published event in memory.
struct MockSink {
events: Mutex<Vec<KvCacheEvent>>,
}
impl MockSink {
fn new() -> Self {
Self {
events: Mutex::new(Vec::new()),
}
}
/// Number of events published so far.
fn event_count(&self) -> usize {
self.events.lock().unwrap().len()
}
}
impl KvCacheEventSink for MockSink {
fn publish(
&self,
event: KvCacheEvent,
_block_token_ids: Option<&[Vec<u32>]>,
) -> anyhow::Result<()> {
self.events.lock().unwrap().push(event);
Ok(())
}
}
// With an empty tree nothing matches, so every token takes a fresh pool slot.
#[test]
fn test_allocate_cache_miss() {
let mut mgr = SglangKvManager::new(100, 1, None, 0);
let result = mgr.allocate_for_request(&[1, 2, 3, 4, 5]).unwrap();
assert_eq!(result.prefix_len, 0);
assert_eq!(result.kv_indices.len(), 5);
assert_eq!(mgr.cache().token_pool.available(), 95);
}
// A finished request's tokens are reused as the prefix of a later request.
#[test]
fn test_allocate_cache_hit() {
let mut mgr = SglangKvManager::new(100, 1, None, 0);
// First request: allocate and cache
let r1 = mgr.allocate_for_request(&[1, 2, 3, 4, 5]).unwrap();
assert_eq!(r1.kv_indices.len(), 5); // 5 pages (page_size=1)
mgr.cache_finished_req(&[1, 2, 3, 4, 5], &r1.kv_indices, r1.last_node);
// Second request with shared prefix
let r2 = mgr.allocate_for_request(&[1, 2, 3, 4, 5, 6, 7]).unwrap();
assert_eq!(r2.prefix_len, 5);
assert_eq!(r2.kv_indices.len(), 7); // 5 reused + 2 new pages
assert_eq!(mgr.cache().token_pool.available(), 93); // 100 - 5 - 2
}
// Freeing without caching must leave no protected tokens behind.
#[test]
fn test_free_request_without_caching() {
let mut mgr = SglangKvManager::new(100, 1, None, 0);
let result = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
mgr.free_request(result.last_node);
// Path is unlocked, tokens still allocated in pool
assert_eq!(mgr.cache().protected_size, 0);
}
// One Stored event per allocation that creates new slots; none on a full prefix hit.
#[test]
fn test_event_publishing() {
let sink = Arc::new(MockSink::new());
let mut mgr = SglangKvManager::new(100, 1, Some(sink.clone()), 0);
let r = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
assert_eq!(sink.event_count(), 1); // BlockStored for 3 new pages
mgr.cache_finished_req(&[1, 2, 3], &r.kv_indices, r.last_node);
// Second request with full cache hit → no new events
let r2 = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
assert_eq!(r2.prefix_len, 3);
assert_eq!(sink.event_count(), 1); // no new event
}
// Allocation reports OOM (`None`) once the pool is exhausted.
#[test]
fn test_allocate_oom() {
let mut mgr = SglangKvManager::new(3, 1, None, 0);
let _r = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
// Pool is full
let result = mgr.allocate_for_request(&[4, 5, 6]);
assert!(result.is_none());
}
}
......@@ -35,26 +35,17 @@
//! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror
//! implementation of the main block manager.
use crate::cache::HashCache;
use crate::common::kv_cache_trace;
use crate::common::protocols::{KvCacheEventSink, MoveBlock, PrefillCost};
use crate::common::sequence::ActiveSequence;
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
use dynamo_runtime::config::environment_names::mocker;
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash};
use std::collections::HashMap;
use std::env;
use std::sync::{Arc, LazyLock};
use std::time::{SystemTime, UNIX_EPOCH};
/// Check the env var to enable KV cache allocation/eviction trace logs.
static KV_CACHE_TRACE_ENABLED: LazyLock<bool> = LazyLock::new(|| {
env::var(mocker::DYN_MOCKER_KV_CACHE_TRACE)
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false)
});
use std::sync::Arc;
pub struct KvManager {
cache: HashCache,
......@@ -104,32 +95,14 @@ impl KvManager {
return;
}
if *KV_CACHE_TRACE_ENABLED {
let active_len = self.cache.num_active();
let inactive_len = self.cache.num_inactive();
let free_blocks = self
.cache
.max_capacity()
.saturating_sub(active_len)
.saturating_sub(inactive_len);
let event = if is_store { "allocation" } else { "eviction" };
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
tracing::info!(
event,
timestamp_ms,
block_ids = ?&full_blocks,
block_size = self.block_size,
free_blocks_after = free_blocks,
active_blocks = active_len,
inactive_blocks = inactive_len,
total_blocks = self.cache.max_capacity(),
dp_rank = self.dp_rank,
"KV cache trace"
);
}
kv_cache_trace::log_vllm_trace(
if is_store { "allocation" } else { "eviction" },
self.dp_rank,
self.block_size,
self.cache.num_active(),
self.cache.num_inactive(),
self.cache.max_capacity(),
);
let Some(ref sink) = self.kv_event_sink else {
return;
......
......@@ -9,5 +9,6 @@
pub mod cache;
pub mod common;
pub mod engine;
pub mod kv_manager;
pub mod scheduler;
......@@ -3,7 +3,112 @@
//! Engine-specific scheduling implementations.
pub mod sglang;
pub mod vllm;
// Backward compatibility: re-export Scheduler from vllm module
pub use vllm::Scheduler;
use crate::common::protocols::DirectRequest;
use tokio::sync::mpsc;
pub use sglang::SglangScheduler;
pub use vllm::{MockerMetrics, Scheduler};
/// Engine-agnostic scheduler interface.
///
/// Both vLLM and SGLang schedulers implement this trait so that the engine
/// wrapper (`MockEngine`) can work with either backend through the same API.
///
/// Implementors must be `Send + Sync`, which lets handles be stored as
/// `Box<dyn SchedulerHandle>` and shared across threads.
pub trait SchedulerHandle: Send + Sync {
/// Send a request to the scheduler's waiting queue.
///
/// This is a synchronous call with no return value, so it gives the caller
/// no indication of whether the request was accepted.
fn receive(&self, request: DirectRequest);
/// Get a clone of the request sender channel for direct sending.
fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest>;
/// Get a watch receiver for scheduler metrics (active decode blocks, etc.).
fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics>;
}
/// Shared test utilities for scheduler stress tests.
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::common::protocols::OutputSignal;
    use tokio::time::Duration;

    /// Send `num_requests` to a scheduler, collect all output signals, and assert
    /// that the scheduler produces exactly `num_requests * max_output_tokens` signals
    /// and returns to idle (0 active decode blocks).
    ///
    /// When `use_shared_tokens` is true, the first `input_len / 2` tokens of
    /// every request are a common prefix, to exercise prefix caching / radix
    /// tree reuse. Each request always carries exactly `input_len` tokens,
    /// including when `input_len` is odd.
    ///
    /// Collection stops early if no signal arrives for 2 seconds; the final
    /// count assertion then reports the shortfall.
    pub async fn assert_scheduler_completes_all(
        scheduler: &dyn SchedulerHandle,
        output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
        num_requests: usize,
        input_len: usize,
        max_output_tokens: usize,
        use_shared_tokens: bool,
    ) {
        // Token ids are drawn at random from [0, 50000).
        let random_tokens = |count: usize| -> Vec<u32> {
            (0..count)
                .map(|_| rand::random::<u32>() % 50000)
                .collect()
        };

        let prefix_len = input_len / 2;
        let shared_tokens = use_shared_tokens.then(|| random_tokens(prefix_len));

        for _ in 0..num_requests {
            let input_tokens = match &shared_tokens {
                Some(shared) => {
                    let mut tokens = shared.clone();
                    // Fill the remainder so the total is `input_len` even when
                    // it is odd (a second `input_len / 2` would drop a token).
                    tokens.extend(random_tokens(input_len - prefix_len));
                    tokens
                }
                None => random_tokens(input_len),
            };

            scheduler.receive(DirectRequest {
                tokens: input_tokens,
                max_output_tokens,
                uuid: None,
                dp_rank: 0,
            });
        }

        let expected_tokens = num_requests * max_output_tokens;
        let mut received_tokens = 0;

        let timeout = tokio::time::sleep(Duration::from_secs(2));
        tokio::pin!(timeout);

        // Checking the count up front means a run that expects zero signals
        // does not sit through the whole inactivity timeout.
        while received_tokens < expected_tokens {
            tokio::select! {
                biased;

                Some(_) = output_rx.recv() => {
                    received_tokens += 1;
                    // Every signal restarts the inactivity window.
                    timeout.set(tokio::time::sleep(Duration::from_secs(2)));
                }

                _ = &mut timeout => break,
            }
        }

        assert_eq!(
            received_tokens, expected_tokens,
            "Expected {expected_tokens} output signals, got {received_tokens}"
        );

        // Verify scheduler returns to idle
        tokio::time::sleep(Duration::from_millis(100)).await;
        let metrics = scheduler.metrics_receiver().borrow().clone();
        assert_eq!(
            metrics.active_decode_blocks, 0,
            "Scheduler should be idle after all requests complete, got {} active blocks",
            metrics.active_decode_blocks
        );
    }
}
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment