Unverified commit e7a73cb1, authored by Paul Li, committed by GitHub
Browse files

test: Simplify test_indexers_sync() with the new helper send_request_via_python_kv_router() (#3339)


Signed-off-by: Paul Li <zhixiong2008@gmail.com>
parent 031dc589
...@@ -314,7 +314,7 @@ async def send_request_via_python_kv_router( ...@@ -314,7 +314,7 @@ async def send_request_via_python_kv_router(
sampling_options=sampling_options, sampling_options=sampling_options,
output_options=output_options, output_options=output_options,
router_config_override=router_config_override, router_config_override=router_config_override,
# worker_id=worker_id, worker_id=worker_id,
) )
if stream is not None: if stream is not None:
...@@ -816,9 +816,10 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -816,9 +816,10 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS
) )
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
# Initialize mockers
mockers.__enter__() mockers.__enter__()
# Run the async test # Use async to manage the test flow
async def test_sync(): async def test_sync():
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime() runtime = get_runtime()
...@@ -827,70 +828,51 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -827,70 +828,51 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
component = namespace.component("mocker") component = namespace.component("mocker")
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
# Create first KV router # Create KvRouterConfig with lower snapshot threshold for testing
from dynamo._core import KvPushRouter, KvRouterConfig
kv_router_config = KvRouterConfig(router_snapshot_threshold=20) kv_router_config = KvRouterConfig(router_snapshot_threshold=20)
async def send_requests_to_router(router, num_requests, router_name): async def send_requests_to_router(router, num_requests, router_name):
# First, send a test request with retry to ensure router is ready
max_retries = 8
wait_time = 1
for attempt in range(max_retries + 1):
try:
logger.info(
f"Testing {router_name} readiness (attempt {attempt + 1})"
)
# Generate small test token IDs # Generate small test token IDs
test_token_ids = [random.randint(1, 10000) for _ in range(10)] test_token_ids = [random.randint(1, 10000) for _ in range(10)]
stream = await router.generate(
token_ids=test_token_ids, # Small test # Initialize and check the readiness of the mockers by sending dummy request
model=MODEL_NAME, logger.info(f"Initializing {router_name} and mocker instances")
stop_conditions={"max_tokens": 1}, await send_request_via_python_kv_router(
) kv_python_router=router,
# Just consume the stream to verify it works token_ids=test_token_ids,
async for _ in stream: initial_wait=1.0,
pass max_retries=8,
logger.info(f"{router_name} is ready!") stop_conditions={"max_tokens": 1}, # Generate just 1 token
break
except Exception as e:
logger.warning(
f"{router_name} attempt {attempt + 1} failed: {e}"
)
if attempt < max_retries:
await asyncio.sleep(wait_time)
wait_time *= 2
else:
raise RuntimeError(
f"Failed to connect to {router_name} after retries"
) )
# Now send the actual requests # Now send the actual requests
tasks = [] tasks = []
for i in range(num_requests): for i in range(num_requests):
# Generate random token IDs for each request # Generate random token IDs for each request
logger.info(
f"Sending request {i + 1}/{num_requests} to {router_name}"
)
# Generate 30 random tokens
request_tokens = [random.randint(1, 10000) for _ in range(30)] request_tokens = [random.randint(1, 10000) for _ in range(30)]
async def single_request(req_id, tokens): # Send request to mocker via the router
try: tasks.append(
stream = await router.generate( asyncio.create_task(
token_ids=tokens, send_request_via_python_kv_router(
model=MODEL_NAME, kv_python_router=router,
stop_conditions={"max_tokens": 10}, token_ids=request_tokens,
initial_wait=1.0,
max_retries=8,
stop_conditions={
"ignore_eos": True, # Don't stop on EOS token
"max_tokens": 10, # Generate exactly 10 tokens
},
)
) )
# Consume the stream
async for _ in stream:
pass
return True
except Exception as e:
logger.error(
f"Request {req_id} to {router_name} failed: {e}"
) )
return False
tasks.append(asyncio.create_task(single_request(i, request_tokens)))
# Wait for all requests to complete
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
successful = sum(1 for r in results if r) successful = sum(1 for r in results if r)
logger.info( logger.info(
...@@ -898,6 +880,7 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -898,6 +880,7 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
) )
return successful return successful
# Launch first router
logger.info("Creating first KV router") logger.info("Creating first KV router")
kv_push_router1 = KvPushRouter( kv_push_router1 = KvPushRouter(
endpoint=endpoint, endpoint=endpoint,
...@@ -920,11 +903,10 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -920,11 +903,10 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
# Launch second router - will automatically sync with the first router's state # Launch second router - will automatically sync with the first router's state
logger.info("Creating second KV router") logger.info("Creating second KV router")
kv_router_config2 = KvRouterConfig(router_snapshot_threshold=20)
kv_push_router2 = KvPushRouter( kv_push_router2 = KvPushRouter(
endpoint=endpoint, endpoint=endpoint,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
kv_router_config=kv_router_config2, kv_router_config=kv_router_config,
) )
# Send 25 requests to second router with initial retry loop # Send 25 requests to second router with initial retry loop
...@@ -995,7 +977,7 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -995,7 +977,7 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
"router2_state": state2_item, "router2_state": state2_item,
} }
) )
# If there are differences, format them for easier debugging
if differences: if differences:
error_msg = f"Router states are not equal. Found {len(differences)} differences:\n" error_msg = f"Router states are not equal. Found {len(differences)} differences:\n"
for diff in differences: for diff in differences:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment