"tests/vscode:/vscode.git/clone" did not exist on "302ef403a2305e9158064f8e386d1b5284d12cb2"
Unverified Commit c09abf7d authored by Yan Ru Pei, committed by GitHub
Browse files

test(router): cover round-robin unset dp-rank flow (#7991)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
Co-authored-by: Ishan Dhanani <ishandhanani@gmail.com>
Co-authored-by: Dmitry Tokarev <dtokarev@nvidia.com>
parent 1dd076cc
......@@ -8,6 +8,7 @@
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::backend::ExecutionContext;
......@@ -297,6 +298,7 @@ pub struct MockEngine {
request_senders: OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>,
senders_ready: Notify,
engine_args: MockEngineArgs,
unset_dp_rank_counter: AtomicU32,
/// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
/// Keep schedulers alive so their CancelGuards don't fire prematurely.
......@@ -311,11 +313,20 @@ impl MockEngine {
request_senders: OnceCell::new(),
senders_ready: Notify::new(),
engine_args,
unset_dp_rank_counter: AtomicU32::new(0),
bootstrap_server: Arc::new(OnceCell::new()),
_schedulers: OnceCell::new(),
}
}
/// Picks the data-parallel rank that should serve `request`.
///
/// An explicit `dp_rank` in the request's routing hints is returned unchanged;
/// it is not range-checked here. When no rank is set, ranks are handed out
/// round-robin: a shared counter is incremented and reduced modulo
/// `engine_args.dp_size`.
///
/// A `dp_size` of zero yields rank 0 instead of panicking on a zero divisor.
fn resolve_dp_rank(&self, request: &PreprocessedRequest) -> u32 {
    if let Some(dp_rank) = request.routing.as_ref().and_then(|routing| routing.dp_rank) {
        return dp_rank;
    }
    // `fetch_add` wraps around on overflow, so taking a ticket never panics.
    let ticket = self.unset_dp_rank_counter.fetch_add(1, Ordering::Relaxed);
    // `%` panics when the divisor is zero; `checked_rem` returns `None` in that
    // case, which is mapped to rank 0 so the caller decides how to reject it.
    ticket.checked_rem(self.engine_args.dp_size).unwrap_or(0)
}
pub async fn start(&self, component: Component) -> Result<()> {
// Use primary_token() instead of child_token() so the mocker continues running
// during graceful shutdown (Phase 1/2) and only stops in Phase 3.
......@@ -583,12 +594,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
) -> Result<ManyOut<LLMEngineOutput>, Error> {
let (request, ctx) = input.into_parts();
// Extract dp_rank from routing hints (defaults to 0 if not set)
let dp_rank = request
.routing
.as_ref()
.and_then(|r| r.dp_rank)
.unwrap_or(0);
let dp_rank = self.resolve_dp_rank(&request);
// Validate dp_rank
if dp_rank >= self.engine_args.dp_size {
......
......@@ -1779,6 +1779,145 @@ def _test_router_decisions_disagg(
)
def _test_router_decisions_disagg_round_robin_prefill_dp_rank(
prefill_workers,
decode_workers,
block_size: int,
request,
frontend_port: int,
test_payload: dict,
expected_prefill_dp_ranks: int,
store_backend: str = "etcd",
request_plane: str = "nats",
):
"""Verify disaggregated round-robin requests store prefill KV blocks across DP ranks.

Starts a frontend with ``router_mode="round-robin"`` and ``enforce_disagg=True``,
takes a baseline of stored-block counts per dp_rank for one prefill worker,
sends ``expected_prefill_dp_ranks * 2`` non-streaming chat-completion requests
with distinct prompts, and asserts that every dp_rank in
``range(expected_prefill_dp_ranks)`` gained at least one stored block.

Args:
    prefill_workers: Prefill worker handle; ``.namespace`` and ``.num_workers`` are read.
    decode_workers: Decode worker handle; ``.namespace`` and ``.num_workers`` are read.
    block_size: Block size passed to the frontend and the observer ``KvRouter``;
        each prompt contains ``block_size * 3`` space-separated words.
    request: Passed through to ``FrontendRouterProcess`` (presumably the pytest
        ``request`` fixture -- confirm against callers).
    frontend_port: Port the frontend is started on and requests are sent to.
    test_payload: Base request payload; ``stream``, ``max_tokens`` and
        ``messages`` are overridden for every request.
    expected_prefill_dp_ranks: Number of dp ranks that must each show new stored blocks.
    store_backend: Forwarded to ``FrontendRouterProcess`` and ``get_runtime``.
    request_plane: Forwarded to ``FrontendRouterProcess`` and ``get_runtime``.

Raises:
    AssertionError: If the expected number of prefill workers is not found, a
        request does not return HTTP 200, or some dp_rank gained no stored blocks.
"""
with FrontendRouterProcess(
request,
block_size,
frontend_port,
decode_workers.namespace,
store_backend,
enforce_disagg=True,
request_plane=request_plane,
router_mode="round-robin",
min_initial_workers=decode_workers.num_workers,
):
logger.info(
"Starting round-robin frontend on port %s for disagg prefill dp-rank test",
frontend_port,
)
# Async body of the test; run to completion with asyncio.run() further down.
async def test_sync():
frontend_url = f"http://localhost:{frontend_port}"
chat_url = f"{frontend_url}/v1/chat/completions"
await wait_for_frontend_ready(
frontend_url=frontend_url,
expected_num_workers=decode_workers.num_workers,
timeout=120,
)
runtime = get_runtime(
store_backend=store_backend, request_plane=request_plane
)
prefill_endpoint = runtime.endpoint(
f"{prefill_workers.namespace}.prefill.generate"
)
# Observer KvRouter on the prefill endpoint; its dump_events() output is the
# only source of stored-block counts used below.
with min_initial_workers_env(prefill_workers.num_workers):
observer_router = KvRouter(
endpoint=prefill_endpoint,
block_size=block_size,
kv_router_config=KvRouterConfig(
router_snapshot_threshold=20,
use_kv_events=True,
durable_kv_events=False,
router_event_threads=4,
router_track_prefill_tokens=True,
router_prefill_load_model="none",
),
)
client = await prefill_endpoint.client()
worker_ids: list[int] = []
# Poll once per second, for up to 60 seconds, until enough prefill instances are visible.
deadline = asyncio.get_running_loop().time() + 60
while asyncio.get_running_loop().time() < deadline:
worker_ids = sorted(set(client.instance_ids()))
if len(worker_ids) >= prefill_workers.num_workers:
break
await asyncio.sleep(1.0)
# Exact match is required, so finding more instances than expected also fails here.
assert len(worker_ids) == prefill_workers.num_workers, (
f"Timed out waiting for prefill workers. "
f"Found {worker_ids}, expected {prefill_workers.num_workers}"
)
# Only events from the lowest-sorted prefill worker id are counted.
prefill_worker_id = worker_ids[0]
# Sum the number of stored blocks per dp_rank for prefill_worker_id from a JSON event dump.
def stored_blocks_by_dp_rank(events_json: str) -> dict[int, int]:
# Pre-seed every expected rank with 0 so ranks that saw no events still appear.
counts = {dp_rank: 0 for dp_rank in range(expected_prefill_dp_ranks)}
for event in json.loads(events_json):
if event.get("worker_id") != prefill_worker_id:
continue
# Events with no "stored" payload under event.data are ignored.
stored = event.get("event", {}).get("data", {}).get("stored")
if stored is None:
continue
# Events without a dp_rank field are attributed to rank 0.
dp_rank = event.get("event", {}).get("dp_rank", 0)
counts[dp_rank] = counts.get(dp_rank, 0) + len(
stored.get("blocks", [])
)
return counts
# Fixed pause before the baseline snapshot -- presumably lets earlier KV events
# reach the observer; TODO confirm.
await asyncio.sleep(2.0)
baseline_counts = stored_blocks_by_dp_rank(
await observer_router.dump_events()
)
async with aiohttp.ClientSession() as session:
# Two requests per expected dp rank; request_idx is embedded in every word, so
# no two prompts are the same.
for request_idx in range(expected_prefill_dp_ranks * 2):
prompt_tokens = " ".join(
f"prefill-{request_idx}-token-{token_idx}"
for token_idx in range(block_size * 3)
)
payload = {
**test_payload,
"stream": False,
"max_tokens": 1,
"messages": [
{
"role": "user",
"content": prompt_tokens,
}
],
}
async with session.post(chat_url, json=payload) as response:
assert response.status == 200, (
f"Request {request_idx + 1} failed with status "
f"{response.status}: {await response.text()}"
)
# Read the full response body before moving on.
await response.text()
await asyncio.sleep(0.5)
# Fixed pause before the final snapshot -- presumably lets the new KV events
# arrive; TODO confirm.
await asyncio.sleep(2.0)
final_counts = stored_blocks_by_dp_rank(await observer_router.dump_events())
return prefill_worker_id, baseline_counts, final_counts
prefill_worker_id, baseline_counts, final_counts = asyncio.run(test_sync())
# Per-rank growth in stored blocks between the two snapshots.
delta_counts = {
dp_rank: final_counts.get(dp_rank, 0) - baseline_counts.get(dp_rank, 0)
for dp_rank in range(expected_prefill_dp_ranks)
}
# Ranks whose stored-block count increased while the requests were sent.
active_dp_ranks = sorted(
dp_rank for dp_rank, block_count in delta_counts.items() if block_count > 0
)
assert active_dp_ranks == list(range(expected_prefill_dp_ranks)), (
f"Expected round-robin prefill requests for worker {prefill_worker_id} "
f"to store KV blocks on dp_ranks {list(range(expected_prefill_dp_ranks))}, "
f"but saw deltas {delta_counts}"
)
def _test_router_decisions(
engine_workers,
endpoint,
......
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
import sys
from tests.utils.managed_process import ManagedProcess
......@@ -34,7 +35,7 @@ class FrontendRouterProcess(ManagedProcess):
use_remote_indexer: bool = False,
):
command = [
"python3",
sys.executable,
"-m",
"dynamo.frontend",
"--router-mode",
......@@ -113,6 +114,7 @@ class FrontendRouterProcess(ManagedProcess):
],
log_dir=request.node.name,
terminate_all_matching_process_names=False,
display_name=f"dynamo-frontend-{router_mode}",
)
self.port = frontend_port
self.router_mode = router_mode
......@@ -141,7 +143,7 @@ class DirectRouterProcess(ManagedProcess):
request_plane: str = "nats",
):
command = [
"python3",
sys.executable,
"-m",
"dynamo.frontend",
"--router-mode",
......@@ -169,6 +171,7 @@ class DirectRouterProcess(ManagedProcess):
],
log_dir=request.node.name,
terminate_all_matching_process_names=False,
display_name="dynamo-frontend-direct",
)
self.port = frontend_port
......
......@@ -11,8 +11,10 @@
import asyncio
import logging
import os
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Iterator, Optional
import aiohttp
import pytest
......@@ -25,6 +27,7 @@ from tests.router.common import (
_test_router_basic,
_test_router_decisions,
_test_router_decisions_disagg,
_test_router_decisions_disagg_round_robin_prefill_dp_rank,
_test_router_indexers_sync,
_test_router_overload_503,
_test_router_query_instance_id,
......@@ -139,7 +142,7 @@ def _build_mocker_command(
List of command arguments for subprocess
"""
command = [
"python",
sys.executable,
"-m",
"dynamo.mocker",
"--model-path",
......@@ -319,6 +322,7 @@ class MockerProcess:
health_check_urls=[],
log_dir=request.node.name,
terminate_all_matching_process_names=False,
display_name="dynamo-mocker",
)
logger.info(
f"Created mocker process with {num_mockers} worker(s), endpoint: {self.endpoint}"
......@@ -640,6 +644,7 @@ class DisaggMockerProcess:
health_check_urls=[],
log_dir=request.node.name,
terminate_all_matching_process_names=False,
display_name=f"dynamo-mocker-{worker_type}",
)
logger.info(
f"Created {worker_type} mocker process with {num_mockers} worker(s), "
......@@ -668,6 +673,77 @@ class DisaggMockerProcess:
self._bootstrap_ports = []
@contextmanager
def _launch_disagg_workers(
    request,
    namespace: str,
    registration_order: str,
    *,
    prefill_mocker_args: Dict[str, Any],
    decode_mocker_args: Dict[str, Any],
    num_prefill_mockers: int,
    num_decode_mockers: int,
    enable_disagg_bootstrap: bool,
    request_plane: str = "nats",
) -> Iterator[tuple[DisaggMockerProcess, DisaggMockerProcess]]:
    """Start prefill and decode mocker groups in the requested order.

    Both groups are ``DisaggMockerProcess`` context managers sharing ``namespace``
    and ``request_plane``. The group named by ``registration_order`` is entered
    first and the other is entered inside it, so they are exited in reverse order.
    Only the prefill group receives ``enable_bootstrap``.

    Args:
        request: Forwarded as the first positional argument of ``DisaggMockerProcess``.
        namespace: Namespace passed to both groups.
        registration_order: ``"prefill_first"`` or ``"decode_first"``.
        prefill_mocker_args: ``mocker_args`` for the prefill group.
        decode_mocker_args: ``mocker_args`` for the decode group.
        num_prefill_mockers: ``num_mockers`` for the prefill group.
        num_decode_mockers: ``num_mockers`` for the decode group.
        enable_disagg_bootstrap: Passed to the prefill group as ``enable_bootstrap``.
        request_plane: Request plane passed to both groups.

    Yields:
        ``(prefill_workers, decode_workers)``, always in that order regardless of
        which group was started first.

    Raises:
        ValueError: If ``registration_order`` is not one of the two accepted values.
    """
    prefill_goes_first = registration_order == "prefill_first"
    if not prefill_goes_first and registration_order != "decode_first":
        raise ValueError(f"Unexpected registration order: {registration_order}")

    def _prefill_group():
        # Built lazily so construction happens right where the `with` is entered.
        return DisaggMockerProcess(
            request,
            namespace=namespace,
            worker_type="prefill",
            mocker_args=prefill_mocker_args,
            num_mockers=num_prefill_mockers,
            request_plane=request_plane,
            enable_bootstrap=enable_disagg_bootstrap,
        )

    def _decode_group():
        # Decode workers are created without the bootstrap flag.
        return DisaggMockerProcess(
            request,
            namespace=namespace,
            worker_type="decode",
            mocker_args=decode_mocker_args,
            num_mockers=num_decode_mockers,
            request_plane=request_plane,
        )

    if not prefill_goes_first:
        logger.info("Starting %s decode mocker instances (first)", num_decode_mockers)
        with _decode_group() as decode_workers:
            logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
            logger.info(
                "Starting %s prefill mocker instances (second)", num_prefill_mockers
            )
            with _prefill_group() as prefill_workers:
                logger.info(
                    f"Prefill workers using endpoint: {prefill_workers.endpoint}"
                )
                yield prefill_workers, decode_workers
        return

    logger.info("Starting %s prefill mocker instances (first)", num_prefill_mockers)
    with _prefill_group() as prefill_workers:
        logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
        logger.info(
            "Starting %s decode mocker instances (second)", num_decode_mockers
        )
        with _decode_group() as decode_workers:
            logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
            yield prefill_workers, decode_workers
@pytest.mark.timeout(180) # planner-profile mocker setup can exceed 120s on CI CPUs
@pytest.mark.parametrize(
"router_mode,durable_kv_events,mocker_args_override",
......@@ -1192,38 +1268,19 @@ def test_router_decisions_disagg(
# durable_kv_events defaults to False (NATS Core mode)
}
if registration_order == "prefill_first":
# Start prefill workers first
logger.info("Starting 4 prefill mocker instances (first)")
with DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="prefill",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap,
) as prefill_workers:
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Then start decode workers
logger.info("Starting 4 decode mocker instances (second)")
with DisaggMockerProcess(
with _launch_disagg_workers(
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
) as decode_workers:
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
# Get unique port for this test
shared_namespace,
registration_order,
prefill_mocker_args=mocker_args,
decode_mocker_args=mocker_args,
num_prefill_mockers=4,
num_decode_mockers=4,
enable_disagg_bootstrap=enable_disagg_bootstrap,
) as (prefill_workers, decode_workers):
frontend_port = get_unique_ports(
request, num_ports=1, registration_order=registration_order
)[0]
# Run disagg routing test
_test_router_decisions_disagg(
prefill_workers=prefill_workers,
decode_workers=decode_workers,
......@@ -1233,50 +1290,67 @@ def test_router_decisions_disagg(
test_payload=TEST_PAYLOAD,
request_plane="nats",
)
else:
# Start decode workers first
logger.info("Starting 4 decode mocker instances (first)")
with DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
) as decode_workers:
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
# Then start prefill workers
logger.info("Starting 4 prefill mocker instances (second)")
with DisaggMockerProcess(
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
@pytest.mark.parametrize(
"enable_disagg_bootstrap", [False, True], ids=["no_bootstrap", "with_bootstrap"]
)
@pytest.mark.timeout(180)
def test_router_decisions_disagg_round_robin_prefill_dp_rank(
request,
namespace=shared_namespace,
worker_type="prefill",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap,
) as prefill_workers:
runtime_services_dynamic_ports,
predownload_tokenizers,
registration_order,
enable_disagg_bootstrap,
):
"""Verify round-robin disagg prefill requests spread KV stores across DP ranks."""
logger.info(
f"Prefill workers using endpoint: {prefill_workers.endpoint}"
"Starting disaggregated round-robin prefill dp-rank test "
"(registration_order=%s, bootstrap=%s)",
registration_order,
enable_disagg_bootstrap,
)
# Get unique port for this test
namespace_suffix = generate_random_suffix()
shared_namespace = f"test-namespace-{namespace_suffix}"
prefill_mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"dp_size": 4,
}
decode_mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
}
def run_case(prefill_workers, decode_workers):
frontend_port = get_unique_ports(
request, num_ports=1, registration_order=registration_order
)[0]
# Run disagg routing test
_test_router_decisions_disagg(
_test_router_decisions_disagg_round_robin_prefill_dp_rank(
prefill_workers=prefill_workers,
decode_workers=decode_workers,
block_size=BLOCK_SIZE,
request=request,
frontend_port=frontend_port,
test_payload=TEST_PAYLOAD,
expected_prefill_dp_ranks=prefill_mocker_args["dp_size"],
request_plane="nats",
)
with _launch_disagg_workers(
request,
shared_namespace,
registration_order,
prefill_mocker_args=prefill_mocker_args,
decode_mocker_args=decode_mocker_args,
num_prefill_mockers=1,
num_decode_mockers=1,
enable_disagg_bootstrap=enable_disagg_bootstrap,
) as (prefill_workers, decode_workers):
run_case(prefill_workers, decode_workers)
@pytest.mark.timeout(180)
def test_router_decisions_disagg_router_aic(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment