Unverified Commit e330d969 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: enable / disable chunked prefill for mockers (#2015)


Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent 353146e2
...@@ -9,15 +9,13 @@ The mocker engine is a mock vLLM implementation designed for testing and develop ...@@ -9,15 +9,13 @@ The mocker engine is a mock vLLM implementation designed for testing and develop
**Basic usage:** **Basic usage:**
The `--model-path` is required but can point to any valid model path - the mocker doesn't actually load the model weights (but the pre-processor needs the tokenizer). The arguments `block-size`, `num-gpu-blocks`, `max-num-seqs`, `max-num-batched-tokens`, and `enable-prefix-caching` are common arguments shared with the real VLLM engine. The `--model-path` is required but can point to any valid model path - the mocker doesn't actually load the model weights (but the pre-processor needs the tokenizer). The arguments `block_size`, `num_gpu_blocks`, `max_num_seqs`, `max_num_batched_tokens`, `enable_prefix_caching`, and `enable_chunked_prefill` are common arguments shared with the real VLLM engine.
And below are arguments that are mocker-specific: And below are arguments that are mocker-specific:
- `speedup_ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster. - `speedup_ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster.
- `dp_size`: Number of data parallel workers to simulate (default: 1) - `dp_size`: Number of data parallel workers to simulate (default: 1)
- `watermark`: KV cache watermark threshold as a fraction (default: 0.01). This argument also exists for the real VLLM engine but cannot be passed as an engine arg. - `watermark`: KV cache watermark threshold as a fraction (default: 0.01). This argument also exists for the real VLLM engine but cannot be passed as an engine arg.
>[!NOTE]
>Currently, `enable_chunked_prefill` is always assumed to be false, which mirrors the vllm v0 behavior. This is also the current behavior in `examples/llm`. This will be updated in the near future as we move to support vllm v1 (and deprecate support for vllm v0).
```bash ```bash
echo '{"speedup_ratio": 10.0}' > mocker_args.json echo '{"speedup_ratio": 10.0}' > mocker_args.json
python -m dynamo.mocker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args mocker_args.json python -m dynamo.mocker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args mocker_args.json
......
...@@ -549,15 +549,13 @@ The mocker engine is a mock vLLM implementation designed for testing and develop ...@@ -549,15 +549,13 @@ The mocker engine is a mock vLLM implementation designed for testing and develop
**Basic usage:** **Basic usage:**
The `--model-path` is required but can point to any valid model path - the mocker doesn't actually load the model weights. The arguments `block-size`, `num-gpu-blocks`, `max-num-seqs`, `max-num-batched-tokens`, and `enable-prefix-caching` are common arguments shared with the real VLLM engine. The `--model-path` is required but can point to any valid model path - the mocker doesn't actually load the model weights (but the pre-processor needs the tokenizer). The arguments `block_size`, `num_gpu_blocks`, `max_num_seqs`, `max_num_batched_tokens`, `enable_prefix_caching`, and `enable_chunked_prefill` are common arguments shared with the real VLLM engine.
And below are arguments that are mocker-specific: And below are arguments that are mocker-specific:
- `speedup_ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster. - `speedup_ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster.
- `dp_size`: Number of data parallel workers to simulate (default: 1) - `dp_size`: Number of data parallel workers to simulate (default: 1)
- `watermark`: KV cache watermark threshold as a fraction (default: 0.01). This argument also exists for the real VLLM engine but cannot be passed as an engine arg. - `watermark`: KV cache watermark threshold as a fraction (default: 0.01). This argument also exists for the real VLLM engine but cannot be passed as an engine arg.
>[!NOTE]
>Currently, `enable_chunked_prefill` is always assumed to be false, which mirrors the vllm v0 behavior. This is also the current behavior in `examples/llm`. This will be updated in the near future as we move to support vllm v1 (and deprecate support for vllm v0).
```bash ```bash
echo '{"speedup_ratio": 10.0}' > mocker_args.json echo '{"speedup_ratio": 10.0}' > mocker_args.json
dynamo-run in=dyn://dynamo.mocker.generate out=mocker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args mocker_args.json dynamo-run in=dyn://dynamo.mocker.generate out=mocker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args mocker_args.json
......
...@@ -293,14 +293,9 @@ impl KvManager { ...@@ -293,14 +293,9 @@ impl KvManager {
let overlap_blocks = seq_blocks.len() - new_blocks; let overlap_blocks = seq_blocks.len() - new_blocks;
let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size; let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size;
// Calculate prefill compute
let prefill_compute =
1.25e-6 * (new_tokens as f64).powi(2) + 7.41e-2 * (new_tokens as f64) + 2.62e1;
PrefillCost { PrefillCost {
new_blocks, new_blocks,
new_tokens, new_tokens,
prefill_compute,
} }
} }
} }
......
...@@ -58,7 +58,13 @@ pub struct DirectRequest { ...@@ -58,7 +58,13 @@ pub struct DirectRequest {
pub struct PrefillCost { pub struct PrefillCost {
pub new_blocks: usize, pub new_blocks: usize,
pub new_tokens: usize, pub new_tokens: usize,
pub prefill_compute: f64, }
impl PrefillCost {
pub fn predict_prefill_compute(&self, new_tokens: Option<usize>) -> f64 {
let tokens = new_tokens.unwrap_or(self.new_tokens);
1.25e-6 * (tokens as f64).powi(2) + 7.41e-2 * (tokens as f64) + 2.62e1
}
} }
/// Signal for output token generation with completion status /// Signal for output token generation with completion status
...@@ -89,6 +95,9 @@ pub struct MockEngineArgs { ...@@ -89,6 +95,9 @@ pub struct MockEngineArgs {
#[builder(default = true)] #[builder(default = true)]
pub enable_prefix_caching: bool, pub enable_prefix_caching: bool,
#[builder(default = true)]
pub enable_chunked_prefill: bool,
#[builder(default = "0.01")] #[builder(default = "0.01")]
pub watermark: f64, pub watermark: f64,
...@@ -127,6 +136,7 @@ impl MockEngineArgs { ...@@ -127,6 +136,7 @@ impl MockEngineArgs {
"max_num_seqs", "max_num_seqs",
"max_num_batched_tokens", "max_num_batched_tokens",
"enable_prefix_caching", "enable_prefix_caching",
"enable_chunked_prefill",
"watermark", "watermark",
"speedup_ratio", "speedup_ratio",
"dp_size", "dp_size",
...@@ -181,6 +191,12 @@ impl MockEngineArgs { ...@@ -181,6 +191,12 @@ impl MockEngineArgs {
} }
} }
if let Some(value) = extra_args.get("enable_chunked_prefill") {
if let Some(enabled) = value.as_bool() {
builder = builder.enable_chunked_prefill(enabled);
}
}
if let Some(value) = extra_args.get("watermark") { if let Some(value) = extra_args.get("watermark") {
if let Some(num) = value.as_f64() { if let Some(num) = value.as_f64() {
builder = builder.watermark(num); builder = builder.watermark(num);
......
...@@ -67,10 +67,20 @@ struct SchedulerState { ...@@ -67,10 +67,20 @@ struct SchedulerState {
prefill: VecDeque<Uuid>, prefill: VecDeque<Uuid>,
decode: LRUEvictor<Uuid>, decode: LRUEvictor<Uuid>,
requests: HashMap<Uuid, Request>, requests: HashMap<Uuid, Request>,
prefill_costs: HashMap<Uuid, Option<PrefillCost>>, prefill_costs: HashMap<Uuid, PrefillCost>,
max_num_batched_tokens: Option<usize>,
active_tokens: usize,
waiting_tokens: usize,
} }
impl SchedulerState { impl SchedulerState {
fn new(max_num_batched_tokens: Option<usize>) -> Self {
SchedulerState {
max_num_batched_tokens,
..Default::default()
}
}
/// Create a new UUID for a DirectRequest, add it to requests, and push the UUID to waiting. /// Create a new UUID for a DirectRequest, add it to requests, and push the UUID to waiting.
fn receive(&mut self, request: DirectRequest) -> Uuid { fn receive(&mut self, request: DirectRequest) -> Uuid {
// Use the provided UUID if available, otherwise generate a new one // Use the provided UUID if available, otherwise generate a new one
...@@ -97,34 +107,81 @@ impl SchedulerState { ...@@ -97,34 +107,81 @@ impl SchedulerState {
} }
/// Move a UUID and its Request to the ready queue. /// Move a UUID and its Request to the ready queue.
fn start_prefill(&mut self, uuid: Uuid, active_seq: ActiveSequence, cost: Option<PrefillCost>) { fn move_to_prefill(&mut self, uuid: Uuid, active_seq: ActiveSequence, cost: PrefillCost) {
self.waiting_tokens += cost.new_tokens;
self.requests.insert(uuid, Request::Active(active_seq)); self.requests.insert(uuid, Request::Active(active_seq));
self.prefill.push_back(uuid); self.prefill.push_back(uuid);
self.prefill_costs.insert(uuid, cost); self.prefill_costs.insert(uuid, cost);
} }
/// Pop from prefill queue and move to decode queue. /// Try (chunked) prefill and move to decode queue
/// Returns the prefill_compute value if available. ///
fn start_decode(&mut self) -> Option<(f64, MoveBlock)> { /// Returns `Some((prefill_compute, creation_signal, is_full_prefill))` where:
/// - `prefill_compute`: The compute time in milliseconds for this prefill operation
/// - `creation_signal`: Optional MoveBlock signal for KV cache block creation
/// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked
fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, bool)> {
let uuid = self.prefill.pop_front()?; let uuid = self.prefill.pop_front()?;
self.decode.insert(uuid);
// Remove and extract prefill_compute from prefill_costs // Remove and extract prefill_compute from prefill_costs
let prefill_cost = self let mut prefill_cost = self
.prefill_costs .prefill_costs
.remove(&uuid) .remove(&uuid)
.flatten()
.expect("Expects valid prefill cost."); .expect("Expects valid prefill cost.");
let Some(Request::Active(sequence)) = self.requests.get(&uuid) else { let new_tokens = prefill_cost.new_tokens;
let maybe_prefill_tokens = self.max_num_batched_tokens.and_then(|max_tokens| {
let remaining_tokens = max_tokens - self.active_tokens;
if prefill_cost.new_tokens > remaining_tokens {
Some(remaining_tokens)
} else {
None
}
});
let (prefill_compute, is_full_prefill) = if let Some(prefill_tokens) = maybe_prefill_tokens
{
let prefill_compute = prefill_cost.predict_prefill_compute(Some(prefill_tokens));
prefill_cost.new_tokens -= prefill_tokens;
assert!(
(prefill_cost.new_tokens > 0) && (prefill_compute > 0.0),
"Encountered negative prefill tokens or prefill compute cost."
);
self.prefill.push_front(uuid);
self.prefill_costs.insert(uuid, prefill_cost);
self.active_tokens = self.max_num_batched_tokens.unwrap();
self.waiting_tokens -= prefill_tokens;
(prefill_compute, false)
} else {
// Assume possible to complete prefilling the sequence, transfer to decode
self.decode.insert(uuid);
self.active_tokens += new_tokens;
self.waiting_tokens -= new_tokens;
(prefill_cost.predict_prefill_compute(None), true)
};
// NOTE: the current behavior allocates the KV blocks for the entire sequence,
// even if only a chunk is prefilled
let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
panic!("Request does not exist."); panic!("Request does not exist.");
}; };
let creation_signal = sequence
.creation_signal()
.clone()
.expect("Must have creation signal.");
Some((prefill_cost.prefill_compute, creation_signal)) Some((
prefill_compute,
sequence.take_creation_signal(),
is_full_prefill,
))
}
// assume (chunked) prefills are completed, then active tokens would be 1 per decoding sequence
fn reset_active_tokens(&mut self) {
self.active_tokens = self.decode.len();
} }
fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> { fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
...@@ -141,23 +198,13 @@ impl SchedulerState { ...@@ -141,23 +198,13 @@ impl SchedulerState {
self.prefill.len() + self.decode.len() self.prefill.len() + self.decode.len()
} }
/// Calculate the current running batched tokens
fn num_batched_tokens(&self) -> usize {
self.prefill_costs
.values()
.map(|cost| match cost {
Some(cost) => cost.new_tokens,
None => 1,
})
.sum()
}
/// Remove a UUID and its associated Request from collections. /// Remove a UUID and its associated Request from collections.
fn complete(&mut self, uuid: &Uuid) { fn complete(&mut self, uuid: &Uuid) {
// println!("Request {} will complete", uuid); tracing::debug!("Request {} will complete", uuid);
self.decode.remove(uuid); self.decode.remove(uuid);
self.requests.remove(uuid); self.requests.remove(uuid);
self.prefill_costs.remove(uuid); self.prefill_costs.remove(uuid);
self.active_tokens -= 1;
} }
/// Preempt the oldest running request by evicting it from running, resetting the sequence, /// Preempt the oldest running request by evicting it from running, resetting the sequence,
...@@ -174,7 +221,8 @@ impl SchedulerState { ...@@ -174,7 +221,8 @@ impl SchedulerState {
.remove(&uuid) .remove(&uuid)
.expect("Request does not exist."); .expect("Request does not exist.");
self.prefill_costs.remove(&uuid); self.prefill_costs.remove(&uuid);
eprintln!("Request {uuid} will be preempted"); self.active_tokens -= 1;
tracing::warn!("Request {uuid} will be preempted");
// Reset the sequence and get the new sequence and signal // Reset the sequence and get the new sequence and signal
// Insert the new sequence back into the requests map and add to waiting queue // Insert the new sequence back into the requests map and add to waiting queue
...@@ -211,7 +259,7 @@ impl Scheduler { ...@@ -211,7 +259,7 @@ impl Scheduler {
kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>, kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
) -> Self { ) -> Self {
let state = Arc::new(Mutex::new(SchedulerState::default())); let state = Arc::new(Mutex::new(SchedulerState::new(args.max_num_batched_tokens)));
// Create internal channel for KV events only if needed // Create internal channel for KV events only if needed
let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() { let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() {
...@@ -274,7 +322,7 @@ impl Scheduler { ...@@ -274,7 +322,7 @@ impl Scheduler {
// Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't // Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
// schedule anymore. // schedule anymore.
let mut current_blocks = kv_manager_guard.num_active_blocks(); let mut current_blocks = kv_manager_guard.num_active_blocks();
let mut current_tokens = state_guard.num_batched_tokens(); let mut current_tokens = state_guard.active_tokens + state_guard.waiting_tokens;
let mut current_seqs = state_guard.num_active_requests(); let mut current_seqs = state_guard.num_active_requests();
while let Some((uuid, request)) = state_guard.next() { while let Some((uuid, request)) = state_guard.next() {
...@@ -283,16 +331,19 @@ impl Scheduler { ...@@ -283,16 +331,19 @@ impl Scheduler {
// Update predictive budgets // Update predictive budgets
let prefill_cost = kv_manager_guard.get_prefill_cost(&active_sequence); let prefill_cost = kv_manager_guard.get_prefill_cost(&active_sequence);
let total_tokens = active_sequence.len(); let total_tokens = active_sequence.len();
let new_blocks = (total_tokens + 1) / args.block_size; // this is conservative, assumes no cache hit // this is conservative, assumes no cache hit so never over-schedules
let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize;
let new_tokens = prefill_cost.new_tokens; let new_tokens = prefill_cost.new_tokens;
current_blocks += new_blocks; current_blocks += new_blocks;
current_tokens += new_tokens; current_tokens += new_tokens;
current_seqs += 1; current_seqs += 1;
// Check if it can be scheduled // Check various budgets to see if possible to schedule
let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager_guard.max_capacity() as f64; let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager_guard.max_capacity() as f64;
let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| current_tokens <= limit); // If chunked prefill is enabled, we can be under token budget when scheduling
let comparison_tokens = if args.enable_chunked_prefill {current_tokens - new_tokens} else {current_tokens};
let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| comparison_tokens <= limit);
let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit); let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
// Cannot schedule, put first in line instead // Cannot schedule, put first in line instead
...@@ -311,7 +362,7 @@ impl Scheduler { ...@@ -311,7 +362,7 @@ impl Scheduler {
} }
} }
state_guard.start_prefill(uuid, active_sequence, Some(prefill_cost)); state_guard.move_to_prefill(uuid, active_sequence, prefill_cost);
should_schedule = false; should_schedule = false;
} }
} }
...@@ -332,12 +383,13 @@ impl Scheduler { ...@@ -332,12 +383,13 @@ impl Scheduler {
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0); let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
// Process prefilling // Process prefilling
while let Some((prefill_compute, creation_signal)) = state_guard.start_decode() { while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = state_guard.try_prefill() {
// NOTE: Prefill cost/time is always incremented for new blocks, even if they // NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior. // could be cached by other requests in the same batch. This matches vLLM behavior.
total_time += Duration::from_secs_f64(prefill_compute / 1000.0); total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
let prefill_success = process_signals(&mut kv_manager_guard, std::slice::from_ref(&creation_signal));
if !prefill_success { if let Some(creation_signal) = maybe_creation_signal {
if !process_signals(&mut kv_manager_guard, std::slice::from_ref(&creation_signal)) {
panic!("Block allocation for prefilling cannot fail."); panic!("Block allocation for prefilling cannot fail.");
} }
...@@ -347,8 +399,14 @@ impl Scheduler { ...@@ -347,8 +399,14 @@ impl Scheduler {
let _ = relay_tx.send(block_response_to_kv_event(event)); let _ = relay_tx.send(block_response_to_kv_event(event));
} }
} }
};
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
if !is_full_prefill { break; }
} }
state_guard.reset_active_tokens();
// Process decoding // Process decoding
let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect(); let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
if !uuids.is_empty() {should_schedule = true}; if !uuids.is_empty() {should_schedule = true};
...@@ -424,24 +482,30 @@ impl Scheduler { ...@@ -424,24 +482,30 @@ impl Scheduler {
let _ = self.request_tx.send(request); let _ = self.request_tx.send(request);
} }
/// Expose the sender
pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> { pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone() self.request_tx.clone()
} }
/// Get the count of waiting requests
pub async fn waiting_count(&self) -> usize { pub async fn waiting_count(&self) -> usize {
let state = self.state.lock().await; let state = self.state.lock().await;
state.waiting.len() state.waiting.len()
} }
/// Get the count of running requests
pub async fn running_count(&self) -> usize { pub async fn running_count(&self) -> usize {
let state = self.state.lock().await; let state = self.state.lock().await;
state.decode.len() state.decode.len()
} }
/// Get the current capacity of the KvManager pub async fn waiting_tokens(&self) -> usize {
let state = self.state.lock().await;
state.waiting_tokens
}
pub async fn active_tokens(&self) -> usize {
let state = self.state.lock().await;
state.active_tokens
}
pub async fn kv_usage_perc(&self) -> f64 { pub async fn kv_usage_perc(&self) -> f64 {
let kv_manager = self.kv_manager.lock().await; let kv_manager = self.kv_manager.lock().await;
kv_manager.current_capacity_perc() kv_manager.current_capacity_perc()
...@@ -540,12 +604,16 @@ fn process_signals( ...@@ -540,12 +604,16 @@ fn process_signals(
// Check we have a Use signal with blocks // Check we have a Use signal with blocks
let MoveBlock::Use(blocks) = signal else { let MoveBlock::Use(blocks) = signal else {
panic!("Failed signal is Invalid. Has to fail on generation signal."); panic!("Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}");
}; };
// Verify the signal contains exactly one block // Verify the signal contains exactly one block
if blocks.len() != 1 { let num_blocks = blocks.len();
panic!("Failed signal is Invalid. Can have only one generation signal."); let num_active_blocks = kv_manager_guard.num_active_blocks();
if num_blocks != 1 {
panic!(
"Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
);
} }
// Verify the block is a PartialBlock (generation block) // Verify the block is a PartialBlock (generation block)
...@@ -566,20 +634,25 @@ mod tests { ...@@ -566,20 +634,25 @@ mod tests {
use std::time::Duration; use std::time::Duration;
#[rstest] #[rstest]
#[case::random_no_prefix_caching(false, false)] #[case::case_1(false, false, false)]
#[case::random_with_prefix_caching(false, true)] #[case::case_2(false, true, false)]
#[case::caching_no_prefix_caching(true, false)] #[case::case_3(true, false, false)]
#[case::caching_with_prefix_caching(true, true)] #[case::case_4(true, true, false)]
#[case::case_5(false, false, true)]
#[case::case_6(false, true, true)]
#[case::case_7(true, false, true)]
#[case::case_8(true, true, true)]
#[tokio::test] #[tokio::test]
async fn test_scheduler_token_generation_patterns( async fn test_scheduler_token_generation_patterns(
#[case] use_shared_tokens: bool, #[case] use_shared_tokens: bool,
#[case] enable_prefix_caching: bool, #[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) { ) {
std::env::set_var("RUST_LOG", "debug"); std::env::set_var("RUST_LOG", "debug");
let kv_capacity: usize = 500; let kv_capacity: usize = 500;
let block_size: usize = 64; let block_size: usize = 64;
let num_requests: usize = 100; let num_requests: usize = 200;
let input_len: usize = 1000; let input_len: usize = 1000;
let max_output_tokens: usize = 100; let max_output_tokens: usize = 100;
...@@ -592,6 +665,7 @@ mod tests { ...@@ -592,6 +665,7 @@ mod tests {
.block_size(block_size) .block_size(block_size)
.speedup_ratio(10.0) .speedup_ratio(10.0)
.enable_prefix_caching(enable_prefix_caching) .enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.build() .build()
.unwrap(); .unwrap();
...@@ -671,14 +745,12 @@ mod tests { ...@@ -671,14 +745,12 @@ mod tests {
// Calculate and print elapsed time // Calculate and print elapsed time
let elapsed = start_time.elapsed(); let elapsed = start_time.elapsed();
println!( println!(
"Test completed in: {:?} for {} case with prefix_caching={}", "Test completed in: {elapsed:?} for {} case with prefix_caching={enable_prefix_caching} and chunked_prefill={enable_chunked_prefill}",
elapsed,
if use_shared_tokens { if use_shared_tokens {
"caching" "caching"
} else { } else {
"random" "random"
}, }
enable_prefix_caching
); );
// Assert that we received the expected number of tokens // Assert that we received the expected number of tokens
...@@ -686,6 +758,18 @@ mod tests { ...@@ -686,6 +758,18 @@ mod tests {
received_tokens == expected_tokens, received_tokens == expected_tokens,
"Received {received_tokens} tokens but expected exactly {expected_tokens}" "Received {received_tokens} tokens but expected exactly {expected_tokens}"
); );
let active_tokens = scheduler.active_tokens().await;
assert!(
active_tokens == 0,
"Scheduler still have {active_tokens} active tokens but expected 0"
);
let waiting_tokens = scheduler.waiting_tokens().await;
assert!(
waiting_tokens == 0,
"Scheduler still have {waiting_tokens} waiting tokens but expected 0"
);
} }
#[tokio::test] #[tokio::test]
......
...@@ -120,6 +120,10 @@ impl ActiveSequence { ...@@ -120,6 +120,10 @@ impl ActiveSequence {
self.tokens.total_tokens() == 0 self.tokens.total_tokens() == 0
} }
pub fn take_creation_signal(&mut self) -> Option<MoveBlock> {
self.creation_signal.take()
}
/// Create a new ActiveSequence instance and return the creation signal /// Create a new ActiveSequence instance and return the creation signal
pub fn new_with_signal( pub fn new_with_signal(
tokens: Vec<u32>, tokens: Vec<u32>,
...@@ -128,7 +132,7 @@ impl ActiveSequence { ...@@ -128,7 +132,7 @@ impl ActiveSequence {
enable_prefix_caching: bool, enable_prefix_caching: bool,
) -> (Self, Option<MoveBlock>) { ) -> (Self, Option<MoveBlock>) {
let mut sequence = Self::new(tokens, max_output_tokens, block_size, enable_prefix_caching); let mut sequence = Self::new(tokens, max_output_tokens, block_size, enable_prefix_caching);
let signal = sequence.creation_signal.take(); let signal = sequence.take_creation_signal();
(sequence, signal) (sequence, signal)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment