Unverified Commit 941ad640 authored by luc-hiverge's avatar luc-hiverge Committed by GitHub
Browse files

fix: emit first token creation signal after sleeping. (#5681)


Signed-off-by: default avatarLuc Grosheintz <luc@hiverge.ai>
parent d1697dc3
...@@ -297,42 +297,33 @@ impl Scheduler { ...@@ -297,42 +297,33 @@ impl Scheduler {
break; break;
} }
// Start timing for this forward pass (schedule + simulate)
let iteration_start = std::time::Instant::now();
// 2. Schedule waiting requests (once per iteration) // 2. Schedule waiting requests (once per iteration)
try_schedule(&mut state, &kv_manager, &mut hit_rates, &args); try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
// 3. Simulate prefill + decode // 3. Simulate prefill + decode
let prefill_time = simulate_prefill( simulate_prefill(
&mut state, &mut state,
&mut kv_manager, &mut kv_manager,
&args.perf_model, &args.perf_model,
args.worker_type, args.worker_type,
); args.speedup_ratio,
let decode_time = simulate_decode( )
.await;
simulate_decode(
&mut state, &mut state,
&mut kv_manager, &mut kv_manager,
&output_tx, &output_tx,
&args.perf_model, &args.perf_model,
args.block_size, args.block_size,
); args.speedup_ratio,
let total_time = prefill_time + decode_time; )
.await;
// 4. Send metrics once per forward pass (after all prefill and decode processing) // 4. Send metrics once per forward pass (after all prefill and decode processing)
let _ = metrics_tx.send(MockerMetrics { let _ = metrics_tx.send(MockerMetrics {
dp_rank, dp_rank,
active_decode_blocks: kv_manager.num_active_blocks() as u64, active_decode_blocks: kv_manager.num_active_blocks() as u64,
}); });
// 5. Sleep to maintain target iteration timing
let target_duration =
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
let elapsed = iteration_start.elapsed();
if elapsed < target_duration {
tokio::time::sleep(target_duration - elapsed).await;
}
} }
}); });
...@@ -392,12 +383,14 @@ async fn receive_requests( ...@@ -392,12 +383,14 @@ async fn receive_requests(
/// Simulate prefill phase for all pending prefill requests. /// Simulate prefill phase for all pending prefill requests.
/// Returns the total prefill compute time. /// Returns the total prefill compute time.
fn simulate_prefill( async fn simulate_prefill(
state: &mut SchedulerState, state: &mut SchedulerState,
kv_manager: &mut KvManager, kv_manager: &mut KvManager,
perf_model: &PerfModel, perf_model: &PerfModel,
worker_type: WorkerType, worker_type: WorkerType,
speedup_ratio: f64,
) -> Duration { ) -> Duration {
let start_time = tokio::time::Instant::now();
let mut total_time = Duration::ZERO; let mut total_time = Duration::ZERO;
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
...@@ -422,18 +415,23 @@ fn simulate_prefill( ...@@ -422,18 +415,23 @@ fn simulate_prefill(
} }
} }
let deadline = start_time + Duration::from_secs_f64(total_time.as_secs_f64() / speedup_ratio);
tokio::time::sleep_until(deadline).await;
total_time total_time
} }
/// Simulate decode phase for all active decode requests. /// Simulate decode phase for all active decode requests.
/// Returns the total decode compute time. /// Returns the total decode compute time.
fn simulate_decode( async fn simulate_decode(
state: &mut SchedulerState, state: &mut SchedulerState,
kv_manager: &mut KvManager, kv_manager: &mut KvManager,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
perf_model: &PerfModel, perf_model: &PerfModel,
block_size: usize, block_size: usize,
speedup_ratio: f64,
) -> Duration { ) -> Duration {
let start_time = tokio::time::Instant::now();
// Compute decode timing // Compute decode timing
let active_kv_tokens = kv_manager.num_active_blocks() * block_size; let active_kv_tokens = kv_manager.num_active_blocks() * block_size;
// Compute average context length across all active decode requests // Compute average context length across all active decode requests
...@@ -496,6 +494,9 @@ fn simulate_decode( ...@@ -496,6 +494,9 @@ fn simulate_decode(
} }
} }
let deadline = start_time + Duration::from_secs_f64(total_time.as_secs_f64() / speedup_ratio);
tokio::time::sleep_until(deadline).await;
total_time total_time
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment