Unverified Commit d870d4b0 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: mocker states do not need arc mutex (#3883)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent ee275cbf
...@@ -38,8 +38,7 @@ use crate::tokens::BlockHash; ...@@ -38,8 +38,7 @@ use crate::tokens::BlockHash;
use crate::tokens::blocks::UniqueBlock; use crate::tokens::blocks::UniqueBlock;
use std::collections::HashMap; use std::collections::HashMap;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::Arc; use tokio::sync::mpsc;
use tokio::sync::{Mutex, mpsc};
use tokio::time::Duration; use tokio::time::Duration;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use uuid::Uuid; use uuid::Uuid;
...@@ -238,8 +237,6 @@ impl SchedulerState { ...@@ -238,8 +237,6 @@ impl SchedulerState {
/// Manages scheduling of requests using KvManager resources /// Manages scheduling of requests using KvManager resources
#[derive(Clone)] #[derive(Clone)]
pub struct Scheduler { pub struct Scheduler {
state: Arc<Mutex<SchedulerState>>,
kv_manager: Arc<Mutex<KvManager>>,
request_tx: mpsc::UnboundedSender<DirectRequest>, request_tx: mpsc::UnboundedSender<DirectRequest>,
metrics_rx: tokio::sync::watch::Receiver<ForwardPassMetrics>, metrics_rx: tokio::sync::watch::Receiver<ForwardPassMetrics>,
} }
...@@ -253,8 +250,6 @@ impl Scheduler { ...@@ -253,8 +250,6 @@ impl Scheduler {
kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>, kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
) -> Self { ) -> Self {
let state = Arc::new(Mutex::new(SchedulerState::new(args.max_num_batched_tokens)));
// Create internal channel for KV events only if needed // Create internal channel for KV events only if needed
let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() { let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() {
let (tx, rx) = mpsc::unbounded_channel::<MoveBlockResponse>(); let (tx, rx) = mpsc::unbounded_channel::<MoveBlockResponse>();
...@@ -263,13 +258,6 @@ impl Scheduler { ...@@ -263,13 +258,6 @@ impl Scheduler {
(None, None) (None, None)
}; };
let kv_manager = Arc::new(Mutex::new(KvManager::new_with_sender(
args.num_gpu_blocks,
args.block_size,
block_resp_tx,
)));
let hit_rates = Arc::new(Mutex::new(VecDeque::with_capacity(1000)));
// Assert speedup_ratio is greater than 0 // Assert speedup_ratio is greater than 0
assert!( assert!(
args.speedup_ratio > 0.0, args.speedup_ratio > 0.0,
...@@ -284,30 +272,26 @@ impl Scheduler { ...@@ -284,30 +272,26 @@ impl Scheduler {
let (metrics_tx, metrics_rx) = let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<ForwardPassMetrics>(initial_metrics); tokio::sync::watch::channel::<ForwardPassMetrics>(initial_metrics);
// Create a clone for the background task
let state_clone = state.clone();
let kv_manager_clone = kv_manager.clone();
let output_tx_clone = output_tx.clone();
let cancel_token_clone = cancellation_token.unwrap_or_default().clone(); let cancel_token_clone = cancellation_token.unwrap_or_default().clone();
// Spawn main background task with cancellation token // Spawn main background task with cancellation token
tokio::spawn(async move { tokio::spawn(async move {
// Create state and kv_manager as local variables owned by this task
let mut state = SchedulerState::new(args.max_num_batched_tokens);
let mut kv_manager =
KvManager::new_with_sender(args.num_gpu_blocks, args.block_size, block_resp_tx);
let mut hit_rates = VecDeque::with_capacity(1000);
let mut should_schedule = true; let mut should_schedule = true;
loop { loop {
{ {
let state_guard = state_clone.lock().await;
// Enqueue new request, blocks until at least one is received, so no redundant work is done // Enqueue new request, blocks until at least one is received, so no redundant work is done
// TODO: clean this up? double lock acquisition is ugly, but needed to not hold the lock forever if state.is_empty() {
if state_guard.is_empty() {
drop(state_guard);
let Some(request) = request_rx.recv().await else { let Some(request) = request_rx.recv().await else {
tracing::warn!("request sender is dropped"); tracing::warn!("request sender is dropped");
break; break;
}; };
let mut state_guard = state_clone.lock().await; state.receive(request);
state_guard.receive(request);
} }
} }
...@@ -316,7 +300,6 @@ impl Scheduler { ...@@ -316,7 +300,6 @@ impl Scheduler {
// Enqueue new request // Enqueue new request
Some(request) = request_rx.recv() => { Some(request) = request_rx.recv() => {
let mut state = state_clone.lock().await;
state.receive(request); state.receive(request);
} }
...@@ -327,20 +310,17 @@ impl Scheduler { ...@@ -327,20 +310,17 @@ impl Scheduler {
continue; continue;
} }
let mut state_guard = state_clone.lock().await;
let kv_manager_guard = kv_manager_clone.lock().await;
// Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't // Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
// schedule anymore. // schedule anymore.
let mut current_blocks = kv_manager_guard.num_active_blocks(); let mut current_blocks = kv_manager.num_active_blocks();
let mut current_tokens = state_guard.active_tokens + state_guard.waiting_tokens; let mut current_tokens = state.active_tokens + state.waiting_tokens;
let mut current_seqs = state_guard.num_active_requests(); let mut current_seqs = state.num_active_requests();
while let Some((uuid, request)) = state_guard.next() { while let Some((uuid, request)) = state.next() {
let active_sequence = get_active_sequence(request, args.block_size, args.enable_prefix_caching); let active_sequence = get_active_sequence(request, args.block_size, args.enable_prefix_caching);
// Update predictive budgets // Update predictive budgets
let prefill_cost = kv_manager_guard.get_prefill_cost(&active_sequence); let prefill_cost = kv_manager.get_prefill_cost(&active_sequence);
let total_tokens = active_sequence.len(); let total_tokens = active_sequence.len();
// this is conservative, assumes no cache hit so never over-schedules // this is conservative, assumes no cache hit so never over-schedules
let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize; let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize;
...@@ -351,7 +331,7 @@ impl Scheduler { ...@@ -351,7 +331,7 @@ impl Scheduler {
current_seqs += 1; current_seqs += 1;
// Check various budgets to see if possible to schedule // Check various budgets to see if possible to schedule
let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager_guard.max_capacity() as f64; let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager.max_capacity() as f64;
// If chunked prefill is enabled, we can be under token budget when scheduling // If chunked prefill is enabled, we can be under token budget when scheduling
let comparison_tokens = if args.enable_chunked_prefill {current_tokens - new_tokens} else {current_tokens}; let comparison_tokens = if args.enable_chunked_prefill {current_tokens - new_tokens} else {current_tokens};
let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| comparison_tokens <= limit); let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| comparison_tokens <= limit);
...@@ -359,21 +339,18 @@ impl Scheduler { ...@@ -359,21 +339,18 @@ impl Scheduler {
// Cannot schedule, put first in line instead // Cannot schedule, put first in line instead
if !(under_block_budget && under_token_budget && under_seq_budget) { if !(under_block_budget && under_token_budget && under_seq_budget) {
state_guard.first_in_line(uuid, Request::Active(active_sequence)); state.first_in_line(uuid, Request::Active(active_sequence));
break; break;
} }
// Compute and store hit rate // Compute and store hit rate
let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 }; let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 };
{ hit_rates.push_back(hit_rate);
let mut hit_rates_guard = hit_rates.lock().await; if hit_rates.len() > 1000 {
hit_rates_guard.push_back(hit_rate); hit_rates.pop_front();
if hit_rates_guard.len() > 1000 {
hit_rates_guard.pop_front();
}
} }
state_guard.move_to_prefill(uuid, active_sequence, prefill_cost); state.move_to_prefill(uuid, active_sequence, prefill_cost);
should_schedule = false; should_schedule = false;
} }
} }
...@@ -385,11 +362,8 @@ impl Scheduler { ...@@ -385,11 +362,8 @@ impl Scheduler {
} }
// Simulates prefill + decode // Simulates prefill + decode
let mut state_guard = state_clone.lock().await;
let mut kv_manager_guard = kv_manager_clone.lock().await;
// Base time needed for decoding using active percentage and quadratic formula // Base time needed for decoding using active percentage and quadratic formula
let active_perc = kv_manager_guard.get_active_perc(); let active_perc = kv_manager.get_active_perc();
let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44; let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0); let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
...@@ -399,17 +373,15 @@ impl Scheduler { ...@@ -399,17 +373,15 @@ impl Scheduler {
maybe_creation_signal, maybe_creation_signal,
block_hashes, block_hashes,
is_full_prefill, is_full_prefill,
)) = state_guard.try_prefill() )) = state.try_prefill()
{ {
// NOTE: Prefill cost/time is always incremented for new blocks, even if they // NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior. // could be cached by other requests in the same batch. This matches vLLM behavior.
total_time += Duration::from_secs_f64(prefill_compute / 1000.0); total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
if let Some(creation_signal) = maybe_creation_signal { if let Some(creation_signal) = maybe_creation_signal {
if !process_signals( if !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal))
&mut kv_manager_guard, {
std::slice::from_ref(&creation_signal),
) {
panic!("Block allocation for prefilling cannot fail."); panic!("Block allocation for prefilling cannot fail.");
} }
...@@ -428,36 +400,25 @@ impl Scheduler { ...@@ -428,36 +400,25 @@ impl Scheduler {
} }
} }
state_guard.reset_active_tokens(); state.reset_active_tokens();
{
let hit_rates_guard = hit_rates.lock().await;
let metrics = get_fwd_pass_metrics(
&state_guard,
&kv_manager_guard,
&hit_rates_guard,
dp_rank,
);
let _ = metrics_tx.send(metrics);
}
// Process decoding // Process decoding
let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect(); let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
if !uuids.is_empty() { if !uuids.is_empty() {
should_schedule = true should_schedule = true
}; };
for uuid in uuids { for uuid in uuids {
let Some(sequence) = state_guard.run(uuid) else { let Some(sequence) = state.run(uuid) else {
continue; continue;
}; };
let signals = sequence.generate(); let signals = sequence.generate();
// Process all signals with the KvManager // Process all signals with the KvManager
// Handling of preemption on failure // Handling of preemption on failure
if !process_signals(&mut kv_manager_guard, &signals) { if !process_signals(&mut kv_manager, &signals) {
sequence.pop(); // revert the failed generation op sequence.pop(); // revert the failed generation op
for signal in state_guard.preempt() { for signal in state.preempt() {
kv_manager_guard.process(&signal); kv_manager.process(&signal);
} }
continue; continue;
} }
...@@ -477,7 +438,7 @@ impl Scheduler { ...@@ -477,7 +438,7 @@ impl Scheduler {
let mut send_failed = false; let mut send_failed = false;
if should_output { if should_output {
send_failed = output_tx_clone.as_ref().is_some_and(|tx| { send_failed = output_tx.as_ref().is_some_and(|tx| {
tx.send(OutputSignal { tx.send(OutputSignal {
uuid, uuid,
completed: is_complete, completed: is_complete,
...@@ -488,30 +449,23 @@ impl Scheduler { ...@@ -488,30 +449,23 @@ impl Scheduler {
if send_failed { if send_failed {
for signal in &sequence.free_signal() { for signal in &sequence.free_signal() {
kv_manager_guard.process(signal); kv_manager.process(signal);
} }
} }
{
let hit_rates_guard = hit_rates.lock().await;
let metrics = get_fwd_pass_metrics(
&state_guard,
&kv_manager_guard,
&hit_rates_guard,
dp_rank,
);
let _ = metrics_tx.send(metrics);
}
if send_failed || is_complete { if send_failed || is_complete {
state_guard.complete(&uuid); state.complete(&uuid);
continue; continue;
} }
} }
// Send metrics once per forward pass (after all prefill and decode processing)
{
let metrics = get_fwd_pass_metrics(&state, &kv_manager, &hit_rates, dp_rank);
let _ = metrics_tx.send(metrics);
}
// Sleep once for the adjusted duration // Sleep once for the adjusted duration
drop(kv_manager_guard);
drop(state_guard);
let adjusted_time = let adjusted_time =
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio); Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
if adjusted_time.as_millis() > 0 { if adjusted_time.as_millis() > 0 {
...@@ -521,8 +475,6 @@ impl Scheduler { ...@@ -521,8 +475,6 @@ impl Scheduler {
}); });
Self { Self {
state,
kv_manager,
request_tx, request_tx,
metrics_rx, metrics_rx,
} }
...@@ -537,31 +489,6 @@ impl Scheduler { ...@@ -537,31 +489,6 @@ impl Scheduler {
self.request_tx.clone() self.request_tx.clone()
} }
pub async fn waiting_count(&self) -> usize {
let state = self.state.lock().await;
state.waiting.len()
}
pub async fn running_count(&self) -> usize {
let state = self.state.lock().await;
state.decode.len()
}
pub async fn waiting_tokens(&self) -> usize {
let state = self.state.lock().await;
state.waiting_tokens
}
pub async fn active_tokens(&self) -> usize {
let state = self.state.lock().await;
state.active_tokens
}
pub async fn kv_usage_perc(&self) -> f64 {
let kv_manager = self.kv_manager.lock().await;
kv_manager.current_capacity_perc()
}
/// Get a watch receiver for forward pass metrics /// Get a watch receiver for forward pass metrics
pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<ForwardPassMetrics> { pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<ForwardPassMetrics> {
self.metrics_rx.clone() self.metrics_rx.clone()
...@@ -648,12 +575,9 @@ fn get_active_sequence( ...@@ -648,12 +575,9 @@ fn get_active_sequence(
/// This validation is important because in normal operation, the only legitimate failure /// This validation is important because in normal operation, the only legitimate failure
/// case should be when trying to acquire a new generation block - any other failures would /// case should be when trying to acquire a new generation block - any other failures would
/// indicate an unexpected state in the system. /// indicate an unexpected state in the system.
fn process_signals( fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
kv_manager_guard: &mut tokio::sync::MutexGuard<'_, KvManager>,
signals: &[MoveBlock],
) -> bool {
for signal in signals { for signal in signals {
if kv_manager_guard.process(signal) { if kv_manager.process(signal) {
continue; continue;
} }
...@@ -666,7 +590,7 @@ fn process_signals( ...@@ -666,7 +590,7 @@ fn process_signals(
// Verify the signal contains exactly one block // Verify the signal contains exactly one block
let num_blocks = blocks.len(); let num_blocks = blocks.len();
let num_active_blocks = kv_manager_guard.num_active_blocks(); let num_active_blocks = kv_manager.num_active_blocks();
if num_blocks != 1 { if num_blocks != 1 {
panic!( panic!(
"Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks." "Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
...@@ -691,6 +615,30 @@ mod tests { ...@@ -691,6 +615,30 @@ mod tests {
use std::time::Duration; use std::time::Duration;
use tokio::time::interval; use tokio::time::interval;
/// Helper function to verify that the scheduler is idle (no active or waiting requests/resources)
fn assert_scheduler_idle(metrics: &ForwardPassMetrics) {
assert_eq!(
metrics.worker_stats.request_active_slots, 0,
"Expected 0 active slots, got {}",
metrics.worker_stats.request_active_slots
);
assert_eq!(
metrics.worker_stats.num_requests_waiting, 0,
"Expected 0 waiting requests, got {}",
metrics.worker_stats.num_requests_waiting
);
assert_eq!(
metrics.kv_stats.kv_active_blocks, 0,
"Expected 0 active blocks, got {}",
metrics.kv_stats.kv_active_blocks
);
assert_eq!(
metrics.kv_stats.gpu_cache_usage_perc, 0.0,
"Expected 0% GPU cache usage, got {}",
metrics.kv_stats.gpu_cache_usage_perc
);
}
#[rstest] #[rstest]
#[case::case_1(false, false, false)] #[case::case_1(false, false, false)]
#[case::case_2(false, true, false)] #[case::case_2(false, true, false)]
...@@ -820,17 +768,11 @@ mod tests { ...@@ -820,17 +768,11 @@ mod tests {
"Received {received_tokens} tokens but expected exactly {expected_tokens}" "Received {received_tokens} tokens but expected exactly {expected_tokens}"
); );
let active_tokens = scheduler.active_tokens().await; // Wait a bit for final metrics update to propagate
assert!( tokio::time::sleep(Duration::from_millis(100)).await;
active_tokens == 0,
"Scheduler still have {active_tokens} active tokens but expected 0"
);
let waiting_tokens = scheduler.waiting_tokens().await; let metrics = scheduler.metrics_receiver().borrow().clone();
assert!( assert_scheduler_idle(&metrics);
waiting_tokens == 0,
"Scheduler still have {waiting_tokens} waiting tokens but expected 0"
);
} }
#[tokio::test] #[tokio::test]
...@@ -913,12 +855,7 @@ mod tests { ...@@ -913,12 +855,7 @@ mod tests {
// Verify forward pass metrics // Verify forward pass metrics
let metrics = metrics_rx.borrow().clone(); let metrics = metrics_rx.borrow().clone();
assert_eq!( assert_scheduler_idle(&metrics);
metrics.worker_stats.num_requests_waiting, 0,
"Expected no waiting requests, got {}",
metrics.worker_stats.num_requests_waiting
);
assert!( assert!(
metrics.kv_stats.gpu_prefix_cache_hit_rate > 0.8, metrics.kv_stats.gpu_prefix_cache_hit_rate > 0.8,
"Expected cache hit rate > 0.8, got {}", "Expected cache hit rate > 0.8, got {}",
...@@ -983,17 +920,6 @@ mod tests { ...@@ -983,17 +920,6 @@ mod tests {
let metrics_rx = scheduler.metrics_receiver(); let metrics_rx = scheduler.metrics_receiver();
let metrics = metrics_rx.borrow().clone(); let metrics = metrics_rx.borrow().clone();
assert_eq!( assert_scheduler_idle(&metrics);
metrics.kv_stats.gpu_cache_usage_perc,
0.0,
"Expected GPU cache usage to be 0%, got {}%",
metrics.kv_stats.gpu_cache_usage_perc * 100.0
);
assert_eq!(
metrics.kv_stats.kv_active_blocks, 0,
"Expected 0 active blocks, got {}",
metrics.kv_stats.kv_active_blocks
);
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment