Unverified Commit 377711d3 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: de-spaghetti mocker scheduler (#4789)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 8f0ac731
...@@ -280,27 +280,11 @@ impl Scheduler { ...@@ -280,27 +280,11 @@ impl Scheduler {
loop { loop {
// 1. Receive requests // 1. Receive requests
if state.is_empty() { if receive_requests(&mut state, &mut request_rx, &cancel_token_clone)
// Fully idle - block until new request arrives .await
tokio::select! { .is_none()
biased; {
Some(request) = request_rx.recv() => { break;
state.receive(request);
}
_ = cancel_token_clone.cancelled() => {
break;
}
}
} else {
// Has active/waiting work - collect any pending requests without blocking
while let Ok(request) = request_rx.try_recv() {
state.receive(request);
}
// Check for cancellation
if cancel_token_clone.is_cancelled() {
break;
}
} }
// Start timing for this forward pass (schedule + simulate) // Start timing for this forward pass (schedule + simulate)
...@@ -310,106 +294,30 @@ impl Scheduler { ...@@ -310,106 +294,30 @@ impl Scheduler {
try_schedule(&mut state, &kv_manager, &mut hit_rates, &args); try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
// 3. Simulate prefill + decode // 3. Simulate prefill + decode
let mut total_time = Duration::ZERO; let prefill_time = simulate_prefill(
&mut state,
// Process prefilling &mut kv_manager,
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = &args.perf_model,
state.try_prefill(&args.perf_model) args.worker_type,
{ );
// NOTE: Prefill cost/time is always incremented for new blocks, even if they let decode_time = simulate_decode(
// could be cached by other requests in the same batch. This matches vLLM behavior. &mut state,
// For decode workers, skip adding prefill compute time &mut kv_manager,
if args.worker_type != WorkerType::Decode { &output_tx,
total_time += Duration::from_secs_f64(prefill_compute / 1000.0); &args.perf_model,
} args.block_size,
);
if let Some(creation_signal) = maybe_creation_signal let total_time = prefill_time + decode_time;
&& !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal))
{ // 4. Send metrics once per forward pass (after all prefill and decode processing)
panic!("Block allocation for prefilling cannot fail."); let _ = metrics_tx.send(get_fwd_pass_metrics(
} &state,
&kv_manager,
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill &hit_rates,
if !is_full_prefill { dp_rank,
break; ));
}
} // 5. Sleep to maintain target iteration timing
// Compute decode timing
let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size;
// Compute average context length across all active decode requests
let (total_length, count) = state
.decode
.keys()
.filter_map(|uuid| state.requests.get(uuid))
.fold((0usize, 0usize), |(sum, cnt), req| {
if let Request::Active(seq) = req {
(sum + seq.len(), cnt + 1)
} else {
(sum, cnt)
}
});
let context_length = if count > 0 { total_length / count } else { 0 };
let decoding_time = args
.perf_model
.predict_decode_time(active_kv_tokens, context_length);
total_time += Duration::from_secs_f64(decoding_time / 1000.0);
state.reset_active_tokens();
// Process decoding
let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
for uuid in uuids {
let Some(sequence) = state.run(uuid) else {
continue;
};
let signals = sequence.generate();
// Process all signals with the KvManager
// Handling of preemption on failure
if !process_signals(&mut kv_manager, &signals) {
sequence.pop(); // revert the failed generation op
for signal in state.preempt() {
kv_manager.process(&signal);
}
continue;
}
// Check completion and send notification
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let should_output =
sequence.generated_tokens() > sequence.already_generated_tokens();
let mut send_failed = false;
if should_output {
send_failed = output_tx.as_ref().is_some_and(|tx| {
tx.send(OutputSignal {
uuid,
completed: is_complete,
})
.is_err()
});
}
if send_failed {
for signal in &sequence.free_signal() {
kv_manager.process(signal);
}
}
if send_failed || is_complete {
state.complete(&uuid);
continue;
}
}
// Send metrics once per forward pass (after all prefill and decode processing)
{
let metrics = get_fwd_pass_metrics(&state, &kv_manager, &hit_rates, dp_rank);
let _ = metrics_tx.send(metrics);
}
// 4. Sleep to maintain target iteration timing
let target_duration = let target_duration =
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio); Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
let elapsed = iteration_start.elapsed(); let elapsed = iteration_start.elapsed();
...@@ -441,6 +349,148 @@ impl Scheduler { ...@@ -441,6 +349,148 @@ impl Scheduler {
} }
} }
/// Receive requests from the channel.
/// Returns `Some(())` to continue the loop, `None` to break (on cancellation).
async fn receive_requests(
state: &mut SchedulerState,
request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
cancel_token: &CancellationToken,
) -> Option<()> {
if cancel_token.is_cancelled() {
return None;
}
if state.is_empty() {
// Fully idle - block until new request arrives
tokio::select! {
biased;
_ = cancel_token.cancelled() => {
return None;
}
Some(request) = request_rx.recv() => {
state.receive(request);
return Some(());
}
}
}
// Has active/waiting work - collect any pending requests without blocking
while let Ok(request) = request_rx.try_recv() {
state.receive(request);
}
Some(())
}
/// Simulate prefill phase for all pending prefill requests.
/// Returns the total prefill compute time.
fn simulate_prefill(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
perf_model: &PerfModel,
worker_type: WorkerType,
) -> Duration {
let mut total_time = Duration::ZERO;
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
state.try_prefill(perf_model)
{
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior.
// For decode workers, skip adding prefill compute time
if worker_type != WorkerType::Decode {
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
}
if let Some(creation_signal) = maybe_creation_signal
&& !process_signals(kv_manager, std::slice::from_ref(&creation_signal))
{
panic!("Block allocation for prefilling cannot fail.");
}
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
if !is_full_prefill {
break;
}
}
total_time
}
/// Simulate decode phase for all active decode requests.
/// Returns the total decode compute time.
fn simulate_decode(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
perf_model: &PerfModel,
block_size: usize,
) -> Duration {
// Compute decode timing
let active_kv_tokens = kv_manager.num_active_blocks() * block_size;
// Compute average context length across all active decode requests
let (total_length, count) = state
.decode
.keys()
.filter_map(|uuid| state.requests.get(uuid))
.fold((0usize, 0usize), |(sum, cnt), req| {
if let Request::Active(seq) = req {
(sum + seq.len(), cnt + 1)
} else {
(sum, cnt)
}
});
let context_length = if count > 0 { total_length / count } else { 0 };
let decoding_time = perf_model.predict_decode_time(active_kv_tokens, context_length);
let total_time = Duration::from_secs_f64(decoding_time / 1000.0);
state.reset_active_tokens();
// Process decoding
let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
for uuid in uuids {
let Some(sequence) = state.run(uuid) else {
continue;
};
let signals = sequence.generate();
// Process all signals with the KvManager
// Handling of preemption on failure
if !process_signals(kv_manager, &signals) {
sequence.pop(); // revert the failed generation op
for signal in state.preempt() {
kv_manager.process(&signal);
}
continue;
}
// Check completion and send notification
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let should_output = sequence.generated_tokens() > sequence.already_generated_tokens();
let send_failed = should_output
&& output_tx.as_ref().is_some_and(|tx| {
tx.send(OutputSignal {
uuid,
completed: is_complete,
})
.is_err()
});
if send_failed {
for signal in &sequence.free_signal() {
kv_manager.process(signal);
}
}
if send_failed || is_complete {
state.complete(&uuid);
}
}
total_time
}
/// Calculate forward pass metrics from current state /// Calculate forward pass metrics from current state
fn get_fwd_pass_metrics( fn get_fwd_pass_metrics(
state: &SchedulerState, state: &SchedulerState,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment