Unverified commit 08db2844, authored by Yan Ru Pei, committed by GitHub
Browse files

test: more stringent test for routing decisions (#6531)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 6654a57d
...@@ -609,6 +609,8 @@ async def send_request_via_python_kv_router( ...@@ -609,6 +609,8 @@ async def send_request_via_python_kv_router(
generated_tokens = [] generated_tokens = []
prefill_worker_id: Optional[int] = None prefill_worker_id: Optional[int] = None
decode_worker_id: Optional[int] = None decode_worker_id: Optional[int] = None
prefill_dp_rank: Optional[int] = None
decode_dp_rank: Optional[int] = None
async for response in stream: async for response in stream:
if isinstance(response, dict): if isinstance(response, dict):
...@@ -625,7 +627,7 @@ async def send_request_via_python_kv_router( ...@@ -625,7 +627,7 @@ async def send_request_via_python_kv_router(
f"Stream finished with reason: {response['finish_reason']}" f"Stream finished with reason: {response['finish_reason']}"
) )
# Extract worker IDs from disaggregated_params if present # Extract worker IDs and dp_ranks from disaggregated_params if present
if return_worker_ids and "disaggregated_params" in response: if return_worker_ids and "disaggregated_params" in response:
disagg_params = response["disaggregated_params"] disagg_params = response["disaggregated_params"]
if isinstance(disagg_params, dict) and "worker_id" in disagg_params: if isinstance(disagg_params, dict) and "worker_id" in disagg_params:
...@@ -635,6 +637,10 @@ async def send_request_via_python_kv_router( ...@@ -635,6 +637,10 @@ async def send_request_via_python_kv_router(
prefill_worker_id = worker_id_info["prefill_worker_id"] prefill_worker_id = worker_id_info["prefill_worker_id"]
if "decode_worker_id" in worker_id_info: if "decode_worker_id" in worker_id_info:
decode_worker_id = worker_id_info["decode_worker_id"] decode_worker_id = worker_id_info["decode_worker_id"]
if "prefill_dp_rank" in worker_id_info:
prefill_dp_rank = worker_id_info["prefill_dp_rank"]
if "decode_dp_rank" in worker_id_info:
decode_dp_rank = worker_id_info["decode_dp_rank"]
# Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True # Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True
logger.debug(f"Total generated tokens: {len(generated_tokens)}") logger.debug(f"Total generated tokens: {len(generated_tokens)}")
...@@ -658,6 +664,8 @@ async def send_request_via_python_kv_router( ...@@ -658,6 +664,8 @@ async def send_request_via_python_kv_router(
return { return {
"prefill_worker_id": prefill_worker_id, "prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id, "decode_worker_id": decode_worker_id,
"prefill_dp_rank": prefill_dp_rank,
"decode_dp_rank": decode_dp_rank,
} }
return True return True
...@@ -1909,21 +1917,23 @@ def _test_router_decisions( ...@@ -1909,21 +1917,23 @@ def _test_router_decisions(
model_name: str, model_name: str,
request, request,
test_dp_rank: bool = False, test_dp_rank: bool = False,
block_size: int = BLOCK_SIZE, block_size: int = 8,
use_kv_events: bool = True, use_kv_events: bool = True,
durable_kv_events: bool = False, durable_kv_events: bool = False,
router_event_threads: int = 1, router_event_threads: int = 1,
): ):
"""Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes. """Validate cross-worker routing decisions based on longest prefix match and tree-size tiebreaking.
Assumes engine workers are already initialized. Assumes engine workers are already initialized.
The first request is forced to a specific worker (and optionally dp_rank), Seeds two routing targets (worker a and worker b) with different prefix trees,
and subsequent requests should naturally route to the same worker due to prefix reuse. then verifies the router picks the correct worker for subsequent requests.
Test sequence: Test sequence (7 blocks A-G, each block_size tokens, 5 requests):
1. Request 1: [A, B, C, D] → Forces to Worker 1, caches 4 blocks 1. [A, B] → force worker a (seed worker a's tree)
2. Request 2: [A, B, E, F] → Shares [A, B] prefix, diverges from Request 1 2. [A, C, D] → force worker a (branch under A on worker a)
3. Request 3: [A, B, C, D, G, H] → Should route to Worker 1 (has [A, B, C, D] cached) 3. [A, C, E] → force worker b (seed worker b's tree)
4. [A, C, D, F] → router picks (worker a wins: prefix [A,C,D]=3 vs worker b [A,C]=2)
5. [A, C, G] → router picks (tie on [A,C], worker b wins by smaller tree: 3 vs 5)
Args: Args:
engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__()) engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__())
...@@ -1931,12 +1941,13 @@ def _test_router_decisions( ...@@ -1931,12 +1941,13 @@ def _test_router_decisions(
model_name: Name of the model model_name: Name of the model
request: Pytest request fixture request: Pytest request fixture
test_dp_rank: If True, also forces and validates dp_rank routing (for data parallel setups) test_dp_rank: If True, also forces and validates dp_rank routing (for data parallel setups)
block_size: KV cache block size. Defaults to 8.
use_kv_events: If True (default), uses KV events from workers. If False, uses use_kv_events: If True (default), uses KV events from workers. If False, uses
approximate routing with TTL-based expiration (--no-kv-events mode). approximate routing with TTL-based expiration (--no-kv-events mode).
durable_kv_events: If True, use durable KV events (JetStream). Defaults to False. durable_kv_events: If True, use durable KV events (JetStream). Defaults to False.
Raises: Raises:
AssertionError: If routing decisions don't follow KV cache prefix reuse as expected AssertionError: If routing decisions don't match expected prefix/tiebreak logic
""" """
# Create KvRouterConfig with lower snapshot threshold for testing # Create KvRouterConfig with lower snapshot threshold for testing
kv_router_config = KvRouterConfig( kv_router_config = KvRouterConfig(
...@@ -1965,187 +1976,178 @@ def _test_router_decisions( ...@@ -1965,187 +1976,178 @@ def _test_router_decisions(
) )
logger.info(f"Workers ready: {worker_ids}") logger.info(f"Workers ready: {worker_ids}")
# Use the first worker_id for forced routing # Determine worker a / worker b routing targets
forced_worker_id = worker_ids[0] if len(worker_ids) >= 2:
forced_dp_rank = 1 if test_dp_rank else None worker_a_id = worker_ids[0]
worker_b_id = worker_ids[1]
if test_dp_rank: elif len(worker_ids) == 1 and test_dp_rank:
logger.info( worker_a_id = worker_ids[0]
f"Will force first request to worker_id={forced_worker_id}, dp_rank={forced_dp_rank}" worker_b_id = worker_ids[0]
)
else: else:
logger.info(f"Will force first request to worker_id={forced_worker_id}") raise AssertionError(
f"Need at least 2 routing targets but got {len(worker_ids)} worker(s) "
f"with test_dp_rank={test_dp_rank}"
)
# Send 3 requests with some shared prefixes and some divergent prefixes dp_rank_a = 0 if test_dp_rank else None
response_worker_ids: list[dict[str, Optional[int]]] = [] dp_rank_b = 1 if test_dp_rank else None
logger.info(
f"Routing targets: worker_a=(id={worker_a_id}, dp_rank={dp_rank_a}), "
f"worker_b=(id={worker_b_id}, dp_rank={dp_rank_b})"
)
num_blocks = 8 # Generate 7 random blocks (A-G)
num_blocks = 7
blocks = [ blocks = [
[random.randint(1, 10000) for _ in range(block_size)] [random.randint(1, 10000) for _ in range(block_size)]
for _ in range(num_blocks) for _ in range(num_blocks)
] ]
A, B, C, D, E, F, G = blocks
requests = [
blocks[0] + blocks[1] + blocks[2] + blocks[3], # 5 requests with specific prefix structure
blocks[0] + blocks[1] + blocks[4] + blocks[5], request_specs = [
blocks[0] + blocks[1] + blocks[2] + blocks[3] + blocks[6] + blocks[7], # (token_ids, forced_worker_id, forced_dp_rank, sleep_after)
(A + B, worker_a_id, dp_rank_a, 0.1), # req1: seed worker a
(
A + C + D,
worker_a_id,
dp_rank_a,
0.1,
), # req2: branch under A on worker a
(A + C + E, worker_b_id, dp_rank_b, 2.0), # req3: seed worker b
(
A + C + D + F,
None,
None,
2.0,
), # req4: router picks (worker a should win)
(A + C + G, None, None, 2.0), # req5: router picks (worker b should win)
] ]
for i, request in enumerate(requests): response_worker_ids: list[dict[str, Optional[int]]] = []
# Force first request to specific worker_id (and dp_rank if testing DP), let subsequent requests follow naturally
worker_id_override = forced_worker_id if i == 0 else None for i, (token_ids, wid_override, dp_override, sleep_after) in enumerate(
dp_rank_override = forced_dp_rank if i == 0 and test_dp_rank else None request_specs
):
log_msg = f"Sending request {i + 1}/4 with {len(request)} tokens " log_msg = f"Sending request {i + 1}/5 with {len(token_ids)} tokens"
if worker_id_override is not None: if wid_override is not None:
if test_dp_rank: log_msg += f" - FORCING worker_id={wid_override}"
log_msg += f" - FORCING worker_id={worker_id_override}, dp_rank={dp_rank_override}" if dp_override is not None:
else: log_msg += f", dp_rank={dp_override}"
log_msg += f" - FORCING worker_id={worker_id_override}"
logger.info(log_msg) logger.info(log_msg)
result = await send_request_via_python_kv_router( result = await send_request_via_python_kv_router(
kv_python_router=kv_router, kv_python_router=kv_router,
model_name=model_name, model_name=model_name,
token_ids=request, token_ids=token_ids,
initial_wait=1.0, initial_wait=1.0,
max_retries=8, max_retries=8,
stop_conditions={ stop_conditions={
"ignore_eos": True, # Don't stop on EOS token "ignore_eos": True,
"max_tokens": 2, # Generate exactly 2 tokens "max_tokens": 2,
}, },
worker_id=worker_id_override, worker_id=wid_override,
dp_rank=dp_rank_override, dp_rank=dp_override,
return_worker_ids=True, return_worker_ids=True,
) )
assert isinstance(result, dict), f"Expected dict result, got {type(result)}" assert isinstance(result, dict), f"Expected dict result, got {type(result)}"
response_worker_ids.append(result) response_worker_ids.append(result)
logger.info( logger.info(
f"Request {i + 1} response: prefill_worker_id={result.get('prefill_worker_id')}, " f"Request {i + 1} response: prefill_worker_id={result.get('prefill_worker_id')}, "
f"decode_worker_id={result.get('decode_worker_id')}" f"decode_worker_id={result.get('decode_worker_id')}, "
f"prefill_dp_rank={result.get('prefill_dp_rank')}, "
f"decode_dp_rank={result.get('decode_dp_rank')}"
) )
# Wait a bit between requests if sleep_after > 0:
await asyncio.sleep(2) await asyncio.sleep(sleep_after)
# Wait for final synchronization (especially important for DP)
if test_dp_rank:
await asyncio.sleep(1)
# Dump events from the router
events_json = await kv_router.dump_events() events_json = await kv_router.dump_events()
return events_json, forced_worker_id, forced_dp_rank, response_worker_ids return (
events_json,
worker_a_id,
worker_b_id,
dp_rank_a,
dp_rank_b,
response_worker_ids,
)
# Run the async test # Run the async test
( (
events_json, events_json,
expected_worker_id, worker_a_id,
expected_dp_rank, worker_b_id,
dp_rank_a,
dp_rank_b,
response_worker_ids, response_worker_ids,
) = asyncio.run(test_sync()) ) = asyncio.run(test_sync())
# Verify worker IDs from responses # Verify request 4 routed to worker a (longest prefix match)
verify_response_worker_ids( req4 = response_worker_ids[3]
response_worker_ids, "decode_worker_id", expected_worker_id assert req4["prefill_worker_id"] == worker_a_id, (
f"Request 4: expected prefill_worker_id={worker_a_id} (longest prefix match), "
f"got {req4['prefill_worker_id']}"
) )
verify_response_worker_ids( if test_dp_rank:
response_worker_ids, "prefill_worker_id", expected_worker_id assert (
req4["prefill_dp_rank"] == dp_rank_a
), f"Request 4: expected prefill_dp_rank={dp_rank_a}, got {req4['prefill_dp_rank']}"
# Verify request 5 routed to worker b (tiebreak by smaller tree)
req5 = response_worker_ids[4]
assert req5["prefill_worker_id"] == worker_b_id, (
f"Request 5: expected prefill_worker_id={worker_b_id} (tiebreak by smaller tree), "
f"got {req5['prefill_worker_id']}"
) )
# Parse events and count by worker routing key (worker_id or (worker_id, dp_rank))
events = json.loads(events_json)
if test_dp_rank: if test_dp_rank:
# Group by (worker_id, dp_rank) tuple for DP testing assert (
events_by_key_dp: dict[tuple[int, int], list[Any]] = {} req5["prefill_dp_rank"] == dp_rank_b
for event in events: ), f"Request 5: expected prefill_dp_rank={dp_rank_b}, got {req5['prefill_dp_rank']}"
worker_id = event.get("worker_id")
dp_rank = event.get("event", {}).get("dp_rank", 0)
key = (worker_id, dp_rank)
if key not in events_by_key_dp:
events_by_key_dp[key] = []
events_by_key_dp[key].append(event)
logger.info(
f"Events by (worker_id, dp_rank): {[(key, len(evts)) for key, evts in events_by_key_dp.items()]}"
)
# Verify: All but one routing key should have no events (due to prefix reuse)
keys_with_events_dp = [
key for key, evts in events_by_key_dp.items() if len(evts) > 0
]
assert len(keys_with_events_dp) == 1, (
f"Expected exactly 1 (worker_id, dp_rank) to have events (due to prefix reuse), "
f"but found {len(keys_with_events_dp)} with events: {keys_with_events_dp}"
)
# Verify: The routing key with events should have exactly 8 events (one per unique block)
active_key_dp = keys_with_events_dp[0]
num_events = len(events_by_key_dp[active_key_dp])
assert num_events == 8, (
f"Expected (worker_id, dp_rank) {active_key_dp} to have exactly 8 events, "
f"but found {num_events} events"
)
# Verify: Routing should match the forced values logger.info(
active_worker_id, active_dp_rank = active_key_dp f"Response routing verified: req4 → worker_a (id={worker_a_id}, dp_rank={dp_rank_a}), "
assert active_worker_id == expected_worker_id, ( f"req5 → worker_b (id={worker_b_id}, dp_rank={dp_rank_b})"
f"Expected all events to have worker_id={expected_worker_id} (forced in first request), " )
f"but found worker_id={active_worker_id}"
)
assert active_dp_rank == expected_dp_rank, (
f"Expected all events to have dp_rank={expected_dp_rank} (forced in first request), "
f"but found dp_rank={active_dp_rank}"
)
logger.info(
f"Successfully verified: Worker {active_worker_id} dp_rank {active_dp_rank} handled all 4 requests with prefix reuse. "
f"All events correctly routed to worker_id={expected_worker_id}, dp_rank={expected_dp_rank} as expected. "
f"KV events synchronized correctly."
)
else:
# Group by worker_id only for multiple workers testing
events_by_key_single: dict[int, list] = {}
for event in events:
worker_id = event.get("worker_id")
if worker_id not in events_by_key_single:
events_by_key_single[worker_id] = []
events_by_key_single[worker_id].append(event)
logger.info( # Parse events and verify event counts per routing target
f"Events by worker_id: {[(key, len(evts)) for key, evts in events_by_key_single.items()]}" events = json.loads(events_json)
)
# Verify: All but one routing key should have no events (due to prefix reuse) # Always group by (worker_id, dp_rank)
keys_with_events_single = [ events_by_key: dict[tuple[int, int], list[Any]] = {}
key for key, evts in events_by_key_single.items() if len(evts) > 0 for event in events:
] worker_id = event.get("worker_id")
dp_rank = event.get("event", {}).get("dp_rank", 0)
key = (worker_id, dp_rank)
if key not in events_by_key:
events_by_key[key] = []
events_by_key[key].append(event)
assert len(keys_with_events_single) == 1, ( logger.info(
f"Expected exactly 1 worker_id to have events (due to prefix reuse), " f"Events by (worker_id, dp_rank): {[(key, len(evts)) for key, evts in events_by_key.items()]}"
f"but found {len(keys_with_events_single)} with events: {keys_with_events_single}" )
)
# Verify: The routing key with events should have exactly 8 events (one per unique block) # Worker a key: 5 events (A, B from req1; C, D from req2; F from req4)
active_worker_id = keys_with_events_single[0] worker_a_key = (worker_a_id, dp_rank_a if dp_rank_a is not None else 0)
num_events = len(events_by_key_single[active_worker_id]) worker_a_events = len(events_by_key.get(worker_a_key, []))
assert worker_a_events == 5, (
f"Expected worker_a {worker_a_key} to have 5 events (A,B + C,D + F), "
f"but found {worker_a_events}"
)
assert num_events == 8, ( # Worker b key: 4 events (A, C, E from req3; G from req5)
f"Expected worker_id {active_worker_id} to have exactly 8 events, " worker_b_key = (worker_b_id, dp_rank_b if dp_rank_b is not None else 0)
f"but found {num_events} events" worker_b_events = len(events_by_key.get(worker_b_key, []))
) assert worker_b_events == 4, (
f"Expected worker_b {worker_b_key} to have 4 events (A,C,E + G), "
f"but found {worker_b_events}"
)
# Verify: Routing should match the forced values logger.info(
assert active_worker_id == expected_worker_id, ( f"Successfully verified cross-worker routing: "
f"Expected all events to have worker_id={expected_worker_id} (forced in first request), " f"worker_a {worker_a_key} has {worker_a_events} events, "
f"but found worker_id={active_worker_id}" f"worker_b {worker_b_key} has {worker_b_events} events"
) )
logger.info(
f"Successfully verified: Worker {active_worker_id} handled all 4 requests with prefix reuse. "
f"All events correctly routed to worker_id={expected_worker_id} as expected. "
f"KV events synchronized correctly."
)
def _test_busy_threshold_endpoint( def _test_busy_threshold_endpoint(
......
...@@ -680,7 +680,7 @@ def test_router_decisions( ...@@ -680,7 +680,7 @@ def test_router_decisions(
# durable_kv_events=True enables JetStream mode; False (default) uses NATS Core with local indexer # durable_kv_events=True enables JetStream mode; False (default) uses NATS Core with local indexer
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": 8,
"dp_size": 4, "dp_size": 4,
"durable_kv_events": durable_kv_events and use_kv_events, "durable_kv_events": durable_kv_events and use_kv_events,
} }
......
...@@ -424,6 +424,7 @@ def test_router_decisions_sglang_multiple_workers( ...@@ -424,6 +424,7 @@ def test_router_decisions_sglang_multiple_workers(
MODEL_NAME, MODEL_NAME,
request, request,
test_dp_rank=False, test_dp_rank=False,
block_size=PAGE_SIZE,
router_event_threads=router_event_threads, router_event_threads=router_event_threads,
) )
...@@ -466,7 +467,12 @@ def test_router_decisions_sglang_dp( ...@@ -466,7 +467,12 @@ def test_router_decisions_sglang_dp(
endpoint = runtime.endpoint(f"{sglang_workers.namespace}.backend.generate") endpoint = runtime.endpoint(f"{sglang_workers.namespace}.backend.generate")
_test_router_decisions( _test_router_decisions(
sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=True sglang_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=True,
block_size=PAGE_SIZE,
) )
......
...@@ -453,6 +453,7 @@ def test_router_decisions_vllm_multiple_workers( ...@@ -453,6 +453,7 @@ def test_router_decisions_vllm_multiple_workers(
MODEL_NAME, MODEL_NAME,
request, request,
test_dp_rank=False, test_dp_rank=False,
block_size=BLOCK_SIZE,
router_event_threads=router_event_threads, router_event_threads=router_event_threads,
) )
...@@ -497,7 +498,12 @@ def test_router_decisions_vllm_dp( ...@@ -497,7 +498,12 @@ def test_router_decisions_vllm_dp(
) # endpoint is backend.generate ) # endpoint is backend.generate
_test_router_decisions( _test_router_decisions(
vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=True vllm_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=True,
block_size=BLOCK_SIZE,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment