Unverified Commit e330d969 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: enable / disable chunked prefill for mockers (#2015)


Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent 353146e2
......@@ -9,15 +9,13 @@ The mocker engine is a mock vLLM implementation designed for testing and develop
**Basic usage:**
The `--model-path` is required but can point to any valid model path - the mocker doesn't actually load the model weights (but the pre-processor needs the tokenizer). The arguments `block-size`, `num-gpu-blocks`, `max-num-seqs`, `max-num-batched-tokens`, and `enable-prefix-caching` are common arguments shared with the real VLLM engine.
The `--model-path` is required but can point to any valid model path - the mocker doesn't actually load the model weights (but the pre-processor needs the tokenizer). The arguments `block_size`, `num_gpu_blocks`, `max_num_seqs`, `max_num_batched_tokens`, `enable_prefix_caching`, and `enable_chunked_prefill` are common arguments shared with the real VLLM engine.
And below are arguments that are mocker-specific:
- `speedup_ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster.
- `dp_size`: Number of data parallel workers to simulate (default: 1)
- `watermark`: KV cache watermark threshold as a fraction (default: 0.01). This argument also exists for the real VLLM engine but cannot be passed as an engine arg.
>[!NOTE]
>Currently, `enable_chunked_prefill` is always assumed to be false, which mirrors the vllm v0 behavior. This is also the current behavior in `examples/llm`. This will be updated in the near future as we move to support vllm v1 (and deprecate support for vllm v0).
```bash
echo '{"speedup_ratio": 10.0}' > mocker_args.json
python -m dynamo.mocker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args mocker_args.json
......
......@@ -549,15 +549,13 @@ The mocker engine is a mock vLLM implementation designed for testing and develop
**Basic usage:**
The `--model-path` is required but can point to any valid model path - the mocker doesn't actually load the model weights. The arguments `block-size`, `num-gpu-blocks`, `max-num-seqs`, `max-num-batched-tokens`, and `enable-prefix-caching` are common arguments shared with the real VLLM engine.
The `--model-path` is required but can point to any valid model path - the mocker doesn't actually load the model weights (but the pre-processor needs the tokenizer). The arguments `block_size`, `num_gpu_blocks`, `max_num_seqs`, `max_num_batched_tokens`, `enable_prefix_caching`, and `enable_chunked_prefill` are common arguments shared with the real VLLM engine.
And below are arguments that are mocker-specific:
- `speedup_ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster.
- `dp_size`: Number of data parallel workers to simulate (default: 1)
- `watermark`: KV cache watermark threshold as a fraction (default: 0.01). This argument also exists for the real VLLM engine but cannot be passed as an engine arg.
>[!NOTE]
>Currently, `enable_chunked_prefill` is always assumed to be false, which mirrors the vllm v0 behavior. This is also the current behavior in `examples/llm`. This will be updated in the near future as we move to support vllm v1 (and deprecate support for vllm v0).
```bash
echo '{"speedup_ratio": 10.0}' > mocker_args.json
dynamo-run in=dyn://dynamo.mocker.generate out=mocker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args mocker_args.json
......
......@@ -293,14 +293,9 @@ impl KvManager {
let overlap_blocks = seq_blocks.len() - new_blocks;
let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size;
// Calculate prefill compute
let prefill_compute =
1.25e-6 * (new_tokens as f64).powi(2) + 7.41e-2 * (new_tokens as f64) + 2.62e1;
PrefillCost {
new_blocks,
new_tokens,
prefill_compute,
}
}
}
......
......@@ -58,7 +58,13 @@ pub struct DirectRequest {
pub struct PrefillCost {
pub new_blocks: usize,
pub new_tokens: usize,
pub prefill_compute: f64,
}
impl PrefillCost {
pub fn predict_prefill_compute(&self, new_tokens: Option<usize>) -> f64 {
let tokens = new_tokens.unwrap_or(self.new_tokens);
1.25e-6 * (tokens as f64).powi(2) + 7.41e-2 * (tokens as f64) + 2.62e1
}
}
/// Signal for output token generation with completion status
......@@ -89,6 +95,9 @@ pub struct MockEngineArgs {
#[builder(default = true)]
pub enable_prefix_caching: bool,
#[builder(default = true)]
pub enable_chunked_prefill: bool,
#[builder(default = "0.01")]
pub watermark: f64,
......@@ -127,6 +136,7 @@ impl MockEngineArgs {
"max_num_seqs",
"max_num_batched_tokens",
"enable_prefix_caching",
"enable_chunked_prefill",
"watermark",
"speedup_ratio",
"dp_size",
......@@ -181,6 +191,12 @@ impl MockEngineArgs {
}
}
if let Some(value) = extra_args.get("enable_chunked_prefill") {
if let Some(enabled) = value.as_bool() {
builder = builder.enable_chunked_prefill(enabled);
}
}
if let Some(value) = extra_args.get("watermark") {
if let Some(num) = value.as_f64() {
builder = builder.watermark(num);
......
......@@ -67,10 +67,20 @@ struct SchedulerState {
prefill: VecDeque<Uuid>,
decode: LRUEvictor<Uuid>,
requests: HashMap<Uuid, Request>,
prefill_costs: HashMap<Uuid, Option<PrefillCost>>,
prefill_costs: HashMap<Uuid, PrefillCost>,
max_num_batched_tokens: Option<usize>,
active_tokens: usize,
waiting_tokens: usize,
}
impl SchedulerState {
fn new(max_num_batched_tokens: Option<usize>) -> Self {
SchedulerState {
max_num_batched_tokens,
..Default::default()
}
}
/// Create a new UUID for a DirectRequest, add it to requests, and push the UUID to waiting.
fn receive(&mut self, request: DirectRequest) -> Uuid {
// Use the provided UUID if available, otherwise generate a new one
......@@ -97,34 +107,81 @@ impl SchedulerState {
}
/// Move a UUID and its Request to the ready queue.
fn start_prefill(&mut self, uuid: Uuid, active_seq: ActiveSequence, cost: Option<PrefillCost>) {
fn move_to_prefill(&mut self, uuid: Uuid, active_seq: ActiveSequence, cost: PrefillCost) {
self.waiting_tokens += cost.new_tokens;
self.requests.insert(uuid, Request::Active(active_seq));
self.prefill.push_back(uuid);
self.prefill_costs.insert(uuid, cost);
}
/// Pop from prefill queue and move to decode queue.
/// Returns the prefill_compute value if available.
fn start_decode(&mut self) -> Option<(f64, MoveBlock)> {
/// Try (chunked) prefill and move to decode queue
///
/// Returns `Some((prefill_compute, creation_signal, is_full_prefill))` where:
/// - `prefill_compute`: The compute time in milliseconds for this prefill operation
/// - `creation_signal`: Optional MoveBlock signal for KV cache block creation
/// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked
fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, bool)> {
let uuid = self.prefill.pop_front()?;
self.decode.insert(uuid);
// Remove and extract prefill_compute from prefill_costs
let prefill_cost = self
let mut prefill_cost = self
.prefill_costs
.remove(&uuid)
.flatten()
.expect("Expects valid prefill cost.");
let Some(Request::Active(sequence)) = self.requests.get(&uuid) else {
let new_tokens = prefill_cost.new_tokens;
let maybe_prefill_tokens = self.max_num_batched_tokens.and_then(|max_tokens| {
let remaining_tokens = max_tokens - self.active_tokens;
if prefill_cost.new_tokens > remaining_tokens {
Some(remaining_tokens)
} else {
None
}
});
let (prefill_compute, is_full_prefill) = if let Some(prefill_tokens) = maybe_prefill_tokens
{
let prefill_compute = prefill_cost.predict_prefill_compute(Some(prefill_tokens));
prefill_cost.new_tokens -= prefill_tokens;
assert!(
(prefill_cost.new_tokens > 0) && (prefill_compute > 0.0),
"Encountered negative prefill tokens or prefill compute cost."
);
self.prefill.push_front(uuid);
self.prefill_costs.insert(uuid, prefill_cost);
self.active_tokens = self.max_num_batched_tokens.unwrap();
self.waiting_tokens -= prefill_tokens;
(prefill_compute, false)
} else {
// Assume possible to complete prefilling the sequence, transfer to decode
self.decode.insert(uuid);
self.active_tokens += new_tokens;
self.waiting_tokens -= new_tokens;
(prefill_cost.predict_prefill_compute(None), true)
};
// NOTE: the current behavior allocates the KV blocks for the entire sequence,
// even if only a chunk is prefilled
let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
panic!("Request does not exist.");
};
let creation_signal = sequence
.creation_signal()
.clone()
.expect("Must have creation signal.");
Some((prefill_cost.prefill_compute, creation_signal))
Some((
prefill_compute,
sequence.take_creation_signal(),
is_full_prefill,
))
}
// assume (chunked) prefills are completed, then active tokens would be 1 per decoding sequence
fn reset_active_tokens(&mut self) {
self.active_tokens = self.decode.len();
}
fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
......@@ -141,23 +198,13 @@ impl SchedulerState {
self.prefill.len() + self.decode.len()
}
/// Calculate the current running batched tokens
fn num_batched_tokens(&self) -> usize {
self.prefill_costs
.values()
.map(|cost| match cost {
Some(cost) => cost.new_tokens,
None => 1,
})
.sum()
}
/// Remove a UUID and its associated Request from collections.
fn complete(&mut self, uuid: &Uuid) {
// println!("Request {} will complete", uuid);
tracing::debug!("Request {} will complete", uuid);
self.decode.remove(uuid);
self.requests.remove(uuid);
self.prefill_costs.remove(uuid);
self.active_tokens -= 1;
}
/// Preempt the oldest running request by evicting it from running, resetting the sequence,
......@@ -174,7 +221,8 @@ impl SchedulerState {
.remove(&uuid)
.expect("Request does not exist.");
self.prefill_costs.remove(&uuid);
eprintln!("Request {uuid} will be preempted");
self.active_tokens -= 1;
tracing::warn!("Request {uuid} will be preempted");
// Reset the sequence and get the new sequence and signal
// Insert the new sequence back into the requests map and add to waiting queue
......@@ -211,7 +259,7 @@ impl Scheduler {
kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>,
cancellation_token: Option<CancellationToken>,
) -> Self {
let state = Arc::new(Mutex::new(SchedulerState::default()));
let state = Arc::new(Mutex::new(SchedulerState::new(args.max_num_batched_tokens)));
// Create internal channel for KV events only if needed
let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() {
......@@ -274,7 +322,7 @@ impl Scheduler {
// Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
// schedule anymore.
let mut current_blocks = kv_manager_guard.num_active_blocks();
let mut current_tokens = state_guard.num_batched_tokens();
let mut current_tokens = state_guard.active_tokens + state_guard.waiting_tokens;
let mut current_seqs = state_guard.num_active_requests();
while let Some((uuid, request)) = state_guard.next() {
......@@ -283,16 +331,19 @@ impl Scheduler {
// Update predictive budgets
let prefill_cost = kv_manager_guard.get_prefill_cost(&active_sequence);
let total_tokens = active_sequence.len();
let new_blocks = (total_tokens + 1) / args.block_size; // this is conservative, assumes no cache hit
// this is conservative, assumes no cache hit so never over-schedules
let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize;
let new_tokens = prefill_cost.new_tokens;
current_blocks += new_blocks;
current_tokens += new_tokens;
current_seqs += 1;
// Check if it can be scheduled
// Check various budgets to see if possible to schedule
let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager_guard.max_capacity() as f64;
let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| current_tokens <= limit);
// If chunked prefill is enabled, we can be under token budget when scheduling
let comparison_tokens = if args.enable_chunked_prefill {current_tokens - new_tokens} else {current_tokens};
let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| comparison_tokens <= limit);
let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
// Cannot schedule, put first in line instead
......@@ -311,7 +362,7 @@ impl Scheduler {
}
}
state_guard.start_prefill(uuid, active_sequence, Some(prefill_cost));
state_guard.move_to_prefill(uuid, active_sequence, prefill_cost);
should_schedule = false;
}
}
......@@ -332,23 +383,30 @@ impl Scheduler {
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
// Process prefilling
while let Some((prefill_compute, creation_signal)) = state_guard.start_decode() {
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = state_guard.try_prefill() {
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior.
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
let prefill_success = process_signals(&mut kv_manager_guard, std::slice::from_ref(&creation_signal));
if !prefill_success {
panic!("Block allocation for prefilling cannot fail.");
}
// Drain KV events and forward to relay after prefill signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
if let Some(creation_signal) = maybe_creation_signal {
if !process_signals(&mut kv_manager_guard, std::slice::from_ref(&creation_signal)) {
panic!("Block allocation for prefilling cannot fail.");
}
}
// Drain KV events and forward to relay after prefill signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
}
}
};
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
if !is_full_prefill { break; }
}
state_guard.reset_active_tokens();
// Process decoding
let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
if !uuids.is_empty() {should_schedule = true};
......@@ -424,24 +482,30 @@ impl Scheduler {
let _ = self.request_tx.send(request);
}
/// Expose the sender
pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
/// Get the count of waiting requests
pub async fn waiting_count(&self) -> usize {
let state = self.state.lock().await;
state.waiting.len()
}
/// Get the count of running requests
pub async fn running_count(&self) -> usize {
let state = self.state.lock().await;
state.decode.len()
}
/// Get the current capacity of the KvManager
pub async fn waiting_tokens(&self) -> usize {
let state = self.state.lock().await;
state.waiting_tokens
}
pub async fn active_tokens(&self) -> usize {
let state = self.state.lock().await;
state.active_tokens
}
pub async fn kv_usage_perc(&self) -> f64 {
let kv_manager = self.kv_manager.lock().await;
kv_manager.current_capacity_perc()
......@@ -540,12 +604,16 @@ fn process_signals(
// Check we have a Use signal with blocks
let MoveBlock::Use(blocks) = signal else {
panic!("Failed signal is Invalid. Has to fail on generation signal.");
panic!("Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}");
};
// Verify the signal contains exactly one block
if blocks.len() != 1 {
panic!("Failed signal is Invalid. Can have only one generation signal.");
let num_blocks = blocks.len();
let num_active_blocks = kv_manager_guard.num_active_blocks();
if num_blocks != 1 {
panic!(
"Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
);
}
// Verify the block is a PartialBlock (generation block)
......@@ -566,20 +634,25 @@ mod tests {
use std::time::Duration;
#[rstest]
#[case::random_no_prefix_caching(false, false)]
#[case::random_with_prefix_caching(false, true)]
#[case::caching_no_prefix_caching(true, false)]
#[case::caching_with_prefix_caching(true, true)]
#[case::case_1(false, false, false)]
#[case::case_2(false, true, false)]
#[case::case_3(true, false, false)]
#[case::case_4(true, true, false)]
#[case::case_5(false, false, true)]
#[case::case_6(false, true, true)]
#[case::case_7(true, false, true)]
#[case::case_8(true, true, true)]
#[tokio::test]
async fn test_scheduler_token_generation_patterns(
#[case] use_shared_tokens: bool,
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
std::env::set_var("RUST_LOG", "debug");
let kv_capacity: usize = 500;
let block_size: usize = 64;
let num_requests: usize = 100;
let num_requests: usize = 200;
let input_len: usize = 1000;
let max_output_tokens: usize = 100;
......@@ -592,6 +665,7 @@ mod tests {
.block_size(block_size)
.speedup_ratio(10.0)
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.build()
.unwrap();
......@@ -671,14 +745,12 @@ mod tests {
// Calculate and print elapsed time
let elapsed = start_time.elapsed();
println!(
"Test completed in: {:?} for {} case with prefix_caching={}",
elapsed,
"Test completed in: {elapsed:?} for {} case with prefix_caching={enable_prefix_caching} and chunked_prefill={enable_chunked_prefill}",
if use_shared_tokens {
"caching"
} else {
"random"
},
enable_prefix_caching
}
);
// Assert that we received the expected number of tokens
......@@ -686,6 +758,18 @@ mod tests {
received_tokens == expected_tokens,
"Received {received_tokens} tokens but expected exactly {expected_tokens}"
);
let active_tokens = scheduler.active_tokens().await;
assert!(
active_tokens == 0,
"Scheduler still have {active_tokens} active tokens but expected 0"
);
let waiting_tokens = scheduler.waiting_tokens().await;
assert!(
waiting_tokens == 0,
"Scheduler still have {waiting_tokens} waiting tokens but expected 0"
);
}
#[tokio::test]
......
......@@ -120,6 +120,10 @@ impl ActiveSequence {
self.tokens.total_tokens() == 0
}
pub fn take_creation_signal(&mut self) -> Option<MoveBlock> {
self.creation_signal.take()
}
/// Create a new ActiveSequence instance and return the creation signal
pub fn new_with_signal(
tokens: Vec<u32>,
......@@ -128,7 +132,7 @@ impl ActiveSequence {
enable_prefix_caching: bool,
) -> (Self, Option<MoveBlock>) {
let mut sequence = Self::new(tokens, max_output_tokens, block_size, enable_prefix_caching);
let signal = sequence.creation_signal.take();
let signal = sequence.take_creation_signal();
(sequence, signal)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment