"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "d22d9e761e6e9a569491654eea5fa439d3904601"
Unverified Commit 91fb78cd authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: make mocker its own crate (#5958)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent a72f41f6
......@@ -2347,6 +2347,7 @@ dependencies = [
"dynamo-async-openai",
"dynamo-kv-router",
"dynamo-memory",
"dynamo-mocker",
"dynamo-parsers",
"dynamo-runtime",
"dynamo-tokens",
......@@ -2442,6 +2443,29 @@ dependencies = [
"tracing",
]
[[package]]
name = "dynamo-mocker"
version = "0.9.0"
dependencies = [
"anyhow",
"dashmap 6.1.0",
"derive-getters",
"derive_builder",
"dynamo-kv-router",
"dynamo-tokens",
"ndarray",
"ndarray-interp",
"ndarray-npy",
"rand 0.9.2",
"rstest 0.18.2",
"serde",
"serde_json",
"tokio",
"tokio-util",
"tracing",
"uuid 1.18.1",
]
[[package]]
name = "dynamo-parsers"
version = "0.9.0"
......
......@@ -8,6 +8,7 @@ members = [
"lib/runtime",
"lib/config",
"lib/tokens",
"lib/mocker",
"lib/kv-router",
"lib/memory",
"lib/async-openai",
......@@ -26,6 +27,7 @@ default-members = [
"lib/runtime",
"lib/config",
"lib/tokens",
"lib/mocker",
"lib/memory",
"lib/async-openai",
"lib/parsers",
......@@ -50,6 +52,7 @@ dynamo-llm = { path = "lib/llm", version = "0.9.0" }
dynamo-config = { path = "lib/config", version = "0.9.0" }
dynamo-tokens = { path = "lib/tokens", version = "0.9.0" }
dynamo-memory = { path = "lib/memory", version = "0.9.0" }
dynamo-mocker = { path = "lib/mocker", version = "0.9.0" }
dynamo-kv-router = { path = "lib/kv-router", version = "0.9.0", features = ["metrics"] }
dynamo-async-openai = { path = "lib/async-openai", version = "0.9.0", features = ["byot"] }
dynamo-parsers = { path = "lib/parsers", version = "0.9.0" }
......
......@@ -164,8 +164,7 @@ async fn engine_for(
let args = flags.mocker_config();
let endpoint = local_model.endpoint_id().clone();
let engine =
dynamo_llm::mocker::engine::make_mocker_engine(drt, endpoint, args).await?;
let engine = dynamo_llm::mocker::make_mocker_engine(drt, endpoint, args).await?;
Ok(EngineConfig::InProcessTokens {
engine,
......
......@@ -1632,6 +1632,7 @@ dependencies = [
"ahash",
"aho-corasick",
"akin",
"aligned-vec",
"anyhow",
"async-nats",
"async-stream",
......@@ -1648,6 +1649,7 @@ dependencies = [
"bytes",
"candle-core",
"chrono",
"cudarc",
"dashmap 5.5.3",
"derive-getters",
"derive_builder",
......@@ -1655,6 +1657,7 @@ dependencies = [
"dynamo-async-openai",
"dynamo-kv-router",
"dynamo-memory",
"dynamo-mocker",
"dynamo-parsers",
"dynamo-runtime",
"dynamo-tokens",
......@@ -1681,6 +1684,7 @@ dependencies = [
"ndarray",
"ndarray-interp",
"ndarray-npy",
"nix 0.26.4",
"nixl-sys",
"object_store",
"offset-allocator",
......@@ -1739,6 +1743,28 @@ dependencies = [
"tracing",
]
[[package]]
name = "dynamo-mocker"
version = "0.9.0"
dependencies = [
"anyhow",
"dashmap 6.1.0",
"derive-getters",
"derive_builder",
"dynamo-kv-router",
"dynamo-tokens",
"ndarray",
"ndarray-interp",
"ndarray-npy",
"rand 0.9.2",
"serde",
"serde_json",
"tokio",
"tokio-util",
"tracing",
"uuid",
]
[[package]]
name = "dynamo-parsers"
version = "0.9.0"
......@@ -3913,6 +3939,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]]
name = "memoffset"
version = "0.9.1"
......@@ -4206,6 +4241,19 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if 1.0.4",
"libc",
"memoffset 0.7.1",
"pin-utils",
]
[[package]]
name = "nix"
version = "0.29.0"
......@@ -5452,7 +5500,7 @@ dependencies = [
"cfg-if 1.0.4",
"indoc",
"libc",
"memoffset",
"memoffset 0.9.1",
"once_cell",
"portable-atomic",
"pyo3-build-config",
......
......@@ -402,7 +402,7 @@ async fn select_engine(
let endpoint = local_model.endpoint_id().clone();
let engine = dynamo_llm::mocker::engine::make_mocker_engine(
let engine = dynamo_llm::mocker::make_mocker_engine(
distributed_runtime.inner,
endpoint,
mocker_args,
......
......@@ -50,6 +50,7 @@ dynamo-runtime = { workspace = true }
dynamo-tokens = { workspace = true }
dynamo-kv-router = { workspace = true, features = ["metrics"] }
dynamo-memory = { workspace = true }
dynamo-mocker = { workspace = true }
# workspace
aho-corasick = "1.1"
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! MockSchedulerEngine - AsyncEngine wrapper around the Scheduler
//!
//! This module provides an AsyncEngine implementation that wraps the Scheduler
//! to provide streaming token generation with realistic timing simulation.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use futures::StreamExt;
use rand::Rng;
use tokio::sync::{Mutex, OnceCell, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::{
component::Component,
engine::AsyncEngineContextProvider,
pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait},
traits::DistributedRuntimeProvider,
};
use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::bootstrap::{BootstrapServer, connect_to_prefill};
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType};
use crate::mocker::scheduler::Scheduler;
use crate::protocols::TokenIdType;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
pub const MOCKER_COMPONENT: &str = "mocker";
/// Returns a random placeholder token id taken from the half-open range `1000..2000`.
fn generate_random_token() -> TokenIdType {
    rand::rng().random_range(1000..2000)
}
/// AsyncEngine wrapper around the Scheduler that emits randomly generated token ids.
///
/// The request map, the scheduler senders and the bootstrap server handle are all
/// held behind `Arc`s, so clones of the engine share that state.
#[derive(Clone)]
pub struct MockVllmEngine {
/// In-flight requests keyed by request UUID; each value is the channel on which
/// that request receives its `OutputSignal`s.
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
/// Request channels of the schedulers, indexed by DP rank. Set exactly once when
/// the schedulers are started; empty until then.
request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>,
/// Configuration this engine was constructed with.
engine_args: MockEngineArgs,
/// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
}
impl MockVllmEngine {
    /// Builds an engine around `args` with empty request bookkeeping.
    ///
    /// No scheduler exists until [`MockVllmEngine::start`] has run.
    pub fn new(args: MockEngineArgs) -> Self {
        let active_requests = Arc::new(Mutex::new(HashMap::new()));
        let request_senders = Arc::new(OnceCell::new());
        let bootstrap_server = Arc::new(OnceCell::new());
        Self {
            active_requests,
            request_senders,
            engine_args: args,
            bootstrap_server,
        }
    }

    /// Brings the engine up for `component`: optional simulated startup delay,
    /// optional bootstrap server, one scheduler per DP rank, and the metrics tasks.
    ///
    /// # Errors
    /// Fails when the bootstrap server cannot be started or when metrics
    /// publishing cannot be set up.
    ///
    /// # Panics
    /// Panics with "Already initialized" if the schedulers were started before.
    pub async fn start(&self, component: Component) -> Result<()> {
        // primary_token() rather than child_token(): child_token() descends from
        // endpoint_shutdown_token, which is cancelled in Phase 1 of graceful shutdown.
        // primary_token() is only cancelled in Phase 3, after inflight requests were
        // waited for, so the mocker keeps running through Phase 1/2.
        let shutdown = component.drt().primary_token();

        // Optional artificial startup latency.
        if let Some(startup_time_secs) = self.engine_args.startup_time {
            tracing::info!("Simulating engine startup time: {:.2}s", startup_time_secs);
            tokio::time::sleep(Duration::from_secs_f64(startup_time_secs)).await;
            tracing::info!("Engine startup simulation completed");
        }

        // Disaggregated mode: only prefill workers with a configured port host a
        // bootstrap server.
        let is_prefill_worker = self.engine_args.worker_type == WorkerType::Prefill;
        if is_prefill_worker && let Some(port) = self.engine_args.bootstrap_port {
            let server = BootstrapServer::start(port, shutdown.clone()).await?;
            let _ = self.bootstrap_server.set(server);
            tracing::info!(port = port, "Bootstrap server started for prefill worker");
        }

        // Schedulers only receive the component when prefix caching is on and this
        // is not a decode worker.
        let wants_component = self.engine_args.enable_prefix_caching
            && self.engine_args.worker_type != WorkerType::Decode;
        let scheduler_component = wants_component.then(|| component.clone());

        let schedulers = self.start_schedulers(
            self.engine_args.clone(),
            self.active_requests.clone(),
            scheduler_component,
            shutdown.clone(),
        );

        Self::start_metrics_publishing(&schedulers, Some(component.clone()), shutdown.clone())
            .await
    }

    /// Hands `request` to the scheduler serving `dp_rank`; a failed send is ignored.
    ///
    /// # Panics
    /// Panics when the schedulers have not been started yet, or when `dp_rank`
    /// does not index an existing scheduler.
    pub fn direct(&self, request: DirectRequest, dp_rank: usize) {
        let scheduler_tx = &self.request_senders.get().expect("Not initialized")[dp_rank];
        let _ = scheduler_tx.send(request);
    }

    /// Creates one scheduler per DP rank, spawns the task that fans each scheduler's
    /// output signals out to the waiting requests, and records the schedulers'
    /// request senders on the engine.
    ///
    /// # Panics
    /// Panics with "Already initialized" if the request senders were set before.
    fn start_schedulers(
        &self,
        args: MockEngineArgs,
        active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
        component: Option<Component>,
        cancel_token: CancellationToken,
    ) -> Vec<Scheduler> {
        let (senders, schedulers): (Vec<_>, Vec<Scheduler>) = (0..args.dp_size)
            .map(|dp_rank| {
                // Every scheduler reports through its own output channel.
                let (output_tx, output_rx) = mpsc::unbounded_channel::<OutputSignal>();
                let scheduler = Scheduler::new(
                    args.clone(),
                    dp_rank,
                    Some(output_tx),
                    component.clone(),
                    Some(cancel_token.clone()),
                );
                let request_tx = scheduler.request_sender();
                Self::spawn_output_dispatcher(
                    output_rx,
                    active_requests.clone(),
                    cancel_token.clone(),
                );
                (request_tx, scheduler)
            })
            .unzip();

        // The senders may only be installed once.
        self.request_senders
            .set(senders)
            .expect("Already initialized");

        schedulers
    }

    /// Spawns the background task that forwards each signal read from `output_rx`
    /// to the request registered under the signal's UUID.
    ///
    /// The task ends when the channel closes or when `cancel_token` fires; on
    /// cancellation it empties `active_requests` first.
    fn spawn_output_dispatcher(
        mut output_rx: mpsc::UnboundedReceiver<OutputSignal>,
        active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
        cancel_token: CancellationToken,
    ) {
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    received = output_rx.recv() => {
                        match received {
                            // Channel closed
                            None => break,
                            Some(signal) => {
                                // Route the signal to the request it belongs to, if still registered.
                                let registry = active_requests.lock().await;
                                if let Some(request_tx) = registry.get(&signal.uuid) {
                                    let _ = request_tx.send(signal);
                                }
                            }
                        }
                    }
                    _ = cancel_token.cancelled() => {
                        tracing::info!("Scheduler output task cancelled, clearing active requests");
                        // Dropping every registered sender unblocks the waiting request
                        // handlers: their request_rx.recv() then returns None.
                        active_requests.lock().await.clear();
                        break;
                    }
                }
            }
        });
    }

    /// Creates the metrics publisher and spawns one task per scheduler that
    /// publishes that scheduler's metrics whenever they change.
    ///
    /// When `component` is given, the metrics endpoint is additionally created in
    /// a background task.
    ///
    /// # Errors
    /// Fails when the `WorkerMetricsPublisher` cannot be constructed.
    async fn start_metrics_publishing(
        schedulers: &[Scheduler],
        component: Option<Component>,
        cancel_token: CancellationToken,
    ) -> Result<()> {
        tracing::debug!("Creating metrics publisher");
        let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?);
        tracing::debug!("Metrics publisher created");

        if let Some(comp) = component {
            tracing::debug!("Creating metrics endpoint");
            let endpoint_publisher = Arc::clone(&metrics_publisher);
            tokio::spawn(async move {
                if let Err(e) = endpoint_publisher.create_endpoint(comp.clone()).await {
                    tracing::error!("Metrics endpoint failed: {e}");
                }
            });
            // Short grace period so the endpoint task gets a chance to start.
            tokio::time::sleep(Duration::from_millis(100)).await;
            tracing::debug!("Metrics endpoint started (background)");
        }

        tracing::debug!("Starting metrics background tasks");
        for scheduler in schedulers {
            let mut metrics_rx = scheduler.metrics_receiver();
            let publisher = Arc::clone(&metrics_publisher);
            let cancel_token = cancel_token.clone();
            tokio::spawn(async move {
                loop {
                    tokio::select! {
                        // Fires whenever the scheduler reports changed metrics.
                        Ok(_) = metrics_rx.changed() => {
                            // Snapshot the most recent value.
                            let metrics = metrics_rx.borrow().clone();
                            // Publish through the flat API.
                            match publisher.publish(Some(metrics.dp_rank), metrics.active_decode_blocks) {
                                Ok(_) => {
                                    tracing::trace!("Published metrics for DP rank {}", metrics.dp_rank);
                                }
                                Err(e) => {
                                    tracing::warn!("Failed to publish metrics for DP rank {}: {e}", metrics.dp_rank);
                                }
                            }
                        }
                        _ = cancel_token.cancelled() => {
                            tracing::debug!("Metrics publishing cancelled");
                            break;
                        }
                    }
                }
            });
        }
        tracing::info!("Metrics background tasks started");
        Ok(())
    }
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
for MockVllmEngine
{
async fn generate(
&self,
input: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<LLMEngineOutput>, Error> {
let (request, ctx) = input.into_parts();
// Extract dp_rank from routing hints (defaults to 0 if not set)
let dp_rank = request
.routing
.as_ref()
.and_then(|r| r.dp_rank)
.unwrap_or(0);
// Validate dp_rank
if dp_rank >= self.engine_args.dp_size {
return Err(Error::msg(format!(
"dp_rank {} is out of bounds for dp_size {}",
dp_rank, self.engine_args.dp_size
)));
}
// Bootstrap rendezvous for disaggregated serving
// - Decode: connect to prefill's server, block until prefill completes
// - Prefill: complete_room() is called after first token (see below)
let bootstrap_room = request.bootstrap_info.as_ref().map(|b| b.bootstrap_room);
if let Some(bootstrap_info) = &request.bootstrap_info
&& self.engine_args.worker_type == WorkerType::Decode
{
connect_to_prefill(
&bootstrap_info.bootstrap_host,
bootstrap_info.bootstrap_port,
bootstrap_info.bootstrap_room,
)
.await
.map_err(|e| Error::msg(format!("Bootstrap connection failed: {e}")))?;
}
let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
// For prefill workers, override max_tokens to 1
let is_prefill = self.engine_args.worker_type == WorkerType::Prefill;
let max_output_tokens = if is_prefill {
1
} else {
request
.stop_conditions
.max_tokens
.expect("max_output_tokens must be specified for mocker") as usize
};
// Convert PreprocessedRequest to DirectRequest for scheduler
let direct_request = DirectRequest {
tokens: request.token_ids.clone(),
max_output_tokens,
uuid: Some(request_uuid),
dp_rank,
};
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>();
{
let mut active = self.active_requests.lock().await;
active.insert(request_uuid, request_tx);
}
// Send the request to the appropriate scheduler based on dp_rank
self.direct(direct_request, dp_rank as usize);
// Create a simple channel for the stream
let (stream_tx, stream_rx) = mpsc::unbounded_channel::<LLMEngineOutput>();
let active_requests = self.active_requests.clone();
let async_context = ctx.context();
let bootstrap_server = self.bootstrap_server.clone();
// Spawn a task to handle the complex async logic
tokio::spawn(async move {
let mut token_count = 0;
loop {
tokio::select! {
maybe_signal = request_rx.recv() => {
let Some(signal) = maybe_signal else {
let _ = stream_tx.send(LLMEngineOutput::error("All output transmitters closed".to_string()));
break;
};
// Generate a new token
let token_id = generate_random_token();
token_count += 1;
let output = LLMEngineOutput {
token_ids: vec![token_id],
tokens: None, // Let backend handle detokenization
text: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: None,
// Add dummy disaggregated_params for prefill workers
disaggregated_params: if is_prefill {
Some(serde_json::json!("dummy"))
} else {
None
},
extra_args: None,
completion_usage: None,
};
// Prefill: after first token, mark room complete (unblocks decode)
if is_prefill
&& token_count == 1
&& let (Some(server), Some(room_id)) = (bootstrap_server.get(), bootstrap_room)
{
server.complete_room(room_id);
}
if signal.completed && token_count < max_output_tokens {
let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string()));
break;
}
if signal.completed {
let _ = stream_tx.send(output);
let _ = stream_tx.send(LLMEngineOutput::length());
break;
}
if stream_tx.send(output).is_err() {
tracing::error!("Output stream receiver closed.");
break;
}
}
_ = async_context.stopped() => {
let _ = stream_tx.send(LLMEngineOutput::cancelled());
break;
}
}
}
// Clean up: remove this request from active requests
let mut active = active_requests.lock().await;
active.remove(&request_uuid);
});
// Create a simple UnboundedReceiverStream which is naturally Send + Sync
let stream = UnboundedReceiverStream::new(stream_rx);
Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
}
}
/// Adapter around a shared `MockVllmEngine` whose `AsyncEngine` implementation
/// yields the engine's outputs wrapped as `Annotated<LLMEngineOutput>`.
pub struct AnnotatedMockEngine {
/// The wrapped engine; also held by the background task that starts it.
inner: Arc<MockVllmEngine>,
}
impl AnnotatedMockEngine {
pub fn new(
inner: MockVllmEngine,
distributed_runtime: DistributedRuntime,
endpoint_id: dynamo_runtime::protocols::EndpointId,
) -> Self {
let inner = Arc::new(inner);
let inner_clone = inner.clone();
// Start background task to wait for component service and start the engine
tokio::spawn(async move {
loop {
// Try to create component
let Ok(namespace) = distributed_runtime.namespace(&endpoint_id.namespace) else {
tracing::debug!("Namespace not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
let Ok(component) = namespace.component(&endpoint_id.component) else {
tracing::debug!("Component not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
// Check if service is available by trying to list instances
let Ok(instances) = component.list_instances().await else {
tracing::debug!("Cannot list instances yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
if instances.is_empty() {
tracing::debug!("No instances available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
}
tracing::debug!("Component service is now available, starting mocker engine");
// Start the engine with the component
if let Err(e) = inner_clone.start(component).await {
tracing::error!("Failed to start mocker engine: {e}");
}
break;
}
});
Self { inner }
}
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
    for AnnotatedMockEngine
{
    /// Delegates to the wrapped engine and wraps every produced output with
    /// `Annotated::from_data`, keeping the inner stream's context.
    async fn generate(
        &self,
        input: SingleIn<PreprocessedRequest>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
        let raw_outputs = self.inner.generate(input).await?;
        // Capture the context first: mapping the stream below consumes it.
        let ctx = raw_outputs.context();

        // LLMEngineOutput -> Annotated<LLMEngineOutput>, item by item
        Ok(ResponseStream::new(
            Box::pin(raw_outputs.map(Annotated::from_data)),
            ctx,
        ))
    }
}
/// Builds a mocker engine from `args` and returns it as an `ExecutionContext`.
///
/// The engine is wrapped in an `AnnotatedMockEngine` constructed with
/// `distributed_runtime` and `endpoint_id`.
pub async fn make_mocker_engine(
    distributed_runtime: DistributedRuntime,
    endpoint_id: dynamo_runtime::protocols::EndpointId,
    args: MockEngineArgs,
) -> Result<crate::backend::ExecutionContext, Error> {
    tracing::info!("Creating mocker engine with config: {args:?}");

    let engine = MockVllmEngine::new(args);
    let annotated = AnnotatedMockEngine::new(engine, distributed_runtime, endpoint_id);

    Ok(Arc::new(annotated))
}
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
[package]
name = "dynamo-mocker"
description = "Mock LLM scheduler and KV manager for testing"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
[dependencies]
# repo
dynamo-tokens = { workspace = true }
dynamo-kv-router = { workspace = true }
# workspace
anyhow = { workspace = true }
dashmap = { workspace = true }
derive_builder = { workspace = true }
derive-getters = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
# crate-specific
ndarray = "0.16"
ndarray-npy = "0.9"
ndarray-interp = "0.5"
[dev-dependencies]
rstest = "0.18.2"
......@@ -33,16 +33,14 @@
//! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror
//! implementation of the main block manager.
use crate::kv_router::protocols::{
use crate::evictor::LRUEvictor;
use crate::protocols::{KvCacheEventSink, MoveBlock, PrefillCost};
use crate::sequence::ActiveSequence;
use derive_getters::Getters;
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
use crate::kv_router::publisher::KvEventPublisher;
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, PrefillCost};
use crate::mocker::sequence::ActiveSequence;
use derive_getters::Getters;
use dynamo_runtime::component::Component;
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash};
use std::collections::HashMap;
......@@ -60,7 +58,7 @@ pub struct KvManager {
inactive_blocks: LRUEvictor<UniqueBlock>,
kv_event_publisher: Option<Arc<KvEventPublisher>>,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
#[getter(copy)]
dp_rank: u32,
......@@ -70,41 +68,36 @@ pub struct KvManager {
impl KvManager {
pub fn new(max_capacity: usize, block_size: usize) -> Self {
Self::new_with_publisher(max_capacity, block_size, None, 0, false)
Self::new_with_event_sink(max_capacity, block_size, None, 0)
}
pub fn new_with_publisher(
pub fn new_with_event_sink(
max_capacity: usize,
block_size: usize,
component: Option<Component>,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
dp_rank: u32,
enable_local_indexer: bool,
) -> Self {
let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default();
let kv_event_publisher = component.map(|comp| {
if kv_event_sink.is_some() {
tracing::info!(
"Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}, enable_local_indexer={enable_local_indexer}"
"KvManager initialized with event sink for DP rank {dp_rank} with block_size {block_size}"
);
Arc::new(
KvEventPublisher::new_with_local_indexer(comp, block_size as u32, None, enable_local_indexer, dp_rank)
.expect("Failed to create KV event publisher"),
)
});
}
KvManager {
max_capacity,
block_size,
active_blocks,
inactive_blocks,
kv_event_publisher,
kv_event_sink,
dp_rank,
next_event_id: 0,
}
}
/// Converts stored/removed blocks into KvCacheEventData and publishes if publisher is available
/// Converts stored/removed blocks into KvCacheEventData and publishes if sink is available
fn publish_kv_event(
&mut self,
full_blocks: Vec<SequenceHash>,
......@@ -116,7 +109,7 @@ impl KvManager {
return;
}
let Some(ref publisher) = self.kv_event_publisher else {
let Some(ref sink) = self.kv_event_sink else {
return;
};
......@@ -158,7 +151,7 @@ impl KvManager {
dp_rank: self.dp_rank,
};
if let Err(e) = publisher.publish(event) {
if let Err(e) = sink.publish(event) {
tracing::warn!("Failed to publish KV event: {e}");
}
}
......@@ -207,7 +200,7 @@ impl KvManager {
// Now insert the new block in active blocks with reference count 1
self.active_blocks.insert(hash.clone(), 1);
if self.kv_event_publisher.is_some()
if self.kv_event_sink.is_some()
&& let UniqueBlock::FullBlock(stored_full_block) = hash
{
blocks_stored.push(*stored_full_block);
......@@ -230,7 +223,7 @@ impl KvManager {
self.active_blocks.remove(hash).unwrap();
// Track blocks for batch sending
if self.kv_event_publisher.is_some()
if self.kv_event_sink.is_some()
&& let UniqueBlock::FullBlock(destroyed_full_block) = hash
{
blocks_destroyed.push(*destroyed_full_block);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Mock LLM scheduler and KV manager for testing.
//!
//! This crate provides a mock implementation of an LLM scheduler that simulates
//! KV cache management, request scheduling, and token generation timing without
//! requiring actual GPU resources or a full distributed runtime.
pub mod bootstrap;
pub mod evictor;
pub mod kv_manager;
pub mod perf_model;
pub mod protocols;
pub mod running_mean;
pub mod scheduler;
pub mod sequence;
// Re-export commonly used types
pub use protocols::{DirectRequest, KvCacheEventSink, MockEngineArgs, MockEngineArgsBuilder};
pub use scheduler::Scheduler;
......@@ -8,10 +8,17 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use uuid::Uuid;
use crate::mocker::perf_model::PerfModel;
use crate::perf_model::PerfModel;
use dynamo_kv_router::protocols::KvCacheEvent;
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash, Token};
/// Trait for publishing KV cache events.
/// This abstracts the runtime dependency so mocker components can remain generic.
///
/// Implementations must be `Send + Sync`.
pub trait KvCacheEventSink: Send + Sync {
/// Publishes one KV cache `event`. Implementations report failure through the
/// returned `anyhow::Result`.
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()>;
}
pub type NumBlocks = usize;
/// Represents different block movement operations in the cache
......
......@@ -28,17 +28,19 @@
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP
use crate::kv_router::protocols::DpRank;
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::kv_manager::KvManager;
use crate::mocker::perf_model::PerfModel;
use crate::mocker::protocols::{
DirectRequest, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost, WorkerType,
use crate::evictor::LRUEvictor;
use crate::kv_manager::KvManager;
use crate::perf_model::PerfModel;
use crate::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost,
WorkerType,
};
use crate::mocker::running_mean::RunningMean;
use crate::mocker::sequence::ActiveSequence;
use crate::running_mean::RunningMean;
use crate::sequence::ActiveSequence;
use dynamo_kv_router::protocols::DpRank;
use dynamo_tokens::blocks::UniqueBlock;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
......@@ -254,7 +256,7 @@ impl Scheduler {
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
component: Option<dynamo_runtime::component::Component>,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
cancellation_token: Option<CancellationToken>,
) -> Self {
// Assert speedup_ratio is non-negative (0 means infinite speedup)
......@@ -279,12 +281,11 @@ impl Scheduler {
tokio::spawn(async move {
// Create state and kv_manager as local variables owned by this task
let mut state = SchedulerState::new(args.max_num_batched_tokens);
let mut kv_manager = KvManager::new_with_publisher(
let mut kv_manager = KvManager::new_with_event_sink(
args.num_gpu_blocks,
args.block_size,
component,
kv_event_sink,
dp_rank,
args.enable_local_indexer,
);
let mut hit_rates = RunningMean::new(1000);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::mocker::protocols::MoveBlock;
use crate::protocols::MoveBlock;
use derive_getters::Getters;
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{TokenBlockSequence, Tokens};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment