Unverified Commit ba3ac235 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

test: add router e2e test with mockers to per-merge ci (#2073)


Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
parent 2fc65ad8
...@@ -191,7 +191,7 @@ impl KvRouter { ...@@ -191,7 +191,7 @@ impl KvRouter {
} }
}; };
if let Err(e) = kv_events_tx.send(event).await { if let Err(e) = kv_events_tx.send(event).await {
tracing::debug!( tracing::warn!(
"failed to send kv event to indexer; shutting down: {:?}", "failed to send kv event to indexer; shutting down: {:?}",
e e
); );
......
...@@ -177,6 +177,13 @@ impl KvScheduler { ...@@ -177,6 +177,13 @@ impl KvScheduler {
request.respond(response); request.respond(response);
continue 'outer; continue 'outer;
} }
Err(KvSchedulerError::NoEndpoints) => {
tracing::trace!("no endpoints available; waiting for endpoints update");
endpoints_rx.changed().await.ok();
endpoints = endpoints_rx.borrow_and_update().clone();
pending_endpoint_update = Some(endpoints.worker_ids());
continue;
}
// TODO: this is not actually hooked up // TODO: this is not actually hooked up
Err(KvSchedulerError::AllWorkersBusy) => { Err(KvSchedulerError::AllWorkersBusy) => {
tracing::trace!("all workers busy; waiting for more capacity"); tracing::trace!("all workers busy; waiting for more capacity");
......
...@@ -51,7 +51,7 @@ use std::collections::HashMap; ...@@ -51,7 +51,7 @@ use std::collections::HashMap;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{mpsc, Mutex};
use tokio::time::{interval, Duration}; use tokio::time::Duration;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use uuid::Uuid; use uuid::Uuid;
...@@ -81,6 +81,10 @@ impl SchedulerState { ...@@ -81,6 +81,10 @@ impl SchedulerState {
} }
} }
fn is_empty(&self) -> bool {
self.requests.is_empty()
}
/// Create a new UUID for a DirectRequest, add it to requests, and push the UUID to waiting. /// Create a new UUID for a DirectRequest, add it to requests, and push the UUID to waiting.
fn receive(&mut self, request: DirectRequest) -> Uuid { fn receive(&mut self, request: DirectRequest) -> Uuid {
// Use the provided UUID if available, otherwise generate a new one // Use the provided UUID if available, otherwise generate a new one
...@@ -295,11 +299,25 @@ impl Scheduler { ...@@ -295,11 +299,25 @@ impl Scheduler {
// Spawn main background task with cancellation token // Spawn main background task with cancellation token
tokio::spawn(async move { tokio::spawn(async move {
let mut schedule_interval = interval(Duration::from_secs_f64(1e-3));
let mut simulate_interval = interval(Duration::from_secs_f64(1e-4));
let mut should_schedule = true; let mut should_schedule = true;
loop { loop {
{
let state_guard = state_clone.lock().await;
// Enqueue new request, blocks until at least one is received, so no redundant work is done
// TODO: clean this up? double lock acquisition is ugly, but needed to not hold the lock forever
if state_guard.is_empty() {
drop(state_guard);
let Some(request) = request_rx.recv().await else {
tracing::warn!("request sender is dropped");
break;
};
let mut state_guard = state_clone.lock().await;
state_guard.receive(request);
}
}
tokio::select! { tokio::select! {
biased; biased;
...@@ -310,7 +328,7 @@ impl Scheduler { ...@@ -310,7 +328,7 @@ impl Scheduler {
} }
// Try Scheduling Requests - runs on normal interval or after simulation // Try Scheduling Requests - runs on normal interval or after simulation
_ = schedule_interval.tick() => { _ = tokio::task::yield_now() => {
// Skip if we just ran scheduling after simulation to prevent consecutive runs // Skip if we just ran scheduling after simulation to prevent consecutive runs
if !should_schedule { if !should_schedule {
continue; continue;
...@@ -371,100 +389,117 @@ impl Scheduler { ...@@ -371,100 +389,117 @@ impl Scheduler {
_ = cancel_token_clone.cancelled() => { _ = cancel_token_clone.cancelled() => {
break; break;
} }
}
// Simulate running requests (prefill + decode) // Simulates prefill + decode
_ = simulate_interval.tick() => { let mut state_guard = state_clone.lock().await;
let mut state_guard = state_clone.lock().await; let mut kv_manager_guard = kv_manager_clone.lock().await;
let mut kv_manager_guard = kv_manager_clone.lock().await;
// Base time needed for decoding using active percentage and quadratic formula
// Base time needed for decoding using active percentage and quadratic formula let active_perc = kv_manager_guard.get_active_perc();
let active_perc = kv_manager_guard.get_active_perc(); let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44; let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
// Process prefilling
// Process prefilling while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = state_guard.try_prefill() { state_guard.try_prefill()
// NOTE: Prefill cost/time is always incremented for new blocks, even if they {
// could be cached by other requests in the same batch. This matches vLLM behavior. // NOTE: Prefill cost/time is always incremented for new blocks, even if they
total_time += Duration::from_secs_f64(prefill_compute / 1000.0); // could be cached by other requests in the same batch. This matches vLLM behavior.
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
if let Some(creation_signal) = maybe_creation_signal {
if !process_signals(&mut kv_manager_guard, std::slice::from_ref(&creation_signal)) { if let Some(creation_signal) = maybe_creation_signal {
panic!("Block allocation for prefilling cannot fail."); if !process_signals(
} &mut kv_manager_guard,
std::slice::from_ref(&creation_signal),
// Drain KV events and forward to relay after prefill signal processing ) {
if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) { panic!("Block allocation for prefilling cannot fail.");
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
}
}
};
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
if !is_full_prefill { break; }
} }
state_guard.reset_active_tokens(); // Drain KV events and forward to relay after prefill signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) =
// Process decoding (&kv_events_tx, &mut block_resp_rx)
let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect(); {
if !uuids.is_empty() {should_schedule = true}; while let Ok(event) = rx.try_recv() {
for uuid in uuids { let _ = relay_tx.send(block_response_to_kv_event(event));
let Some(sequence) = state_guard.run(uuid) else {
continue;
};
let signals = sequence.generate();
// Process all signals with the KvManager
// Handling of preemption on failure
if !process_signals(&mut kv_manager_guard, &signals) {
sequence.pop(); // revert the failed generation op
for signal in state_guard.preempt() {
kv_manager_guard.process(&signal);
}
continue;
} }
}
};
// Drain KV events and forward to relay after decode signal processing // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) { if !is_full_prefill {
while let Ok(event) = rx.try_recv() { break;
let _ = relay_tx.send(block_response_to_kv_event(event)); }
} }
}
// Check completion and send notification state_guard.reset_active_tokens();
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let should_output = sequence.generated_tokens() > sequence.already_generated_tokens(); // Process decoding
let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
if !uuids.is_empty() {
should_schedule = true
};
for uuid in uuids {
let Some(sequence) = state_guard.run(uuid) else {
continue;
};
let signals = sequence.generate();
// Process all signals with the KvManager
// Handling of preemption on failure
if !process_signals(&mut kv_manager_guard, &signals) {
sequence.pop(); // revert the failed generation op
for signal in state_guard.preempt() {
kv_manager_guard.process(&signal);
}
continue;
}
let mut send_failed = false; // Drain KV events and forward to relay after decode signal processing
if should_output { if let (Some(ref relay_tx), Some(ref mut rx)) =
send_failed = output_tx_clone.as_ref().is_some_and(|tx| { (&kv_events_tx, &mut block_resp_rx)
tx.send(OutputSignal { uuid, completed: is_complete }).is_err() {
}); while let Ok(event) = rx.try_recv() {
} let _ = relay_tx.send(block_response_to_kv_event(event));
}
}
if send_failed { // Check completion and send notification
for signal in &sequence.free_signal() { let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
kv_manager_guard.process(signal); let should_output =
} sequence.generated_tokens() > sequence.already_generated_tokens();
}
let mut send_failed = false;
if should_output {
send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
tx.send(OutputSignal {
uuid,
completed: is_complete,
})
.is_err()
});
}
if send_failed || is_complete { if send_failed {
state_guard.complete(&uuid); for signal in &sequence.free_signal() {
continue; kv_manager_guard.process(signal);
}
} }
}
// Sleep once for the adjusted duration if send_failed || is_complete {
drop(kv_manager_guard); state_guard.complete(&uuid);
drop(state_guard); continue;
let adjusted_time = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
if adjusted_time.as_millis() > 0 {
tokio::time::sleep(adjusted_time).await;
}
} }
} }
// Sleep once for the adjusted duration
drop(kv_manager_guard);
drop(state_guard);
let adjusted_time =
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
if adjusted_time.as_millis() > 0 {
tokio::time::sleep(adjusted_time).await;
}
} }
}); });
...@@ -632,6 +667,7 @@ mod tests { ...@@ -632,6 +667,7 @@ mod tests {
use super::*; use super::*;
use rstest::rstest; use rstest::rstest;
use std::time::Duration; use std::time::Duration;
use tokio::time::interval;
#[rstest] #[rstest]
#[case::case_1(false, false, false)] #[case::case_1(false, false, false)]
......
...@@ -33,6 +33,66 @@ logging.basicConfig( ...@@ -33,6 +33,66 @@ logging.basicConfig(
datefmt=DATE_FORMAT, # ISO 8601 UTC format datefmt=DATE_FORMAT, # ISO 8601 UTC format
) )
# List of models used in tests
TEST_MODELS = [
"Qwen/Qwen3-0.6B",
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"llava-hf/llava-1.5-7b-hf",
]
def download_models(model_list=None):
"""Download models - can be called directly or via fixture
Args:
model_list: List of model IDs to download. If None, downloads TEST_MODELS.
"""
if model_list is None:
model_list = TEST_MODELS
# Check for HF_TOKEN in environment
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
logging.info("HF_TOKEN found in environment")
else:
logging.warning(
"HF_TOKEN not found in environment. "
"Some models may fail to download or you may encounter rate limits. "
"Get a token from https://huggingface.co/settings/tokens"
)
try:
from huggingface_hub import snapshot_download
for model_id in model_list:
logging.info(f"Pre-downloading model: {model_id}")
try:
# Download the full model snapshot (includes all files)
# HuggingFace will handle caching automatically
snapshot_download(
repo_id=model_id,
token=hf_token,
)
logging.info(f"Successfully pre-downloaded: {model_id}")
except Exception as e:
logging.error(f"Failed to pre-download {model_id}: {e}")
# Don't fail the fixture - let individual tests handle missing models
except ImportError:
logging.warning(
"huggingface_hub not installed. "
"Models will be downloaded during test execution."
)
@pytest.fixture(scope="session")
def predownload_models():
"""Fixture wrapper around download_models for all TEST_MODELS"""
download_models()
yield
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def logger(request): def logger(request):
...@@ -64,6 +124,18 @@ def pytest_collection_modifyitems(config, items): ...@@ -64,6 +124,18 @@ def pytest_collection_modifyitems(config, items):
if "tensorrtllm" in item.keywords: if "tensorrtllm" in item.keywords:
item.add_marker(skip_tensorrtllm) item.add_marker(skip_tensorrtllm)
# Auto-inject predownload_models fixture for serve tests only (not router tests)
# Skip items that don't have fixturenames (like MypyFileItem)
if hasattr(item, "fixturenames"):
# Only apply to tests in the serve directory
if (
("serve" in str(item.path))
and ("predownload_models" not in item.fixturenames)
and (not item.get_closest_marker("skip_model_download"))
):
item.fixturenames = list(item.fixturenames)
item.fixturenames.append("predownload_models")
class EtcdServer(ManagedProcess): class EtcdServer(ManagedProcess):
def __init__(self, request, port=2379, timeout=300): def __init__(self, request, port=2379, timeout=300):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import logging
import os
import aiohttp
import pytest
from tests.conftest import download_models
from tests.utils.managed_process import ManagedProcess
pytestmark = pytest.mark.pre_merge
logger = logging.getLogger(__name__)
MODEL_NAME = "Qwen/Qwen3-0.6B"
NUM_MOCKERS = 2
SPEEDUP_RATIO = 10.0
NUM_REQUESTS = 100
PORT = 8090 # Starting port for mocker instances
class MockerProcess(ManagedProcess):
"""Manages a single mocker engine instance"""
def __init__(self, request, endpoint: str, mocker_args_file: str):
command = [
"python",
"-m",
"dynamo.mocker",
"--model-path",
MODEL_NAME,
"--extra-engine-args",
mocker_args_file,
"--endpoint",
endpoint,
]
super().__init__(
command=command,
timeout=60,
display_output=True,
health_check_ports=[],
health_check_urls=[],
log_dir=request.node.name,
terminate_existing=False,
)
self.endpoint = endpoint
class KVRouterProcess(ManagedProcess):
"""Manages the KV router process using dynamo.frontend"""
def __init__(self, request, frontend_port: int):
command = [
"python",
"-m",
"dynamo.frontend",
"--router-mode",
"kv",
"--http-port",
str(frontend_port),
]
super().__init__(
command=command,
timeout=60,
display_output=True,
health_check_ports=[frontend_port],
health_check_urls=[
(f"http://localhost:{frontend_port}/v1/models", self._check_ready)
],
log_dir=request.node.name,
terminate_existing=False,
)
self.port = frontend_port
def _check_ready(self, response):
"""Check if KV router is ready"""
return response.status_code == 200
def __exit__(self, exc_type, exc_val, exc_tb):
super().__exit__(exc_type, exc_val, exc_tb)
@pytest.mark.pre_merge
def test_mocker_kv_router(request, runtime_services):
"""
Test KV router with multiple mocker engine instances.
This test doesn't require GPUs and runs quickly for pre-merge validation.
"""
# Download only the Qwen model for this test
download_models([MODEL_NAME])
# runtime_services starts etcd and nats
logger.info("Starting mocker KV router test")
# Create mocker args file
mocker_args = {"speedup_ratio": SPEEDUP_RATIO}
mocker_args_file = os.path.join(request.node.name, "mocker_args.json")
with open(mocker_args_file, "w") as f:
json.dump(mocker_args, f)
# Start mocker instances
mocker_processes = []
try:
# Start KV router (frontend)
frontend_port = PORT
logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(request, frontend_port)
kv_router.__enter__()
for i in range(NUM_MOCKERS):
# Use unique endpoints for each mocker
endpoint = "dyn://test-namespace.mocker.generate"
logger.info(f"Starting mocker instance {i} on endpoint {endpoint}")
mocker = MockerProcess(request, endpoint, mocker_args_file)
mocker_processes.append(mocker)
# Start all mockers
for mocker in mocker_processes:
mocker.__enter__()
# Send test requests
test_payload = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.",
}
],
"stream": True,
"max_tokens": 10,
}
# Use async to send requests concurrently for better performance
asyncio.run(
send_concurrent_requests(
f"http://localhost:{frontend_port}/v1/chat/completions",
test_payload,
NUM_REQUESTS,
)
)
logger.info(f"Successfully completed {NUM_REQUESTS} requests")
finally:
# Clean up
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
for mocker in mocker_processes:
mocker.__exit__(None, None, None)
if os.path.exists(mocker_args_file):
os.unlink(mocker_args_file)
async def send_request_with_retry(url: str, payload: dict, max_retries: int = 4):
"""Send a single request with exponential backoff retry"""
wait_time = 1 # Start with 1 second
for attempt in range(max_retries + 1):
await asyncio.sleep(wait_time)
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload) as response:
if response.status == 200:
# Read the response to ensure it's valid
async for _ in response.content:
pass
logger.info(f"First request succeeded on attempt {attempt + 1}")
return True
else:
logger.warning(
f"Attempt {attempt + 1} failed with status {response.status}"
)
except Exception as e:
logger.warning(f"Attempt {attempt + 1} failed with error: {e}")
if attempt < max_retries:
wait_time *= 2 # Double the wait time
return False
async def send_concurrent_requests(url: str, payload: dict, num_requests: int):
"""Send multiple requests concurrently and verify responses"""
# First, send a test request with retry to ensure the system is ready
logger.info("Sending initial test request with retry...")
if not await send_request_with_retry(url, payload):
raise RuntimeError("Failed to connect after multiple retries")
async def send_single_request(session: aiohttp.ClientSession, request_id: int):
try:
async with session.post(url, json=payload) as response:
if response.status != 200:
logger.error(
f"Request {request_id} failed with status {response.status}"
)
return False
# For streaming responses, read the entire stream
chunks = []
async for line in response.content:
if line:
chunks.append(line)
logger.debug(
f"Request {request_id} completed with {len(chunks)} chunks"
)
return True
except Exception as e:
logger.error(f"Request {request_id} failed with error: {e}")
return False
# Send all requests at once
async with aiohttp.ClientSession() as session:
tasks = [send_single_request(session, i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
successful = sum(1 for r in results if r)
failed = sum(1 for r in results if not r)
logger.info(f"Completed all requests: {successful} successful, {failed} failed")
assert (
successful == num_requests
), f"Expected {num_requests} successful requests, got {successful}"
logger.info(f"All {num_requests} requests completed successfully")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import pytest
# List of models used in the serve tests
SERVE_TEST_MODELS = [
"Qwen/Qwen3-0.6B",
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"llava-hf/llava-1.5-7b-hf",
]
logger = logging.getLogger(__name__)
@pytest.fixture(scope="session")
def predownload_models():
# Check for HF_TOKEN in environment
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
logger.info("HF_TOKEN found in environment")
else:
logger.warning(
"HF_TOKEN not found in environment. "
"Some models may fail to download or you may encounter rate limits. "
"Get a token from https://huggingface.co/settings/tokens"
)
try:
from huggingface_hub import snapshot_download
for model_id in SERVE_TEST_MODELS:
logger.info(f"Pre-downloading model: {model_id}")
try:
# Download the full model snapshot (includes all files)
# HuggingFace will handle caching automatically
snapshot_download(
repo_id=model_id,
token=hf_token,
)
logger.info(f"Successfully pre-downloaded: {model_id}")
except Exception as e:
logger.error(f"Failed to pre-download {model_id}: {e}")
# Don't fail the fixture - let individual tests handle missing models
except ImportError:
logger.warning(
"huggingface_hub not installed. "
"Models will be downloaded during test execution."
)
yield
# Automatically use the predownload fixture for all serve tests
def pytest_collection_modifyitems(config, items):
for item in items:
# Skip items that don't have fixturenames (like MypyFileItem)
if not hasattr(item, "fixturenames"):
continue
# Only apply to tests in the serve directory
if "serve" in str(item.path):
# Check if the test already uses the fixture
if "predownload_models" not in item.fixturenames:
# Don't add if test explicitly marks to skip model download
if not item.get_closest_marker("skip_model_download"):
item.fixturenames = list(item.fixturenames)
item.fixturenames.append("predownload_models")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment