"tests/vscode:/vscode.git/clone" did not exist on "e6ce4db3b29b3c942581a843ead98c6ee569a2ee"
Unverified Commit 29ec4969 authored by Neal Vaidya's avatar Neal Vaidya Committed by GitHub
Browse files

chore: add test for diverging prefixes (#4919)


Signed-off-by: default avatarNeal Vaidya <nealv@nvidia.com>
parent bbeb2808
...@@ -1798,11 +1798,16 @@ def _test_router_decisions( ...@@ -1798,11 +1798,16 @@ def _test_router_decisions(
test_dp_rank: bool = False, test_dp_rank: bool = False,
block_size: int = BLOCK_SIZE, block_size: int = BLOCK_SIZE,
): ):
"""Validate KV cache prefix reuse and worker routing by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes.
Assumes engine workers are already initialized. Sends 4 progressive requests where each extends Assumes engine workers are already initialized.
the previous tokens by `block_size`. The first request is forced to a specific worker (and optionally The first request is forced to a specific worker (and optionally dp_rank),
dp_rank), and subsequent requests should naturally route to the same worker due to prefix reuse. and subsequent requests should naturally route to the same worker due to prefix reuse.
Test sequence:
1. Request 1: [A, B, C, D] → Forces to Worker 1, caches 4 blocks
2. Request 2: [A, B, E, F] → Shares [A, B] prefix, diverges from Request 1
3. Request 3: [A, B, C, D, G, H] → Should route to Worker 1 (has [A, B, C, D] cached)
Args: Args:
engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__()) engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__())
...@@ -1844,23 +1849,27 @@ def _test_router_decisions( ...@@ -1844,23 +1849,27 @@ def _test_router_decisions(
else: else:
logger.info(f"Will force first request to worker_id={forced_worker_id}") logger.info(f"Will force first request to worker_id={forced_worker_id}")
# Send 4 progressive requests with overlapping prefixes # Send 3 requests with some shared prefixes and some divergent prefixes
cumulative_tokens = []
response_worker_ids: list[dict[str, Optional[int]]] = [] response_worker_ids: list[dict[str, Optional[int]]] = []
for i in range(4): num_blocks = 8
# Add `block_size` new random tokens blocks = [
new_tokens = [random.randint(1, 10000) for _ in range(block_size)] [random.randint(1, 10000) for _ in range(block_size)]
cumulative_tokens.extend(new_tokens) for _ in range(num_blocks)
]
requests = [
blocks[0] + blocks[1] + blocks[2] + blocks[3],
blocks[0] + blocks[1] + blocks[4] + blocks[5],
blocks[0] + blocks[1] + blocks[2] + blocks[3] + blocks[6] + blocks[7],
]
for i, request in enumerate(requests):
# Force first request to specific worker_id (and dp_rank if testing DP), let subsequent requests follow naturally # Force first request to specific worker_id (and dp_rank if testing DP), let subsequent requests follow naturally
worker_id_override = forced_worker_id if i == 0 else None worker_id_override = forced_worker_id if i == 0 else None
dp_rank_override = forced_dp_rank if i == 0 and test_dp_rank else None dp_rank_override = forced_dp_rank if i == 0 and test_dp_rank else None
log_msg = ( log_msg = f"Sending request {i + 1}/4 with {len(request)} tokens "
f"Sending request {i + 1}/4 with {len(cumulative_tokens)} tokens "
f"(added {len(new_tokens)} new tokens)"
)
if worker_id_override is not None: if worker_id_override is not None:
if test_dp_rank: if test_dp_rank:
log_msg += f" - FORCING worker_id={worker_id_override}, dp_rank={dp_rank_override}" log_msg += f" - FORCING worker_id={worker_id_override}, dp_rank={dp_rank_override}"
...@@ -1871,7 +1880,7 @@ def _test_router_decisions( ...@@ -1871,7 +1880,7 @@ def _test_router_decisions(
result = await send_request_via_python_kv_router( result = await send_request_via_python_kv_router(
kv_python_router=kv_push_router, kv_python_router=kv_push_router,
model_name=model_name, model_name=model_name,
token_ids=cumulative_tokens.copy(), token_ids=request,
initial_wait=1.0, initial_wait=1.0,
max_retries=8, max_retries=8,
stop_conditions={ stop_conditions={
...@@ -1944,12 +1953,12 @@ def _test_router_decisions( ...@@ -1944,12 +1953,12 @@ def _test_router_decisions(
f"but found {len(keys_with_events_dp)} with events: {keys_with_events_dp}" f"but found {len(keys_with_events_dp)} with events: {keys_with_events_dp}"
) )
# Verify: The routing key with events should have exactly 4 events (one per request) # Verify: The routing key with events should have exactly 8 events (one per unique block)
active_key_dp = keys_with_events_dp[0] active_key_dp = keys_with_events_dp[0]
num_events = len(events_by_key_dp[active_key_dp]) num_events = len(events_by_key_dp[active_key_dp])
assert num_events == 4, ( assert num_events == 8, (
f"Expected (worker_id, dp_rank) {active_key_dp} to have exactly 4 events, " f"Expected (worker_id, dp_rank) {active_key_dp} to have exactly 8 events, "
f"but found {num_events} events" f"but found {num_events} events"
) )
...@@ -1991,12 +2000,12 @@ def _test_router_decisions( ...@@ -1991,12 +2000,12 @@ def _test_router_decisions(
f"but found {len(keys_with_events_single)} with events: {keys_with_events_single}" f"but found {len(keys_with_events_single)} with events: {keys_with_events_single}"
) )
# Verify: The routing key with events should have exactly 4 events (one per request) # Verify: The routing key with events should have exactly 8 events (one per unique block)
active_worker_id = keys_with_events_single[0] active_worker_id = keys_with_events_single[0]
num_events = len(events_by_key_single[active_worker_id]) num_events = len(events_by_key_single[active_worker_id])
assert num_events == 4, ( assert num_events == 8, (
f"Expected worker_id {active_worker_id} to have exactly 4 events, " f"Expected worker_id {active_worker_id} to have exactly 8 events, "
f"but found {num_events} events" f"but found {num_events} events"
) )
......
...@@ -333,6 +333,8 @@ def test_sglang_kv_router_basic( ...@@ -333,6 +333,8 @@ def test_sglang_kv_router_basic(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.skip(reason="Broken by sglang changes")
# TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged
def test_router_decisions_sglang_multiple_workers( def test_router_decisions_sglang_multiple_workers(
request, runtime_services, predownload_models, set_ucx_tls_no_mm request, runtime_services, predownload_models, set_ucx_tls_no_mm
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment