Unverified Commit 236cb17d authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix(mocker): align vLLM scheduler with v1 — drop watermark, LIFO preemption, retry loop (#7020)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent b407b419
......@@ -21,7 +21,7 @@ The mocker engine now supports a vLLM-style CLI interface with individual argume
- `--max-num-batched-tokens`: Maximum number of batched tokens per iteration (default: 8192)
- `--enable-prefix-caching` / `--no-enable-prefix-caching`: Enable/disable automatic prefix caching (default: True)
- `--enable-chunked-prefill` / `--no-enable-chunked-prefill`: Enable/disable chunked prefill (default: True)
- `--watermark`: KV cache watermark threshold as a fraction (default: 0.01)
- `--preemption-mode`: Preemption mode for decode eviction under memory pressure: `lifo` (default, matches vLLM v1) or `fifo`
- `--speedup-ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster. Use `0` for infinite speedup (no simulation delays)
- `--data-parallel-size`: Number of data parallel workers to simulate (default: 1)
- `--num-workers`: Number of mocker workers to launch in the same process (default: 1). All workers share the same tokio runtime and thread pool
......
......@@ -107,7 +107,7 @@ def create_temp_engine_args_file(args: argparse.Namespace) -> Path:
"max_num_batched_tokens": getattr(args, "max_num_batched_tokens", None),
"enable_prefix_caching": getattr(args, "enable_prefix_caching", None),
"enable_chunked_prefill": getattr(args, "enable_chunked_prefill", None),
"watermark": getattr(args, "watermark", None),
"preemption_mode": getattr(args, "preemption_mode", None),
"speedup_ratio": getattr(args, "speedup_ratio", None),
"dp_size": getattr(args, "dp_size", None),
"startup_time": getattr(args, "startup_time", None),
......@@ -287,10 +287,13 @@ def parse_args() -> argparse.Namespace:
help="Disable chunked prefill",
)
parser.add_argument(
"--watermark",
type=float,
"--preemption-mode",
type=str,
default=None,
help="Watermark value for the mocker engine (default: 0.01)",
choices=["lifo", "fifo"],
help="Preemption mode for decode eviction under memory pressure. "
"'lifo' (default) evicts the newest request (matches vLLM v1), "
"'fifo' evicts the oldest request.",
)
parser.add_argument(
"--speedup-ratio",
......
......@@ -75,6 +75,16 @@ pub struct OutputSignal {
pub completed: bool,
}
/// Preemption policy for evicting decode requests under memory pressure
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum PreemptionMode {
/// Evict the newest request (matches vLLM v1 default)
#[default]
Lifo,
/// Evict the oldest request
Fifo,
}
/// Worker type for disaggregated serving configurations
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum WorkerType {
......@@ -147,10 +157,6 @@ pub struct MockEngineArgs {
#[builder(default = true)]
pub enable_chunked_prefill: bool,
#[builder(default = "0.01")]
#[validate(range(min = 0.0, max = 1.0))]
pub watermark: f64,
#[builder(default = "1.0")]
#[validate(range(min = 0.0))]
pub speedup_ratio: f64,
......@@ -205,6 +211,11 @@ pub struct MockEngineArgs {
/// A KvEventPublisher relay subscribes to this socket and forwards events to NATS.
#[builder(default = "None")]
pub zmq_kv_events_port: Option<u16>,
/// Preemption mode for decode eviction under memory pressure.
/// Lifo (default) evicts the newest request; Fifo evicts the oldest.
#[builder(default)]
pub preemption_mode: PreemptionMode,
}
impl Default for MockEngineArgs {
......@@ -248,7 +259,6 @@ impl MockEngineArgs {
"max_num_batched_tokens",
"enable_prefix_caching",
"enable_chunked_prefill",
"watermark",
"speedup_ratio",
"dp_size",
"startup_time",
......@@ -261,6 +271,7 @@ impl MockEngineArgs {
"kv_transfer_bandwidth",
"reasoning",
"zmq_kv_events_port",
"preemption_mode",
]
.iter()
.cloned()
......@@ -318,12 +329,6 @@ impl MockEngineArgs {
builder = builder.enable_chunked_prefill(enabled);
}
if let Some(value) = extra_args.get("watermark")
&& let Some(num) = value.as_f64()
{
builder = builder.watermark(num);
}
if let Some(value) = extra_args.get("speedup_ratio")
&& let Some(num) = value.as_f64()
{
......@@ -378,6 +383,22 @@ impl MockEngineArgs {
builder = builder.zmq_kv_events_port(Some(port as u16));
}
if let Some(value) = extra_args.get("preemption_mode")
&& let Some(mode_str) = value.as_str()
{
let mode = match mode_str {
"lifo" => PreemptionMode::Lifo,
"fifo" => PreemptionMode::Fifo,
_ => {
return Err(anyhow::anyhow!(
"Invalid preemption_mode: '{}'. Must be 'lifo' or 'fifo'.",
mode_str
));
}
};
builder = builder.preemption_mode(mode);
}
// Parse worker type from is_prefill and is_decode flags
let is_prefill = extra_args
.get("is_prefill")
......
......@@ -54,7 +54,8 @@ pub struct ActiveSequence {
#[getter(copy)]
num_input_tokens: usize,
creation_signal: Option<MoveBlock>,
#[getter(copy)]
num_allocated_tokens: usize,
#[getter(copy)]
enable_prefix_caching: bool,
......@@ -75,28 +76,9 @@ impl ActiveSequence {
let block_size = block_size.unwrap_or(64);
let num_input_tokens = tokens.len();
let block_token_ids: Option<Vec<Vec<u32>>> = if emit_token_ids {
let num_complete = tokens.len() / block_size;
Some(
tokens
.chunks(block_size)
.take(num_complete)
.map(|c| c.to_vec())
.collect(),
)
} else {
None
};
let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
let unique_blocks =
create_unique_blocks_from_sequence(&tokens, block_size, enable_prefix_caching);
let block_hashes = tokens.blocks().iter().map(|b| b.block_hash()).collect();
let creation_signal = Some(MoveBlock::Use(
unique_blocks.clone(),
block_hashes,
block_token_ids,
));
let seq = Self {
unique_blocks,
......@@ -105,7 +87,7 @@ impl ActiveSequence {
max_output_tokens,
generated_tokens: 0,
num_input_tokens,
creation_signal,
num_allocated_tokens: 0,
enable_prefix_caching,
emit_token_ids,
};
......@@ -125,8 +107,60 @@ impl ActiveSequence {
self.tokens.total_tokens() == 0
}
/// Build a `MoveBlock::Use` signal for blocks up to `cumulative_tokens`
/// without updating internal state. Returns `None` if no new blocks are needed.
/// Call `commit_allocation` after the signal is successfully processed.
pub fn prepare_allocation(&self, cumulative_tokens: usize) -> Option<MoveBlock> {
let prev_blocks = self
.num_allocated_tokens
.div_ceil(self.block_size)
.min(self.unique_blocks.len());
let target_blocks = cumulative_tokens
.div_ceil(self.block_size)
.min(self.unique_blocks.len());
if target_blocks <= prev_blocks {
return None;
}
let range = prev_blocks..target_blocks;
let blocks = self.unique_blocks[range.clone()].to_vec();
let all_hashes = self.block_hashes();
let num_full = all_hashes.len();
let hash_start = prev_blocks.min(num_full);
let hash_end = target_blocks.min(num_full);
let hashes = all_hashes[hash_start..hash_end].to_vec();
let token_ids = if self.emit_token_ids && hash_start < hash_end {
let all_token_ids: Vec<Vec<u32>> = self
.tokens
.blocks()
.iter()
.map(|b| b.tokens().to_vec())
.collect();
Some(all_token_ids[hash_start..hash_end].to_vec())
} else {
None
};
Some(MoveBlock::Use(blocks, hashes, token_ids))
}
/// Commit a successful allocation by advancing `num_allocated_tokens`.
pub fn commit_allocation(&mut self, cumulative_tokens: usize) {
self.num_allocated_tokens = cumulative_tokens;
}
/// Prepare + commit in one call (convenience for paths where failure is impossible).
pub fn allocate_blocks_for_chunk(&mut self, cumulative_tokens: usize) -> Option<MoveBlock> {
let signal = self.prepare_allocation(cumulative_tokens);
self.commit_allocation(cumulative_tokens);
signal
}
/// Allocate all remaining blocks at once (backward compat).
pub fn take_creation_signal(&mut self) -> Option<MoveBlock> {
self.creation_signal.take()
self.allocate_blocks_for_chunk(self.len())
}
pub fn block_hashes(&self) -> Vec<u64> {
......@@ -262,31 +296,12 @@ impl ActiveSequence {
.collect()
}
/// Move the request to a preempted state and return the free signals from freeing current blocks
/// Move the request to a preempted state and return the free signals from freeing current blocks.
/// Upon preemption, the sequence retains the tokens generated during the decode phase (if any).
/// Resets `num_allocated_tokens` so re-admission will re-allocate from scratch.
pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
let free_signal = self.free_signal();
// Don't reset generated_tokens since we're keeping the tokens in the sequence
let block_token_ids = if self.emit_token_ids {
Some(
self.tokens
.blocks()
.iter()
.map(|b| b.tokens().to_vec())
.collect(),
)
} else {
None
};
self.creation_signal = Some(MoveBlock::Use(
self.unique_blocks.clone(),
self.block_hashes(),
block_token_ids,
));
self.num_allocated_tokens = 0;
free_signal
}
......
......@@ -28,7 +28,7 @@
//! ## Preemption
//! If a Use operation fails (typically due to insufficient space), a false boolean signal
//! is returned to the scheduler for preemption. Initial KV block allocations for new requests
//! should not fail due to the watermark checking.
//! should not fail due to the capacity checking during scheduling.
//!
//! ## NOTE
//! For simplicity (or non-simplicity), reference counting is tracked manually instead of using
......@@ -177,8 +177,14 @@ impl KvManager {
}
}
/// Process a MoveBlock instruction synchronously
pub fn process(&mut self, event: &MoveBlock) -> bool {
/// Process a MoveBlock instruction synchronously.
///
/// For `MoveBlock::Use`, returns the number of blocks successfully allocated.
/// On partial failure, blocks 0..N are committed but block N+1 could not be
/// allocated. Callers should use the return value to track partial progress.
///
/// For other variants, returns the total block count (they always succeed or panic).
pub fn process(&mut self, event: &MoveBlock) -> usize {
match event {
MoveBlock::Use(hashes, local_hashes, token_ids) => {
let mut blocks_stored = Vec::<u64>::new();
......@@ -186,25 +192,28 @@ impl KvManager {
token_ids.as_ref().map(|_| Vec::new());
let mut parent_block: Option<&UniqueBlock> = None;
let mut allocated = 0;
for (i, hash) in hashes.iter().enumerate() {
// First check if it already exists in active blocks
if self.cache.contains_active(hash) {
// Block already active, just increment reference count
self.cache.increment_ref(hash);
parent_block = Some(hash);
allocated += 1;
continue;
}
// Then check if it exists in inactive and move it to active if found
if self.cache.reactivate(hash) {
parent_block = Some(hash);
allocated += 1;
continue;
}
// If at max capacity, evict the oldest entry from inactive blocks
if self.cache.is_at_capacity() {
let Some(evicted) = self.cache.evict_inactive() else {
return false;
return allocated;
};
tracing::trace!(
"Evicting block from inactive pool: {evicted:?}, dp_rank={}",
......@@ -217,6 +226,7 @@ impl KvManager {
// Now insert the new block in active blocks with reference count 1
self.cache.insert_active(hash.clone(), 1);
allocated += 1;
// Track blocks for trace/event
if let UniqueBlock::FullBlock(stored_full_block) = hash {
blocks_stored.push(*stored_full_block);
......@@ -238,6 +248,7 @@ impl KvManager {
true,
stored_token_ids,
);
return allocated;
}
MoveBlock::Destroy(hashes) => {
......@@ -306,8 +317,7 @@ impl KvManager {
}
}
// Return true if we made it this far
true
1
}
/// Get the count of blocks that aren't in active or inactive pools
......@@ -406,8 +416,8 @@ mod tests {
// Create a KvManager with 10 blocks capacity
let mut manager = KvManager::new(10, 16);
// Helper function to use multiple blocks that returns the response
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> bool {
// Helper function to use multiple blocks that returns the count allocated
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> usize {
let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
let hashes: Vec<_> = ids.into_iter().collect();
manager.process(&MoveBlock::Use(blocks, hashes, None))
......@@ -415,16 +425,16 @@ mod tests {
// First use 10 blocks (0 to 9) in a batch
let response = use_blocks(&mut manager, (0..10).collect());
assert!(response, "Expected success response");
assert_eq!(response, 10, "Expected all 10 blocks allocated");
// Verify we are at capacity
assert_eq!(manager.current_capacity(), 10);
// The 11th block should return false, not panic
// The 11th block should return 0, not panic
let response = use_blocks(&mut manager, vec![10]);
assert!(
!response,
"Expected failure response when exceeding max capacity"
assert_eq!(
response, 0,
"Expected 0 blocks allocated when exceeding max capacity"
);
}
......
......@@ -9,7 +9,7 @@
//! 3. Simulating the execution of running requests with realistic timing
//!
//! ## Scheduling Process
//! The scheduler uses a watermark-based approach to determine if there's sufficient
//! The scheduler checks direct block capacity to determine if there's sufficient
//! KV cache space for new requests. It also enforces a batched tokens budget to prevent
//! oversubscription of computational resources. Only requests that can be allocated
//! these resources are moved from waiting to running state.
......@@ -22,16 +22,15 @@
//! ## Resource Management
//! The scheduler communicates with the KvManager through MoveBlock signals at each
//! stage of request processing. When resources become constrained, it employs an
//! LRU-based preemption strategy where the oldest running request is evicted and
//! placed at the back of the waiting queue to be rescheduled later.
//! preemption strategy (LIFO by default, matching vLLM v1) where a running request
//! is evicted and placed at the front of the waiting queue to be rescheduled later.
//!
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP
use crate::common::evictor::LRUEvictor;
use crate::common::perf_model::PerfModel;
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost,
DirectRequest, KvCacheEventSink, MockEngineArgs, MoveBlock, OutputSignal, PreemptionMode,
WorkerType,
};
use crate::common::running_mean::RunningMean;
......@@ -66,128 +65,55 @@ pub enum Request {
struct SchedulerState {
waiting: VecDeque<Uuid>,
prefill: VecDeque<Uuid>,
decode: LRUEvictor<Uuid>,
decode: VecDeque<Uuid>,
requests: HashMap<Uuid, Request>,
prefill_costs: HashMap<Uuid, PrefillCost>,
max_num_batched_tokens: Option<usize>,
active_tokens: usize,
waiting_tokens: usize,
}
impl SchedulerState {
fn new(max_num_batched_tokens: Option<usize>) -> Self {
SchedulerState {
max_num_batched_tokens,
..Default::default()
}
}
fn is_empty(&self) -> bool {
self.requests.is_empty()
}
/// Create a new UUID for a DirectRequest, add it to requests, and push the UUID to waiting.
fn receive(&mut self, request: DirectRequest) -> Uuid {
// Use the provided UUID if available, otherwise generate a new one
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
self.requests.insert(uuid, Request::Direct(request));
self.waiting.push_back(uuid);
uuid
}
/// Get the next UUID from ready or waiting queue and its associated Request.
fn next(&mut self) -> Option<(Uuid, Request)> {
let uuid = self.waiting.pop_front()?;
let request = self
.requests
.remove(&uuid)
.expect("Request does not exist.");
Some((uuid, request))
}
/// Move a UUID and its Request to the waiting queue (front).
fn first_in_line(&mut self, uuid: Uuid, request: Request) {
self.requests.insert(uuid, request);
self.waiting.push_front(uuid);
}
/// Move a UUID and its Request to the ready queue.
fn move_to_prefill(&mut self, uuid: Uuid, active_seq: ActiveSequence, cost: PrefillCost) {
self.waiting_tokens += cost.new_tokens;
self.requests.insert(uuid, Request::Active(active_seq));
self.prefill.push_back(uuid);
self.prefill_costs.insert(uuid, cost);
}
/// Try (chunked) prefill and move to decode queue
///
/// Returns `Some((prefill_compute, creation_signal, is_full_prefill))` where:
/// - `prefill_compute`: The compute time in milliseconds for this prefill operation
/// - `creation_signal`: Optional MoveBlock signal for KV cache block creation
/// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked
fn try_prefill(&mut self, perf_model: &PerfModel) -> Option<(f64, Option<MoveBlock>, bool)> {
let uuid = self.prefill.pop_front()?;
// Remove and extract prefill_compute from prefill_costs
let mut prefill_cost = self
.prefill_costs
.remove(&uuid)
.expect("Expects valid prefill cost.");
let new_tokens = prefill_cost.new_tokens;
let maybe_prefill_tokens = self.max_num_batched_tokens.and_then(|max_tokens| {
let remaining_tokens = max_tokens - self.active_tokens;
if prefill_cost.new_tokens > remaining_tokens {
Some(remaining_tokens)
} else {
None
}
});
let (prefill_compute, is_full_prefill) = if let Some(prefill_tokens) = maybe_prefill_tokens
{
let prefill_compute =
prefill_cost.predict_prefill_compute(Some(prefill_tokens), perf_model);
prefill_cost.new_tokens -= prefill_tokens;
assert!(
prefill_cost.new_tokens > 0,
"Encountered negative prefill tokens."
);
self.prefill.push_front(uuid);
self.prefill_costs.insert(uuid, prefill_cost);
self.active_tokens = self.max_num_batched_tokens.unwrap();
self.waiting_tokens -= prefill_tokens;
(prefill_compute, false)
} else {
// Assume possible to complete prefilling the sequence, transfer to decode
self.decode.insert(uuid);
self.active_tokens += new_tokens;
self.waiting_tokens -= new_tokens;
(prefill_cost.predict_prefill_compute(None, perf_model), true)
/// Try to admit one request from waiting → prefill.
/// Converts DirectRequest → ActiveSequence if needed. PrefillCost is computed
/// later in simulate_prefill when the request reaches the front of the queue.
fn admit_one(&mut self, args: &MockEngineArgs) -> bool {
let Some(&uuid) = self.waiting.front() else {
return false;
};
let num_active = self.prefill.len() + self.decode.len();
if args.max_num_seqs.is_some_and(|limit| num_active >= limit) {
return false;
}
// NOTE: the current behavior allocates the KV blocks for the entire sequence,
// even if only a chunk is prefilled
let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
panic!("Request does not exist.");
};
self.waiting.pop_front();
Some((
prefill_compute,
sequence.take_creation_signal(),
is_full_prefill,
))
}
// Convert DirectRequest → ActiveSequence if needed
if let Some(Request::Direct(_)) = self.requests.get(&uuid) {
let Some(Request::Direct(direct)) = self.requests.remove(&uuid) else {
unreachable!()
};
self.requests.insert(
uuid,
Request::Active(ActiveSequence::new(
direct.tokens,
direct.max_output_tokens,
Some(args.block_size),
args.enable_prefix_caching,
args.zmq_kv_events_port.is_some(),
)),
);
}
// assume (chunked) prefills are completed, then active tokens would be 1 per decoding sequence
fn reset_active_tokens(&mut self) {
self.active_tokens = self.decode.len();
self.prefill.push_back(uuid);
true
}
fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
......@@ -200,47 +126,37 @@ impl SchedulerState {
Some(sequence)
}
fn num_active_requests(&self) -> usize {
self.prefill.len() + self.decode.len()
}
/// Remove a UUID and its associated Request from collections.
fn complete(&mut self, uuid: &Uuid) {
tracing::trace!("Request {uuid} will complete");
self.decode.remove(uuid);
self.decode.retain(|u| u != uuid);
self.requests.remove(uuid);
self.prefill_costs.remove(uuid);
self.active_tokens -= 1;
}
/// Preempt the oldest running request by evicting it from running, resetting the sequence,
/// and adding it back to the waiting queue.
/// Returns the signal from reset_with_signal or None if no requests are running.
fn preempt(&mut self) -> Vec<MoveBlock> {
// Evict the oldest UUID from running
let uuid = self
.decode
.evict()
.expect("Nothing to evict for preemption.");
/// Preempt a running request by evicting it from decode, resetting the sequence,
/// and adding it back to the front of the waiting queue.
/// In LIFO mode, evicts the newest request (matches vLLM v1).
/// In FIFO mode, evicts the oldest request.
fn preempt(&mut self, mode: PreemptionMode) -> Vec<MoveBlock> {
let uuid = match mode {
PreemptionMode::Lifo => self.decode.pop_back(),
PreemptionMode::Fifo => self.decode.pop_front(),
}
.expect("Nothing to evict for preemption.");
let request = self
.requests
.remove(&uuid)
.expect("Request does not exist.");
self.prefill_costs.remove(&uuid);
self.active_tokens -= 1;
tracing::warn!("Request {uuid} will be preempted");
// Reset the sequence and get the new sequence and signal
// Insert the new sequence back into the requests map and add to waiting queue
// Reset the sequence and re-queue for prefill
let Request::Active(mut active_sequence) = request else {
panic!("Expected ActiveSequence in running queue")
};
let signals = active_sequence.reset_with_signal();
// Note: For preemption, we don't compute hit rate since we don't have access to new_tokens
// and the sequence is being reset anyway. Hit rate tracking is primarily for new scheduling attempts.
self.first_in_line(uuid, Request::Active(active_sequence));
self.requests.insert(uuid, Request::Active(active_sequence));
self.waiting.push_front(uuid);
signals
}
......@@ -291,7 +207,7 @@ impl Scheduler {
// Spawn main background task with cancellation token
tokio::spawn(async move {
// Create state and kv_manager as local variables owned by this task
let mut state = SchedulerState::new(args.max_num_batched_tokens);
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new_with_event_sink(
args.num_gpu_blocks,
args.block_size,
......@@ -309,18 +225,8 @@ impl Scheduler {
break;
}
// 2. Schedule waiting requests (once per iteration)
try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
// 3. Simulate prefill + decode
simulate_prefill(
&mut state,
&mut kv_manager,
&args.perf_model,
args.worker_type,
args.speedup_ratio,
)
.await;
// 2. Simulate prefill + decode
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
simulate_decode(
&mut state,
......@@ -329,10 +235,11 @@ impl Scheduler {
&args.perf_model,
args.block_size,
args.speedup_ratio,
args.preemption_mode,
)
.await;
// 4. Send metrics once per forward pass (after all prefill and decode processing)
// 3. Send metrics once per forward pass (after all prefill and decode processing)
let _ = metrics_tx.send(MockerMetrics {
dp_rank,
active_decode_blocks: kv_manager.num_active_blocks() as u64,
......@@ -347,7 +254,7 @@ impl Scheduler {
}
}
/// Add a new request to the waiting queue
/// Add a new request to the prefill queue
pub async fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request);
}
......@@ -399,41 +306,115 @@ async fn receive_requests(
}
/// Simulate prefill phase for all pending prefill requests.
/// Returns the total prefill compute time.
///
/// Handles token budget, block allocation, and preemption inline.
/// Token budget: `max_num_batched_tokens - decode.len()` (1 token per decode request).
/// When blocks are unavailable, decode requests are preempted (LIFO by default)
/// to free capacity, matching vLLM v1 behavior.
async fn simulate_prefill(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
perf_model: &PerfModel,
worker_type: WorkerType,
speedup_ratio: f64,
hit_rates: &mut RunningMean<f32>,
args: &MockEngineArgs,
) -> Duration {
let start_time = Instant::now();
let mut total_time = Duration::ZERO;
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
state.try_prefill(perf_model)
{
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior.
// For decode workers, skip adding prefill compute time
if worker_type != WorkerType::Decode {
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
let mut token_budget = args
.max_num_batched_tokens
.map_or(usize::MAX, |t| t.saturating_sub(state.decode.len()));
'prefill: while token_budget > 0 {
// Drain prefill first, then pull from waiting one at a time
if state.prefill.is_empty() && !state.admit_one(args) {
break;
}
let uuid = state.prefill[0];
let Some(Request::Active(seq)) = state.requests.get(&uuid) else {
panic!("Request does not exist.");
};
let prefill_cost = kv_manager.get_prefill_cost(seq);
let sequence_len = seq.len();
let allocated_tokens = seq.num_allocated_tokens();
let remaining = prefill_cost.new_tokens;
// Token budget check
let tokens_left = sequence_len - allocated_tokens;
if !args.enable_chunked_prefill && tokens_left > token_budget {
break;
}
let chunk = tokens_left.min(token_budget);
let cumulative = allocated_tokens + chunk;
// Allocate blocks. process() returns the number of blocks committed.
// On partial success, preempt a decode request and retry — the next
// loop iteration re-prepares from the updated num_allocated_tokens.
let Some(Request::Active(seq)) = state.requests.get_mut(&uuid) else {
panic!("Request does not exist.");
};
if let Some(signal) = seq.prepare_allocation(cumulative) {
let expected = match &signal {
MoveBlock::Use(blocks, ..) => blocks.len(),
_ => unreachable!(),
};
let allocated = kv_manager.process(&signal);
// Commit the blocks that were actually allocated
let committed_tokens = if allocated == expected {
cumulative
} else {
// Partial: compute token boundary from block count
let prev_blocks = allocated_tokens
.div_ceil(seq.block_size())
.min(seq.unique_blocks().len());
(prev_blocks + allocated) * seq.block_size()
};
seq.commit_allocation(committed_tokens.min(cumulative));
if allocated < expected {
if state.decode.is_empty() {
break;
}
for signal in state.preempt(args.preemption_mode) {
kv_manager.process(&signal);
}
continue 'prefill; // retry with freed capacity
}
} else {
seq.commit_allocation(cumulative);
}
if let Some(creation_signal) = maybe_creation_signal
&& !process_signals(kv_manager, std::slice::from_ref(&creation_signal))
{
panic!("Block allocation for prefilling cannot fail.");
// Accumulate prefill compute time (only for the new tokens in this chunk)
let new_tokens_in_chunk = chunk.min(remaining);
if args.worker_type != WorkerType::Decode && new_tokens_in_chunk > 0 {
total_time += Duration::from_secs_f64(
prefill_cost.predict_prefill_compute(Some(new_tokens_in_chunk), &args.perf_model)
/ 1000.0,
);
}
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
if !is_full_prefill {
// Hit rate: fraction of tokens that were already cached
let hit_rate = if sequence_len > 0 {
1.0 - (remaining as f32 / sequence_len as f32)
} else {
0.0
};
hit_rates.push(hit_rate);
token_budget -= chunk;
if cumulative >= sequence_len {
// Fully prefilled — promote to decode queue
state.prefill.pop_front();
state.decode.push_back(uuid);
} else {
// Partially prefilled — resume next iteration with updated allocated_tokens
break;
}
}
if speedup_ratio > 0.0 && total_time > Duration::ZERO {
let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / speedup_ratio);
if args.speedup_ratio > 0.0 && total_time > Duration::ZERO {
let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
let deadline = start_time + sleep_duration;
sleep_until_precise(deadline).await;
......@@ -451,6 +432,7 @@ async fn simulate_decode(
perf_model: &PerfModel,
block_size: usize,
speedup_ratio: f64,
preemption_mode: PreemptionMode,
) -> Duration {
let start_time = Instant::now();
......@@ -460,7 +442,7 @@ async fn simulate_decode(
// Compute average context length across all active decode requests
let total_length: usize = state
.decode
.keys()
.iter()
.map(|uuid| {
if let Request::Active(seq) = state.requests.get(uuid).unwrap() {
seq.len()
......@@ -475,26 +457,48 @@ async fn simulate_decode(
let decoding_time = perf_model.predict_decode_time(active_kv_tokens, context_length);
let total_time = Duration::from_secs_f64(decoding_time / 1000.0);
state.reset_active_tokens();
// Process decoding
let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
let uuids: Vec<Uuid> = state.decode.iter().copied().collect();
for uuid in uuids {
let Some(sequence) = state.run(uuid) else {
continue;
};
let signals = sequence.generate();
// Try to generate; if allocation fails, preempt until it succeeds
// or nothing is left to preempt (matches vLLM v1 scheduler loop).
// Reborrow sequence each iteration so the mutable ref doesn't
// conflict with state.preempt().
let mut allocated = false;
loop {
let Some(sequence) = state.run(uuid) else {
break;
};
let signals = sequence.generate();
if process_signals(kv_manager, &signals) {
allocated = true;
break;
}
sequence.pop(); // revert the failed generation
// Process all signals with the KvManager
// Handling of preemption on failure
if !process_signals(kv_manager, &signals) {
sequence.pop(); // revert the failed generation op
for signal in state.preempt() {
if state.decode.is_empty() {
break;
}
// Preempt one request and free its blocks
for signal in state.preempt(preemption_mode) {
kv_manager.process(&signal);
}
// If the current request was the one preempted, stop retrying
if !state.decode.contains(&uuid) {
break;
}
}
if !allocated {
continue;
}
let Some(sequence) = state.run(uuid) else {
continue;
};
// Check completion and send notification
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
......@@ -527,78 +531,6 @@ async fn simulate_decode(
total_time
}
/// Attempts to schedule waiting requests from the state queue.
/// Returns the number of requests successfully scheduled.
fn try_schedule(
state: &mut SchedulerState,
kv_manager: &KvManager,
hit_rates: &mut RunningMean<f32>,
args: &MockEngineArgs,
) -> usize {
let mut scheduled_count = 0;
let mut current_blocks = kv_manager.num_active_blocks();
let mut current_tokens = state.active_tokens + state.waiting_tokens;
let mut current_seqs = state.num_active_requests();
while let Some((uuid, request)) = state.next() {
// Convert Request to ActiveSequence
let active_sequence = match request {
Request::Active(active_seq) => active_seq,
Request::Direct(direct_request) => ActiveSequence::new(
direct_request.tokens,
direct_request.max_output_tokens,
Some(args.block_size),
args.enable_prefix_caching,
args.zmq_kv_events_port.is_some(),
),
};
// Update predictive budgets
let prefill_cost = kv_manager.get_prefill_cost(&active_sequence);
let total_tokens = active_sequence.len();
// this is conservative, assumes no cache hit so never over-schedules
let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize;
let new_tokens = prefill_cost.new_tokens;
current_blocks += new_blocks;
current_tokens += new_tokens;
current_seqs += 1;
// Check various budgets to see if possible to schedule
let under_block_budget =
current_blocks as f64 <= (1. - args.watermark) * kv_manager.max_capacity() as f64;
// If chunked prefill is enabled, we can be under token budget when scheduling
let comparison_tokens = if args.enable_chunked_prefill {
current_tokens - new_tokens
} else {
current_tokens
};
let under_token_budget = args
.max_num_batched_tokens
.is_none_or(|limit| comparison_tokens <= limit);
let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
// Cannot schedule, put first in line instead
if !(under_block_budget && under_token_budget && under_seq_budget) {
state.first_in_line(uuid, Request::Active(active_sequence));
break;
}
// Compute and store hit rate
let hit_rate = if !active_sequence.is_empty() {
1.0 - (new_tokens as f32 / active_sequence.len() as f32)
} else {
0.0
};
hit_rates.push(hit_rate);
state.move_to_prefill(uuid, active_sequence, prefill_cost);
scheduled_count += 1;
}
scheduled_count
}
/// Processes MoveBlock signals with the KvManager.
///
/// When a signal fails, this function verifies that the failure is for an expected case:
......@@ -608,7 +540,7 @@ fn try_schedule(
/// indicate an unexpected state in the system.
fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
for signal in signals {
if kv_manager.process(signal) {
if kv_manager.process(signal) > 0 {
continue;
}
......@@ -875,6 +807,186 @@ mod tests {
println!("Test passed! Received {received_tokens} tokens");
}
/// White-box unit test that directly creates SchedulerState + KvManager,
/// manually invokes simulate_prefill / simulate_decode, and asserts on
/// queue states and block counts after each step.
#[tokio::test]
async fn test_scheduler_internal_state_transitions() {
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6)
.max_num_batched_tokens(Some(12))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
let r1_uuid = Uuid::from_u128(1);
let r2_uuid = Uuid::from_u128(2);
let r3_uuid = Uuid::from_u128(3);
// ── Step 1: Receive 3 requests ──
// R1: 8 input, 2 max_output → 2 blocks
// R2: 8 input, 2 max_output → 2 blocks
// R3: 12 input, 2 max_output → 3 blocks
state.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 2,
uuid: Some(r1_uuid),
dp_rank: 0,
});
state.receive(DirectRequest {
tokens: (100..108).collect(),
max_output_tokens: 2,
uuid: Some(r2_uuid),
dp_rank: 0,
});
state.receive(DirectRequest {
tokens: (200..212).collect(),
max_output_tokens: 2,
uuid: Some(r3_uuid),
dp_rank: 0,
});
assert_eq!(state.waiting.len(), 3);
assert_eq!(state.prefill.len(), 0);
assert_eq!(state.decode.len(), 0);
assert_eq!(kv_manager.num_active_blocks(), 0);
// ── Step 2: First simulate_prefill ──
// Budget=12. R1 takes 8 tokens (2 blocks), fully prefilled → decode.
// R2 takes 4 tokens (1 block, chunked), partially prefilled → stays in prefill.
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
assert_eq!(state.waiting.len(), 1);
assert_eq!(state.prefill.len(), 1);
assert_eq!(state.decode.len(), 1);
assert_eq!(state.decode[0], r1_uuid);
assert_eq!(state.prefill[0], r2_uuid);
assert_eq!(state.waiting[0], r3_uuid);
assert_eq!(kv_manager.num_active_blocks(), 3); // 2 for R1 + 1 for R2
let seq = match state.requests.get(&r1_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.num_allocated_tokens(), 8);
assert_eq!(seq.generated_tokens(), 0);
let seq = match state.requests.get(&r2_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.num_allocated_tokens(), 4);
assert_eq!(seq.generated_tokens(), 0);
// ── Step 3: First simulate_decode ──
// R1 generates 1 token, gains a partial block.
simulate_decode(
&mut state,
&mut kv_manager,
&output_tx,
&args.perf_model,
args.block_size,
args.speedup_ratio,
args.preemption_mode,
)
.await;
assert_eq!(state.decode.len(), 1);
assert_eq!(state.decode[0], r1_uuid);
assert_eq!(kv_manager.num_active_blocks(), 4); // +1 partial for R1
let seq = match state.requests.get(&r1_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.generated_tokens(), 1);
// ── Step 4: Second simulate_prefill ──
// Budget=11. R2 finishes (4 more tokens, 1 block → active=5, decode).
// R3 admitted, needs 2 blocks for chunk of 7. Only 1 free slot → partial.
// Preempt R2 (LIFO) → R2 back to waiting. Retry R3 → evicts R2's
// inactive blocks, allocates 2 more → R3 allocated_tokens=11.
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
assert_eq!(state.waiting.len(), 1, "R2 preempted back to waiting");
assert_eq!(state.waiting[0], r2_uuid);
assert_eq!(state.prefill.len(), 1, "R3 partially prefilled");
assert_eq!(state.prefill[0], r3_uuid);
assert_eq!(state.decode.len(), 1, "R1 still decoding");
assert_eq!(state.decode[0], r1_uuid);
assert_eq!(kv_manager.num_active_blocks(), 6); // at capacity
let seq = match state.requests.get(&r3_uuid).unwrap() {
Request::Active(s) => s,
_ => panic!("expected ActiveSequence"),
};
assert_eq!(seq.num_allocated_tokens(), 11);
// ── Step 5: Second simulate_decode ──
// R1 generates 2nd token → complete. Frees 3 blocks (1 destroyed, 2 deactivated).
simulate_decode(
&mut state,
&mut kv_manager,
&output_tx,
&args.perf_model,
args.block_size,
args.speedup_ratio,
args.preemption_mode,
)
.await;
assert!(!state.requests.contains_key(&r1_uuid), "R1 completed");
assert_eq!(state.decode.len(), 0);
assert_eq!(state.prefill.len(), 1);
assert_eq!(state.waiting.len(), 1);
assert_eq!(kv_manager.num_active_blocks(), 3); // only R3's 3 blocks
// ── Step 6: Third simulate_prefill ──
// R3 finishes prefill (1 token left, no new blocks) → decode.
// R2 re-admitted, fully prefilled (2 blocks via inactive eviction) → decode.
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
assert_eq!(state.waiting.len(), 0);
assert_eq!(state.prefill.len(), 0);
assert_eq!(state.decode.len(), 2);
assert!(state.decode.contains(&r3_uuid));
assert!(state.decode.contains(&r2_uuid));
assert_eq!(kv_manager.num_active_blocks(), 5); // 3 for R3 + 2 for R2
// ── Steps 7+: Cycle until all requests complete ──
loop {
simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
simulate_decode(
&mut state,
&mut kv_manager,
&output_tx,
&args.perf_model,
args.block_size,
args.speedup_ratio,
args.preemption_mode,
)
.await;
if state.is_empty() {
break;
}
}
assert_eq!(state.waiting.len(), 0);
assert_eq!(state.prefill.len(), 0);
assert_eq!(state.decode.len(), 0);
assert_eq!(kv_manager.num_active_blocks(), 0);
}
#[tokio::test]
async fn test_receiver_drop_cleans_up_resources() {
let block_size: usize = 64;
......
......@@ -158,8 +158,8 @@ def _build_mocker_command(
command.append("--enable-chunked-prefill")
else:
command.append("--no-enable-chunked-prefill")
if "watermark" in mocker_args:
command.extend(["--watermark", str(mocker_args["watermark"])])
if "preemption_mode" in mocker_args:
command.extend(["--preemption-mode", str(mocker_args["preemption_mode"])])
if "dp_size" in mocker_args:
command.extend(["--data-parallel-size", str(mocker_args["dp_size"])])
# Use --durable-kv-events to enable JetStream mode (local indexer disabled)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment