Unverified commit 29ec4969, authored by Neal Vaidya, committed by GitHub
Browse files

chore: add test for diverging prefixes (#4919)


Signed-off-by: Neal Vaidya <nealv@nvidia.com>
parent bbeb2808
......@@ -1798,11 +1798,16 @@ def _test_router_decisions(
test_dp_rank: bool = False,
block_size: int = BLOCK_SIZE,
):
"""Validate KV cache prefix reuse and worker routing by sending progressive requests with overlapping prefixes.
"""Validate KV cache prefix reuse and worker routing by sending requests with diverging prefixes.
Assumes engine workers are already initialized. Sends 4 progressive requests where each extends
the previous tokens by `block_size`. The first request is forced to a specific worker (and optionally
dp_rank), and subsequent requests should naturally route to the same worker due to prefix reuse.
Assumes engine workers are already initialized.
The first request is forced to a specific worker (and optionally dp_rank),
and subsequent requests should naturally route to the same worker due to prefix reuse.
Test sequence:
1. Request 1: [A, B, C, D] → Forces to Worker 1, caches 4 blocks
2. Request 2: [A, B, E, F] → Shares [A, B] prefix, diverges from Request 1
3. Request 3: [A, B, C, D, G, H] → Should route to Worker 1 (has [A, B, C, D] cached)
Args:
engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__())
......@@ -1844,23 +1849,27 @@ def _test_router_decisions(
else:
logger.info(f"Will force first request to worker_id={forced_worker_id}")
# Send 4 progressive requests with overlapping prefixes
cumulative_tokens = []
# Send 3 requests with some shared prefixes and some divergent prefixes
response_worker_ids: list[dict[str, Optional[int]]] = []
for i in range(4):
# Add `block_size` new random tokens
new_tokens = [random.randint(1, 10000) for _ in range(block_size)]
cumulative_tokens.extend(new_tokens)
num_blocks = 8
blocks = [
[random.randint(1, 10000) for _ in range(block_size)]
for _ in range(num_blocks)
]
requests = [
blocks[0] + blocks[1] + blocks[2] + blocks[3],
blocks[0] + blocks[1] + blocks[4] + blocks[5],
blocks[0] + blocks[1] + blocks[2] + blocks[3] + blocks[6] + blocks[7],
]
for i, request in enumerate(requests):
# Force first request to specific worker_id (and dp_rank if testing DP), let subsequent requests follow naturally
worker_id_override = forced_worker_id if i == 0 else None
dp_rank_override = forced_dp_rank if i == 0 and test_dp_rank else None
log_msg = (
f"Sending request {i + 1}/4 with {len(cumulative_tokens)} tokens "
f"(added {len(new_tokens)} new tokens)"
)
log_msg = f"Sending request {i + 1}/4 with {len(request)} tokens "
if worker_id_override is not None:
if test_dp_rank:
log_msg += f" - FORCING worker_id={worker_id_override}, dp_rank={dp_rank_override}"
......@@ -1871,7 +1880,7 @@ def _test_router_decisions(
result = await send_request_via_python_kv_router(
kv_python_router=kv_push_router,
model_name=model_name,
token_ids=cumulative_tokens.copy(),
token_ids=request,
initial_wait=1.0,
max_retries=8,
stop_conditions={
......@@ -1944,12 +1953,12 @@ def _test_router_decisions(
f"but found {len(keys_with_events_dp)} with events: {keys_with_events_dp}"
)
# Verify: The routing key with events should have exactly 4 events (one per request)
# Verify: The routing key with events should have exactly 8 events (one per unique block)
active_key_dp = keys_with_events_dp[0]
num_events = len(events_by_key_dp[active_key_dp])
assert num_events == 4, (
f"Expected (worker_id, dp_rank) {active_key_dp} to have exactly 4 events, "
assert num_events == 8, (
f"Expected (worker_id, dp_rank) {active_key_dp} to have exactly 8 events, "
f"but found {num_events} events"
)
......@@ -1991,12 +2000,12 @@ def _test_router_decisions(
f"but found {len(keys_with_events_single)} with events: {keys_with_events_single}"
)
# Verify: The routing key with events should have exactly 4 events (one per request)
# Verify: The routing key with events should have exactly 8 events (one per unique block)
active_worker_id = keys_with_events_single[0]
num_events = len(events_by_key_single[active_worker_id])
assert num_events == 4, (
f"Expected worker_id {active_worker_id} to have exactly 4 events, "
assert num_events == 8, (
f"Expected worker_id {active_worker_id} to have exactly 8 events, "
f"but found {num_events} events"
)
......
......@@ -333,6 +333,8 @@ def test_sglang_kv_router_basic(
@pytest.mark.pre_merge
@pytest.mark.gpu_1
@pytest.mark.skip(reason="Broken by sglang changes")
# TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged
def test_router_decisions_sglang_multiple_workers(
request, runtime_services, predownload_models, set_ucx_tls_no_mm
):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment