Unverified Commit 377711d3 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: de-spaghetti mocker scheduler (#4789)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 8f0ac731
...@@ -280,51 +280,130 @@ impl Scheduler { ...@@ -280,51 +280,130 @@ impl Scheduler {
loop { loop {
// 1. Receive requests // 1. Receive requests
if receive_requests(&mut state, &mut request_rx, &cancel_token_clone)
.await
.is_none()
{
break;
}
// Start timing for this forward pass (schedule + simulate)
let iteration_start = std::time::Instant::now();
// 2. Schedule waiting requests (once per iteration)
try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
// 3. Simulate prefill + decode
let prefill_time = simulate_prefill(
&mut state,
&mut kv_manager,
&args.perf_model,
args.worker_type,
);
let decode_time = simulate_decode(
&mut state,
&mut kv_manager,
&output_tx,
&args.perf_model,
args.block_size,
);
let total_time = prefill_time + decode_time;
// 4. Send metrics once per forward pass (after all prefill and decode processing)
let _ = metrics_tx.send(get_fwd_pass_metrics(
&state,
&kv_manager,
&hit_rates,
dp_rank,
));
// 5. Sleep to maintain target iteration timing
let target_duration =
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
let elapsed = iteration_start.elapsed();
if elapsed < target_duration {
tokio::time::sleep(target_duration - elapsed).await;
}
}
});
Self {
request_tx,
metrics_rx,
}
}
/// Add a new request to the waiting queue
pub async fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request);
}
pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
/// Get a watch receiver for forward pass metrics
pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<ForwardPassMetrics> {
self.metrics_rx.clone()
}
}
/// Receive requests from the channel.
/// Returns `Some(())` to continue the loop, `None` to break (on cancellation).
async fn receive_requests(
state: &mut SchedulerState,
request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
cancel_token: &CancellationToken,
) -> Option<()> {
if cancel_token.is_cancelled() {
return None;
}
if state.is_empty() { if state.is_empty() {
// Fully idle - block until new request arrives // Fully idle - block until new request arrives
tokio::select! { tokio::select! {
biased; biased;
_ = cancel_token.cancelled() => {
return None;
}
Some(request) = request_rx.recv() => { Some(request) = request_rx.recv() => {
state.receive(request); state.receive(request);
return Some(());
} }
_ = cancel_token_clone.cancelled() => {
break;
} }
} }
} else {
// Has active/waiting work - collect any pending requests without blocking // Has active/waiting work - collect any pending requests without blocking
while let Ok(request) = request_rx.try_recv() { while let Ok(request) = request_rx.try_recv() {
state.receive(request); state.receive(request);
} }
// Check for cancellation Some(())
if cancel_token_clone.is_cancelled() { }
break;
}
}
// Start timing for this forward pass (schedule + simulate)
let iteration_start = std::time::Instant::now();
// 2. Schedule waiting requests (once per iteration)
try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
// 3. Simulate prefill + decode /// Simulate prefill phase for all pending prefill requests.
/// Returns the total prefill compute time.
fn simulate_prefill(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
perf_model: &PerfModel,
worker_type: WorkerType,
) -> Duration {
let mut total_time = Duration::ZERO; let mut total_time = Duration::ZERO;
// Process prefilling
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
state.try_prefill(&args.perf_model) state.try_prefill(perf_model)
{ {
// NOTE: Prefill cost/time is always incremented for new blocks, even if they // NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior. // could be cached by other requests in the same batch. This matches vLLM behavior.
// For decode workers, skip adding prefill compute time // For decode workers, skip adding prefill compute time
if args.worker_type != WorkerType::Decode { if worker_type != WorkerType::Decode {
total_time += Duration::from_secs_f64(prefill_compute / 1000.0); total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
} }
if let Some(creation_signal) = maybe_creation_signal if let Some(creation_signal) = maybe_creation_signal
&& !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal)) && !process_signals(kv_manager, std::slice::from_ref(&creation_signal))
{ {
panic!("Block allocation for prefilling cannot fail."); panic!("Block allocation for prefilling cannot fail.");
} }
...@@ -335,8 +414,20 @@ impl Scheduler { ...@@ -335,8 +414,20 @@ impl Scheduler {
} }
} }
total_time
}
/// Simulate decode phase for all active decode requests.
/// Returns the total decode compute time.
fn simulate_decode(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
perf_model: &PerfModel,
block_size: usize,
) -> Duration {
// Compute decode timing // Compute decode timing
let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size; let active_kv_tokens = kv_manager.num_active_blocks() * block_size;
// Compute average context length across all active decode requests // Compute average context length across all active decode requests
let (total_length, count) = state let (total_length, count) = state
.decode .decode
...@@ -350,10 +441,8 @@ impl Scheduler { ...@@ -350,10 +441,8 @@ impl Scheduler {
} }
}); });
let context_length = if count > 0 { total_length / count } else { 0 }; let context_length = if count > 0 { total_length / count } else { 0 };
let decoding_time = args let decoding_time = perf_model.predict_decode_time(active_kv_tokens, context_length);
.perf_model let total_time = Duration::from_secs_f64(decoding_time / 1000.0);
.predict_decode_time(active_kv_tokens, context_length);
total_time += Duration::from_secs_f64(decoding_time / 1000.0);
state.reset_active_tokens(); state.reset_active_tokens();
...@@ -367,7 +456,7 @@ impl Scheduler { ...@@ -367,7 +456,7 @@ impl Scheduler {
// Process all signals with the KvManager // Process all signals with the KvManager
// Handling of preemption on failure // Handling of preemption on failure
if !process_signals(&mut kv_manager, &signals) { if !process_signals(kv_manager, &signals) {
sequence.pop(); // revert the failed generation op sequence.pop(); // revert the failed generation op
for signal in state.preempt() { for signal in state.preempt() {
kv_manager.process(&signal); kv_manager.process(&signal);
...@@ -377,19 +466,16 @@ impl Scheduler { ...@@ -377,19 +466,16 @@ impl Scheduler {
// Check completion and send notification // Check completion and send notification
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens(); let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let should_output = let should_output = sequence.generated_tokens() > sequence.already_generated_tokens();
sequence.generated_tokens() > sequence.already_generated_tokens();
let mut send_failed = false; let send_failed = should_output
if should_output { && output_tx.as_ref().is_some_and(|tx| {
send_failed = output_tx.as_ref().is_some_and(|tx| {
tx.send(OutputSignal { tx.send(OutputSignal {
uuid, uuid,
completed: is_complete, completed: is_complete,
}) })
.is_err() .is_err()
}); });
}
if send_failed { if send_failed {
for signal in &sequence.free_signal() { for signal in &sequence.free_signal() {
...@@ -399,46 +485,10 @@ impl Scheduler { ...@@ -399,46 +485,10 @@ impl Scheduler {
if send_failed || is_complete { if send_failed || is_complete {
state.complete(&uuid); state.complete(&uuid);
continue;
} }
} }
// Send metrics once per forward pass (after all prefill and decode processing) total_time
{
let metrics = get_fwd_pass_metrics(&state, &kv_manager, &hit_rates, dp_rank);
let _ = metrics_tx.send(metrics);
}
// 4. Sleep to maintain target iteration timing
let target_duration =
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
let elapsed = iteration_start.elapsed();
if elapsed < target_duration {
tokio::time::sleep(target_duration - elapsed).await;
}
}
});
Self {
request_tx,
metrics_rx,
}
}
/// Add a new request to the waiting queue
pub async fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request);
}
pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
/// Get a watch receiver for forward pass metrics
pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<ForwardPassMetrics> {
self.metrics_rx.clone()
}
} }
/// Calculate forward pass metrics from current state /// Calculate forward pass metrics from current state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment