"examples/vscode:/vscode.git/clone" did not exist on "e14be96a1b564cce17972799899805b1fbd93b95"
Unverified Commit 91fb78cd authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: make mocker it's own crate (#5958)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent a72f41f6
...@@ -2347,6 +2347,7 @@ dependencies = [ ...@@ -2347,6 +2347,7 @@ dependencies = [
"dynamo-async-openai", "dynamo-async-openai",
"dynamo-kv-router", "dynamo-kv-router",
"dynamo-memory", "dynamo-memory",
"dynamo-mocker",
"dynamo-parsers", "dynamo-parsers",
"dynamo-runtime", "dynamo-runtime",
"dynamo-tokens", "dynamo-tokens",
...@@ -2442,6 +2443,29 @@ dependencies = [ ...@@ -2442,6 +2443,29 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "dynamo-mocker"
version = "0.9.0"
dependencies = [
"anyhow",
"dashmap 6.1.0",
"derive-getters",
"derive_builder",
"dynamo-kv-router",
"dynamo-tokens",
"ndarray",
"ndarray-interp",
"ndarray-npy",
"rand 0.9.2",
"rstest 0.18.2",
"serde",
"serde_json",
"tokio",
"tokio-util",
"tracing",
"uuid 1.18.1",
]
[[package]] [[package]]
name = "dynamo-parsers" name = "dynamo-parsers"
version = "0.9.0" version = "0.9.0"
......
...@@ -8,6 +8,7 @@ members = [ ...@@ -8,6 +8,7 @@ members = [
"lib/runtime", "lib/runtime",
"lib/config", "lib/config",
"lib/tokens", "lib/tokens",
"lib/mocker",
"lib/kv-router", "lib/kv-router",
"lib/memory", "lib/memory",
"lib/async-openai", "lib/async-openai",
...@@ -26,6 +27,7 @@ default-members = [ ...@@ -26,6 +27,7 @@ default-members = [
"lib/runtime", "lib/runtime",
"lib/config", "lib/config",
"lib/tokens", "lib/tokens",
"lib/mocker",
"lib/memory", "lib/memory",
"lib/async-openai", "lib/async-openai",
"lib/parsers", "lib/parsers",
...@@ -50,6 +52,7 @@ dynamo-llm = { path = "lib/llm", version = "0.9.0" } ...@@ -50,6 +52,7 @@ dynamo-llm = { path = "lib/llm", version = "0.9.0" }
dynamo-config = { path = "lib/config", version = "0.9.0" } dynamo-config = { path = "lib/config", version = "0.9.0" }
dynamo-tokens = { path = "lib/tokens", version = "0.9.0" } dynamo-tokens = { path = "lib/tokens", version = "0.9.0" }
dynamo-memory = { path = "lib/memory", version = "0.9.0" } dynamo-memory = { path = "lib/memory", version = "0.9.0" }
dynamo-mocker = { path = "lib/mocker", version = "0.9.0" }
dynamo-kv-router = { path = "lib/kv-router", version = "0.9.0", features = ["metrics"] } dynamo-kv-router = { path = "lib/kv-router", version = "0.9.0", features = ["metrics"] }
dynamo-async-openai = { path = "lib/async-openai", version = "0.9.0", features = ["byot"] } dynamo-async-openai = { path = "lib/async-openai", version = "0.9.0", features = ["byot"] }
dynamo-parsers = { path = "lib/parsers", version = "0.9.0" } dynamo-parsers = { path = "lib/parsers", version = "0.9.0" }
......
...@@ -164,8 +164,7 @@ async fn engine_for( ...@@ -164,8 +164,7 @@ async fn engine_for(
let args = flags.mocker_config(); let args = flags.mocker_config();
let endpoint = local_model.endpoint_id().clone(); let endpoint = local_model.endpoint_id().clone();
let engine = let engine = dynamo_llm::mocker::make_mocker_engine(drt, endpoint, args).await?;
dynamo_llm::mocker::engine::make_mocker_engine(drt, endpoint, args).await?;
Ok(EngineConfig::InProcessTokens { Ok(EngineConfig::InProcessTokens {
engine, engine,
......
...@@ -1632,6 +1632,7 @@ dependencies = [ ...@@ -1632,6 +1632,7 @@ dependencies = [
"ahash", "ahash",
"aho-corasick", "aho-corasick",
"akin", "akin",
"aligned-vec",
"anyhow", "anyhow",
"async-nats", "async-nats",
"async-stream", "async-stream",
...@@ -1648,6 +1649,7 @@ dependencies = [ ...@@ -1648,6 +1649,7 @@ dependencies = [
"bytes", "bytes",
"candle-core", "candle-core",
"chrono", "chrono",
"cudarc",
"dashmap 5.5.3", "dashmap 5.5.3",
"derive-getters", "derive-getters",
"derive_builder", "derive_builder",
...@@ -1655,6 +1657,7 @@ dependencies = [ ...@@ -1655,6 +1657,7 @@ dependencies = [
"dynamo-async-openai", "dynamo-async-openai",
"dynamo-kv-router", "dynamo-kv-router",
"dynamo-memory", "dynamo-memory",
"dynamo-mocker",
"dynamo-parsers", "dynamo-parsers",
"dynamo-runtime", "dynamo-runtime",
"dynamo-tokens", "dynamo-tokens",
...@@ -1681,6 +1684,7 @@ dependencies = [ ...@@ -1681,6 +1684,7 @@ dependencies = [
"ndarray", "ndarray",
"ndarray-interp", "ndarray-interp",
"ndarray-npy", "ndarray-npy",
"nix 0.26.4",
"nixl-sys", "nixl-sys",
"object_store", "object_store",
"offset-allocator", "offset-allocator",
...@@ -1739,6 +1743,28 @@ dependencies = [ ...@@ -1739,6 +1743,28 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "dynamo-mocker"
version = "0.9.0"
dependencies = [
"anyhow",
"dashmap 6.1.0",
"derive-getters",
"derive_builder",
"dynamo-kv-router",
"dynamo-tokens",
"ndarray",
"ndarray-interp",
"ndarray-npy",
"rand 0.9.2",
"serde",
"serde_json",
"tokio",
"tokio-util",
"tracing",
"uuid",
]
[[package]] [[package]]
name = "dynamo-parsers" name = "dynamo-parsers"
version = "0.9.0" version = "0.9.0"
...@@ -3913,6 +3939,15 @@ version = "0.3.3" ...@@ -3913,6 +3939,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "memoffset" name = "memoffset"
version = "0.9.1" version = "0.9.1"
...@@ -4206,6 +4241,19 @@ dependencies = [ ...@@ -4206,6 +4241,19 @@ dependencies = [
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if 1.0.4",
"libc",
"memoffset 0.7.1",
"pin-utils",
]
[[package]] [[package]]
name = "nix" name = "nix"
version = "0.29.0" version = "0.29.0"
...@@ -5452,7 +5500,7 @@ dependencies = [ ...@@ -5452,7 +5500,7 @@ dependencies = [
"cfg-if 1.0.4", "cfg-if 1.0.4",
"indoc", "indoc",
"libc", "libc",
"memoffset", "memoffset 0.9.1",
"once_cell", "once_cell",
"portable-atomic", "portable-atomic",
"pyo3-build-config", "pyo3-build-config",
......
...@@ -402,7 +402,7 @@ async fn select_engine( ...@@ -402,7 +402,7 @@ async fn select_engine(
let endpoint = local_model.endpoint_id().clone(); let endpoint = local_model.endpoint_id().clone();
let engine = dynamo_llm::mocker::engine::make_mocker_engine( let engine = dynamo_llm::mocker::make_mocker_engine(
distributed_runtime.inner, distributed_runtime.inner,
endpoint, endpoint,
mocker_args, mocker_args,
......
...@@ -50,6 +50,7 @@ dynamo-runtime = { workspace = true } ...@@ -50,6 +50,7 @@ dynamo-runtime = { workspace = true }
dynamo-tokens = { workspace = true } dynamo-tokens = { workspace = true }
dynamo-kv-router = { workspace = true, features = ["metrics"] } dynamo-kv-router = { workspace = true, features = ["metrics"] }
dynamo-memory = { workspace = true } dynamo-memory = { workspace = true }
dynamo-mocker = { workspace = true }
# workspace # workspace
aho-corasick = "1.1" aho-corasick = "1.1"
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! MockSchedulerEngine - AsyncEngine wrapper around the Scheduler
//!
//! This module provides an AsyncEngine implementation that wraps the Scheduler
//! to provide streaming token generation with realistic timing simulation.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use futures::StreamExt;
use rand::Rng;
use tokio::sync::{Mutex, OnceCell, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::{
component::Component,
engine::AsyncEngineContextProvider,
pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait},
traits::DistributedRuntimeProvider,
};
use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::bootstrap::{BootstrapServer, connect_to_prefill};
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType};
use crate::mocker::scheduler::Scheduler;
use crate::protocols::TokenIdType;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
pub const MOCKER_COMPONENT: &str = "mocker";
fn generate_random_token() -> TokenIdType {
let mut rng = rand::rng();
rng.random_range(1000..2000)
}
/// AsyncEngine wrapper around the Scheduler that generates random character tokens
#[derive(Clone)]
pub struct MockVllmEngine {
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>,
engine_args: MockEngineArgs,
/// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
}
impl MockVllmEngine {
/// Create a new MockVllmEngine with the given parameters
pub fn new(args: MockEngineArgs) -> Self {
Self {
active_requests: Arc::new(Mutex::new(HashMap::new())),
request_senders: Arc::new(OnceCell::new()),
engine_args: args,
bootstrap_server: Arc::new(OnceCell::new()),
}
}
pub async fn start(&self, component: Component) -> Result<()> {
// Use primary_token() instead of child_token() so the mocker continues running
// during graceful shutdown (Phase 1/2) and only stops in Phase 3.
// child_token() is a child of endpoint_shutdown_token which is cancelled in Phase 1.
// primary_token() is only cancelled in Phase 3, after waiting for inflight requests.
let cancel_token = component.drt().primary_token();
// Simulate engine startup time if configured
if let Some(startup_time_secs) = self.engine_args.startup_time {
tracing::info!("Simulating engine startup time: {:.2}s", startup_time_secs);
tokio::time::sleep(Duration::from_secs_f64(startup_time_secs)).await;
tracing::info!("Engine startup simulation completed");
}
// Start bootstrap server for prefill workers in disaggregated mode
if self.engine_args.worker_type == WorkerType::Prefill
&& let Some(port) = self.engine_args.bootstrap_port
{
let server = BootstrapServer::start(port, cancel_token.clone()).await?;
let _ = self.bootstrap_server.set(server);
tracing::info!(port = port, "Bootstrap server started for prefill worker");
}
// Pass component to schedulers only if prefix caching is enabled and not a decode worker
let scheduler_component = if self.engine_args.enable_prefix_caching
&& self.engine_args.worker_type != WorkerType::Decode
{
Some(component.clone())
} else {
None
};
let schedulers = self.start_schedulers(
self.engine_args.clone(),
self.active_requests.clone(),
scheduler_component,
cancel_token.clone(),
);
Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone())
.await?;
Ok(())
}
pub fn direct(&self, request: DirectRequest, dp_rank: usize) {
let senders = self.request_senders.get().expect("Not initialized");
let _ = senders[dp_rank].send(request);
}
/// Create schedulers and spawn their background tasks for distributing token notifications
fn start_schedulers(
&self,
args: MockEngineArgs,
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
component: Option<Component>,
cancel_token: CancellationToken,
) -> Vec<Scheduler> {
let mut schedulers = Vec::<Scheduler>::new();
let mut senders = Vec::with_capacity(args.dp_size as usize);
// Create multiple schedulers and their background tasks
for dp_rank in 0..args.dp_size {
// Create a shared output channel that this scheduler will use
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler = Scheduler::new(
args.clone(),
dp_rank,
Some(output_tx),
component.clone(),
Some(cancel_token.clone()),
);
senders.push(scheduler.request_sender());
schedulers.push(scheduler);
// Spawn a background task for this scheduler to distribute token notifications to active requests
// let output_rx = Arc::new(Mutex::new(output_rx));
let active_requests_clone = active_requests.clone();
let cancel_token_cloned = cancel_token.clone();
tokio::spawn(async move {
loop {
tokio::select! {
signal_result = output_rx.recv() => {
let Some(signal) = signal_result else {
break; // Channel closed
};
// Notify the specific request that a token was generated
let active = active_requests_clone.lock().await;
if let Some(request_tx) = active.get(&signal.uuid) {
let _ = request_tx.send(signal);
}
}
_ = cancel_token_cloned.cancelled() => {
tracing::info!("Scheduler output task cancelled, clearing active requests");
// Clear all active requests to unblock waiting request handlers
// This will cause their request_rx.recv() to return None
let mut active = active_requests_clone.lock().await;
active.clear();
break;
}
}
}
});
}
// Set the senders once
self.request_senders
.set(senders)
.expect("Already initialized");
schedulers
}
/// Start background tasks to publish metrics on change
async fn start_metrics_publishing(
schedulers: &[Scheduler],
component: Option<Component>,
cancel_token: CancellationToken,
) -> Result<()> {
tracing::debug!("Creating metrics publisher");
let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?);
tracing::debug!("Metrics publisher created");
if let Some(comp) = component {
tracing::debug!("Creating metrics endpoint");
tokio::spawn({
let publisher = metrics_publisher.clone();
async move {
if let Err(e) = publisher.create_endpoint(comp.clone()).await {
tracing::error!("Metrics endpoint failed: {e}");
}
}
});
// Give it a moment to start
tokio::time::sleep(Duration::from_millis(100)).await;
tracing::debug!("Metrics endpoint started (background)");
}
tracing::debug!("Starting metrics background tasks");
for scheduler in schedulers.iter() {
let mut metrics_rx = scheduler.metrics_receiver();
let publisher = metrics_publisher.clone();
let cancel_token = cancel_token.clone();
tokio::spawn(async move {
loop {
tokio::select! {
// Watch for metrics changes
Ok(_) = metrics_rx.changed() => {
// Get the latest metrics
let metrics = metrics_rx.borrow().clone();
// Publish metrics using flat API
if let Err(e) = publisher.publish(Some(metrics.dp_rank), metrics.active_decode_blocks) {
tracing::warn!("Failed to publish metrics for DP rank {}: {e}", metrics.dp_rank);
} else {
tracing::trace!("Published metrics for DP rank {}", metrics.dp_rank);
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Metrics publishing cancelled");
break;
}
}
}
});
}
tracing::info!("Metrics background tasks started");
Ok(())
}
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
for MockVllmEngine
{
async fn generate(
&self,
input: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<LLMEngineOutput>, Error> {
let (request, ctx) = input.into_parts();
// Extract dp_rank from routing hints (defaults to 0 if not set)
let dp_rank = request
.routing
.as_ref()
.and_then(|r| r.dp_rank)
.unwrap_or(0);
// Validate dp_rank
if dp_rank >= self.engine_args.dp_size {
return Err(Error::msg(format!(
"dp_rank {} is out of bounds for dp_size {}",
dp_rank, self.engine_args.dp_size
)));
}
// Bootstrap rendezvous for disaggregated serving
// - Decode: connect to prefill's server, block until prefill completes
// - Prefill: complete_room() is called after first token (see below)
let bootstrap_room = request.bootstrap_info.as_ref().map(|b| b.bootstrap_room);
if let Some(bootstrap_info) = &request.bootstrap_info
&& self.engine_args.worker_type == WorkerType::Decode
{
connect_to_prefill(
&bootstrap_info.bootstrap_host,
bootstrap_info.bootstrap_port,
bootstrap_info.bootstrap_room,
)
.await
.map_err(|e| Error::msg(format!("Bootstrap connection failed: {e}")))?;
}
let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
// For prefill workers, override max_tokens to 1
let is_prefill = self.engine_args.worker_type == WorkerType::Prefill;
let max_output_tokens = if is_prefill {
1
} else {
request
.stop_conditions
.max_tokens
.expect("max_output_tokens must be specified for mocker") as usize
};
// Convert PreprocessedRequest to DirectRequest for scheduler
let direct_request = DirectRequest {
tokens: request.token_ids.clone(),
max_output_tokens,
uuid: Some(request_uuid),
dp_rank,
};
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>();
{
let mut active = self.active_requests.lock().await;
active.insert(request_uuid, request_tx);
}
// Send the request to the appropriate scheduler based on dp_rank
self.direct(direct_request, dp_rank as usize);
// Create a simple channel for the stream
let (stream_tx, stream_rx) = mpsc::unbounded_channel::<LLMEngineOutput>();
let active_requests = self.active_requests.clone();
let async_context = ctx.context();
let bootstrap_server = self.bootstrap_server.clone();
// Spawn a task to handle the complex async logic
tokio::spawn(async move {
let mut token_count = 0;
loop {
tokio::select! {
maybe_signal = request_rx.recv() => {
let Some(signal) = maybe_signal else {
let _ = stream_tx.send(LLMEngineOutput::error("All output transmitters closed".to_string()));
break;
};
// Generate a new token
let token_id = generate_random_token();
token_count += 1;
let output = LLMEngineOutput {
token_ids: vec![token_id],
tokens: None, // Let backend handle detokenization
text: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: None,
// Add dummy disaggregated_params for prefill workers
disaggregated_params: if is_prefill {
Some(serde_json::json!("dummy"))
} else {
None
},
extra_args: None,
completion_usage: None,
};
// Prefill: after first token, mark room complete (unblocks decode)
if is_prefill
&& token_count == 1
&& let (Some(server), Some(room_id)) = (bootstrap_server.get(), bootstrap_room)
{
server.complete_room(room_id);
}
if signal.completed && token_count < max_output_tokens {
let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string()));
break;
}
if signal.completed {
let _ = stream_tx.send(output);
let _ = stream_tx.send(LLMEngineOutput::length());
break;
}
if stream_tx.send(output).is_err() {
tracing::error!("Output stream receiver closed.");
break;
}
}
_ = async_context.stopped() => {
let _ = stream_tx.send(LLMEngineOutput::cancelled());
break;
}
}
}
// Clean up: remove this request from active requests
let mut active = active_requests.lock().await;
active.remove(&request_uuid);
});
// Create a simple UnboundedReceiverStream which is naturally Send + Sync
let stream = UnboundedReceiverStream::new(stream_rx);
Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
}
}
pub struct AnnotatedMockEngine {
inner: Arc<MockVllmEngine>,
}
impl AnnotatedMockEngine {
pub fn new(
inner: MockVllmEngine,
distributed_runtime: DistributedRuntime,
endpoint_id: dynamo_runtime::protocols::EndpointId,
) -> Self {
let inner = Arc::new(inner);
let inner_clone = inner.clone();
// Start background task to wait for component service and start the engine
tokio::spawn(async move {
loop {
// Try to create component
let Ok(namespace) = distributed_runtime.namespace(&endpoint_id.namespace) else {
tracing::debug!("Namespace not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
let Ok(component) = namespace.component(&endpoint_id.component) else {
tracing::debug!("Component not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
// Check if service is available by trying to list instances
let Ok(instances) = component.list_instances().await else {
tracing::debug!("Cannot list instances yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
if instances.is_empty() {
tracing::debug!("No instances available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
}
tracing::debug!("Component service is now available, starting mocker engine");
// Start the engine with the component
if let Err(e) = inner_clone.start(component).await {
tracing::error!("Failed to start mocker engine: {e}");
}
break;
}
});
Self { inner }
}
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for AnnotatedMockEngine
{
async fn generate(
&self,
input: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
let stream = self.inner.generate(input).await?;
let context = stream.context();
// Convert stream of LLMEngineOutput to Annotated<LLMEngineOutput>
let annotated_stream = stream.map(Annotated::from_data);
Ok(ResponseStream::new(Box::pin(annotated_stream), context))
}
}
/// Create a mocker engine as ExecutionContext
pub async fn make_mocker_engine(
distributed_runtime: DistributedRuntime,
endpoint_id: dynamo_runtime::protocols::EndpointId,
args: MockEngineArgs,
) -> Result<crate::backend::ExecutionContext, Error> {
// Create the mocker engine
tracing::info!("Creating mocker engine with config: {args:?}");
let annotated_engine =
AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint_id);
Ok(Arc::new(annotated_engine))
}
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
[package]
name = "dynamo-mocker"
description = "Mock LLM scheduler and KV manager for testing"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
[dependencies]
# repo
dynamo-tokens = { workspace = true }
dynamo-kv-router = { workspace = true }
# workspace
anyhow = { workspace = true }
dashmap = { workspace = true }
derive_builder = { workspace = true }
derive-getters = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
# crate-specific
ndarray = "0.16"
ndarray-npy = "0.9"
ndarray-interp = "0.5"
[dev-dependencies]
rstest = "0.18.2"
...@@ -33,16 +33,14 @@ ...@@ -33,16 +33,14 @@
//! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror //! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror
//! implementation of the main block manager. //! implementation of the main block manager.
use crate::kv_router::protocols::{ use crate::evictor::LRUEvictor;
use crate::protocols::{KvCacheEventSink, MoveBlock, PrefillCost};
use crate::sequence::ActiveSequence;
use derive_getters::Getters;
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, KvCacheStoredBlockData, LocalBlockHash,
}; };
use crate::kv_router::publisher::KvEventPublisher;
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, PrefillCost};
use crate::mocker::sequence::ActiveSequence;
use derive_getters::Getters;
use dynamo_runtime::component::Component;
use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash}; use dynamo_tokens::{BlockHash, SequenceHash};
use std::collections::HashMap; use std::collections::HashMap;
...@@ -60,7 +58,7 @@ pub struct KvManager { ...@@ -60,7 +58,7 @@ pub struct KvManager {
inactive_blocks: LRUEvictor<UniqueBlock>, inactive_blocks: LRUEvictor<UniqueBlock>,
kv_event_publisher: Option<Arc<KvEventPublisher>>, kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
#[getter(copy)] #[getter(copy)]
dp_rank: u32, dp_rank: u32,
...@@ -70,41 +68,36 @@ pub struct KvManager { ...@@ -70,41 +68,36 @@ pub struct KvManager {
impl KvManager { impl KvManager {
pub fn new(max_capacity: usize, block_size: usize) -> Self { pub fn new(max_capacity: usize, block_size: usize) -> Self {
Self::new_with_publisher(max_capacity, block_size, None, 0, false) Self::new_with_event_sink(max_capacity, block_size, None, 0)
} }
pub fn new_with_publisher( pub fn new_with_event_sink(
max_capacity: usize, max_capacity: usize,
block_size: usize, block_size: usize,
component: Option<Component>, kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
dp_rank: u32, dp_rank: u32,
enable_local_indexer: bool,
) -> Self { ) -> Self {
let active_blocks = HashMap::new(); let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default(); let inactive_blocks = LRUEvictor::default();
let kv_event_publisher = component.map(|comp| { if kv_event_sink.is_some() {
tracing::info!( tracing::info!(
"Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}, enable_local_indexer={enable_local_indexer}" "KvManager initialized with event sink for DP rank {dp_rank} with block_size {block_size}"
); );
Arc::new( }
KvEventPublisher::new_with_local_indexer(comp, block_size as u32, None, enable_local_indexer, dp_rank)
.expect("Failed to create KV event publisher"),
)
});
KvManager { KvManager {
max_capacity, max_capacity,
block_size, block_size,
active_blocks, active_blocks,
inactive_blocks, inactive_blocks,
kv_event_publisher, kv_event_sink,
dp_rank, dp_rank,
next_event_id: 0, next_event_id: 0,
} }
} }
/// Converts stored/removed blocks into KvCacheEventData and publishes if publisher is available /// Converts stored/removed blocks into KvCacheEventData and publishes if sink is available
fn publish_kv_event( fn publish_kv_event(
&mut self, &mut self,
full_blocks: Vec<SequenceHash>, full_blocks: Vec<SequenceHash>,
...@@ -116,7 +109,7 @@ impl KvManager { ...@@ -116,7 +109,7 @@ impl KvManager {
return; return;
} }
let Some(ref publisher) = self.kv_event_publisher else { let Some(ref sink) = self.kv_event_sink else {
return; return;
}; };
...@@ -158,7 +151,7 @@ impl KvManager { ...@@ -158,7 +151,7 @@ impl KvManager {
dp_rank: self.dp_rank, dp_rank: self.dp_rank,
}; };
if let Err(e) = publisher.publish(event) { if let Err(e) = sink.publish(event) {
tracing::warn!("Failed to publish KV event: {e}"); tracing::warn!("Failed to publish KV event: {e}");
} }
} }
...@@ -207,7 +200,7 @@ impl KvManager { ...@@ -207,7 +200,7 @@ impl KvManager {
// Now insert the new block in active blocks with reference count 1 // Now insert the new block in active blocks with reference count 1
self.active_blocks.insert(hash.clone(), 1); self.active_blocks.insert(hash.clone(), 1);
if self.kv_event_publisher.is_some() if self.kv_event_sink.is_some()
&& let UniqueBlock::FullBlock(stored_full_block) = hash && let UniqueBlock::FullBlock(stored_full_block) = hash
{ {
blocks_stored.push(*stored_full_block); blocks_stored.push(*stored_full_block);
...@@ -230,7 +223,7 @@ impl KvManager { ...@@ -230,7 +223,7 @@ impl KvManager {
self.active_blocks.remove(hash).unwrap(); self.active_blocks.remove(hash).unwrap();
// Track blocks for batch sending // Track blocks for batch sending
if self.kv_event_publisher.is_some() if self.kv_event_sink.is_some()
&& let UniqueBlock::FullBlock(destroyed_full_block) = hash && let UniqueBlock::FullBlock(destroyed_full_block) = hash
{ {
blocks_destroyed.push(*destroyed_full_block); blocks_destroyed.push(*destroyed_full_block);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Mock LLM scheduler and KV manager for testing.
//!
//! This crate provides a mock implementation of an LLM scheduler that simulates
//! KV cache management, request scheduling, and token generation timing without
//! requiring actual GPU resources or a full distributed runtime.
pub mod bootstrap;
pub mod evictor;
pub mod kv_manager;
pub mod perf_model;
pub mod protocols;
pub mod running_mean;
pub mod scheduler;
pub mod sequence;
// Re-export commonly used types
pub use protocols::{DirectRequest, KvCacheEventSink, MockEngineArgs, MockEngineArgsBuilder};
pub use scheduler::Scheduler;
...@@ -8,10 +8,17 @@ use std::path::{Path, PathBuf}; ...@@ -8,10 +8,17 @@ use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
use crate::mocker::perf_model::PerfModel; use crate::perf_model::PerfModel;
use dynamo_kv_router::protocols::KvCacheEvent;
use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash, Token}; use dynamo_tokens::{BlockHash, SequenceHash, Token};
/// Trait for publishing KV cache events.
/// This abstracts the runtime dependency so mocker components can remain generic.
pub trait KvCacheEventSink: Send + Sync {
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()>;
}
pub type NumBlocks = usize; pub type NumBlocks = usize;
/// Represents different block movement operations in the cache /// Represents different block movement operations in the cache
......
...@@ -28,17 +28,19 @@ ...@@ -28,17 +28,19 @@
//! ## NOTE //! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP //! The current prefill and decoding time simulations are not scientific at all and are WIP
use crate::kv_router::protocols::DpRank; use crate::evictor::LRUEvictor;
use crate::mocker::evictor::LRUEvictor; use crate::kv_manager::KvManager;
use crate::mocker::kv_manager::KvManager; use crate::perf_model::PerfModel;
use crate::mocker::perf_model::PerfModel; use crate::protocols::{
use crate::mocker::protocols::{ DirectRequest, KvCacheEventSink, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost,
DirectRequest, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost, WorkerType, WorkerType,
}; };
use crate::mocker::running_mean::RunningMean; use crate::running_mean::RunningMean;
use crate::mocker::sequence::ActiveSequence; use crate::sequence::ActiveSequence;
use dynamo_kv_router::protocols::DpRank;
use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::blocks::UniqueBlock;
use std::collections::{HashMap, VecDeque}; use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::time::Duration; use tokio::time::Duration;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -254,7 +256,7 @@ impl Scheduler { ...@@ -254,7 +256,7 @@ impl Scheduler {
args: MockEngineArgs, args: MockEngineArgs,
dp_rank: u32, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
component: Option<dynamo_runtime::component::Component>, kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
) -> Self { ) -> Self {
// Assert speedup_ratio is non-negative (0 means infinite speedup) // Assert speedup_ratio is non-negative (0 means infinite speedup)
...@@ -279,12 +281,11 @@ impl Scheduler { ...@@ -279,12 +281,11 @@ impl Scheduler {
tokio::spawn(async move { tokio::spawn(async move {
// Create state and kv_manager as local variables owned by this task // Create state and kv_manager as local variables owned by this task
let mut state = SchedulerState::new(args.max_num_batched_tokens); let mut state = SchedulerState::new(args.max_num_batched_tokens);
let mut kv_manager = KvManager::new_with_publisher( let mut kv_manager = KvManager::new_with_event_sink(
args.num_gpu_blocks, args.num_gpu_blocks,
args.block_size, args.block_size,
component, kv_event_sink,
dp_rank, dp_rank,
args.enable_local_indexer,
); );
let mut hit_rates = RunningMean::new(1000); let mut hit_rates = RunningMean::new(1000);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::mocker::protocols::MoveBlock; use crate::protocols::MoveBlock;
use derive_getters::Getters; use derive_getters::Getters;
use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{TokenBlockSequence, Tokens}; use dynamo_tokens::{TokenBlockSequence, Tokens};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment