Unverified Commit cc4c3516 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: mocker disagg (#3833)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent e3346dab
...@@ -196,14 +196,20 @@ for i in $(seq 1 $NUM_WORKERS); do ...@@ -196,14 +196,20 @@ for i in $(seq 1 $NUM_WORKERS); do
# Run mocker engine (no GPU assignment needed) # Run mocker engine (no GPU assignment needed)
MOCKER_ARGS=() MOCKER_ARGS=()
MOCKER_ARGS+=("--model-path" "$MODEL_PATH") MOCKER_ARGS+=("--model-path" "$MODEL_PATH")
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then # Set endpoint based on worker mode
MOCKER_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
if [ "$MODE" = "prefill" ]; then if [ "$MODE" = "prefill" ]; then
MOCKER_ARGS+=("--endpoint" "dyn://test.prefill.generate")
MOCKER_ARGS+=("--is-prefill-worker") MOCKER_ARGS+=("--is-prefill-worker")
elif [ "$MODE" = "decode" ]; then elif [ "$MODE" = "decode" ]; then
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
MOCKER_ARGS+=("--is-decode-worker") MOCKER_ARGS+=("--is-decode-worker")
else
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
fi
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
MOCKER_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi fi
MOCKER_ARGS+=("${EXTRA_ARGS[@]}") MOCKER_ARGS+=("${EXTRA_ARGS[@]}")
......
...@@ -12,6 +12,7 @@ from . import __version__ ...@@ -12,6 +12,7 @@ from . import __version__
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo") DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate" DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"
DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -85,8 +86,8 @@ def parse_args(): ...@@ -85,8 +86,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--endpoint", "--endpoint",
type=str, type=str,
default=DEFAULT_ENDPOINT, default=None,
help=f"Dynamo endpoint string (default: {DEFAULT_ENDPOINT})", help=f"Dynamo endpoint string (default: {DEFAULT_ENDPOINT} for aggregated/decode, {DEFAULT_PREFILL_ENDPOINT} for prefill)",
) )
parser.add_argument( parser.add_argument(
"--model-name", "--model-name",
...@@ -199,4 +200,14 @@ def parse_args(): ...@@ -199,4 +200,14 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
validate_worker_type_args(args) validate_worker_type_args(args)
# Set endpoint default based on worker type if not explicitly provided
if args.endpoint is None:
if args.is_prefill_worker:
args.endpoint = DEFAULT_PREFILL_ENDPOINT
logger.debug(f"Using default prefill endpoint: {args.endpoint}")
else:
args.endpoint = DEFAULT_ENDPOINT
logger.debug(f"Using default endpoint: {args.endpoint}")
return args return args
...@@ -60,7 +60,6 @@ pub enum DynamoLlmResult { ...@@ -60,7 +60,6 @@ pub enum DynamoLlmResult {
pub unsafe extern "C" fn dynamo_llm_init( pub unsafe extern "C" fn dynamo_llm_init(
namespace_c_str: *const c_char, namespace_c_str: *const c_char,
component_c_str: *const c_char, component_c_str: *const c_char,
worker_id: i64,
kv_block_size: u32, kv_block_size: u32,
) -> DynamoLlmResult { ) -> DynamoLlmResult {
initialize_tracing(); initialize_tracing();
...@@ -102,7 +101,7 @@ pub unsafe extern "C" fn dynamo_llm_init( ...@@ -102,7 +101,7 @@ pub unsafe extern "C" fn dynamo_llm_init(
match result { match result {
Ok(_) => match KV_PUB.get_or_try_init(move || { Ok(_) => match KV_PUB.get_or_try_init(move || {
dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size) dynamo_create_kv_publisher(namespace, component, kv_block_size)
}) { }) {
Ok(_) => DynamoLlmResult::OK, Ok(_) => DynamoLlmResult::OK,
Err(e) => { Err(e) => {
...@@ -144,7 +143,6 @@ pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult { ...@@ -144,7 +143,6 @@ pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
fn dynamo_create_kv_publisher( fn dynamo_create_kv_publisher(
namespace: String, namespace: String,
component: String, component: String,
worker_id: i64,
kv_block_size: u32, kv_block_size: u32,
) -> Result<KvEventPublisher, anyhow::Error> { ) -> Result<KvEventPublisher, anyhow::Error> {
tracing::info!("Creating KV Publisher for model: {}", component); tracing::info!("Creating KV Publisher for model: {}", component);
...@@ -154,7 +152,7 @@ fn dynamo_create_kv_publisher( ...@@ -154,7 +152,7 @@ fn dynamo_create_kv_publisher(
{ {
Ok(drt) => { Ok(drt) => {
let backend = drt.namespace(namespace)?.component(component)?; let backend = drt.namespace(namespace)?.component(component)?;
KvEventPublisher::new(backend, worker_id as u64, kv_block_size, None) KvEventPublisher::new(backend, kv_block_size, None)
} }
Err(e) => Err(e), Err(e) => Err(e),
} }
......
...@@ -143,7 +143,6 @@ impl ZmqKvEventPublisher { ...@@ -143,7 +143,6 @@ impl ZmqKvEventPublisher {
fn new(component: Component, config: ZmqKvEventPublisherConfig) -> PyResult<Self> { fn new(component: Component, config: ZmqKvEventPublisherConfig) -> PyResult<Self> {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner, component.inner,
config.worker_id,
config.kv_block_size as u32, config.kv_block_size as u32,
Some(KvEventSourceConfig::Zmq { Some(KvEventSourceConfig::Zmq {
endpoint: config.zmq_endpoint, endpoint: config.zmq_endpoint,
...@@ -239,20 +238,14 @@ pub(crate) struct KvEventPublisher { ...@@ -239,20 +238,14 @@ pub(crate) struct KvEventPublisher {
#[pymethods] #[pymethods]
impl KvEventPublisher { impl KvEventPublisher {
#[new] #[new]
#[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0))] #[pyo3(signature = (component, kv_block_size, dp_rank=0))]
fn new( fn new(component: Component, kv_block_size: usize, dp_rank: DpRank) -> PyResult<Self> {
component: Component,
worker_id: WorkerId,
kv_block_size: usize,
dp_rank: DpRank,
) -> PyResult<Self> {
if kv_block_size == 0 { if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0"))); return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
} }
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner, component.inner,
worker_id,
kv_block_size as u32, kv_block_size as u32,
None, None,
) )
......
...@@ -354,9 +354,10 @@ impl RadixTree { ...@@ -354,9 +354,10 @@ impl RadixTree {
None => { None => {
tracing::warn!( tracing::warn!(
worker_id = worker.worker_id.to_string(), worker_id = worker.worker_id.to_string(),
dp_rank = ?worker.dp_rank, dp_rank = worker.dp_rank,
id, id,
parent_hash = ?op.parent_hash, parent_hash = ?op.parent_hash,
num_blocks = op.blocks.len(),
"Failed to find parent block; skipping store operation" "Failed to find parent block; skipping store operation"
); );
return Err(KvCacheEventError::ParentBlockNotFound); return Err(KvCacheEventError::ParentBlockNotFound);
...@@ -412,8 +413,10 @@ impl RadixTree { ...@@ -412,8 +413,10 @@ impl RadixTree {
Some(entry) => entry.clone(), Some(entry) => entry.clone(),
None => { None => {
tracing::warn!( tracing::warn!(
worker_id = worker_id.to_string(), worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id, id,
block_hash = ?block,
"Failed to find block to remove; skipping remove operation" "Failed to find block to remove; skipping remove operation"
); );
return Err(KvCacheEventError::BlockNotFound); return Err(KvCacheEventError::BlockNotFound);
......
...@@ -213,8 +213,12 @@ impl ...@@ -213,8 +213,12 @@ impl
let (req, context) = request.into_parts(); let (req, context) = request.into_parts();
let request_id = context.id().to_string(); let request_id = context.id().to_string();
// Prepare prefill request with linked context for cancellation propagation // Save original max_tokens for decode
let prefill_req = req.clone(); let original_max_tokens = req.stop_conditions.max_tokens;
// Prepare prefill request with max_tokens = 1
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
// Link the prefill context as a child so that kill signals propagate // Link the prefill context as a child so that kill signals propagate
...@@ -230,6 +234,8 @@ impl ...@@ -230,6 +234,8 @@ impl
// Update request with disaggregated_params and router config // Update request with disaggregated_params and router config
let mut decode_req = req; let mut decode_req = req;
decode_req.disaggregated_params = Some(disaggregated_params); decode_req.disaggregated_params = Some(disaggregated_params);
// Restore original max_tokens for decode
decode_req.stop_conditions.max_tokens = original_max_tokens;
// Set router_config_override for decode: overlap_score_weight = 0 // Set router_config_override for decode: overlap_score_weight = 0
let existing_override = decode_req.router_config_override.take(); let existing_override = decode_req.router_config_override.take();
......
...@@ -97,7 +97,6 @@ pub struct KvEventPublisher { ...@@ -97,7 +97,6 @@ pub struct KvEventPublisher {
impl KvEventPublisher { impl KvEventPublisher {
pub fn new( pub fn new(
component: Component, component: Component,
worker_id: u64,
kv_block_size: u32, kv_block_size: u32,
source_config: Option<KvEventSourceConfig>, source_config: Option<KvEventSourceConfig>,
) -> Result<Self> { ) -> Result<Self> {
...@@ -105,6 +104,9 @@ impl KvEventPublisher { ...@@ -105,6 +104,9 @@ impl KvEventPublisher {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
// Infer worker_id from component's connection
let worker_id = component.drt().connection_id();
// Create our event source (if any) // Create our event source (if any)
let mut source = None; let mut source = None;
if let Some(config) = source_config { if let Some(config) = source_config {
......
...@@ -5,5 +5,6 @@ pub mod engine; ...@@ -5,5 +5,6 @@ pub mod engine;
pub mod evictor; pub mod evictor;
pub mod kv_manager; pub mod kv_manager;
pub mod protocols; pub mod protocols;
pub mod running_mean;
pub mod scheduler; pub mod scheduler;
pub mod sequence; pub mod sequence;
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
use crate::kv_router::publisher::WorkerMetricsPublisher; use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::protocols::DirectRequest; use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal}; use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType};
use crate::mocker::scheduler::Scheduler; use crate::mocker::scheduler::Scheduler;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}; use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
...@@ -23,9 +23,6 @@ use dynamo_runtime::{ ...@@ -23,9 +23,6 @@ use dynamo_runtime::{
pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait}, pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait},
traits::DistributedRuntimeProvider, traits::DistributedRuntimeProvider,
}; };
use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData};
use crate::kv_router::publisher::KvEventPublisher;
use futures::StreamExt; use futures::StreamExt;
use rand::Rng; use rand::Rng;
use std::collections::HashMap; use std::collections::HashMap;
...@@ -37,10 +34,9 @@ use uuid::Uuid; ...@@ -37,10 +34,9 @@ use uuid::Uuid;
pub const MOCKER_COMPONENT: &str = "mocker"; pub const MOCKER_COMPONENT: &str = "mocker";
/// Generate a random token ID from 1k to 5k
fn generate_random_token() -> TokenIdType { fn generate_random_token() -> TokenIdType {
let mut rng = rand::rng(); let mut rng = rand::rng();
rng.random_range(1000..5000) rng.random_range(1000..2000)
} }
/// AsyncEngine wrapper around the Scheduler that generates random character tokens /// AsyncEngine wrapper around the Scheduler that generates random character tokens
...@@ -71,26 +67,25 @@ impl MockVllmEngine { ...@@ -71,26 +67,25 @@ impl MockVllmEngine {
tracing::info!("Engine startup simulation completed"); tracing::info!("Engine startup simulation completed");
} }
let (schedulers, kv_event_receiver) = self.start_schedulers( // Pass component to schedulers only if prefix caching is enabled and not a decode worker
let scheduler_component = if self.engine_args.enable_prefix_caching
&& self.engine_args.worker_type != WorkerType::Decode
{
Some(component.clone())
} else {
None
};
let schedulers = self.start_schedulers(
self.engine_args.clone(), self.engine_args.clone(),
self.active_requests.clone(), self.active_requests.clone(),
scheduler_component,
cancel_token.clone(), cancel_token.clone(),
); );
Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone()) Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone())
.await?; .await?;
// Start KV events publishing with the actual receivers from schedulers
if self.engine_args.enable_prefix_caching {
Self::start_kv_events_publishing(
kv_event_receiver,
Some(component.clone()),
self.engine_args.block_size,
cancel_token.clone(),
)
.await?;
}
Ok(()) Ok(())
} }
...@@ -100,18 +95,14 @@ impl MockVllmEngine { ...@@ -100,18 +95,14 @@ impl MockVllmEngine {
} }
/// Create schedulers and spawn their background tasks for distributing token notifications /// Create schedulers and spawn their background tasks for distributing token notifications
/// Returns schedulers and their corresponding KV event receivers
fn start_schedulers( fn start_schedulers(
&self, &self,
args: MockEngineArgs, args: MockEngineArgs,
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>, active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
component: Option<Component>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> ( ) -> Vec<Scheduler> {
Vec<Scheduler>,
Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
) {
let mut schedulers = Vec::<Scheduler>::new(); let mut schedulers = Vec::<Scheduler>::new();
let mut kv_event_receivers = Vec::new();
let mut senders = Vec::with_capacity(args.dp_size as usize); let mut senders = Vec::with_capacity(args.dp_size as usize);
// Create multiple schedulers and their background tasks // Create multiple schedulers and their background tasks
...@@ -119,20 +110,16 @@ impl MockVllmEngine { ...@@ -119,20 +110,16 @@ impl MockVllmEngine {
// Create a shared output channel that this scheduler will use // Create a shared output channel that this scheduler will use
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create a channel for KV events from this scheduler
let (kv_events_tx, kv_events_rx) = mpsc::unbounded_channel::<KvCacheEventData>();
let scheduler = Scheduler::new( let scheduler = Scheduler::new(
args.clone(), args.clone(),
dp_rank, dp_rank,
Some(output_tx), Some(output_tx),
Some(kv_events_tx), // Pass the KV events sender to scheduler component.clone(),
Some(cancel_token.clone()), Some(cancel_token.clone()),
); );
senders.push(scheduler.request_sender()); senders.push(scheduler.request_sender());
schedulers.push(scheduler); schedulers.push(scheduler);
kv_event_receivers.push(kv_events_rx);
// Spawn a background task for this scheduler to distribute token notifications to active requests // Spawn a background task for this scheduler to distribute token notifications to active requests
// let output_rx = Arc::new(Mutex::new(output_rx)); // let output_rx = Arc::new(Mutex::new(output_rx));
...@@ -166,7 +153,7 @@ impl MockVllmEngine { ...@@ -166,7 +153,7 @@ impl MockVllmEngine {
.set(senders) .set(senders)
.expect("Already initialized"); .expect("Already initialized");
(schedulers, kv_event_receivers) schedulers
} }
/// Start background tasks to publish metrics on change /// Start background tasks to publish metrics on change
...@@ -228,78 +215,6 @@ impl MockVllmEngine { ...@@ -228,78 +215,6 @@ impl MockVllmEngine {
tracing::info!("Metrics background tasks started"); tracing::info!("Metrics background tasks started");
Ok(()) Ok(())
} }
/// Start background tasks to collect and publish KV events from schedulers
async fn start_kv_events_publishing(
kv_event_receivers: Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
component: Option<Component>,
block_size: usize,
cancel_token: CancellationToken,
) -> Result<()> {
tracing::debug!("Starting KV events publishing");
// Only start KV events publishing if we have a component
let Some(comp) = component else {
tracing::warn!("No component provided, skipping KV events publishing");
return Ok(());
};
tracing::debug!("Component found for KV events publishing");
tracing::debug!("Getting worker_id");
let worker_id = comp.drt().connection_id();
tracing::debug!("Worker_id set to: {worker_id}");
tracing::debug!("Creating KV event publisher");
let kv_event_publisher = Arc::new(KvEventPublisher::new(
comp.clone(),
worker_id,
block_size as u32,
None,
)?);
tracing::debug!("KV event publisher created");
tracing::debug!(
"Starting KV event background tasks for {} receivers",
kv_event_receivers.len()
);
for (dp_rank, mut kv_events_rx) in kv_event_receivers.into_iter().enumerate() {
tracing::debug!("Starting background task for DP rank {dp_rank}");
let publisher = kv_event_publisher.clone();
let dp_rank = dp_rank as u32;
let cancel_token = cancel_token.clone();
tokio::spawn(async move {
tracing::debug!("Background task started for DP rank {dp_rank}");
loop {
tokio::select! {
// Receive actual KV events from the scheduler
Some(event_data) = kv_events_rx.recv() => {
// Convert KvCacheEventData to KvCacheEvent with random UUID as event_id
let event = KvCacheEvent {
event_id: Uuid::new_v4().as_u128() as u64,
data: event_data,
dp_rank,
};
// Publish the event
if let Err(e) = publisher.publish(event) {
tracing::warn!("Failed to publish KV event for DP rank {dp_rank}: {e}");
} else {
tracing::trace!("Published KV event for DP rank {dp_rank}");
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("KV events publishing cancelled for DP rank {dp_rank}");
break;
}
}
}
});
}
tracing::info!("All KV event background tasks started");
Ok(())
}
} }
#[async_trait] #[async_trait]
...@@ -325,14 +240,21 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -325,14 +240,21 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4()); let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
// For prefill workers, override max_tokens to 1
let is_prefill = self.engine_args.worker_type == WorkerType::Prefill;
let max_output_tokens = if is_prefill {
1
} else {
request
.stop_conditions
.max_tokens
.expect("max_output_tokens must be specified for mocker") as usize
};
// Convert PreprocessedRequest to DirectRequest for scheduler // Convert PreprocessedRequest to DirectRequest for scheduler
let direct_request = DirectRequest { let direct_request = DirectRequest {
tokens: request.token_ids.clone(), tokens: request.token_ids.clone(),
max_output_tokens: request max_output_tokens,
.stop_conditions
.max_tokens
.expect("max_output_tokens must be specified for mocker")
as usize,
uuid: Some(request_uuid), uuid: Some(request_uuid),
dp_rank, dp_rank,
}; };
...@@ -351,7 +273,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -351,7 +273,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let active_requests = self.active_requests.clone(); let active_requests = self.active_requests.clone();
let async_context = ctx.context(); let async_context = ctx.context();
let max_tokens = request.stop_conditions.max_tokens.unwrap_or(100) as usize;
// Spawn a task to handle the complex async logic // Spawn a task to handle the complex async logic
tokio::spawn(async move { tokio::spawn(async move {
...@@ -378,11 +299,16 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -378,11 +299,16 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
disaggregated_params: None, // Add dummy disaggregated_params for prefill workers
disaggregated_params: if is_prefill {
Some(serde_json::json!("dummy"))
} else {
None
},
extra_args: None, extra_args: None,
}; };
if signal.completed && token_count < max_tokens { if signal.completed && token_count < max_output_tokens {
let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string())); let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string()));
break; break;
} }
......
...@@ -33,13 +33,20 @@ ...@@ -33,13 +33,20 @@
//! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror //! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror
//! implementation of the main block manager. //! implementation of the main block manager.
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
use crate::kv_router::publisher::KvEventPublisher;
use crate::mocker::evictor::LRUEvictor; use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost}; use crate::mocker::protocols::{MoveBlock, PrefillCost};
use crate::mocker::sequence::ActiveSequence; use crate::mocker::sequence::ActiveSequence;
use crate::tokens::blocks::UniqueBlock; use crate::tokens::blocks::UniqueBlock;
use crate::tokens::{BlockHash, SequenceHash};
use derive_getters::Getters; use derive_getters::Getters;
use dynamo_runtime::component::Component;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc; use std::sync::Arc;
#[derive(Getters)] #[derive(Getters)]
pub struct KvManager { pub struct KvManager {
...@@ -55,60 +62,113 @@ pub struct KvManager { ...@@ -55,60 +62,113 @@ pub struct KvManager {
all_blocks: HashSet<UniqueBlock>, all_blocks: HashSet<UniqueBlock>,
move_block_response_tx: Option<mpsc::UnboundedSender<MoveBlockResponse>>, kv_event_publisher: Option<Arc<KvEventPublisher>>,
#[getter(copy)]
dp_rank: u32,
next_event_id: u64,
} }
impl KvManager { impl KvManager {
pub fn new(max_capacity: usize, block_size: usize) -> Self { pub fn new(max_capacity: usize, block_size: usize) -> Self {
Self::new_with_sender(max_capacity, block_size, None) Self::new_with_publisher(max_capacity, block_size, None, 0)
} }
pub fn new_with_sender( pub fn new_with_publisher(
max_capacity: usize, max_capacity: usize,
block_size: usize, block_size: usize,
move_block_response_tx: Option<mpsc::UnboundedSender<MoveBlockResponse>>, component: Option<Component>,
dp_rank: u32,
) -> Self { ) -> Self {
let active_blocks = HashMap::new(); let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default(); let inactive_blocks = LRUEvictor::default();
let all_blocks = HashSet::new(); let all_blocks = HashSet::new();
let kv_event_publisher = component.map(|comp| {
tracing::info!(
"Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}"
);
Arc::new(
KvEventPublisher::new(comp, block_size as u32, None)
.expect("Failed to create KV event publisher"),
)
});
KvManager { KvManager {
max_capacity, max_capacity,
block_size, block_size,
active_blocks, active_blocks,
inactive_blocks, inactive_blocks,
all_blocks, all_blocks,
move_block_response_tx, kv_event_publisher,
dp_rank,
next_event_id: 0,
} }
} }
/// Utility method to send block responses with optional reversing /// Converts stored/removed blocks into KvCacheEventData and publishes if publisher is available
fn send_block_response( fn publish_kv_event(
&self, &mut self,
mut blocks: Vec<u64>, full_blocks: Vec<SequenceHash>,
reverse: bool, local_hashes: &[BlockHash],
store: bool,
parent_hash: Option<u64>, parent_hash: Option<u64>,
is_store: bool,
) { ) {
if let Some(ref tx) = self.move_block_response_tx if full_blocks.is_empty() {
&& !blocks.is_empty() return;
{ }
if reverse {
blocks.reverse(); let Some(ref publisher) = self.kv_event_publisher else {
} return;
let response = if store { };
MoveBlockResponse::Store(blocks, parent_hash)
} else { let event_data = if is_store {
MoveBlockResponse::Remove(blocks) let num_blocks = full_blocks.len();
}; let local_hashes_slice = &local_hashes[local_hashes
tx.send(response).unwrap(); .len()
.checked_sub(num_blocks)
.expect("local hashes fewer than stored blocks")..];
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
blocks: full_blocks
.into_iter()
.zip(local_hashes_slice.iter())
.map(|(global_hash, local_hash)| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(global_hash),
tokens_hash: LocalBlockHash(*local_hash),
})
.collect(),
})
} else {
KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: full_blocks
.into_iter()
.map(ExternalSequenceBlockHash)
.collect(),
})
};
// Use incremental event ID starting from 0
let event_id = self.next_event_id;
self.next_event_id += 1;
let event = KvCacheEvent {
event_id,
data: event_data,
dp_rank: self.dp_rank,
};
if let Err(e) = publisher.publish(event) {
tracing::warn!("Failed to publish KV event: {e}");
} }
} }
/// Process a MoveBlock instruction synchronously /// Process a MoveBlock instruction synchronously
pub fn process(&mut self, event: &MoveBlock) -> bool { pub fn process(&mut self, event: &MoveBlock) -> bool {
match event { match event {
MoveBlock::Use(hashes) => { MoveBlock::Use(hashes, local_hashes) => {
let mut blocks_stored = Vec::<u64>::new(); let mut blocks_stored = Vec::<u64>::new();
let mut parent_block: Option<&UniqueBlock> = None; let mut parent_block: Option<&UniqueBlock> = None;
...@@ -138,16 +198,20 @@ impl KvManager { ...@@ -138,16 +198,20 @@ impl KvManager {
let Some(evicted) = self.inactive_blocks.evict() else { let Some(evicted) = self.inactive_blocks.evict() else {
return false; return false;
}; };
tracing::trace!(
"Evicting block from inactive pool: {evicted:?}, dp_rank={}",
self.dp_rank
);
self.all_blocks.remove(&evicted); self.all_blocks.remove(&evicted);
if let UniqueBlock::FullBlock(evicted_full_block) = evicted { if let UniqueBlock::FullBlock(evicted_full_block) = evicted {
self.send_block_response(vec![evicted_full_block], false, false, None); self.publish_kv_event(vec![evicted_full_block], &[], None, false);
} }
} }
// Now insert the new block in active blocks with reference count 1 // Now insert the new block in active blocks with reference count 1
self.active_blocks.insert(hash.clone(), 1); self.active_blocks.insert(hash.clone(), 1);
self.all_blocks.insert(hash.clone()); self.all_blocks.insert(hash.clone());
if self.move_block_response_tx.is_some() if self.kv_event_publisher.is_some()
&& let UniqueBlock::FullBlock(stored_full_block) = hash && let UniqueBlock::FullBlock(stored_full_block) = hash
{ {
blocks_stored.push(*stored_full_block); blocks_stored.push(*stored_full_block);
...@@ -159,32 +223,32 @@ impl KvManager { ...@@ -159,32 +223,32 @@ impl KvManager {
Some(UniqueBlock::FullBlock(block)) => Some(*block), Some(UniqueBlock::FullBlock(block)) => Some(*block),
Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"), Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"),
}; };
self.send_block_response(blocks_stored, false, true, parent_hash); self.publish_kv_event(blocks_stored, local_hashes, parent_hash, true);
} }
MoveBlock::Destroy(hashes) => { MoveBlock::Destroy(hashes) => {
let mut blocks_destroyed = Vec::<u64>::new(); let mut blocks_destroyed = Vec::<u64>::new();
// Loop in inverse direction // Process blocks in order (already reversed by caller if needed)
for hash in hashes.iter().rev() { for hash in hashes.iter() {
self.active_blocks.remove(hash).unwrap(); self.active_blocks.remove(hash).unwrap();
// Remove from all_blocks when destroyed // Remove from all_blocks when destroyed
assert!(self.all_blocks.remove(hash)); assert!(self.all_blocks.remove(hash));
// Track blocks for batch sending // Track blocks for batch sending
if self.move_block_response_tx.is_some() if self.kv_event_publisher.is_some()
&& let UniqueBlock::FullBlock(destroyed_full_block) = hash && let UniqueBlock::FullBlock(destroyed_full_block) = hash
{ {
blocks_destroyed.push(*destroyed_full_block); blocks_destroyed.push(*destroyed_full_block);
} }
} }
self.send_block_response(blocks_destroyed, true, false, None); self.publish_kv_event(blocks_destroyed, &[], None, false);
} }
MoveBlock::Deref(hashes) => { MoveBlock::Deref(hashes) => {
// Loop in inverse direction // Process blocks in order (already reversed by caller if needed)
for hash in hashes.iter().rev() { for hash in hashes.iter() {
// Decrement reference count and check if we need to move to inactive // Decrement reference count and check if we need to move to inactive
if let Some(ref_count) = self.active_blocks.get_mut(hash) { if let Some(ref_count) = self.active_blocks.get_mut(hash) {
if *ref_count == 0 { if *ref_count == 0 {
...@@ -202,24 +266,30 @@ impl KvManager { ...@@ -202,24 +266,30 @@ impl KvManager {
} }
} }
MoveBlock::Promote(uuid, hash, parent_hash) => { MoveBlock::Promote(uuid, hash, parent_hash, local_hash) => {
let uuid_block = UniqueBlock::PartialBlock(*uuid); let uuid_block = UniqueBlock::PartialBlock(*uuid);
let hash_block = UniqueBlock::FullBlock(*hash); let hash_block = UniqueBlock::FullBlock(*hash);
let Some(ref_count) = self.active_blocks.remove(&uuid_block) else { assert_eq!(
let in_all_blocks = self.all_blocks.contains(&uuid_block); self.active_blocks.remove(&uuid_block),
panic!( Some(1),
"Missing active block for promotion: {uuid_block:?}. Block still exists: {in_all_blocks}" "uuid_block {uuid_block:?} should exist and be unique with ref_count=1"
); );
let hash_ref_count = if let Some(ref_count) = self.active_blocks.get(&hash_block) {
*ref_count
} else if self.inactive_blocks.remove(&hash_block) {
0
} else {
self.publish_kv_event(vec![*hash], &[*local_hash], *parent_hash, true);
0
}; };
// Replace with hash block, keeping the same reference count self.active_blocks
self.active_blocks.insert(hash_block.clone(), ref_count); .insert(hash_block.clone(), hash_ref_count + 1);
// Update all_blocks
assert!(self.all_blocks.remove(&uuid_block)); assert!(self.all_blocks.remove(&uuid_block));
self.all_blocks.insert(hash_block); self.all_blocks.insert(hash_block);
self.send_block_response(vec![*hash], false, true, *parent_hash);
} }
} }
...@@ -291,7 +361,6 @@ impl KvManager { ...@@ -291,7 +361,6 @@ impl KvManager {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tokio::sync::mpsc;
#[test] #[test]
fn test_failure_on_max_capacity() { fn test_failure_on_max_capacity() {
...@@ -300,8 +369,9 @@ mod tests { ...@@ -300,8 +369,9 @@ mod tests {
// Helper function to use multiple blocks that returns the response // Helper function to use multiple blocks that returns the response
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> bool { fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> bool {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
manager.process(&MoveBlock::Use(blocks)) let hashes: Vec<_> = ids.into_iter().collect();
manager.process(&MoveBlock::Use(blocks, hashes))
} }
// First use 10 blocks (0 to 9) in a batch // First use 10 blocks (0 to 9) in a batch
...@@ -321,16 +391,14 @@ mod tests { ...@@ -321,16 +391,14 @@ mod tests {
#[test] #[test]
fn test_block_lifecycle_stringent() { fn test_block_lifecycle_stringent() {
// Create a channel to listen to block responses // Create a KvManager with 10 blocks capacity (no KV event publisher for tests)
let (tx, mut rx) = mpsc::unbounded_channel::<MoveBlockResponse>(); let mut manager = KvManager::new(10, 16);
// Create a KvManager with 10 blocks capacity and the response sender
let mut manager = KvManager::new_with_sender(10, 16, Some(tx));
// Helper function to use multiple blocks // Helper function to use multiple blocks
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) { fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
manager.process(&MoveBlock::Use(blocks)); let hashes: Vec<_> = ids.into_iter().collect();
manager.process(&MoveBlock::Use(blocks, hashes));
} }
// Helper function to destroy multiple blocks // Helper function to destroy multiple blocks
...@@ -345,56 +413,6 @@ mod tests { ...@@ -345,56 +413,6 @@ mod tests {
manager.process(&MoveBlock::Deref(blocks)); manager.process(&MoveBlock::Deref(blocks));
} }
// Helper function to assert block responses
fn assert_block_response(
rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
expected_type: &str,
expected_blocks: Vec<u64>,
description: &str,
) {
let response = rx
.try_recv()
.unwrap_or_else(|_| panic!("Expected {expected_type} response {description}"));
match (&response, expected_type) {
(MoveBlockResponse::Store(blocks, _parent_hash), "Store") => {
assert_eq!(
blocks.len(),
expected_blocks.len(),
"Expected {} blocks in Store response {}",
expected_blocks.len(),
description
);
assert_eq!(
*blocks, expected_blocks,
"Store blocks don't match expected {description}"
);
}
(MoveBlockResponse::Remove(blocks), "Remove") => {
assert_eq!(
blocks.len(),
expected_blocks.len(),
"Expected {} blocks in Remove response {}",
expected_blocks.len(),
description
);
assert_eq!(
*blocks, expected_blocks,
"Remove blocks don't match expected {description}"
);
}
_ => panic!("Expected {expected_type} response, got {response:?} {description}"),
}
}
// Helper function to assert no response is received
fn assert_no_response(
rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
description: &str,
) {
assert!(rx.try_recv().is_err(), "Expected no response {description}",);
}
// Helper function to check if active blocks contain expected blocks with expected ref counts // Helper function to check if active blocks contain expected blocks with expected ref counts
fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) { fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) {
assert_eq!( assert_eq!(
...@@ -442,11 +460,9 @@ mod tests { ...@@ -442,11 +460,9 @@ mod tests {
// First use blocks 0, 1, 2, 3, 4 in a batch // First use blocks 0, 1, 2, 3, 4 in a batch
use_blocks(&mut manager, (0..5).collect()); use_blocks(&mut manager, (0..5).collect());
assert_block_response(&mut rx, "Store", vec![0, 1, 2, 3, 4], "after first use");
// Then use blocks 0, 1, 5, 6 in a batch // Then use blocks 0, 1, 5, 6 in a batch
use_blocks(&mut manager, vec![0, 1, 5, 6]); use_blocks(&mut manager, vec![0, 1, 5, 6]);
assert_block_response(&mut rx, "Store", vec![5, 6], "after second use");
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2 // Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
assert_active_blocks( assert_active_blocks(
...@@ -456,11 +472,9 @@ mod tests { ...@@ -456,11 +472,9 @@ mod tests {
// Now destroy block 4 // Now destroy block 4
destroy_blocks(&mut manager, vec![4]); destroy_blocks(&mut manager, vec![4]);
assert_block_response(&mut rx, "Remove", vec![4], "after destroy block 4");
// And deref blocks 3, 2, 1, 0 in this order as a batch // And deref blocks 3, 2, 1, 0 in this order as a batch
deref_blocks(&mut manager, vec![0, 1, 2, 3]); deref_blocks(&mut manager, vec![0, 1, 2, 3]);
assert_no_response(&mut rx, "after deref operation");
// Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2 // Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2
assert_inactive_blocks(&manager, 2, &[3, 2]); assert_inactive_blocks(&manager, 2, &[3, 2]);
...@@ -468,7 +482,6 @@ mod tests { ...@@ -468,7 +482,6 @@ mod tests {
// Now destroy block 6 // Now destroy block 6
destroy_blocks(&mut manager, vec![6]); destroy_blocks(&mut manager, vec![6]);
assert_block_response(&mut rx, "Remove", vec![6], "after block 6 eviction");
// And deref blocks 5, 1, 0 as a batch // And deref blocks 5, 1, 0 as a batch
deref_blocks(&mut manager, vec![0, 1, 5]); deref_blocks(&mut manager, vec![0, 1, 5]);
...@@ -479,7 +492,6 @@ mod tests { ...@@ -479,7 +492,6 @@ mod tests {
// Now use 0, 1, 2, 7, 8, 9 as a batch // Now use 0, 1, 2, 7, 8, 9 as a batch
use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]); use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]);
assert_block_response(&mut rx, "Store", vec![7, 8, 9], "after [7, 8, 9] use");
// Check that the inactive_blocks is size 2, and contains 3 and 5 // Check that the inactive_blocks is size 2, and contains 3 and 5
assert_inactive_blocks(&manager, 2, &[3, 5]); assert_inactive_blocks(&manager, 2, &[3, 5]);
...@@ -494,14 +506,10 @@ mod tests { ...@@ -494,14 +506,10 @@ mod tests {
// Now use blocks 10, 11, 12 as a batch // Now use blocks 10, 11, 12 as a batch
use_blocks(&mut manager, vec![10, 11, 12]); use_blocks(&mut manager, vec![10, 11, 12]);
assert_block_response(&mut rx, "Remove", vec![3], "after block 5 eviction");
assert_block_response(&mut rx, "Store", vec![10, 11, 12], "after [10, 11, 12] use");
// Check that the inactive_blocks is size 1 and contains only 5 // Check that the inactive_blocks is size 1 and contains only 5
assert_inactive_blocks(&manager, 1, &[5]); assert_inactive_blocks(&manager, 1, &[5]);
use_blocks(&mut manager, vec![13]); use_blocks(&mut manager, vec![13]);
assert_block_response(&mut rx, "Remove", vec![5], "after block 5 eviction");
assert_block_response(&mut rx, "Store", vec![13], "after block 13 use");
} }
} }
...@@ -7,23 +7,19 @@ use std::collections::{HashMap, HashSet}; ...@@ -7,23 +7,19 @@ use std::collections::{HashMap, HashSet};
use std::path::Path; use std::path::Path;
use uuid::Uuid; use uuid::Uuid;
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
use crate::tokens::blocks::UniqueBlock; use crate::tokens::blocks::UniqueBlock;
use crate::tokens::{BlockHash, SequenceHash, Token}; use crate::tokens::{BlockHash, SequenceHash, Token};
pub type NumBlocks = usize; pub type NumBlocks = usize;
/// Represents different block movement operations in the cache /// Represents different block movement operations in the cache
/// For Use and Promote variants, parent hash is the second field /// For Use and Promote variants, block hashes are included for KV event publishing
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock { pub enum MoveBlock {
Use(Vec<UniqueBlock>), Use(Vec<UniqueBlock>, Vec<BlockHash>),
Destroy(Vec<UniqueBlock>), Destroy(Vec<UniqueBlock>),
Deref(Vec<UniqueBlock>), Deref(Vec<UniqueBlock>),
Promote(Uuid, SequenceHash, Option<u64>), Promote(Uuid, SequenceHash, Option<u64>, BlockHash),
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
...@@ -50,7 +46,7 @@ pub struct PrefillCost { ...@@ -50,7 +46,7 @@ pub struct PrefillCost {
impl PrefillCost { impl PrefillCost {
pub fn predict_prefill_compute(&self, new_tokens: Option<usize>) -> f64 { pub fn predict_prefill_compute(&self, new_tokens: Option<usize>) -> f64 {
let tokens = new_tokens.unwrap_or(self.new_tokens); let tokens = new_tokens.unwrap_or(self.new_tokens);
1.25e-6 * (tokens as f64).powi(2) + 7.41e-2 * (tokens as f64) + 2.62e1 4.209989e-07 * (tokens as f64).powi(2) + 1.518344e-02 * (tokens as f64) + 1.650142e+01
} }
} }
...@@ -260,49 +256,6 @@ impl MockEngineArgs { ...@@ -260,49 +256,6 @@ impl MockEngineArgs {
} }
} }
/// Converts a MoveBlockResponse from the mocker backend into a KvCacheEventData.
///
/// This function assumes that the stored sequence hashes in the response always
/// correspond to the tail part of the local hashes array. This is the expected
/// behavior of KV block storage, where blocks are stored sequentially and the
/// response contains the most recent blocks that were stored.
///
/// # Panics
/// Panics if the number of blocks in the Store response exceeds the length
/// of local_hashes.
pub fn block_response_to_kv_event(
response: MoveBlockResponse,
local_hashes: &[BlockHash],
) -> KvCacheEventData {
match response {
MoveBlockResponse::Store(full_blocks, parent_hash) => {
let num_blocks = full_blocks.len();
let local_hashes_slice = &local_hashes[local_hashes
.len()
.checked_sub(num_blocks)
.expect("local hashes fewer than block response signal")..];
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
blocks: full_blocks
.into_iter()
.zip(local_hashes_slice.iter())
.map(|(global_hash, local_hash)| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(global_hash),
tokens_hash: LocalBlockHash(*local_hash),
})
.collect(),
})
}
MoveBlockResponse::Remove(full_blocks) => KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: full_blocks
.into_iter()
.map(ExternalSequenceBlockHash)
.collect(),
}),
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::ops::{Add, Div, Sub};
/// A generic running mean calculator with a fixed-size sliding window.
///
/// Keeps at most `max_size` of the most recently pushed samples together with their
/// running sum, so [`RunningMean::mean`] is O(1) regardless of the window size.
/// A window created with `max_size == 0` retains no samples at all.
///
/// The arithmetic bounds on `T` are declared on the `impl` block only, so the type
/// itself can be named without restating them.
///
/// NOTE: the sum is maintained incrementally (added on insert, subtracted on eviction).
/// For floating-point `T` this can accumulate rounding error over a long-lived window.
#[derive(Debug, Clone)]
pub struct RunningMean<T> {
    /// Maximum number of samples retained in the window.
    max_size: u16,
    /// Sum of every sample currently held in `values`.
    sum: T,
    /// Retained samples, oldest at the front.
    values: VecDeque<T>,
}

impl<T> RunningMean<T>
where
    T: Copy + Add<Output = T> + Sub<Output = T> + Div<Output = T> + Default + From<u16>,
{
    /// Creates an empty window that retains at most `max_size` samples.
    pub fn new(max_size: u16) -> Self {
        Self {
            max_size,
            sum: T::default(),
            values: VecDeque::with_capacity(usize::from(max_size)),
        }
    }

    /// Records `value`, evicting the oldest sample first when the window is full.
    ///
    /// Does nothing when the window was created with `max_size == 0`.
    pub fn push(&mut self, value: T) {
        let capacity = usize::from(self.max_size);
        if capacity == 0 {
            // A zero-sized window must stay empty; otherwise `len()` would exceed
            // `max_size` and `mean()` would report the last pushed value.
            return;
        }

        // If at capacity, drop the oldest sample and remove it from the running sum.
        let evicted = if self.values.len() >= capacity {
            self.values.pop_front()
        } else {
            None
        };
        if let Some(oldest) = evicted {
            self.sum = self.sum - oldest;
        }

        // Account for the new sample.
        self.sum = self.sum + value;
        self.values.push_back(value);
    }

    /// Returns the mean of the retained samples, or `T::default()` when the window is
    /// empty. The division is performed with `T`'s own `Div` implementation.
    pub fn mean(&self) -> T {
        if self.values.is_empty() {
            return T::default();
        }
        // Lossless cast: `push` never lets the window grow past `max_size`, a `u16`.
        self.sum / T::from(self.values.len() as u16)
    }

    /// Returns the number of samples currently retained.
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns `true` when no samples are retained.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Clear all values from the window and reset the running sum.
    pub fn clear(&mut self) {
        self.sum = T::default();
        self.values.clear();
    }
}
...@@ -28,16 +28,16 @@ ...@@ -28,16 +28,16 @@
//! ## NOTE //! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP //! The current prefill and decoding time simulations are not scientific at all and are WIP
use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEventData, KvStats, WorkerStats}; use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats};
use crate::mocker::evictor::LRUEvictor; use crate::mocker::evictor::LRUEvictor;
use crate::mocker::kv_manager::KvManager; use crate::mocker::kv_manager::KvManager;
use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse}; use crate::mocker::protocols::{
use crate::mocker::protocols::{MoveBlock, OutputSignal, PrefillCost, block_response_to_kv_event}; DirectRequest, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost, WorkerType,
};
use crate::mocker::running_mean::RunningMean;
use crate::mocker::sequence::ActiveSequence; use crate::mocker::sequence::ActiveSequence;
use crate::tokens::BlockHash;
use crate::tokens::blocks::UniqueBlock; use crate::tokens::blocks::UniqueBlock;
use std::collections::HashMap; use std::collections::{HashMap, VecDeque};
use std::collections::VecDeque;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::time::Duration; use tokio::time::Duration;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -111,9 +111,8 @@ impl SchedulerState { ...@@ -111,9 +111,8 @@ impl SchedulerState {
/// Returns `Some((prefill_compute, creation_signal, is_full_prefill))` where: /// Returns `Some((prefill_compute, creation_signal, is_full_prefill))` where:
/// - `prefill_compute`: The compute time in milliseconds for this prefill operation /// - `prefill_compute`: The compute time in milliseconds for this prefill operation
/// - `creation_signal`: Optional MoveBlock signal for KV cache block creation /// - `creation_signal`: Optional MoveBlock signal for KV cache block creation
/// - `block_hashes`: Block hashes of the sequence being prefilled
/// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked /// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked
fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, Vec<BlockHash>, bool)> { fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, bool)> {
let uuid = self.prefill.pop_front()?; let uuid = self.prefill.pop_front()?;
// Remove and extract prefill_compute from prefill_costs // Remove and extract prefill_compute from prefill_costs
...@@ -168,7 +167,6 @@ impl SchedulerState { ...@@ -168,7 +167,6 @@ impl SchedulerState {
Some(( Some((
prefill_compute, prefill_compute,
sequence.take_creation_signal(), sequence.take_creation_signal(),
sequence.block_hashes(),
is_full_prefill, is_full_prefill,
)) ))
} }
...@@ -247,17 +245,9 @@ impl Scheduler { ...@@ -247,17 +245,9 @@ impl Scheduler {
args: MockEngineArgs, args: MockEngineArgs,
dp_rank: u32, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>, component: Option<dynamo_runtime::component::Component>,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
) -> Self { ) -> Self {
// Create internal channel for KV events only if needed
let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() {
let (tx, rx) = mpsc::unbounded_channel::<MoveBlockResponse>();
(Some(tx), Some(rx))
} else {
(None, None)
};
// Assert speedup_ratio is greater than 0 // Assert speedup_ratio is greater than 0
assert!( assert!(
args.speedup_ratio > 0.0, args.speedup_ratio > 0.0,
...@@ -278,121 +268,64 @@ impl Scheduler { ...@@ -278,121 +268,64 @@ impl Scheduler {
tokio::spawn(async move { tokio::spawn(async move {
// Create state and kv_manager as local variables owned by this task // Create state and kv_manager as local variables owned by this task
let mut state = SchedulerState::new(args.max_num_batched_tokens); let mut state = SchedulerState::new(args.max_num_batched_tokens);
let mut kv_manager = let mut kv_manager = KvManager::new_with_publisher(
KvManager::new_with_sender(args.num_gpu_blocks, args.block_size, block_resp_tx); args.num_gpu_blocks,
let mut hit_rates = VecDeque::with_capacity(1000); args.block_size,
let mut should_schedule = true; component,
dp_rank,
);
let mut hit_rates = RunningMean::new(1000);
loop { loop {
{ // 1. Receive requests
// Enqueue new request, blocks until at least one is received, so no redundant work is done if state.is_empty() {
if state.is_empty() { // Fully idle - block until new request arrives
let Some(request) = request_rx.recv().await else { tokio::select! {
tracing::warn!("request sender is dropped"); biased;
Some(request) = request_rx.recv() => {
state.receive(request);
}
_ = cancel_token_clone.cancelled() => {
break; break;
}; }
state.receive(request);
} }
} } else {
// Has active/waiting work - collect any pending requests without blocking
tokio::select! { while let Ok(request) = request_rx.try_recv() {
biased;
// Enqueue new request
Some(request) = request_rx.recv() => {
state.receive(request); state.receive(request);
} }
// Try Scheduling Requests - runs on normal interval or after simulation
_ = tokio::task::yield_now() => {
// Skip if we just ran scheduling after simulation to prevent consecutive runs
if !should_schedule {
continue;
}
// Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
// schedule anymore.
let mut current_blocks = kv_manager.num_active_blocks();
let mut current_tokens = state.active_tokens + state.waiting_tokens;
let mut current_seqs = state.num_active_requests();
while let Some((uuid, request)) = state.next() {
let active_sequence = get_active_sequence(request, args.block_size, args.enable_prefix_caching);
// Update predictive budgets
let prefill_cost = kv_manager.get_prefill_cost(&active_sequence);
let total_tokens = active_sequence.len();
// this is conservative, assumes no cache hit so never over-schedules
let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize;
let new_tokens = prefill_cost.new_tokens;
current_blocks += new_blocks;
current_tokens += new_tokens;
current_seqs += 1;
// Check various budgets to see if possible to schedule
let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager.max_capacity() as f64;
// If chunked prefill is enabled, we can be under token budget when scheduling
let comparison_tokens = if args.enable_chunked_prefill {current_tokens - new_tokens} else {current_tokens};
let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| comparison_tokens <= limit);
let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
// Cannot schedule, put first in line instead
if !(under_block_budget && under_token_budget && under_seq_budget) {
state.first_in_line(uuid, Request::Active(active_sequence));
break;
}
// Compute and store hit rate
let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 };
hit_rates.push_back(hit_rate);
if hit_rates.len() > 1000 {
hit_rates.pop_front();
}
state.move_to_prefill(uuid, active_sequence, prefill_cost);
should_schedule = false;
}
}
// Check for cancellation // Check for cancellation
_ = cancel_token_clone.cancelled() => { if cancel_token_clone.is_cancelled() {
break; break;
} }
} }
// Simulates prefill + decode // Start timing for this forward pass (schedule + simulate)
// Base time needed for decoding using active percentage and quadratic formula let iteration_start = std::time::Instant::now();
let active_perc = kv_manager.get_active_perc();
let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44; // 2. Schedule waiting requests (once per iteration)
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0); try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
// 3. Simulate prefill + decode
let mut total_time = Duration::ZERO;
// Process prefilling // Process prefilling
while let Some(( while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
prefill_compute, state.try_prefill()
maybe_creation_signal,
block_hashes,
is_full_prefill,
)) = state.try_prefill()
{ {
// NOTE: Prefill cost/time is always incremented for new blocks, even if they // NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior. // could be cached by other requests in the same batch. This matches vLLM behavior.
total_time += Duration::from_secs_f64(prefill_compute / 1000.0); // For decode workers, skip adding prefill compute time
if args.worker_type != WorkerType::Decode {
if let Some(creation_signal) = maybe_creation_signal { total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
if !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal)) }
{
panic!("Block allocation for prefilling cannot fail.");
}
// Drain KV events and forward to relay after prefill signal processing if let Some(creation_signal) = maybe_creation_signal
if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) { && !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal))
while let Ok(event) = rx.try_recv() { {
let _ = panic!("Block allocation for prefilling cannot fail.");
relay_tx.send(block_response_to_kv_event(event, &block_hashes)); }
}
}
};
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
if !is_full_prefill { if !is_full_prefill {
...@@ -400,13 +333,15 @@ impl Scheduler { ...@@ -400,13 +333,15 @@ impl Scheduler {
} }
} }
let active_perc = kv_manager.get_active_perc();
// TODO: share the same logic with Planner
let decoding_time = -25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74;
total_time += Duration::from_secs_f64(decoding_time / 1000.0);
state.reset_active_tokens(); state.reset_active_tokens();
// Process decoding // Process decoding
let uuids: Vec<Uuid> = state.decode.keys().cloned().collect(); let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
if !uuids.is_empty() {
should_schedule = true
};
for uuid in uuids { for uuid in uuids {
let Some(sequence) = state.run(uuid) else { let Some(sequence) = state.run(uuid) else {
continue; continue;
...@@ -423,14 +358,6 @@ impl Scheduler { ...@@ -423,14 +358,6 @@ impl Scheduler {
continue; continue;
} }
// Drain KV events and forward to relay after decode signal processing
if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ = relay_tx
.send(block_response_to_kv_event(event, &sequence.block_hashes()));
}
}
// Check completion and send notification // Check completion and send notification
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens(); let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let should_output = let should_output =
...@@ -465,11 +392,13 @@ impl Scheduler { ...@@ -465,11 +392,13 @@ impl Scheduler {
let _ = metrics_tx.send(metrics); let _ = metrics_tx.send(metrics);
} }
// Sleep once for the adjusted duration // 4. Sleep to maintain target iteration timing
let adjusted_time = let target_duration =
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio); Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
if adjusted_time.as_millis() > 0 { let elapsed = iteration_start.elapsed();
tokio::time::sleep(adjusted_time).await;
if elapsed < target_duration {
tokio::time::sleep(target_duration - elapsed).await;
} }
} }
}); });
...@@ -499,7 +428,7 @@ impl Scheduler { ...@@ -499,7 +428,7 @@ impl Scheduler {
fn get_fwd_pass_metrics( fn get_fwd_pass_metrics(
state: &SchedulerState, state: &SchedulerState,
kv_manager: &KvManager, kv_manager: &KvManager,
hit_rates: &VecDeque<f32>, hit_rates: &RunningMean<f32>,
dp_rank: u32, dp_rank: u32,
) -> ForwardPassMetrics { ) -> ForwardPassMetrics {
// Get state metrics // Get state metrics
...@@ -507,7 +436,7 @@ fn get_fwd_pass_metrics( ...@@ -507,7 +436,7 @@ fn get_fwd_pass_metrics(
let num_requests_waiting = state.waiting.len() as u64; let num_requests_waiting = state.waiting.len() as u64;
// Get KV manager metrics // Get KV manager metrics
let active_blocks_count = kv_manager.active_blocks().len() as u64; let active_blocks_count = kv_manager.num_active_blocks() as u64;
let total_capacity = kv_manager.max_capacity() as u64; let total_capacity = kv_manager.max_capacity() as u64;
let gpu_cache_usage_perc = if total_capacity > 0 { let gpu_cache_usage_perc = if total_capacity > 0 {
active_blocks_count as f32 / total_capacity as f32 active_blocks_count as f32 / total_capacity as f32
...@@ -515,13 +444,8 @@ fn get_fwd_pass_metrics( ...@@ -515,13 +444,8 @@ fn get_fwd_pass_metrics(
0.0 0.0
}; };
// Get hit rate metrics // Get hit rate metrics - O(1) access
let gpu_prefix_cache_hit_rate = if hit_rates.is_empty() { let gpu_prefix_cache_hit_rate = hit_rates.mean();
0.0
} else {
let sum: f32 = hit_rates.iter().sum();
sum / hit_rates.len() as f32
};
let worker_stats = WorkerStats { let worker_stats = WorkerStats {
data_parallel_rank: Some(dp_rank), data_parallel_rank: Some(dp_rank),
...@@ -546,26 +470,75 @@ fn get_fwd_pass_metrics( ...@@ -546,26 +470,75 @@ fn get_fwd_pass_metrics(
} }
} }
/// Convert a Request to an ActiveSequence /// Attempts to schedule waiting requests from the state queue.
fn get_active_sequence( /// Returns the number of requests successfully scheduled.
request: Request, fn try_schedule(
block_size: usize, state: &mut SchedulerState,
enable_prefix_caching: bool, kv_manager: &KvManager,
) -> ActiveSequence { hit_rates: &mut RunningMean<f32>,
if let Request::Active(active_seq) = request { args: &MockEngineArgs,
return active_seq; ) -> usize {
} let mut scheduled_count = 0;
let mut current_blocks = kv_manager.num_active_blocks();
let mut current_tokens = state.active_tokens + state.waiting_tokens;
let mut current_seqs = state.num_active_requests();
while let Some((uuid, request)) = state.next() {
// Convert Request to ActiveSequence
let active_sequence = match request {
Request::Active(active_seq) => active_seq,
Request::Direct(direct_request) => ActiveSequence::new(
direct_request.tokens,
direct_request.max_output_tokens,
Some(args.block_size),
args.enable_prefix_caching,
),
};
let Request::Direct(direct_request) = request else { // Update predictive budgets
unreachable!("Request must be either Direct or Active"); let prefill_cost = kv_manager.get_prefill_cost(&active_sequence);
}; let total_tokens = active_sequence.len();
// this is conservative, assumes no cache hit so never over-schedules
let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize;
let new_tokens = prefill_cost.new_tokens;
current_blocks += new_blocks;
current_tokens += new_tokens;
current_seqs += 1;
// Check various budgets to see if possible to schedule
let under_block_budget =
current_blocks as f64 <= (1. - args.watermark) * kv_manager.max_capacity() as f64;
// If chunked prefill is enabled, we can be under token budget when scheduling
let comparison_tokens = if args.enable_chunked_prefill {
current_tokens - new_tokens
} else {
current_tokens
};
let under_token_budget = args
.max_num_batched_tokens
.is_none_or(|limit| comparison_tokens <= limit);
let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
// Cannot schedule, put first in line instead
if !(under_block_budget && under_token_budget && under_seq_budget) {
state.first_in_line(uuid, Request::Active(active_sequence));
break;
}
// Compute and store hit rate
let hit_rate = if !active_sequence.is_empty() {
1.0 - (new_tokens as f32 / active_sequence.len() as f32)
} else {
0.0
};
hit_rates.push(hit_rate);
state.move_to_prefill(uuid, active_sequence, prefill_cost);
scheduled_count += 1;
}
ActiveSequence::new( scheduled_count
direct_request.tokens,
direct_request.max_output_tokens,
Some(block_size),
enable_prefix_caching,
)
} }
/// Processes MoveBlock signals with the KvManager. /// Processes MoveBlock signals with the KvManager.
...@@ -582,7 +555,7 @@ fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool { ...@@ -582,7 +555,7 @@ fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
} }
// Check we have a Use signal with blocks // Check we have a Use signal with blocks
let MoveBlock::Use(blocks) = signal else { let MoveBlock::Use(blocks, _hashes) = signal else {
panic!( panic!(
"Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}" "Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
); );
......
...@@ -6,12 +6,10 @@ use crate::tokens::blocks::UniqueBlock; ...@@ -6,12 +6,10 @@ use crate::tokens::blocks::UniqueBlock;
use crate::tokens::{TokenBlockSequence, Tokens}; use crate::tokens::{TokenBlockSequence, Tokens};
use derive_getters::Getters; use derive_getters::Getters;
use rand::random; use rand::random;
use uuid;
/// Create unique blocks from a TokenBlockSequence /// Create unique blocks from a TokenBlockSequence
fn create_unique_blocks_from_sequence( fn create_unique_blocks_from_sequence(
tokens: &TokenBlockSequence, tokens: &TokenBlockSequence,
uuid: Option<uuid::Uuid>,
block_size: usize, block_size: usize,
enable_prefix_caching: bool, enable_prefix_caching: bool,
) -> Vec<UniqueBlock> { ) -> Vec<UniqueBlock> {
...@@ -29,10 +27,7 @@ fn create_unique_blocks_from_sequence( ...@@ -29,10 +27,7 @@ fn create_unique_blocks_from_sequence(
// Only push the partial block if tokens count isn't a multiple of block_size // Only push the partial block if tokens count isn't a multiple of block_size
if !tokens.total_tokens().is_multiple_of(block_size) { if !tokens.total_tokens().is_multiple_of(block_size) {
unique_blocks.push(match uuid { unique_blocks.push(UniqueBlock::default());
Some(uuid) => UniqueBlock::PartialBlock(uuid),
None => UniqueBlock::default(),
});
} }
unique_blocks unique_blocks
} }
...@@ -80,8 +75,9 @@ impl ActiveSequence { ...@@ -80,8 +75,9 @@ impl ActiveSequence {
let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337)); let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
let unique_blocks = let unique_blocks =
create_unique_blocks_from_sequence(&tokens, None, block_size, enable_prefix_caching); create_unique_blocks_from_sequence(&tokens, block_size, enable_prefix_caching);
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone())); let block_hashes = tokens.blocks().iter().map(|b| b.block_hash()).collect();
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone(), block_hashes));
Self { Self {
unique_blocks, unique_blocks,
...@@ -132,17 +128,6 @@ impl ActiveSequence { ...@@ -132,17 +128,6 @@ impl ActiveSequence {
(sequence, signal) (sequence, signal)
} }
/// Get the parent hash from the second-to-last block if it exists and is a FullBlock
fn get_parent_hash(&self) -> Option<u64> {
if self.unique_blocks.len() < 2 {
return None;
}
match &self.unique_blocks[self.unique_blocks.len() - 2] {
UniqueBlock::FullBlock(hash) => Some(*hash),
_ => panic!("Cannot have a partial block as parent"),
}
}
/// Push a token to the sequence /// Push a token to the sequence
pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> { pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
self.tokens.append(token).expect("Token push failed."); self.tokens.append(token).expect("Token push failed.");
...@@ -158,24 +143,33 @@ impl ActiveSequence { ...@@ -158,24 +143,33 @@ impl ActiveSequence {
// Replace last partial block with full block if it exists // Replace last partial block with full block if it exists
if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() { if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() {
let last_block_hash = if self.enable_prefix_caching { let last_seq_hash = if self.enable_prefix_caching {
self.tokens.last_complete_block().unwrap().sequence_hash() self.tokens.last_complete_block().unwrap().sequence_hash()
} else { } else {
random::<u64>() random::<u64>()
}; };
let last_block_hash = self.tokens.last_complete_block().unwrap().block_hash();
self.unique_blocks.pop(); self.unique_blocks.pop();
// After pop, the last element is the parent block
let second_to_last_hash = self.unique_blocks.last().map(|block| match block {
UniqueBlock::FullBlock(hash) => *hash,
UniqueBlock::PartialBlock(_) => panic!("Cannot have a partial block as parent"),
});
self.unique_blocks self.unique_blocks
.push(UniqueBlock::FullBlock(last_block_hash)); .push(UniqueBlock::FullBlock(last_seq_hash));
signals.push(MoveBlock::Promote( signals.push(MoveBlock::Promote(
uuid, uuid,
last_seq_hash,
second_to_last_hash,
last_block_hash, last_block_hash,
self.get_parent_hash(),
)); ));
} }
let new_partial_block = UniqueBlock::default(); let new_partial_block = UniqueBlock::default();
self.unique_blocks.push(new_partial_block.clone()); self.unique_blocks.push(new_partial_block.clone());
signals.push(MoveBlock::Use(vec![new_partial_block])); signals.push(MoveBlock::Use(vec![new_partial_block], vec![]));
Some(signals) Some(signals)
} }
...@@ -241,13 +235,15 @@ impl ActiveSequence { ...@@ -241,13 +235,15 @@ impl ActiveSequence {
self.tokens.truncate(self.num_input_tokens).unwrap(); self.tokens.truncate(self.num_input_tokens).unwrap();
self.unique_blocks = create_unique_blocks_from_sequence( self.unique_blocks = create_unique_blocks_from_sequence(
&self.tokens, &self.tokens,
None,
self.block_size, self.block_size,
self.enable_prefix_caching, self.enable_prefix_caching,
); );
self.already_generated_tokens = self.generated_tokens.max(self.already_generated_tokens); self.already_generated_tokens = self.generated_tokens.max(self.already_generated_tokens);
self.generated_tokens = 0; self.generated_tokens = 0;
self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone())); self.creation_signal = Some(MoveBlock::Use(
self.unique_blocks.clone(),
self.block_hashes(),
));
free_signal free_signal
} }
...@@ -280,7 +276,7 @@ mod tests { ...@@ -280,7 +276,7 @@ mod tests {
// Check that we got a Use signal // Check that we got a Use signal
assert!(signal1.is_some()); assert!(signal1.is_some());
match &signal1 { match &signal1 {
Some(MoveBlock::Use(blocks)) => { Some(MoveBlock::Use(blocks, _hashes)) => {
assert_eq!(blocks.len(), 1); assert_eq!(blocks.len(), 1);
} }
_ => panic!("Expected Use signal"), _ => panic!("Expected Use signal"),
...@@ -301,7 +297,7 @@ mod tests { ...@@ -301,7 +297,7 @@ mod tests {
// First signal should be Promote for the previous block // First signal should be Promote for the previous block
match &signal_16[0] { match &signal_16[0] {
MoveBlock::Promote(_, _, parent_hash) => { MoveBlock::Promote(_, _, parent_hash, _hash) => {
assert_eq!(*parent_hash, None); assert_eq!(*parent_hash, None);
} }
_ => panic!("Expected Promote signal as second signal"), _ => panic!("Expected Promote signal as second signal"),
...@@ -309,7 +305,7 @@ mod tests { ...@@ -309,7 +305,7 @@ mod tests {
// Second signal should be Use for new partial block // Second signal should be Use for new partial block
match &signal_16[1] { match &signal_16[1] {
MoveBlock::Use(blocks) => { MoveBlock::Use(blocks, _hashes) => {
assert_eq!(blocks.len(), 1); assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
} }
...@@ -396,7 +392,7 @@ mod tests { ...@@ -396,7 +392,7 @@ mod tests {
// Check that signal[0] is promote // Check that signal[0] is promote
match &signal[0] { match &signal[0] {
MoveBlock::Promote(_, _, parent_hash) => { MoveBlock::Promote(_, _, parent_hash, _hash) => {
// Check that the parent_hash matches unique_blocks[1], which should be a full block // Check that the parent_hash matches unique_blocks[1], which should be a full block
if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] { if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] {
assert_eq!( assert_eq!(
...@@ -430,7 +426,7 @@ mod tests { ...@@ -430,7 +426,7 @@ mod tests {
// Initial signal - should have received a Use signal for the partial block // Initial signal - should have received a Use signal for the partial block
assert!(signal.is_some()); assert!(signal.is_some());
match signal { match signal {
Some(MoveBlock::Use(blocks)) => { Some(MoveBlock::Use(blocks, _hashes)) => {
assert_eq!(blocks.len(), 1); assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
} }
...@@ -448,7 +444,7 @@ mod tests { ...@@ -448,7 +444,7 @@ mod tests {
// First signal should be Promote // First signal should be Promote
match &signals_second[0] { match &signals_second[0] {
MoveBlock::Promote(_, _, parent_hash) => { MoveBlock::Promote(_, _, parent_hash, _hash) => {
assert_eq!(*parent_hash, None); assert_eq!(*parent_hash, None);
} }
_ => panic!("Expected Promote signal as first signal after second token"), _ => panic!("Expected Promote signal as first signal after second token"),
...@@ -456,7 +452,7 @@ mod tests { ...@@ -456,7 +452,7 @@ mod tests {
// Second signal should be Use for new partial block // Second signal should be Use for new partial block
match &signals_second[1] { match &signals_second[1] {
MoveBlock::Use(blocks) => { MoveBlock::Use(blocks, _hashes) => {
assert_eq!(blocks.len(), 1); assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment