"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "b5e762b2b875551bff51a7b09dba47d2d79fc8a8"
Unverified Commit 08db2844 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

test: more stringent test for routing decisions (#6531)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 6654a57d
......@@ -609,6 +609,8 @@ async def send_request_via_python_kv_router(
generated_tokens = []
prefill_worker_id: Optional[int] = None
decode_worker_id: Optional[int] = None
prefill_dp_rank: Optional[int] = None
decode_dp_rank: Optional[int] = None
async for response in stream:
if isinstance(response, dict):
......@@ -625,7 +627,7 @@ async def send_request_via_python_kv_router(
f"Stream finished with reason: {response['finish_reason']}"
)
# Extract worker IDs from disaggregated_params if present
# Extract worker IDs and dp_ranks from disaggregated_params if present
if return_worker_ids and "disaggregated_params" in response:
disagg_params = response["disaggregated_params"]
if isinstance(disagg_params, dict) and "worker_id" in disagg_params:
......@@ -635,6 +637,10 @@ async def send_request_via_python_kv_router(
prefill_worker_id = worker_id_info["prefill_worker_id"]
if "decode_worker_id" in worker_id_info:
decode_worker_id = worker_id_info["decode_worker_id"]
if "prefill_dp_rank" in worker_id_info:
prefill_dp_rank = worker_id_info["prefill_dp_rank"]
if "decode_dp_rank" in worker_id_info:
decode_dp_rank = worker_id_info["decode_dp_rank"]
# Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True
logger.debug(f"Total generated tokens: {len(generated_tokens)}")
......@@ -658,6 +664,8 @@ async def send_request_via_python_kv_router(
return {
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
"prefill_dp_rank": prefill_dp_rank,
"decode_dp_rank": decode_dp_rank,
}
return True
......@@ -1909,21 +1917,23 @@ def _test_router_decisions(
model_name: str,
request,
test_dp_rank: bool = False,
block_size: int = BLOCK_SIZE,
block_size: int = 8,
use_kv_events: bool = True,
durable_kv_events: bool = False,
router_event_threads: int = 1,
):
"""Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes.
"""Validate cross-worker routing decisions based on longest prefix match and tree-size tiebreaking.
Assumes engine workers are already initialized.
The first request is forced to a specific worker (and optionally dp_rank),
and subsequent requests should naturally route to the same worker due to prefix reuse.
Seeds two routing targets (worker a and worker b) with different prefix trees,
then verifies the router picks the correct worker for subsequent requests.
Test sequence:
1. Request 1: [A, B, C, D] → Forces to Worker 1, caches 4 blocks
2. Request 2: [A, B, E, F] → Shares [A, B] prefix, diverges from Request 1
3. Request 3: [A, B, C, D, G, H] → Should route to Worker 1 (has [A, B, C, D] cached)
Test sequence (7 blocks A-G, each block_size tokens, 5 requests):
1. [A, B] → force worker a (seed worker a's tree)
2. [A, C, D] → force worker a (branch under A on worker a)
3. [A, C, E] → force worker b (seed worker b's tree)
4. [A, C, D, F] → router picks (worker a wins: prefix [A,C,D]=3 vs worker b [A,C]=2)
5. [A, C, G] → router picks (tie on [A,C], worker b wins by smaller tree: 3 vs 5)
Args:
engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__())
......@@ -1931,12 +1941,13 @@ def _test_router_decisions(
model_name: Name of the model
request: Pytest request fixture
test_dp_rank: If True, also forces and validates dp_rank routing (for data parallel setups)
block_size: KV cache block size. Defaults to 8.
use_kv_events: If True (default), uses KV events from workers. If False, uses
approximate routing with TTL-based expiration (--no-kv-events mode).
durable_kv_events: If True, use durable KV events (JetStream). Defaults to False.
Raises:
AssertionError: If routing decisions don't follow KV cache prefix reuse as expected
AssertionError: If routing decisions don't match expected prefix/tiebreak logic
"""
# Create KvRouterConfig with lower snapshot threshold for testing
kv_router_config = KvRouterConfig(
......@@ -1965,187 +1976,178 @@ def _test_router_decisions(
)
logger.info(f"Workers ready: {worker_ids}")
# Use the first worker_id for forced routing
forced_worker_id = worker_ids[0]
forced_dp_rank = 1 if test_dp_rank else None
if test_dp_rank:
logger.info(
f"Will force first request to worker_id={forced_worker_id}, dp_rank={forced_dp_rank}"
)
# Determine worker a / worker b routing targets
if len(worker_ids) >= 2:
worker_a_id = worker_ids[0]
worker_b_id = worker_ids[1]
elif len(worker_ids) == 1 and test_dp_rank:
worker_a_id = worker_ids[0]
worker_b_id = worker_ids[0]
else:
logger.info(f"Will force first request to worker_id={forced_worker_id}")
raise AssertionError(
f"Need at least 2 routing targets but got {len(worker_ids)} worker(s) "
f"with test_dp_rank={test_dp_rank}"
)
# Send 3 requests with some shared prefixes and some divergent prefixes
response_worker_ids: list[dict[str, Optional[int]]] = []
dp_rank_a = 0 if test_dp_rank else None
dp_rank_b = 1 if test_dp_rank else None
logger.info(
f"Routing targets: worker_a=(id={worker_a_id}, dp_rank={dp_rank_a}), "
f"worker_b=(id={worker_b_id}, dp_rank={dp_rank_b})"
)
num_blocks = 8
# Generate 7 random blocks (A-G)
num_blocks = 7
blocks = [
[random.randint(1, 10000) for _ in range(block_size)]
for _ in range(num_blocks)
]
requests = [
blocks[0] + blocks[1] + blocks[2] + blocks[3],
blocks[0] + blocks[1] + blocks[4] + blocks[5],
blocks[0] + blocks[1] + blocks[2] + blocks[3] + blocks[6] + blocks[7],
A, B, C, D, E, F, G = blocks
# 5 requests with specific prefix structure
request_specs = [
# (token_ids, forced_worker_id, forced_dp_rank, sleep_after)
(A + B, worker_a_id, dp_rank_a, 0.1), # req1: seed worker a
(
A + C + D,
worker_a_id,
dp_rank_a,
0.1,
), # req2: branch under A on worker a
(A + C + E, worker_b_id, dp_rank_b, 2.0), # req3: seed worker b
(
A + C + D + F,
None,
None,
2.0,
), # req4: router picks (worker a should win)
(A + C + G, None, None, 2.0), # req5: router picks (worker b should win)
]
for i, request in enumerate(requests):
# Force first request to specific worker_id (and dp_rank if testing DP), let subsequent requests follow naturally
worker_id_override = forced_worker_id if i == 0 else None
dp_rank_override = forced_dp_rank if i == 0 and test_dp_rank else None
log_msg = f"Sending request {i + 1}/4 with {len(request)} tokens "
if worker_id_override is not None:
if test_dp_rank:
log_msg += f" - FORCING worker_id={worker_id_override}, dp_rank={dp_rank_override}"
else:
log_msg += f" - FORCING worker_id={worker_id_override}"
response_worker_ids: list[dict[str, Optional[int]]] = []
for i, (token_ids, wid_override, dp_override, sleep_after) in enumerate(
request_specs
):
log_msg = f"Sending request {i + 1}/5 with {len(token_ids)} tokens"
if wid_override is not None:
log_msg += f" - FORCING worker_id={wid_override}"
if dp_override is not None:
log_msg += f", dp_rank={dp_override}"
logger.info(log_msg)
result = await send_request_via_python_kv_router(
kv_python_router=kv_router,
model_name=model_name,
token_ids=request,
token_ids=token_ids,
initial_wait=1.0,
max_retries=8,
stop_conditions={
"ignore_eos": True, # Don't stop on EOS token
"max_tokens": 2, # Generate exactly 2 tokens
"ignore_eos": True,
"max_tokens": 2,
},
worker_id=worker_id_override,
dp_rank=dp_rank_override,
worker_id=wid_override,
dp_rank=dp_override,
return_worker_ids=True,
)
assert isinstance(result, dict), f"Expected dict result, got {type(result)}"
response_worker_ids.append(result)
logger.info(
f"Request {i + 1} response: prefill_worker_id={result.get('prefill_worker_id')}, "
f"decode_worker_id={result.get('decode_worker_id')}"
f"decode_worker_id={result.get('decode_worker_id')}, "
f"prefill_dp_rank={result.get('prefill_dp_rank')}, "
f"decode_dp_rank={result.get('decode_dp_rank')}"
)
# Wait a bit between requests
await asyncio.sleep(2)
if sleep_after > 0:
await asyncio.sleep(sleep_after)
# Wait for final synchronization (especially important for DP)
if test_dp_rank:
await asyncio.sleep(1)
# Dump events from the router
events_json = await kv_router.dump_events()
return events_json, forced_worker_id, forced_dp_rank, response_worker_ids
return (
events_json,
worker_a_id,
worker_b_id,
dp_rank_a,
dp_rank_b,
response_worker_ids,
)
# Run the async test
(
events_json,
expected_worker_id,
expected_dp_rank,
worker_a_id,
worker_b_id,
dp_rank_a,
dp_rank_b,
response_worker_ids,
) = asyncio.run(test_sync())
# Verify worker IDs from responses
verify_response_worker_ids(
response_worker_ids, "decode_worker_id", expected_worker_id
# Verify request 4 routed to worker a (longest prefix match)
req4 = response_worker_ids[3]
assert req4["prefill_worker_id"] == worker_a_id, (
f"Request 4: expected prefill_worker_id={worker_a_id} (longest prefix match), "
f"got {req4['prefill_worker_id']}"
)
verify_response_worker_ids(
response_worker_ids, "prefill_worker_id", expected_worker_id
if test_dp_rank:
assert (
req4["prefill_dp_rank"] == dp_rank_a
), f"Request 4: expected prefill_dp_rank={dp_rank_a}, got {req4['prefill_dp_rank']}"
# Verify request 5 routed to worker b (tiebreak by smaller tree)
req5 = response_worker_ids[4]
assert req5["prefill_worker_id"] == worker_b_id, (
f"Request 5: expected prefill_worker_id={worker_b_id} (tiebreak by smaller tree), "
f"got {req5['prefill_worker_id']}"
)
# Parse events and count by worker routing key (worker_id or (worker_id, dp_rank))
events = json.loads(events_json)
if test_dp_rank:
# Group by (worker_id, dp_rank) tuple for DP testing
events_by_key_dp: dict[tuple[int, int], list[Any]] = {}
for event in events:
worker_id = event.get("worker_id")
dp_rank = event.get("event", {}).get("dp_rank", 0)
key = (worker_id, dp_rank)
if key not in events_by_key_dp:
events_by_key_dp[key] = []
events_by_key_dp[key].append(event)
logger.info(
f"Events by (worker_id, dp_rank): {[(key, len(evts)) for key, evts in events_by_key_dp.items()]}"
)
# Verify: All but one routing key should have no events (due to prefix reuse)
keys_with_events_dp = [
key for key, evts in events_by_key_dp.items() if len(evts) > 0
]
assert len(keys_with_events_dp) == 1, (
f"Expected exactly 1 (worker_id, dp_rank) to have events (due to prefix reuse), "
f"but found {len(keys_with_events_dp)} with events: {keys_with_events_dp}"
)
# Verify: The routing key with events should have exactly 8 events (one per unique block)
active_key_dp = keys_with_events_dp[0]
num_events = len(events_by_key_dp[active_key_dp])
assert num_events == 8, (
f"Expected (worker_id, dp_rank) {active_key_dp} to have exactly 8 events, "
f"but found {num_events} events"
)
assert (
req5["prefill_dp_rank"] == dp_rank_b
), f"Request 5: expected prefill_dp_rank={dp_rank_b}, got {req5['prefill_dp_rank']}"
# Verify: Routing should match the forced values
active_worker_id, active_dp_rank = active_key_dp
assert active_worker_id == expected_worker_id, (
f"Expected all events to have worker_id={expected_worker_id} (forced in first request), "
f"but found worker_id={active_worker_id}"
)
assert active_dp_rank == expected_dp_rank, (
f"Expected all events to have dp_rank={expected_dp_rank} (forced in first request), "
f"but found dp_rank={active_dp_rank}"
)
logger.info(
f"Successfully verified: Worker {active_worker_id} dp_rank {active_dp_rank} handled all 4 requests with prefix reuse. "
f"All events correctly routed to worker_id={expected_worker_id}, dp_rank={expected_dp_rank} as expected. "
f"KV events synchronized correctly."
)
else:
# Group by worker_id only for multiple workers testing
events_by_key_single: dict[int, list] = {}
for event in events:
worker_id = event.get("worker_id")
if worker_id not in events_by_key_single:
events_by_key_single[worker_id] = []
events_by_key_single[worker_id].append(event)
logger.info(
f"Response routing verified: req4 → worker_a (id={worker_a_id}, dp_rank={dp_rank_a}), "
f"req5 → worker_b (id={worker_b_id}, dp_rank={dp_rank_b})"
)
logger.info(
f"Events by worker_id: {[(key, len(evts)) for key, evts in events_by_key_single.items()]}"
)
# Parse events and verify event counts per routing target
events = json.loads(events_json)
# Verify: All but one routing key should have no events (due to prefix reuse)
keys_with_events_single = [
key for key, evts in events_by_key_single.items() if len(evts) > 0
]
# Always group by (worker_id, dp_rank)
events_by_key: dict[tuple[int, int], list[Any]] = {}
for event in events:
worker_id = event.get("worker_id")
dp_rank = event.get("event", {}).get("dp_rank", 0)
key = (worker_id, dp_rank)
if key not in events_by_key:
events_by_key[key] = []
events_by_key[key].append(event)
assert len(keys_with_events_single) == 1, (
f"Expected exactly 1 worker_id to have events (due to prefix reuse), "
f"but found {len(keys_with_events_single)} with events: {keys_with_events_single}"
)
logger.info(
f"Events by (worker_id, dp_rank): {[(key, len(evts)) for key, evts in events_by_key.items()]}"
)
# Verify: The routing key with events should have exactly 8 events (one per unique block)
active_worker_id = keys_with_events_single[0]
num_events = len(events_by_key_single[active_worker_id])
# Worker a key: 5 events (A, B from req1; C, D from req2; F from req4)
worker_a_key = (worker_a_id, dp_rank_a if dp_rank_a is not None else 0)
worker_a_events = len(events_by_key.get(worker_a_key, []))
assert worker_a_events == 5, (
f"Expected worker_a {worker_a_key} to have 5 events (A,B + C,D + F), "
f"but found {worker_a_events}"
)
assert num_events == 8, (
f"Expected worker_id {active_worker_id} to have exactly 8 events, "
f"but found {num_events} events"
)
# Worker b key: 4 events (A, C, E from req3; G from req5)
worker_b_key = (worker_b_id, dp_rank_b if dp_rank_b is not None else 0)
worker_b_events = len(events_by_key.get(worker_b_key, []))
assert worker_b_events == 4, (
f"Expected worker_b {worker_b_key} to have 4 events (A,C,E + G), "
f"but found {worker_b_events}"
)
# Verify: Routing should match the forced values
assert active_worker_id == expected_worker_id, (
f"Expected all events to have worker_id={expected_worker_id} (forced in first request), "
f"but found worker_id={active_worker_id}"
)
logger.info(
f"Successfully verified: Worker {active_worker_id} handled all 4 requests with prefix reuse. "
f"All events correctly routed to worker_id={expected_worker_id} as expected. "
f"KV events synchronized correctly."
)
logger.info(
f"Successfully verified cross-worker routing: "
f"worker_a {worker_a_key} has {worker_a_events} events, "
f"worker_b {worker_b_key} has {worker_b_events} events"
)
def _test_busy_threshold_endpoint(
......
......@@ -680,7 +680,7 @@ def test_router_decisions(
# durable_kv_events=True enables JetStream mode; False (default) uses NATS Core with local indexer
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"block_size": 8,
"dp_size": 4,
"durable_kv_events": durable_kv_events and use_kv_events,
}
......
......@@ -424,6 +424,7 @@ def test_router_decisions_sglang_multiple_workers(
MODEL_NAME,
request,
test_dp_rank=False,
block_size=PAGE_SIZE,
router_event_threads=router_event_threads,
)
......@@ -466,7 +467,12 @@ def test_router_decisions_sglang_dp(
endpoint = runtime.endpoint(f"{sglang_workers.namespace}.backend.generate")
_test_router_decisions(
sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=True
sglang_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=True,
block_size=PAGE_SIZE,
)
......
......@@ -453,6 +453,7 @@ def test_router_decisions_vllm_multiple_workers(
MODEL_NAME,
request,
test_dp_rank=False,
block_size=BLOCK_SIZE,
router_event_threads=router_event_threads,
)
......@@ -497,7 +498,12 @@ def test_router_decisions_vllm_dp(
) # endpoint is backend.generate
_test_router_decisions(
vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=True
vllm_workers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=True,
block_size=BLOCK_SIZE,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment