Unverified Commit 2160b913 authored by Paul Li's avatar Paul Li Committed by GitHub
Browse files

test: Add per-worker routing and KV event synchronization test + minor...


test: Add per-worker routing and KV event synchronization test + minor subscriber worker_id parsing fix (#3426)
Signed-off-by: default avatarPaul Li <zhixiong2008@gmail.com>
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 332482a9
...@@ -205,12 +205,14 @@ pub async fn start_kv_router_background( ...@@ -205,12 +205,14 @@ pub async fn start_kv_router_background(
let key = String::from_utf8_lossy(kv.key()); let key = String::from_utf8_lossy(kv.key());
let Some(worker_id_str) = key.split('/').next_back() else { // Extract the hex worker ID after the colon (e.g., "generate:694d99badb9f7c07" -> "694d99badb9f7c07")
let Some(worker_id_str) = key.split(':').next_back() else {
tracing::warn!("Could not extract worker ID from instance key: {}", key); tracing::warn!("Could not extract worker ID from instance key: {}", key);
continue; continue;
}; };
let Ok(worker_id) = worker_id_str.parse::<i64>() else { // Parse as hexadecimal (base 16)
let Ok(worker_id) = i64::from_str_radix(worker_id_str, 16) else {
tracing::warn!("Could not parse worker ID from instance key: {}", key); tracing::warn!("Could not parse worker ID from instance key: {}", key);
continue; continue;
}; };
......
...@@ -7,6 +7,7 @@ import logging ...@@ -7,6 +7,7 @@ import logging
import os import os
import random import random
import string import string
import subprocess
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import aiohttp import aiohttp
...@@ -188,7 +189,9 @@ async def send_request_with_retry(url: str, payload: dict, max_retries: int = 8) ...@@ -188,7 +189,9 @@ async def send_request_with_retry(url: str, payload: dict, max_retries: int = 8)
# Read the response to ensure it's valid # Read the response to ensure it's valid
async for _ in response.content: async for _ in response.content:
pass pass
logger.info(f"First request succeeded on attempt {attempt + 1}") logger.debug(
f"First request succeeded on attempt {attempt + 1}"
)
return True return True
else: else:
logger.warning( logger.warning(
...@@ -305,7 +308,7 @@ async def send_request_via_python_kv_router( ...@@ -305,7 +308,7 @@ async def send_request_via_python_kv_router(
# Retry loop sending reuqest to mocker worker with exponential backoff # Retry loop sending reuqest to mocker worker with exponential backoff
for attempt in range(max_retries + 1): for attempt in range(max_retries + 1):
try: try:
logger.info(f"Sending request to {log_message} (attempt {attempt + 1})") logger.debug(f"Sending request to {log_message} (attempt {attempt + 1})")
stream = await kv_python_router.generate( stream = await kv_python_router.generate(
token_ids=token_ids, token_ids=token_ids,
...@@ -318,7 +321,7 @@ async def send_request_via_python_kv_router( ...@@ -318,7 +321,7 @@ async def send_request_via_python_kv_router(
) )
if stream is not None: if stream is not None:
logger.info(f"Request succeeded on attempt {attempt + 1}") logger.debug(f"Request succeeded on attempt {attempt + 1}")
break break
except Exception as e: except Exception as e:
...@@ -344,10 +347,12 @@ async def send_request_via_python_kv_router( ...@@ -344,10 +347,12 @@ async def send_request_via_python_kv_router(
# Check for finish reason # Check for finish reason
if "finish_reason" in response: if "finish_reason" in response:
logger.info(f"Stream finished with reason: {response['finish_reason']}") logger.debug(
f"Stream finished with reason: {response['finish_reason']}"
)
# Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True # Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True
logger.info(f"Total generated tokens: {len(generated_tokens)}") logger.debug(f"Total generated tokens: {len(generated_tokens)}")
if ( if (
stop_conditions stop_conditions
and "max_tokens" in stop_conditions and "max_tokens" in stop_conditions
...@@ -360,7 +365,7 @@ async def send_request_via_python_kv_router( ...@@ -360,7 +365,7 @@ async def send_request_via_python_kv_router(
f"Tokens: {generated_tokens}" f"Tokens: {generated_tokens}"
) )
logger.info( logger.debug(
f"Successfully verified {max_tokens} tokens generated as expected via KvPushRouter with ignore_eos=True" f"Successfully verified {max_tokens} tokens generated as expected via KvPushRouter with ignore_eos=True"
) )
return True return True
...@@ -368,6 +373,74 @@ async def send_request_via_python_kv_router( ...@@ -368,6 +373,74 @@ async def send_request_via_python_kv_router(
return False return False
async def wait_for_mockers_ready(
endpoint, router: KvPushRouter, expected_num_workers: int = NUM_MOCKERS
) -> list[int]:
"""Wait for mocker workers to be ready and return their instance IDs.
This function polls the endpoint's client for instance IDs until the expected
number of workers are available, then sends a warmup request to verify they
can handle requests.
Args:
endpoint: The endpoint object to get the client from
router: The KvPushRouter to use for sending warmup requests
expected_num_workers: Number of workers to wait for (default: NUM_MOCKERS)
Returns:
Sorted list of unique instance IDs (ints).
Raises:
AssertionError: If workers don't become ready or warmup request fails.
"""
logger.info("Waiting for mockers to be ready")
# Get the client from the endpoint
client = await endpoint.client()
# Poll for instance IDs until we have the expected number
instance_ids: list[int] = []
max_wait_time = 60 # seconds
start_time = asyncio.get_event_loop().time()
while len(instance_ids) < expected_num_workers:
instance_ids = client.instance_ids()
logger.info(f"Found {len(instance_ids)} instance(s): {instance_ids}")
if len(instance_ids) >= expected_num_workers:
break
# Check timeout
if asyncio.get_event_loop().time() - start_time > max_wait_time:
raise AssertionError(
f"Timeout waiting for workers. Found {len(instance_ids)} instance(s), expected {expected_num_workers}"
)
# Wait 1 second before polling again
await asyncio.sleep(1.0)
# Send a warmup request to verify workers can handle requests
test_token_ids = [random.randint(1, 10000) for _ in range(4)]
logger.info(f"Sending warmup request with {len(test_token_ids)} tokens")
try:
await send_request_via_python_kv_router(
kv_python_router=router,
token_ids=test_token_ids,
initial_wait=1.0,
max_retries=8,
stop_conditions={
"ignore_eos": True,
"max_tokens": 2,
},
)
except Exception as e:
raise AssertionError(f"Warmup request failed: {e}")
logger.info(f"All {len(instance_ids)} workers are ready")
return sorted(instance_ids)
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.model(MODEL_NAME) @pytest.mark.model(MODEL_NAME)
def test_mocker_kv_router(request, runtime_services, predownload_tokenizers): def test_mocker_kv_router(request, runtime_services, predownload_tokenizers):
...@@ -693,21 +766,8 @@ def test_kv_push_router_bindings(request, runtime_services, predownload_tokenize ...@@ -693,21 +766,8 @@ def test_kv_push_router_bindings(request, runtime_services, predownload_tokenize
logger.info("Created KvPushRouter Python object") logger.info("Created KvPushRouter Python object")
# Initialize and check the readiness of the mockers by sending dummy request # Wait for mockers to be ready
asyncio.run( asyncio.run(wait_for_mockers_ready(endpoint, kv_push_router))
send_request_via_python_kv_router(
kv_python_router=kv_push_router,
token_ids=[1, 2, 3],
initial_wait=1.0,
max_retries=8,
stop_conditions={"max_tokens": 1}, # Generate just 1 token
sampling_options={"temperature": 0.7},
output_options={
"include_input_tokens": False,
"return_full_text": False,
},
)
)
# Generate random token IDs (100 to 200 tokens) # Generate random token IDs (100 to 200 tokens)
num_input_tokens = random.randint(100, 200) num_input_tokens = random.randint(100, 200)
...@@ -832,24 +892,11 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -832,24 +892,11 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
kv_router_config = KvRouterConfig(router_snapshot_threshold=20) kv_router_config = KvRouterConfig(router_snapshot_threshold=20)
async def send_requests_to_router(router, num_requests, router_name): async def send_requests_to_router(router, num_requests, router_name):
# Generate small test token IDs
test_token_ids = [random.randint(1, 10000) for _ in range(10)]
# Initialize and check the readiness of the mockers by sending dummy request
logger.info(f"Initializing {router_name} and mocker instances")
await send_request_via_python_kv_router(
kv_python_router=router,
token_ids=test_token_ids,
initial_wait=1.0,
max_retries=8,
stop_conditions={"max_tokens": 1}, # Generate just 1 token
)
# Now send the actual requests # Now send the actual requests
tasks = [] tasks = []
for i in range(num_requests): for i in range(num_requests):
# Generate random token IDs for each request # Generate random token IDs for each request
logger.info( logger.debug(
f"Sending request {i + 1}/{num_requests} to {router_name}" f"Sending request {i + 1}/{num_requests} to {router_name}"
) )
...@@ -888,7 +935,10 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -888,7 +935,10 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
kv_router_config=kv_router_config, kv_router_config=kv_router_config,
) )
# Send 25 requests to first router with initial retry loop # Wait for mockers to be ready
await wait_for_mockers_ready(endpoint, kv_push_router1)
# Send 25 requests to first router
logger.info("Sending 25 requests to first router") logger.info("Sending 25 requests to first router")
# Send requests to first router # Send requests to first router
...@@ -921,6 +971,67 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -921,6 +971,67 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
logger.info("Waiting for final synchronization") logger.info("Waiting for final synchronization")
await asyncio.sleep(1) await asyncio.sleep(1)
# Verify NATS object store bucket was created with snapshot
# Mirror the Rust bucket naming logic from subscriber.rs:
# component.subject() -> "namespace.{ns}.component.{comp}"
# then slugify (convert dots to dashes, lowercase, etc) and append "-radix-bucket"
component_subject = f"namespace.{mockers.namespace}.component.mocker"
slugified = component_subject.lower().replace(".", "-").replace("_", "-")
expected_bucket = f"{slugified}-radix-bucket"
expected_file = "radix-state"
logger.info(f"Verifying NATS object store bucket exists: {expected_bucket}")
snapshot_verified = False
try:
# List objects in the bucket
result = subprocess.run(
[
"nats",
"object",
"ls",
expected_bucket,
"--server",
"nats://localhost:4222",
],
capture_output=True,
text=True,
timeout=10,
)
if result.returncode == 0:
logger.info(
f"Successfully listed bucket contents:\n{result.stdout}"
)
# Check if the expected file exists
if expected_file in result.stdout:
logger.info(
f"✓ Snapshot file '{expected_file}' found in bucket '{expected_bucket}'"
)
snapshot_verified = True
else:
logger.error(
f"Snapshot file '{expected_file}' not found in bucket '{expected_bucket}'"
)
logger.error(f"Bucket contents:\n{result.stdout}")
else:
logger.error(f"Failed to list bucket: {result.stderr}")
except subprocess.TimeoutExpired:
logger.error("Timeout checking NATS object store bucket")
except FileNotFoundError:
logger.warning(
"nats CLI not found in PATH, skipping bucket verification (test will continue)"
)
snapshot_verified = True # Don't fail if nats CLI not installed
except Exception as e:
logger.error(f"Error checking NATS object store: {e}")
# Assert that snapshot was created (threshold=20, sent 25 requests)
if not snapshot_verified:
assert False, (
f"Expected snapshot to be created in bucket '{expected_bucket}' with file '{expected_file}'. "
f"Router sent 25 requests with snapshot_threshold=20, so snapshot should have been triggered."
)
# Dump states from both routers # Dump states from both routers
logger.info("Dumping states from both routers") logger.info("Dumping states from both routers")
state1_json = await kv_push_router1.dump_events() state1_json = await kv_push_router1.dump_events()
...@@ -1192,3 +1303,139 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -1192,3 +1303,139 @@ def test_query_instance_id_returns_worker_and_tokens(
kv_router.__exit__(None, None, None) kv_router.__exit__(None, None, None)
if "mockers" in locals(): if "mockers" in locals():
mockers.__exit__(None, None, None) mockers.__exit__(None, None, None)
@pytest.mark.pre_merge
@pytest.mark.model(MODEL_NAME)
def test_router_decisions(request, runtime_services, predownload_tokenizers):
"""Validate KV cache prefix reuse by sending progressive requests with overlapping prefixes.
Flow:
- Start two mocker workers sharing a namespace.
- Wait for workers to be ready.
- Send 4 progressive requests, each extending the previous tokens:
* Request 1: BLOCK_SIZE random tokens
* Request 2: Request 1 tokens + BLOCK_SIZE new random tokens
* Request 3: Request 2 tokens + BLOCK_SIZE new random tokens
* Request 4: Request 3 tokens + BLOCK_SIZE new random tokens
- Dump events from router and verify:
* All but one worker should have no events (one worker handles all due to prefix reuse)
* The worker with events should have exactly 4 events (one per request)
"""
# runtime_services starts etcd and nats
logger.info("Starting test router prefix reuse and KV events synchronization")
# Create mocker args dictionary
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
try:
# Start mocker instances with the new CLI interface
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
mockers = MockerProcess(
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS
)
logger.info(f"All mockers using endpoint: {mockers.endpoint}")
# Initialize mockers
mockers.__enter__()
# Get runtime and create endpoint
runtime = get_runtime()
# Use the namespace from the mockers
namespace = runtime.namespace(mockers.namespace)
component = namespace.component("mocker")
endpoint = component.endpoint("generate")
# Create KvRouterConfig with lower snapshot threshold for testing
kv_router_config = KvRouterConfig(router_snapshot_threshold=20)
kv_push_router = KvPushRouter(
endpoint=endpoint,
block_size=BLOCK_SIZE,
kv_router_config=kv_router_config,
)
# Use async to manage the test flow
async def test_sync():
# Wait for workers to be ready and get their instance IDs
mocker_worker_ids = await wait_for_mockers_ready(endpoint, kv_push_router)
logger.info(f"Workers ready: {mocker_worker_ids}")
# Send 4 progressive requests with overlapping prefixes
cumulative_tokens = []
for i in range(4):
# Add BLOCK_SIZE new random tokens
new_tokens = [random.randint(1, 10000) for _ in range(BLOCK_SIZE)]
cumulative_tokens.extend(new_tokens)
logger.info(
f"Sending request {i + 1}/4 with {len(cumulative_tokens)} tokens "
f"(added {len(new_tokens)} new tokens)"
)
await send_request_via_python_kv_router(
kv_python_router=kv_push_router,
token_ids=cumulative_tokens.copy(),
initial_wait=1.0,
max_retries=8,
stop_conditions={
"ignore_eos": True, # Don't stop on EOS token
"max_tokens": 2, # Generate exactly 2 tokens
},
)
# Wait a bit between requests
await asyncio.sleep(0.5)
# Wait for final synchronization
await asyncio.sleep(1)
# Dump events from the router
events_json = await kv_push_router.dump_events()
return events_json
# Run the async test
events_json = asyncio.run(test_sync())
# Parse events and count by worker
events = json.loads(events_json)
events_by_worker: dict[int, list[Any]] = {}
for event in events:
worker_id = event.get("worker_id")
if worker_id not in events_by_worker:
events_by_worker[worker_id] = []
events_by_worker[worker_id].append(event)
logger.info(
f"Events by worker: {[(wid, len(evts)) for wid, evts in events_by_worker.items()]}"
)
# Verify: All but one worker should have no events
workers_with_events = [
wid for wid, evts in events_by_worker.items() if len(evts) > 0
]
assert len(workers_with_events) == 1, (
f"Expected exactly 1 worker to have events (due to prefix reuse), "
f"but found {len(workers_with_events)} workers with events: {workers_with_events}"
)
# Verify: The worker with events should have exactly 4 events
active_worker = workers_with_events[0]
num_events = len(events_by_worker[active_worker])
assert num_events == 4, (
f"Expected worker {active_worker} to have exactly 4 events, "
f"but found {num_events} events"
)
logger.info(
f"Successfully verified: Worker {active_worker} handled all 4 requests with prefix reuse. "
f"KV events synchronized correctly."
)
finally:
# Clean up mockers
if "mockers" in locals():
mockers.__exit__(None, None, None)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment