Unverified Commit 930721c8 authored by Yongming Ding's avatar Yongming Ding Committed by GitHub
Browse files

feat(mocker): add SGLang engine simulation (#6977)


Signed-off-by: default avatarYongming Ding <yongmingd@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 22fb3398
......@@ -243,18 +243,18 @@ impl Scheduler {
_cancel_guard: cancel_guard,
}
}
}
/// Add a new request to the prefill queue
pub async fn receive(&self, request: DirectRequest) {
impl super::SchedulerHandle for Scheduler {
fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request);
}
pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
/// Get a watch receiver for forward pass metrics
pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
self.metrics_rx.clone()
}
}
......@@ -564,6 +564,7 @@ fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
#[cfg(test)]
mod tests {
use super::*;
use crate::scheduler::SchedulerHandle;
use rstest::rstest;
use std::time::Duration;
use tokio::time::interval;
......@@ -592,125 +593,28 @@ mod tests {
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
unsafe { std::env::set_var("RUST_LOG", "debug") };
let kv_capacity: usize = 500;
let block_size: usize = 64;
let num_requests: usize = 200;
let input_len: usize = 1000;
let max_output_tokens: usize = 100;
// Create channel for token output
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create scheduler args using builder - now including enable_prefix_caching
let args = MockEngineArgs::builder()
.num_gpu_blocks(kv_capacity)
.block_size(block_size)
.num_gpu_blocks(500)
.block_size(64)
.speedup_ratio(10.0)
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.build()
.unwrap();
// Create scheduler with new args struct
let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
// Create shared tokens for caching case
let shared_tokens = if use_shared_tokens {
Some(
(0..input_len / 2)
.map(|_| rand::random::<u32>() % 50000)
.collect::<Vec<_>>(),
)
} else {
None
};
// Create test requests
for _ in 0..num_requests {
let input_tokens = if let Some(ref shared) = shared_tokens {
// For caching case: use shared tokens for first half, random for second half
let mut tokens = shared.clone();
tokens.extend((0..input_len / 2).map(|_| rand::random::<u32>() % 50000));
tokens
} else {
// For random case: create unique random token vector for each request
(0..input_len)
.map(|_| rand::random::<u32>() % 50000)
.collect::<Vec<_>>()
};
let request = DirectRequest {
tokens: input_tokens,
max_output_tokens,
uuid: None,
dp_rank: 0,
};
scheduler.receive(request).await;
}
let start_time = std::time::Instant::now();
// Collect all generated tokens (should be num_requests * max_output_tokens)
let expected_tokens = num_requests * max_output_tokens;
let mut received_tokens = 0;
// Set up a timeout that causes the test to panic if no tokens are received for 2 seconds
let timeout = tokio::time::sleep(Duration::from_secs(2));
tokio::pin!(timeout);
// Get metrics receiver
let metrics_rx = scheduler.metrics_receiver();
// Set up debug ticker interval
let mut debug_interval = interval(Duration::from_millis(500));
loop {
tokio::select! {
biased;
// Manual debug ticker that prints forward pass metrics
_ = debug_interval.tick() => {
let _metrics = metrics_rx.borrow().clone();
tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
}
Some(_) = output_rx.recv() => {
received_tokens += 1;
// Reset timeout whenever we receive a token
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => {
// Break instead of panicking when timeout occurs
break;
}
}
}
// Calculate and print elapsed time
let elapsed = start_time.elapsed();
println!(
"Test completed in: {elapsed:?} for {} case with prefix_caching={enable_prefix_caching} and chunked_prefill={enable_chunked_prefill}",
if use_shared_tokens {
"caching"
} else {
"random"
}
);
// Assert that we received the expected number of tokens
assert!(
received_tokens == expected_tokens,
"Received {received_tokens} tokens but expected exactly {expected_tokens}"
);
// Wait a bit for final metrics update to propagate
tokio::time::sleep(Duration::from_millis(100)).await;
let metrics = scheduler.metrics_receiver().borrow().clone();
assert_scheduler_idle(&metrics);
crate::scheduler::test_utils::assert_scheduler_completes_all(
&scheduler,
&mut output_rx,
200,
1000,
100,
use_shared_tokens,
)
.await;
}
#[tokio::test]
......@@ -746,7 +650,7 @@ mod tests {
uuid: None,
dp_rank: 0,
};
scheduler.receive(request).await;
scheduler.receive(request);
// Sleep for 0.1 second after each request
tokio::time::sleep(Duration::from_millis(100)).await;
}
......@@ -979,7 +883,7 @@ mod tests {
dp_rank: 0,
};
scheduler.receive(request).await;
scheduler.receive(request);
// Receive exactly 129 tokens
let mut received_count = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment