Unverified Commit c09abf7d authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

test(router): cover round-robin unset dp-rank flow (#7991)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarIshan Dhanani <ishandhanani@gmail.com>
Co-authored-by: default avatarDmitry Tokarev <dtokarev@nvidia.com>
parent 1dd076cc
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::backend::ExecutionContext; use crate::backend::ExecutionContext;
...@@ -297,6 +298,7 @@ pub struct MockEngine { ...@@ -297,6 +298,7 @@ pub struct MockEngine {
request_senders: OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>, request_senders: OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>,
senders_ready: Notify, senders_ready: Notify,
engine_args: MockEngineArgs, engine_args: MockEngineArgs,
unset_dp_rank_counter: AtomicU32,
/// Bootstrap server for prefill workers in disaggregated mode /// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>, bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
/// Keep schedulers alive so their CancelGuards don't fire prematurely. /// Keep schedulers alive so their CancelGuards don't fire prematurely.
...@@ -311,11 +313,20 @@ impl MockEngine { ...@@ -311,11 +313,20 @@ impl MockEngine {
request_senders: OnceCell::new(), request_senders: OnceCell::new(),
senders_ready: Notify::new(), senders_ready: Notify::new(),
engine_args, engine_args,
unset_dp_rank_counter: AtomicU32::new(0),
bootstrap_server: Arc::new(OnceCell::new()), bootstrap_server: Arc::new(OnceCell::new()),
_schedulers: OnceCell::new(), _schedulers: OnceCell::new(),
} }
} }
fn resolve_dp_rank(&self, request: &PreprocessedRequest) -> u32 {
if let Some(dp_rank) = request.routing.as_ref().and_then(|routing| routing.dp_rank) {
return dp_rank;
}
self.unset_dp_rank_counter.fetch_add(1, Ordering::Relaxed) % self.engine_args.dp_size
}
pub async fn start(&self, component: Component) -> Result<()> { pub async fn start(&self, component: Component) -> Result<()> {
// Use primary_token() instead of child_token() so the mocker continues running // Use primary_token() instead of child_token() so the mocker continues running
// during graceful shutdown (Phase 1/2) and only stops in Phase 3. // during graceful shutdown (Phase 1/2) and only stops in Phase 3.
...@@ -583,12 +594,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -583,12 +594,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
) -> Result<ManyOut<LLMEngineOutput>, Error> { ) -> Result<ManyOut<LLMEngineOutput>, Error> {
let (request, ctx) = input.into_parts(); let (request, ctx) = input.into_parts();
// Extract dp_rank from routing hints (defaults to 0 if not set) let dp_rank = self.resolve_dp_rank(&request);
let dp_rank = request
.routing
.as_ref()
.and_then(|r| r.dp_rank)
.unwrap_or(0);
// Validate dp_rank // Validate dp_rank
if dp_rank >= self.engine_args.dp_size { if dp_rank >= self.engine_args.dp_size {
......
...@@ -1779,6 +1779,145 @@ def _test_router_decisions_disagg( ...@@ -1779,6 +1779,145 @@ def _test_router_decisions_disagg(
) )
def _test_router_decisions_disagg_round_robin_prefill_dp_rank(
prefill_workers,
decode_workers,
block_size: int,
request,
frontend_port: int,
test_payload: dict,
expected_prefill_dp_ranks: int,
store_backend: str = "etcd",
request_plane: str = "nats",
):
"""Verify disaggregated round-robin requests store prefill KV blocks across DP ranks."""
with FrontendRouterProcess(
request,
block_size,
frontend_port,
decode_workers.namespace,
store_backend,
enforce_disagg=True,
request_plane=request_plane,
router_mode="round-robin",
min_initial_workers=decode_workers.num_workers,
):
logger.info(
"Starting round-robin frontend on port %s for disagg prefill dp-rank test",
frontend_port,
)
async def test_sync():
frontend_url = f"http://localhost:{frontend_port}"
chat_url = f"{frontend_url}/v1/chat/completions"
await wait_for_frontend_ready(
frontend_url=frontend_url,
expected_num_workers=decode_workers.num_workers,
timeout=120,
)
runtime = get_runtime(
store_backend=store_backend, request_plane=request_plane
)
prefill_endpoint = runtime.endpoint(
f"{prefill_workers.namespace}.prefill.generate"
)
with min_initial_workers_env(prefill_workers.num_workers):
observer_router = KvRouter(
endpoint=prefill_endpoint,
block_size=block_size,
kv_router_config=KvRouterConfig(
router_snapshot_threshold=20,
use_kv_events=True,
durable_kv_events=False,
router_event_threads=4,
router_track_prefill_tokens=True,
router_prefill_load_model="none",
),
)
client = await prefill_endpoint.client()
worker_ids: list[int] = []
deadline = asyncio.get_running_loop().time() + 60
while asyncio.get_running_loop().time() < deadline:
worker_ids = sorted(set(client.instance_ids()))
if len(worker_ids) >= prefill_workers.num_workers:
break
await asyncio.sleep(1.0)
assert len(worker_ids) == prefill_workers.num_workers, (
f"Timed out waiting for prefill workers. "
f"Found {worker_ids}, expected {prefill_workers.num_workers}"
)
prefill_worker_id = worker_ids[0]
def stored_blocks_by_dp_rank(events_json: str) -> dict[int, int]:
counts = {dp_rank: 0 for dp_rank in range(expected_prefill_dp_ranks)}
for event in json.loads(events_json):
if event.get("worker_id") != prefill_worker_id:
continue
stored = event.get("event", {}).get("data", {}).get("stored")
if stored is None:
continue
dp_rank = event.get("event", {}).get("dp_rank", 0)
counts[dp_rank] = counts.get(dp_rank, 0) + len(
stored.get("blocks", [])
)
return counts
await asyncio.sleep(2.0)
baseline_counts = stored_blocks_by_dp_rank(
await observer_router.dump_events()
)
async with aiohttp.ClientSession() as session:
for request_idx in range(expected_prefill_dp_ranks * 2):
prompt_tokens = " ".join(
f"prefill-{request_idx}-token-{token_idx}"
for token_idx in range(block_size * 3)
)
payload = {
**test_payload,
"stream": False,
"max_tokens": 1,
"messages": [
{
"role": "user",
"content": prompt_tokens,
}
],
}
async with session.post(chat_url, json=payload) as response:
assert response.status == 200, (
f"Request {request_idx + 1} failed with status "
f"{response.status}: {await response.text()}"
)
await response.text()
await asyncio.sleep(0.5)
await asyncio.sleep(2.0)
final_counts = stored_blocks_by_dp_rank(await observer_router.dump_events())
return prefill_worker_id, baseline_counts, final_counts
prefill_worker_id, baseline_counts, final_counts = asyncio.run(test_sync())
delta_counts = {
dp_rank: final_counts.get(dp_rank, 0) - baseline_counts.get(dp_rank, 0)
for dp_rank in range(expected_prefill_dp_ranks)
}
active_dp_ranks = sorted(
dp_rank for dp_rank, block_count in delta_counts.items() if block_count > 0
)
assert active_dp_ranks == list(range(expected_prefill_dp_ranks)), (
f"Expected round-robin prefill requests for worker {prefill_worker_id} "
f"to store KV blocks on dp_ranks {list(range(expected_prefill_dp_ranks))}, "
f"but saw deltas {delta_counts}"
)
def _test_router_decisions( def _test_router_decisions(
engine_workers, engine_workers,
endpoint, endpoint,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
import sys
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
...@@ -34,7 +35,7 @@ class FrontendRouterProcess(ManagedProcess): ...@@ -34,7 +35,7 @@ class FrontendRouterProcess(ManagedProcess):
use_remote_indexer: bool = False, use_remote_indexer: bool = False,
): ):
command = [ command = [
"python3", sys.executable,
"-m", "-m",
"dynamo.frontend", "dynamo.frontend",
"--router-mode", "--router-mode",
...@@ -113,6 +114,7 @@ class FrontendRouterProcess(ManagedProcess): ...@@ -113,6 +114,7 @@ class FrontendRouterProcess(ManagedProcess):
], ],
log_dir=request.node.name, log_dir=request.node.name,
terminate_all_matching_process_names=False, terminate_all_matching_process_names=False,
display_name=f"dynamo-frontend-{router_mode}",
) )
self.port = frontend_port self.port = frontend_port
self.router_mode = router_mode self.router_mode = router_mode
...@@ -141,7 +143,7 @@ class DirectRouterProcess(ManagedProcess): ...@@ -141,7 +143,7 @@ class DirectRouterProcess(ManagedProcess):
request_plane: str = "nats", request_plane: str = "nats",
): ):
command = [ command = [
"python3", sys.executable,
"-m", "-m",
"dynamo.frontend", "dynamo.frontend",
"--router-mode", "--router-mode",
...@@ -169,6 +171,7 @@ class DirectRouterProcess(ManagedProcess): ...@@ -169,6 +171,7 @@ class DirectRouterProcess(ManagedProcess):
], ],
log_dir=request.node.name, log_dir=request.node.name,
terminate_all_matching_process_names=False, terminate_all_matching_process_names=False,
display_name="dynamo-frontend-direct",
) )
self.port = frontend_port self.port = frontend_port
......
...@@ -11,8 +11,10 @@ ...@@ -11,8 +11,10 @@
import asyncio import asyncio
import logging import logging
import os import os
import sys
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Iterator, Optional
import aiohttp import aiohttp
import pytest import pytest
...@@ -25,6 +27,7 @@ from tests.router.common import ( ...@@ -25,6 +27,7 @@ from tests.router.common import (
_test_router_basic, _test_router_basic,
_test_router_decisions, _test_router_decisions,
_test_router_decisions_disagg, _test_router_decisions_disagg,
_test_router_decisions_disagg_round_robin_prefill_dp_rank,
_test_router_indexers_sync, _test_router_indexers_sync,
_test_router_overload_503, _test_router_overload_503,
_test_router_query_instance_id, _test_router_query_instance_id,
...@@ -139,7 +142,7 @@ def _build_mocker_command( ...@@ -139,7 +142,7 @@ def _build_mocker_command(
List of command arguments for subprocess List of command arguments for subprocess
""" """
command = [ command = [
"python", sys.executable,
"-m", "-m",
"dynamo.mocker", "dynamo.mocker",
"--model-path", "--model-path",
...@@ -319,6 +322,7 @@ class MockerProcess: ...@@ -319,6 +322,7 @@ class MockerProcess:
health_check_urls=[], health_check_urls=[],
log_dir=request.node.name, log_dir=request.node.name,
terminate_all_matching_process_names=False, terminate_all_matching_process_names=False,
display_name="dynamo-mocker",
) )
logger.info( logger.info(
f"Created mocker process with {num_mockers} worker(s), endpoint: {self.endpoint}" f"Created mocker process with {num_mockers} worker(s), endpoint: {self.endpoint}"
...@@ -640,6 +644,7 @@ class DisaggMockerProcess: ...@@ -640,6 +644,7 @@ class DisaggMockerProcess:
health_check_urls=[], health_check_urls=[],
log_dir=request.node.name, log_dir=request.node.name,
terminate_all_matching_process_names=False, terminate_all_matching_process_names=False,
display_name=f"dynamo-mocker-{worker_type}",
) )
logger.info( logger.info(
f"Created {worker_type} mocker process with {num_mockers} worker(s), " f"Created {worker_type} mocker process with {num_mockers} worker(s), "
...@@ -668,6 +673,77 @@ class DisaggMockerProcess: ...@@ -668,6 +673,77 @@ class DisaggMockerProcess:
self._bootstrap_ports = [] self._bootstrap_ports = []
@contextmanager
def _launch_disagg_workers(
request,
namespace: str,
registration_order: str,
*,
prefill_mocker_args: Dict[str, Any],
decode_mocker_args: Dict[str, Any],
num_prefill_mockers: int,
num_decode_mockers: int,
enable_disagg_bootstrap: bool,
request_plane: str = "nats",
) -> Iterator[tuple[DisaggMockerProcess, DisaggMockerProcess]]:
if registration_order not in ("prefill_first", "decode_first"):
raise ValueError(f"Unexpected registration order: {registration_order}")
if registration_order == "prefill_first":
logger.info("Starting %s prefill mocker instances (first)", num_prefill_mockers)
with DisaggMockerProcess(
request,
namespace=namespace,
worker_type="prefill",
mocker_args=prefill_mocker_args,
num_mockers=num_prefill_mockers,
request_plane=request_plane,
enable_bootstrap=enable_disagg_bootstrap,
) as prefill_workers:
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
logger.info(
"Starting %s decode mocker instances (second)", num_decode_mockers
)
with DisaggMockerProcess(
request,
namespace=namespace,
worker_type="decode",
mocker_args=decode_mocker_args,
num_mockers=num_decode_mockers,
request_plane=request_plane,
) as decode_workers:
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
yield prefill_workers, decode_workers
return
logger.info("Starting %s decode mocker instances (first)", num_decode_mockers)
with DisaggMockerProcess(
request,
namespace=namespace,
worker_type="decode",
mocker_args=decode_mocker_args,
num_mockers=num_decode_mockers,
request_plane=request_plane,
) as decode_workers:
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
logger.info(
"Starting %s prefill mocker instances (second)", num_prefill_mockers
)
with DisaggMockerProcess(
request,
namespace=namespace,
worker_type="prefill",
mocker_args=prefill_mocker_args,
num_mockers=num_prefill_mockers,
request_plane=request_plane,
enable_bootstrap=enable_disagg_bootstrap,
) as prefill_workers:
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
yield prefill_workers, decode_workers
@pytest.mark.timeout(180) # planner-profile mocker setup can exceed 120s on CI CPUs @pytest.mark.timeout(180) # planner-profile mocker setup can exceed 120s on CI CPUs
@pytest.mark.parametrize( @pytest.mark.parametrize(
"router_mode,durable_kv_events,mocker_args_override", "router_mode,durable_kv_events,mocker_args_override",
...@@ -1192,90 +1268,88 @@ def test_router_decisions_disagg( ...@@ -1192,90 +1268,88 @@ def test_router_decisions_disagg(
# durable_kv_events defaults to False (NATS Core mode) # durable_kv_events defaults to False (NATS Core mode)
} }
if registration_order == "prefill_first": with _launch_disagg_workers(
# Start prefill workers first request,
logger.info("Starting 4 prefill mocker instances (first)") shared_namespace,
with DisaggMockerProcess( registration_order,
request, prefill_mocker_args=mocker_args,
namespace=shared_namespace, decode_mocker_args=mocker_args,
worker_type="prefill", num_prefill_mockers=4,
mocker_args=mocker_args, num_decode_mockers=4,
num_mockers=4, enable_disagg_bootstrap=enable_disagg_bootstrap,
) as (prefill_workers, decode_workers):
frontend_port = get_unique_ports(
request, num_ports=1, registration_order=registration_order
)[0]
_test_router_decisions_disagg(
prefill_workers=prefill_workers,
decode_workers=decode_workers,
block_size=BLOCK_SIZE,
request=request,
frontend_port=frontend_port,
test_payload=TEST_PAYLOAD,
request_plane="nats", request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap, )
) as prefill_workers:
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Then start decode workers
logger.info("Starting 4 decode mocker instances (second)")
with DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
) as decode_workers:
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
# Get unique port for this test @pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
frontend_port = get_unique_ports( @pytest.mark.parametrize(
request, num_ports=1, registration_order=registration_order "enable_disagg_bootstrap", [False, True], ids=["no_bootstrap", "with_bootstrap"]
)[0] )
@pytest.mark.timeout(180)
# Run disagg routing test def test_router_decisions_disagg_round_robin_prefill_dp_rank(
_test_router_decisions_disagg( request,
prefill_workers=prefill_workers, runtime_services_dynamic_ports,
decode_workers=decode_workers, predownload_tokenizers,
block_size=BLOCK_SIZE, registration_order,
request=request, enable_disagg_bootstrap,
frontend_port=frontend_port, ):
test_payload=TEST_PAYLOAD, """Verify round-robin disagg prefill requests spread KV stores across DP ranks."""
request_plane="nats", logger.info(
) "Starting disaggregated round-robin prefill dp-rank test "
else: "(registration_order=%s, bootstrap=%s)",
# Start decode workers first registration_order,
logger.info("Starting 4 decode mocker instances (first)") enable_disagg_bootstrap,
with DisaggMockerProcess( )
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
request_plane="nats",
) as decode_workers:
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
# Then start prefill workers namespace_suffix = generate_random_suffix()
logger.info("Starting 4 prefill mocker instances (second)") shared_namespace = f"test-namespace-{namespace_suffix}"
with DisaggMockerProcess( prefill_mocker_args = {
request, "speedup_ratio": SPEEDUP_RATIO,
namespace=shared_namespace, "block_size": BLOCK_SIZE,
worker_type="prefill", "dp_size": 4,
mocker_args=mocker_args, }
num_mockers=4, decode_mocker_args = {
request_plane="nats", "speedup_ratio": SPEEDUP_RATIO,
enable_bootstrap=enable_disagg_bootstrap, "block_size": BLOCK_SIZE,
) as prefill_workers: }
logger.info(
f"Prefill workers using endpoint: {prefill_workers.endpoint}"
)
# Get unique port for this test def run_case(prefill_workers, decode_workers):
frontend_port = get_unique_ports( frontend_port = get_unique_ports(
request, num_ports=1, registration_order=registration_order request, num_ports=1, registration_order=registration_order
)[0] )[0]
_test_router_decisions_disagg_round_robin_prefill_dp_rank(
# Run disagg routing test prefill_workers=prefill_workers,
_test_router_decisions_disagg( decode_workers=decode_workers,
prefill_workers=prefill_workers, block_size=BLOCK_SIZE,
decode_workers=decode_workers, request=request,
block_size=BLOCK_SIZE, frontend_port=frontend_port,
request=request, test_payload=TEST_PAYLOAD,
frontend_port=frontend_port, expected_prefill_dp_ranks=prefill_mocker_args["dp_size"],
test_payload=TEST_PAYLOAD, request_plane="nats",
request_plane="nats", )
)
with _launch_disagg_workers(
request,
shared_namespace,
registration_order,
prefill_mocker_args=prefill_mocker_args,
decode_mocker_args=decode_mocker_args,
num_prefill_mockers=1,
num_decode_mockers=1,
enable_disagg_bootstrap=enable_disagg_bootstrap,
) as (prefill_workers, decode_workers):
run_case(prefill_workers, decode_workers)
@pytest.mark.timeout(180) @pytest.mark.timeout(180)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment