Unverified Commit 96a55928 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(test): declutter router common.py into helper and router_process modules (#7210)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent e094a2db
......@@ -4,9 +4,7 @@
import asyncio
import json
import logging
import os
import random
import string
import time
from typing import TYPE_CHECKING, Any, Optional
......@@ -14,661 +12,24 @@ import aiohttp
import nats
from dynamo.llm import KvRouter, KvRouterConfig
from dynamo.runtime import DistributedRuntime
from tests.utils.managed_process import ManagedProcess
from tests.router.helper import (
_nats_server,
assert_event_dumps_equal,
get_runtime,
send_inflight_requests,
send_request_via_python_kv_router,
send_request_with_retry,
verify_response_timing,
wait_for_frontend_ready,
wait_for_workers_ready,
)
from tests.router.router_process import KVRouterProcess
if TYPE_CHECKING:
from tests.conftest import NatsServer
logger = logging.getLogger(__name__)
NUM_REQUESTS = 100
BLOCK_SIZE = 16
def _nats_server() -> str:
# Prefer dynamically-started NATS from per-test fixtures when present.
return os.environ.get("NATS_SERVER", "nats://localhost:4222")
########################################################
# Helper Classes
########################################################
class KVRouterProcess(ManagedProcess):
"""Manages the KV router process using dynamo.frontend"""
def __init__(
self,
request,
block_size: int,
frontend_port: int,
namespace: str,
store_backend: str = "etcd",
enforce_disagg: bool = False,
blocks_threshold: float | None = None,
tokens_threshold: float | None = None,
tokens_threshold_frac: float | None = None,
request_plane: str = "nats",
durable_kv_events: bool = False,
):
command = [
"python3",
"-m",
"dynamo.frontend",
"--kv-cache-block-size",
str(block_size),
"--router-mode",
"kv",
"--http-port",
str(frontend_port),
"--discovery-backend",
store_backend,
"--namespace",
namespace,
]
if enforce_disagg:
command.append("--enforce-disagg")
if blocks_threshold is not None:
command.extend(["--active-decode-blocks-threshold", str(blocks_threshold)])
if tokens_threshold is not None:
command.extend(["--active-prefill-tokens-threshold", str(tokens_threshold)])
if tokens_threshold_frac is not None:
command.extend(
["--active-prefill-tokens-threshold-frac", str(tokens_threshold_frac)]
)
if durable_kv_events:
command.append("--router-durable-kv-events")
env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane
super().__init__(
command=command,
env=env,
timeout=60,
display_output=True,
health_check_ports=[frontend_port],
health_check_urls=[
(f"http://localhost:{frontend_port}/v1/models", self._check_ready)
],
log_dir=request.node.name,
terminate_all_matching_process_names=False,
)
self.port = frontend_port
def _check_ready(self, response):
"""Check if KV router is ready"""
return response.status_code == 200
def __exit__(self, exc_type, exc_val, exc_tb):
super().__exit__(exc_type, exc_val, exc_tb)
def generate_random_suffix() -> str:
"""Generate a 10-character random alphabetic suffix for namespace isolation."""
return "".join(random.choices(string.ascii_lowercase, k=10)) # noqa: S311
def assert_event_dumps_equal(
expected: list[dict],
actual: list[dict],
expected_label: str,
actual_label: str,
) -> None:
"""Assert two sorted event dump lists are equal, ignoring event_id fields."""
assert len(expected) == len(actual), (
f"{expected_label} has {len(expected)} events, "
f"{actual_label} has {len(actual)} events"
)
differences = []
for i, (exp_item, act_item) in enumerate(zip(expected, actual)):
exp_compare = exp_item.copy()
act_compare = act_item.copy()
if "event" in exp_compare and "event_id" in exp_compare["event"]:
del exp_compare["event"]["event_id"]
if "event" in act_compare and "event_id" in act_compare["event"]:
del act_compare["event"]["event_id"]
if exp_compare != act_compare:
differences.append(
{"index": i, expected_label: exp_item, actual_label: act_item}
)
if differences:
error_msg = (
f"{expected_label} and {actual_label} differ. "
f"Found {len(differences)} differences:\n"
)
for diff in differences:
error_msg += f"\nDifference at index {diff['index']}:\n"
error_msg += (
f"{expected_label}: {json.dumps(diff[expected_label], indent=2)}\n"
)
error_msg += f"{actual_label}: {json.dumps(diff[actual_label], indent=2)}\n"
error_msg += "-" * 80 + "\n"
assert False, error_msg
def verify_response_worker_ids(
response_worker_ids: list[dict[str, Optional[int]]],
key: str,
expected_worker_id: int,
) -> None:
"""Verify that all responses have the same worker ID for a given key.
Args:
response_worker_ids: List of dicts with worker ID info from responses.
key: The key to check (e.g., "decode_worker_id" or "prefill_worker_id").
expected_worker_id: The expected worker ID value.
Raises:
AssertionError: If any response is missing the key, values differ, or don't match expected.
"""
worker_ids = [r.get(key) for r in response_worker_ids]
logger.info(f"Response {key}s: {worker_ids}")
# All responses should have the key
assert all(
wid is not None for wid in worker_ids
), f"Expected all {len(response_worker_ids)} responses to have {key}, got: {worker_ids}"
# All values should be the same (due to prefix reuse routing)
unique_ids = set(worker_ids)
assert len(unique_ids) == 1, (
f"Expected all responses to have the same {key} (due to prefix reuse), "
f"but found {len(unique_ids)} unique values: {unique_ids}"
)
# The value should match the expected worker ID
actual_worker_id = worker_ids[0]
assert actual_worker_id == expected_worker_id, (
f"Expected {key}={expected_worker_id} (forced in first request), "
f"but got {key}={actual_worker_id}"
)
logger.info(
f"✓ Verified all {len(response_worker_ids)} responses have {key}={actual_worker_id}"
)
def verify_response_timing(timing_info: dict[str, Any]) -> None:
"""Verify timing info has valid values (ttft_ms > 0, total_time_ms > 0)."""
ttft_ms = timing_info.get("ttft_ms")
total_time_ms = timing_info.get("total_time_ms")
assert ttft_ms is not None and ttft_ms > 0, f"Expected ttft_ms > 0, got: {ttft_ms}"
assert (
total_time_ms is not None and total_time_ms > 0
), f"Expected total_time_ms > 0, got: {total_time_ms}"
assert (
total_time_ms >= ttft_ms
), f"Expected total_time_ms >= ttft_ms, got {total_time_ms} < {ttft_ms}"
logger.info(
f"✓ Verified timing: ttft_ms={ttft_ms:.2f}, total_time_ms={total_time_ms:.2f}"
)
########################################################
# Utility functions
########################################################
async def wait_for_frontend_ready(
frontend_url: str, expected_num_workers: int = 2, timeout: int = 120
):
"""Wait for backend worker(s) to be ready via the HTTP frontend (OpenAI API).
This function performs a two-phase readiness check through the frontend HTTP server:
1. Polls GET /v1/models until at least one model is registered (workers connected)
2. Sends a test POST to /v1/chat/completions to verify the request pipeline is functional
Use this when testing through the HTTP frontend server (dynamo.frontend).
For direct Python API testing with KvRouter, use wait_for_workers_ready() instead.
Args:
frontend_url: Base URL of the frontend HTTP server (e.g., "http://localhost:8000")
expected_num_workers: Number of workers to wait for (currently logs but doesn't enforce)
timeout: Maximum time to wait in seconds for both phases combined
Raises:
TimeoutError: If workers don't register or pipeline doesn't become ready within timeout
aiohttp.ClientError: If HTTP requests fail unexpectedly
"""
models_url = f"{frontend_url}/v1/models"
chat_url = f"{frontend_url}/v1/chat/completions"
start_time = asyncio.get_event_loop().time()
logger.info(
f"Waiting for {expected_num_workers} workers to register on HTTP frontend (timeout={timeout}s)..."
)
# Phase 1: Wait for models to appear in /v1/models
model_name = None
while True:
elapsed = asyncio.get_event_loop().time() - start_time
if elapsed > timeout:
raise TimeoutError(
f"Timeout waiting for vLLM workers. Waited {elapsed:.1f}s, no workers registered."
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(models_url) as response:
if response.status == 200:
data = await response.json()
models = data.get("data", [])
if len(models) > 0:
model_name = models[0].get("id")
logger.info(
f"Workers registered. Found {len(models)} model(s): {[m.get('id') for m in models]}"
)
break
else:
logger.debug(
f"No models registered yet (elapsed: {elapsed:.1f}s)"
)
except Exception as e:
logger.debug(f"Error checking models endpoint: {e}")
# Wait before next poll
await asyncio.sleep(1)
# Phase 2: Wait for chat completions pipeline to be ready
logger.info("Waiting for chat completions pipeline to be built...")
test_payload = {
"model": model_name,
"messages": [{"role": "user", "content": "test"}],
"max_tokens": 1,
"stream": False,
}
while True:
elapsed = asyncio.get_event_loop().time() - start_time
if elapsed > timeout:
raise TimeoutError(
f"Timeout waiting for chat completions pipeline. Waited {elapsed:.1f}s."
)
try:
async with aiohttp.ClientSession() as session:
async with session.post(chat_url, json=test_payload) as response:
if response.status == 200:
logger.info("Chat completions pipeline ready!")
return
else:
logger.debug(
f"Chat completions not ready yet, status {response.status} (elapsed: {elapsed:.1f}s)"
)
except Exception as e:
logger.debug(f"Error testing chat completions: {e}")
# Wait before next poll
await asyncio.sleep(1)
async def wait_for_workers_ready(
endpoint,
router: KvRouter,
expected_num_workers: int,
model_name: str,
) -> list[int]:
"""Wait for workers to be ready and return their instance IDs.
Supports mocker and vLLM workers.
This function polls the endpoint's client for instance IDs until the expected
number of workers are available, then sends a warmup request to verify they
can handle requests.
Args:
endpoint: The endpoint object to get the client from
router: The KvRouter to use for sending warmup requests
expected_num_workers: Number of workers to wait for
Returns:
Sorted list of unique instance IDs (ints).
Raises:
AssertionError: If workers don't become ready or warmup request fails.
"""
logger.info("Waiting for workers to be ready")
# Get the client from the endpoint
client = await endpoint.client()
# Poll for instance IDs until we have the expected number
instance_ids: list[int] = []
max_wait_time = 60 # seconds
start_time = asyncio.get_running_loop().time()
while len(instance_ids) < expected_num_workers:
instance_ids = client.instance_ids()
logger.info(f"Found {len(instance_ids)} instance(s): {instance_ids}")
if len(instance_ids) >= expected_num_workers:
break
# Check timeout
if asyncio.get_running_loop().time() - start_time > max_wait_time:
raise AssertionError(
f"Timeout waiting for workers. Found {len(instance_ids)} instance(s), expected {expected_num_workers}"
)
# Wait 1 second before polling again
await asyncio.sleep(1.0)
# Send a warmup request to verify workers can handle requests
test_token_ids = [random.randint(1, 10000) for _ in range(4)]
logger.info(f"Sending warmup request with {len(test_token_ids)} tokens")
try:
await send_request_via_python_kv_router(
kv_python_router=router,
model_name=model_name,
token_ids=test_token_ids,
initial_wait=1.0,
max_retries=8,
stop_conditions={
"ignore_eos": True,
"max_tokens": 2,
},
)
except Exception as e:
raise AssertionError(f"Warmup request failed: {e}")
logger.info(f"All {len(instance_ids)} workers are ready")
return sorted(instance_ids)
async def send_request_with_retry(url: str, payload: dict, max_retries: int = 8):
"""Send a single request with exponential backoff retry"""
wait_time = 1 # Start with 1 second
for attempt in range(max_retries + 1):
await asyncio.sleep(wait_time)
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload) as response:
if response.status == 200:
# Read the response to ensure it's valid
async for _ in response.content:
pass
logger.debug(
f"First request succeeded on attempt {attempt + 1}"
)
return True
else:
logger.warning(
f"Attempt {attempt + 1} failed with status {response.status}"
)
except Exception as e:
logger.warning(f"Attempt {attempt + 1} failed with error: {e}")
if attempt < max_retries:
wait_time *= 2 # Double the wait time
return False
def get_runtime(store_backend="etcd", request_plane="tcp"):
"""Create a DistributedRuntime instance for testing.
Args:
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
request_plane: How frontend talks to backend ("tcp", "http" or "nats"). Defaults to "tcp".
"""
try:
# Try to get running loop (works in async context)
loop = asyncio.get_running_loop()
except RuntimeError:
# No running loop, create a new one (sync context)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return DistributedRuntime(loop, store_backend, request_plane)
async def check_nats_consumers(namespace: str, expected_count: Optional[int] = None):
"""Check NATS consumers for the KV events stream.
Args:
namespace: The namespace to check consumers for
expected_count: Optional expected number of consumers. If provided, asserts if count doesn't match.
Returns:
List of consumer names
"""
component_subject = f"namespace.{namespace}.component.mocker"
slugified = component_subject.lower().replace(".", "-").replace("_", "-")
stream_name = f"{slugified}-kv-events"
logger.info(f"Checking consumers for stream: {stream_name}")
nc = await nats.connect(servers=_nats_server())
try:
js = nc.jetstream()
consumer_infos = await js.consumers_info(stream_name)
consumer_names = [info.name for info in consumer_infos]
logger.info(f"Found {len(consumer_names)} consumers: {consumer_names}")
# Log detailed consumer info
for info in consumer_infos:
logger.info(
f"Consumer {info.name}: "
f"num_pending={info.num_pending}, "
f"num_ack_pending={info.num_ack_pending}, "
f"ack_floor={info.ack_floor}, "
f"delivered={info.delivered}"
)
if expected_count is not None:
assert (
len(consumer_names) == expected_count
), f"Expected {expected_count} durable consumers, found {len(consumer_names)}: {consumer_names}"
logger.info(f"✓ Verified {expected_count} durable consumers exist")
return consumer_names
finally:
await nc.close()
async def send_inflight_requests(urls: list, payload: dict, num_requests: int):
"""Send multiple requests concurrently, alternating between URLs if multiple provided"""
# First, send test requests with retry to ensure all systems are ready
for i, url in enumerate(urls):
logger.info(f"Sending initial test request to URL {i} ({url}) with retry...")
if not await send_request_with_retry(url, payload):
raise RuntimeError(f"Failed to connect to URL {i} after multiple retries")
async def send_single_request(session: aiohttp.ClientSession, request_id: int):
# Alternate between URLs based on request_id
url = urls[request_id % len(urls)]
url_index = request_id % len(urls)
try:
async with session.post(url, json=payload) as response:
if response.status != 200:
logger.error(
f"Request {request_id} to URL {url_index} failed with status {response.status}"
)
return False
# For streaming responses, read the entire stream
chunks = []
async for line in response.content:
if line:
chunks.append(line)
logger.debug(
f"Request {request_id} to URL {url_index} completed with {len(chunks)} chunks"
)
return True
except Exception as e:
logger.error(
f"Request {request_id} to URL {url_index} failed with error: {e}"
)
return False
# Send all requests at once
async with aiohttp.ClientSession() as session:
tasks = [send_single_request(session, i) for i in range(num_requests)]
results = await asyncio.gather(*tasks, return_exceptions=True)
successful = sum(1 for r in results if r if r is True)
failed = num_requests - successful
logger.info(f"Completed all requests: {successful} successful, {failed} failed")
assert (
successful == num_requests
), f"Expected {num_requests} successful requests, got {successful}"
logger.info(f"All {num_requests} requests completed successfully")
async def send_request_via_python_kv_router(
kv_python_router: KvRouter,
model_name: str,
token_ids: list,
initial_wait: float,
max_retries: int,
stop_conditions: Optional[dict] = None,
sampling_options: Optional[dict] = None,
output_options: Optional[dict] = None,
router_config_override: Optional[dict] = None,
worker_id: Optional[
int
] = None, # If None, Router will select the best available worker
dp_rank: Optional[int] = None, # Data parallel rank (defaults to 0)
return_worker_ids: bool = False, # If True, return worker IDs from response
) -> bool | dict[str, Optional[int]]:
"""Send a request to the specified worker instance.
Args:
return_worker_ids: If True, returns a dict with prefill_worker_id and decode_worker_id.
If False, returns True on success or False on failure.
Returns:
If return_worker_ids=False: True if workers respond, otherwise raises or returns False.
If return_worker_ids=True: Dict with 'prefill_worker_id' and 'decode_worker_id' keys.
"""
wait_time = initial_wait
log_message = (
f"worker with worker_id={worker_id}"
if worker_id is not None
else "the best available worker"
)
# Retry loop sending request to worker with exponential backoff
stream = None
for attempt in range(max_retries + 1):
try:
logger.debug(f"Sending request to {log_message} (attempt {attempt + 1})")
stream = await kv_python_router.generate(
token_ids=token_ids,
model=model_name,
stop_conditions=stop_conditions, # type: ignore[arg-type]
sampling_options=sampling_options, # type: ignore[arg-type]
output_options=output_options, # type: ignore[arg-type]
router_config_override=router_config_override, # type: ignore[arg-type]
worker_id=worker_id,
dp_rank=dp_rank,
)
if stream is not None:
logger.debug(f"Request succeeded on attempt {attempt + 1}")
break
except Exception as e:
logger.warning(f"Attempt {attempt + 1} failed with error: {e}")
if attempt < max_retries:
await asyncio.sleep(wait_time)
wait_time *= 2
else:
raise RuntimeError(
f"Failed to connect to workers after {max_retries + 1} attempts"
) from e
if stream is None:
raise RuntimeError(
f"Failed to get a valid stream from workers after {max_retries + 1} attempts"
)
# Collect tokens and worker IDs from the SSE stream
generated_tokens = []
prefill_worker_id: Optional[int] = None
decode_worker_id: Optional[int] = None
prefill_dp_rank: Optional[int] = None
decode_dp_rank: Optional[int] = None
async for response in stream:
if isinstance(response, dict):
# Check if response has token_ids
if "token_ids" in response:
tokens = response["token_ids"]
if isinstance(tokens, list):
generated_tokens.extend(tokens)
logger.debug(f"Received {len(tokens)} tokens: {tokens}")
# Check for finish reason
if "finish_reason" in response:
logger.debug(
f"Stream finished with reason: {response['finish_reason']}"
)
# Extract worker IDs and dp_ranks from disaggregated_params if present
if return_worker_ids and "disaggregated_params" in response:
disagg_params = response["disaggregated_params"]
if isinstance(disagg_params, dict) and "worker_id" in disagg_params:
worker_id_info = disagg_params["worker_id"]
if isinstance(worker_id_info, dict):
if "prefill_worker_id" in worker_id_info:
prefill_worker_id = worker_id_info["prefill_worker_id"]
if "decode_worker_id" in worker_id_info:
decode_worker_id = worker_id_info["decode_worker_id"]
if "prefill_dp_rank" in worker_id_info:
prefill_dp_rank = worker_id_info["prefill_dp_rank"]
if "decode_dp_rank" in worker_id_info:
decode_dp_rank = worker_id_info["decode_dp_rank"]
# Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True
logger.debug(f"Total generated tokens: {len(generated_tokens)}")
if (
stop_conditions
and "max_tokens" in stop_conditions
and "ignore_eos" in stop_conditions
and stop_conditions["ignore_eos"]
):
max_tokens = int(stop_conditions["max_tokens"])
assert len(generated_tokens) == max_tokens, (
f"Expected exactly {max_tokens} tokens but got {len(generated_tokens)}. "
f"Tokens: {generated_tokens}"
)
logger.debug(
f"Successfully verified {max_tokens} tokens generated as expected via KvRouter with ignore_eos=True"
)
if return_worker_ids:
return {
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
"prefill_dp_rank": prefill_dp_rank,
"decode_dp_rank": decode_dp_rank,
}
return True
########################################################
# Test templates
......@@ -777,8 +138,6 @@ def _test_router_two_routers(
Raises:
AssertionError: If consumer lifecycle verification fails
"""
import nats
kv_routers = []
try:
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import logging
import os
import random
import string
from typing import Any, Optional
import aiohttp
import nats
from dynamo.llm import KvRouter
from dynamo.runtime import DistributedRuntime
logger = logging.getLogger(__name__)
NUM_REQUESTS = 100
BLOCK_SIZE = 16
def _nats_server() -> str:
# Prefer dynamically-started NATS from per-test fixtures when present.
return os.environ.get("NATS_SERVER", "nats://localhost:4222")
def generate_random_suffix() -> str:
"""Generate a 10-character random alphabetic suffix for namespace isolation."""
return "".join(random.choices(string.ascii_lowercase, k=10)) # noqa: S311
def assert_event_dumps_equal(
expected: list[dict],
actual: list[dict],
expected_label: str,
actual_label: str,
) -> None:
"""Assert two sorted event dump lists are equal, ignoring event_id fields."""
assert len(expected) == len(actual), (
f"{expected_label} has {len(expected)} events, "
f"{actual_label} has {len(actual)} events"
)
differences = []
for i, (exp_item, act_item) in enumerate(zip(expected, actual)):
exp_compare = exp_item.copy()
act_compare = act_item.copy()
if "event" in exp_compare and "event_id" in exp_compare["event"]:
del exp_compare["event"]["event_id"]
if "event" in act_compare and "event_id" in act_compare["event"]:
del act_compare["event"]["event_id"]
if exp_compare != act_compare:
differences.append(
{"index": i, expected_label: exp_item, actual_label: act_item}
)
if differences:
error_msg = (
f"{expected_label} and {actual_label} differ. "
f"Found {len(differences)} differences:\n"
)
for diff in differences:
error_msg += f"\nDifference at index {diff['index']}:\n"
error_msg += (
f"{expected_label}: {json.dumps(diff[expected_label], indent=2)}\n"
)
error_msg += f"{actual_label}: {json.dumps(diff[actual_label], indent=2)}\n"
error_msg += "-" * 80 + "\n"
assert False, error_msg
def verify_response_worker_ids(
response_worker_ids: list[dict[str, Optional[int]]],
key: str,
expected_worker_id: int,
) -> None:
"""Verify that all responses have the same worker ID for a given key.
Args:
response_worker_ids: List of dicts with worker ID info from responses.
key: The key to check (e.g., "decode_worker_id" or "prefill_worker_id").
expected_worker_id: The expected worker ID value.
Raises:
AssertionError: If any response is missing the key, values differ, or don't match expected.
"""
worker_ids = [r.get(key) for r in response_worker_ids]
logger.info(f"Response {key}s: {worker_ids}")
# All responses should have the key
assert all(
wid is not None for wid in worker_ids
), f"Expected all {len(response_worker_ids)} responses to have {key}, got: {worker_ids}"
# All values should be the same (due to prefix reuse routing)
unique_ids = set(worker_ids)
assert len(unique_ids) == 1, (
f"Expected all responses to have the same {key} (due to prefix reuse), "
f"but found {len(unique_ids)} unique values: {unique_ids}"
)
# The value should match the expected worker ID
actual_worker_id = worker_ids[0]
assert actual_worker_id == expected_worker_id, (
f"Expected {key}={expected_worker_id} (forced in first request), "
f"but got {key}={actual_worker_id}"
)
logger.info(
f"✓ Verified all {len(response_worker_ids)} responses have {key}={actual_worker_id}"
)
def verify_response_timing(timing_info: dict[str, Any]) -> None:
"""Verify timing info has valid values (ttft_ms > 0, total_time_ms > 0)."""
ttft_ms = timing_info.get("ttft_ms")
total_time_ms = timing_info.get("total_time_ms")
assert ttft_ms is not None and ttft_ms > 0, f"Expected ttft_ms > 0, got: {ttft_ms}"
assert (
total_time_ms is not None and total_time_ms > 0
), f"Expected total_time_ms > 0, got: {total_time_ms}"
assert (
total_time_ms >= ttft_ms
), f"Expected total_time_ms >= ttft_ms, got {total_time_ms} < {ttft_ms}"
logger.info(
f"✓ Verified timing: ttft_ms={ttft_ms:.2f}, total_time_ms={total_time_ms:.2f}"
)
########################################################
# Utility functions
########################################################
async def wait_for_frontend_ready(
frontend_url: str, expected_num_workers: int = 2, timeout: int = 120
):
"""Wait for backend worker(s) to be ready via the HTTP frontend (OpenAI API).
This function performs a two-phase readiness check through the frontend HTTP server:
1. Polls GET /v1/models until at least one model is registered (workers connected)
2. Sends a test POST to /v1/chat/completions to verify the request pipeline is functional
Use this when testing through the HTTP frontend server (dynamo.frontend).
For direct Python API testing with KvRouter, use wait_for_workers_ready() instead.
Args:
frontend_url: Base URL of the frontend HTTP server (e.g., "http://localhost:8000")
expected_num_workers: Number of workers to wait for (currently logs but doesn't enforce)
timeout: Maximum time to wait in seconds for both phases combined
Raises:
TimeoutError: If workers don't register or pipeline doesn't become ready within timeout
aiohttp.ClientError: If HTTP requests fail unexpectedly
"""
models_url = f"{frontend_url}/v1/models"
chat_url = f"{frontend_url}/v1/chat/completions"
start_time = asyncio.get_event_loop().time()
logger.info(
f"Waiting for {expected_num_workers} workers to register on HTTP frontend (timeout={timeout}s)..."
)
# Phase 1: Wait for models to appear in /v1/models
model_name = None
while True:
elapsed = asyncio.get_event_loop().time() - start_time
if elapsed > timeout:
raise TimeoutError(
f"Timeout waiting for vLLM workers. Waited {elapsed:.1f}s, no workers registered."
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(models_url) as response:
if response.status == 200:
data = await response.json()
models = data.get("data", [])
if len(models) > 0:
model_name = models[0].get("id")
logger.info(
f"Workers registered. Found {len(models)} model(s): {[m.get('id') for m in models]}"
)
break
else:
logger.debug(
f"No models registered yet (elapsed: {elapsed:.1f}s)"
)
except Exception as e:
logger.debug(f"Error checking models endpoint: {e}")
# Wait before next poll
await asyncio.sleep(1)
# Phase 2: Wait for chat completions pipeline to be ready
logger.info("Waiting for chat completions pipeline to be built...")
test_payload = {
"model": model_name,
"messages": [{"role": "user", "content": "test"}],
"max_tokens": 1,
"stream": False,
}
while True:
elapsed = asyncio.get_event_loop().time() - start_time
if elapsed > timeout:
raise TimeoutError(
f"Timeout waiting for chat completions pipeline. Waited {elapsed:.1f}s."
)
try:
async with aiohttp.ClientSession() as session:
async with session.post(chat_url, json=test_payload) as response:
if response.status == 200:
logger.info("Chat completions pipeline ready!")
return
else:
logger.debug(
f"Chat completions not ready yet, status {response.status} (elapsed: {elapsed:.1f}s)"
)
except Exception as e:
logger.debug(f"Error testing chat completions: {e}")
# Wait before next poll
await asyncio.sleep(1)
async def wait_for_workers_ready(
endpoint,
router: KvRouter,
expected_num_workers: int,
model_name: str,
) -> list[int]:
"""Wait for workers to be ready and return their instance IDs.
Supports mocker and vLLM workers.
This function polls the endpoint's client for instance IDs until the expected
number of workers are available, then sends a warmup request to verify they
can handle requests.
Args:
endpoint: The endpoint object to get the client from
router: The KvRouter to use for sending warmup requests
expected_num_workers: Number of workers to wait for
Returns:
Sorted list of unique instance IDs (ints).
Raises:
AssertionError: If workers don't become ready or warmup request fails.
"""
logger.info("Waiting for workers to be ready")
# Get the client from the endpoint
client = await endpoint.client()
# Poll for instance IDs until we have the expected number
instance_ids: list[int] = []
max_wait_time = 60 # seconds
start_time = asyncio.get_running_loop().time()
while len(instance_ids) < expected_num_workers:
instance_ids = client.instance_ids()
logger.info(f"Found {len(instance_ids)} instance(s): {instance_ids}")
if len(instance_ids) >= expected_num_workers:
break
# Check timeout
if asyncio.get_running_loop().time() - start_time > max_wait_time:
raise AssertionError(
f"Timeout waiting for workers. Found {len(instance_ids)} instance(s), expected {expected_num_workers}"
)
# Wait 1 second before polling again
await asyncio.sleep(1.0)
# Send a warmup request to verify workers can handle requests
test_token_ids = [random.randint(1, 10000) for _ in range(4)]
logger.info(f"Sending warmup request with {len(test_token_ids)} tokens")
try:
await send_request_via_python_kv_router(
kv_python_router=router,
model_name=model_name,
token_ids=test_token_ids,
initial_wait=1.0,
max_retries=8,
stop_conditions={
"ignore_eos": True,
"max_tokens": 2,
},
)
except Exception as e:
raise AssertionError(f"Warmup request failed: {e}")
logger.info(f"All {len(instance_ids)} workers are ready")
return sorted(instance_ids)
async def send_request_with_retry(url: str, payload: dict, max_retries: int = 8):
"""Send a single request with exponential backoff retry"""
wait_time = 1 # Start with 1 second
for attempt in range(max_retries + 1):
await asyncio.sleep(wait_time)
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload) as response:
if response.status == 200:
# Read the response to ensure it's valid
async for _ in response.content:
pass
logger.debug(
f"First request succeeded on attempt {attempt + 1}"
)
return True
else:
logger.warning(
f"Attempt {attempt + 1} failed with status {response.status}"
)
except Exception as e:
logger.warning(f"Attempt {attempt + 1} failed with error: {e}")
if attempt < max_retries:
wait_time *= 2 # Double the wait time
return False
def get_runtime(store_backend="etcd", request_plane="tcp"):
"""Create a DistributedRuntime instance for testing.
Args:
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
request_plane: How frontend talks to backend ("tcp", "http" or "nats"). Defaults to "tcp".
"""
try:
# Try to get running loop (works in async context)
loop = asyncio.get_running_loop()
except RuntimeError:
# No running loop, create a new one (sync context)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return DistributedRuntime(loop, store_backend, request_plane)
async def check_nats_consumers(namespace: str, expected_count: Optional[int] = None):
"""Check NATS consumers for the KV events stream.
Args:
namespace: The namespace to check consumers for
expected_count: Optional expected number of consumers. If provided, asserts if count doesn't match.
Returns:
List of consumer names
"""
component_subject = f"namespace.{namespace}.component.mocker"
slugified = component_subject.lower().replace(".", "-").replace("_", "-")
stream_name = f"{slugified}-kv-events"
logger.info(f"Checking consumers for stream: {stream_name}")
nc = await nats.connect(servers=_nats_server())
try:
js = nc.jetstream()
consumer_infos = await js.consumers_info(stream_name)
consumer_names = [info.name for info in consumer_infos]
logger.info(f"Found {len(consumer_names)} consumers: {consumer_names}")
# Log detailed consumer info
for info in consumer_infos:
logger.info(
f"Consumer {info.name}: "
f"num_pending={info.num_pending}, "
f"num_ack_pending={info.num_ack_pending}, "
f"ack_floor={info.ack_floor}, "
f"delivered={info.delivered}"
)
if expected_count is not None:
assert (
len(consumer_names) == expected_count
), f"Expected {expected_count} durable consumers, found {len(consumer_names)}: {consumer_names}"
logger.info(f"✓ Verified {expected_count} durable consumers exist")
return consumer_names
finally:
await nc.close()
async def send_inflight_requests(urls: list, payload: dict, num_requests: int):
"""Send multiple requests concurrently, alternating between URLs if multiple provided"""
# First, send test requests with retry to ensure all systems are ready
for i, url in enumerate(urls):
logger.info(f"Sending initial test request to URL {i} ({url}) with retry...")
if not await send_request_with_retry(url, payload):
raise RuntimeError(f"Failed to connect to URL {i} after multiple retries")
async def send_single_request(session: aiohttp.ClientSession, request_id: int):
# Alternate between URLs based on request_id
url = urls[request_id % len(urls)]
url_index = request_id % len(urls)
try:
async with session.post(url, json=payload) as response:
if response.status != 200:
logger.error(
f"Request {request_id} to URL {url_index} failed with status {response.status}"
)
return False
# For streaming responses, read the entire stream
chunks = []
async for line in response.content:
if line:
chunks.append(line)
logger.debug(
f"Request {request_id} to URL {url_index} completed with {len(chunks)} chunks"
)
return True
except Exception as e:
logger.error(
f"Request {request_id} to URL {url_index} failed with error: {e}"
)
return False
# Send all requests at once
async with aiohttp.ClientSession() as session:
tasks = [send_single_request(session, i) for i in range(num_requests)]
results = await asyncio.gather(*tasks, return_exceptions=True)
successful = sum(1 for r in results if r if r is True)
failed = num_requests - successful
logger.info(f"Completed all requests: {successful} successful, {failed} failed")
assert (
successful == num_requests
), f"Expected {num_requests} successful requests, got {successful}"
logger.info(f"All {num_requests} requests completed successfully")
async def send_request_via_python_kv_router(
kv_python_router: KvRouter,
model_name: str,
token_ids: list,
initial_wait: float,
max_retries: int,
stop_conditions: Optional[dict] = None,
sampling_options: Optional[dict] = None,
output_options: Optional[dict] = None,
router_config_override: Optional[dict] = None,
worker_id: Optional[
int
] = None, # If None, Router will select the best available worker
dp_rank: Optional[int] = None, # Data parallel rank (defaults to 0)
return_worker_ids: bool = False, # If True, return worker IDs from response
) -> bool | dict[str, Optional[int]]:
"""Send a request to the specified worker instance.
Args:
return_worker_ids: If True, returns a dict with prefill_worker_id and decode_worker_id.
If False, returns True on success or False on failure.
Returns:
If return_worker_ids=False: True if workers respond, otherwise raises or returns False.
If return_worker_ids=True: Dict with 'prefill_worker_id' and 'decode_worker_id' keys.
"""
wait_time = initial_wait
log_message = (
f"worker with worker_id={worker_id}"
if worker_id is not None
else "the best available worker"
)
# Retry loop sending request to worker with exponential backoff
stream = None
for attempt in range(max_retries + 1):
try:
logger.debug(f"Sending request to {log_message} (attempt {attempt + 1})")
stream = await kv_python_router.generate(
token_ids=token_ids,
model=model_name,
stop_conditions=stop_conditions, # type: ignore[arg-type]
sampling_options=sampling_options, # type: ignore[arg-type]
output_options=output_options, # type: ignore[arg-type]
router_config_override=router_config_override, # type: ignore[arg-type]
worker_id=worker_id,
dp_rank=dp_rank,
)
if stream is not None:
logger.debug(f"Request succeeded on attempt {attempt + 1}")
break
except Exception as e:
logger.warning(f"Attempt {attempt + 1} failed with error: {e}")
if attempt < max_retries:
await asyncio.sleep(wait_time)
wait_time *= 2
else:
raise RuntimeError(
f"Failed to connect to workers after {max_retries + 1} attempts"
) from e
if stream is None:
raise RuntimeError(
f"Failed to get a valid stream from workers after {max_retries + 1} attempts"
)
# Collect tokens and worker IDs from the SSE stream
generated_tokens = []
prefill_worker_id: Optional[int] = None
decode_worker_id: Optional[int] = None
prefill_dp_rank: Optional[int] = None
decode_dp_rank: Optional[int] = None
async for response in stream:
if isinstance(response, dict):
# Check if response has token_ids
if "token_ids" in response:
tokens = response["token_ids"]
if isinstance(tokens, list):
generated_tokens.extend(tokens)
logger.debug(f"Received {len(tokens)} tokens: {tokens}")
# Check for finish reason
if "finish_reason" in response:
logger.debug(
f"Stream finished with reason: {response['finish_reason']}"
)
# Extract worker IDs and dp_ranks from disaggregated_params if present
if return_worker_ids and "disaggregated_params" in response:
disagg_params = response["disaggregated_params"]
if isinstance(disagg_params, dict) and "worker_id" in disagg_params:
worker_id_info = disagg_params["worker_id"]
if isinstance(worker_id_info, dict):
if "prefill_worker_id" in worker_id_info:
prefill_worker_id = worker_id_info["prefill_worker_id"]
if "decode_worker_id" in worker_id_info:
decode_worker_id = worker_id_info["decode_worker_id"]
if "prefill_dp_rank" in worker_id_info:
prefill_dp_rank = worker_id_info["prefill_dp_rank"]
if "decode_dp_rank" in worker_id_info:
decode_dp_rank = worker_id_info["decode_dp_rank"]
# Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True
logger.debug(f"Total generated tokens: {len(generated_tokens)}")
if (
stop_conditions
and "max_tokens" in stop_conditions
and "ignore_eos" in stop_conditions
and stop_conditions["ignore_eos"]
):
max_tokens = int(stop_conditions["max_tokens"])
assert len(generated_tokens) == max_tokens, (
f"Expected exactly {max_tokens} tokens but got {len(generated_tokens)}. "
f"Tokens: {generated_tokens}"
)
logger.debug(
f"Successfully verified {max_tokens} tokens generated as expected via KvRouter with ignore_eos=True"
)
if return_worker_ids:
return {
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
"prefill_dp_rank": prefill_dp_rank,
"decode_dp_rank": decode_dp_rank,
}
return True
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from tests.utils.managed_process import ManagedProcess
class KVRouterProcess(ManagedProcess):
"""Manages the KV router process using dynamo.frontend"""
def __init__(
self,
request,
block_size: int,
frontend_port: int,
namespace: str,
store_backend: str = "etcd",
enforce_disagg: bool = False,
blocks_threshold: float | None = None,
tokens_threshold: float | None = None,
tokens_threshold_frac: float | None = None,
request_plane: str = "nats",
durable_kv_events: bool = False,
):
command = [
"python3",
"-m",
"dynamo.frontend",
"--kv-cache-block-size",
str(block_size),
"--router-mode",
"kv",
"--http-port",
str(frontend_port),
"--discovery-backend",
store_backend,
"--namespace",
namespace,
]
if enforce_disagg:
command.append("--enforce-disagg")
if blocks_threshold is not None:
command.extend(["--active-decode-blocks-threshold", str(blocks_threshold)])
if tokens_threshold is not None:
command.extend(["--active-prefill-tokens-threshold", str(tokens_threshold)])
if tokens_threshold_frac is not None:
command.extend(
["--active-prefill-tokens-threshold-frac", str(tokens_threshold_frac)]
)
if durable_kv_events:
command.append("--router-durable-kv-events")
env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane
super().__init__(
command=command,
env=env,
timeout=60,
display_output=True,
health_check_ports=[frontend_port],
health_check_urls=[
(f"http://localhost:{frontend_port}/v1/models", self._check_ready)
],
log_dir=request.node.name,
terminate_all_matching_process_names=False,
)
self.port = frontend_port
def _check_ready(self, response):
"""Check if KV router is ready"""
return response.status_code == 200
def __exit__(self, exc_type, exc_val, exc_tb):
super().__exit__(exc_type, exc_val, exc_tb)
......@@ -20,7 +20,7 @@ from typing import Any, Dict, Optional
import aiohttp
import pytest
from tests.router.common import ( # utilities
from tests.router.common import (
_test_busy_threshold_endpoint,
_test_python_router_bindings,
_test_router_basic,
......@@ -30,9 +30,8 @@ from tests.router.common import ( # utilities
_test_router_overload_503,
_test_router_query_instance_id,
_test_router_two_routers,
generate_random_suffix,
get_runtime,
)
from tests.router.helper import generate_random_suffix, get_runtime
from tests.utils.constants import ROUTER_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.port_utils import allocate_ports, deallocate_ports
......
......@@ -12,13 +12,12 @@ from typing import Any, Dict, Optional
import pytest
from tests.router.common import ( # utilities
from tests.router.common import (
_test_router_basic,
_test_router_decisions,
_test_router_indexers_sync,
generate_random_suffix,
get_runtime,
)
from tests.router.helper import generate_random_suffix, get_runtime
from tests.utils.constants import DefaultPort
from tests.utils.managed_process import ManagedProcess
from tests.utils.port_utils import allocate_ports, deallocate_ports
......
......@@ -12,13 +12,12 @@ from typing import Any, Dict, Optional
import pytest
from tests.router.common import ( # utilities
from tests.router.common import (
_test_router_basic,
_test_router_decisions,
_test_router_indexers_sync,
generate_random_suffix,
get_runtime,
)
from tests.router.helper import generate_random_suffix, get_runtime
from tests.utils.constants import DefaultPort
from tests.utils.managed_process import ManagedProcess
from tests.utils.port_utils import allocate_ports, deallocate_ports
......
......@@ -13,13 +13,12 @@ from typing import Any, Dict, Optional
import pytest
from tests.router.common import ( # utilities
from tests.router.common import (
_test_router_basic,
_test_router_decisions,
_test_router_indexers_sync,
generate_random_suffix,
get_runtime,
)
from tests.router.helper import generate_random_suffix, get_runtime
from tests.utils.constants import DefaultPort
from tests.utils.managed_process import ManagedProcess
from tests.utils.port_utils import allocate_ports, deallocate_ports
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment