Unverified Commit e7a73cb1 authored by Paul Li's avatar Paul Li Committed by GitHub
Browse files

test: Simplify test_indexers_sync() with the new helper send_request_via_python_kv_router() (#3339)


Signed-off-by: default avatarPaul Li <zhixiong2008@gmail.com>
parent 031dc589
...@@ -314,7 +314,7 @@ async def send_request_via_python_kv_router( ...@@ -314,7 +314,7 @@ async def send_request_via_python_kv_router(
sampling_options=sampling_options, sampling_options=sampling_options,
output_options=output_options, output_options=output_options,
router_config_override=router_config_override, router_config_override=router_config_override,
# worker_id=worker_id, worker_id=worker_id,
) )
if stream is not None: if stream is not None:
...@@ -816,9 +816,10 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -816,9 +816,10 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS
) )
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
# Initialize mockers
mockers.__enter__() mockers.__enter__()
# Run the async test # Use async to manage the test flow
async def test_sync(): async def test_sync():
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime() runtime = get_runtime()
...@@ -827,70 +828,51 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -827,70 +828,51 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
component = namespace.component("mocker") component = namespace.component("mocker")
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
# Create first KV router # Create KvRouterConfig with lower snapshot threshold for testing
from dynamo._core import KvPushRouter, KvRouterConfig
kv_router_config = KvRouterConfig(router_snapshot_threshold=20) kv_router_config = KvRouterConfig(router_snapshot_threshold=20)
async def send_requests_to_router(router, num_requests, router_name): async def send_requests_to_router(router, num_requests, router_name):
# First, send a test request with retry to ensure router is ready # Generate small test token IDs
max_retries = 8 test_token_ids = [random.randint(1, 10000) for _ in range(10)]
wait_time = 1
# Initialize and check the readiness of the mockers by sending dummy request
for attempt in range(max_retries + 1): logger.info(f"Initializing {router_name} and mocker instances")
try: await send_request_via_python_kv_router(
logger.info( kv_python_router=router,
f"Testing {router_name} readiness (attempt {attempt + 1})" token_ids=test_token_ids,
) initial_wait=1.0,
# Generate small test token IDs max_retries=8,
test_token_ids = [random.randint(1, 10000) for _ in range(10)] stop_conditions={"max_tokens": 1}, # Generate just 1 token
stream = await router.generate( )
token_ids=test_token_ids, # Small test
model=MODEL_NAME,
stop_conditions={"max_tokens": 1},
)
# Just consume the stream to verify it works
async for _ in stream:
pass
logger.info(f"{router_name} is ready!")
break
except Exception as e:
logger.warning(
f"{router_name} attempt {attempt + 1} failed: {e}"
)
if attempt < max_retries:
await asyncio.sleep(wait_time)
wait_time *= 2
else:
raise RuntimeError(
f"Failed to connect to {router_name} after retries"
)
# Now send the actual requests # Now send the actual requests
tasks = [] tasks = []
for i in range(num_requests): for i in range(num_requests):
# Generate random token IDs for each request # Generate random token IDs for each request
logger.info(
f"Sending request {i + 1}/{num_requests} to {router_name}"
)
# Generate 30 random tokens
request_tokens = [random.randint(1, 10000) for _ in range(30)] request_tokens = [random.randint(1, 10000) for _ in range(30)]
async def single_request(req_id, tokens): # Send request to mocker via the router
try: tasks.append(
stream = await router.generate( asyncio.create_task(
token_ids=tokens, send_request_via_python_kv_router(
model=MODEL_NAME, kv_python_router=router,
stop_conditions={"max_tokens": 10}, token_ids=request_tokens,
) initial_wait=1.0,
# Consume the stream max_retries=8,
async for _ in stream: stop_conditions={
pass "ignore_eos": True, # Don't stop on EOS token
return True "max_tokens": 10, # Generate exactly 10 tokens
except Exception as e: },
logger.error(
f"Request {req_id} to {router_name} failed: {e}"
) )
return False )
)
tasks.append(asyncio.create_task(single_request(i, request_tokens)))
# Wait for all requests to complete
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
successful = sum(1 for r in results if r) successful = sum(1 for r in results if r)
logger.info( logger.info(
...@@ -898,6 +880,7 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -898,6 +880,7 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
) )
return successful return successful
# Launch first router
logger.info("Creating first KV router") logger.info("Creating first KV router")
kv_push_router1 = KvPushRouter( kv_push_router1 = KvPushRouter(
endpoint=endpoint, endpoint=endpoint,
...@@ -920,11 +903,10 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -920,11 +903,10 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
# Launch second router - will automatically sync with the first router's state # Launch second router - will automatically sync with the first router's state
logger.info("Creating second KV router") logger.info("Creating second KV router")
kv_router_config2 = KvRouterConfig(router_snapshot_threshold=20)
kv_push_router2 = KvPushRouter( kv_push_router2 = KvPushRouter(
endpoint=endpoint, endpoint=endpoint,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
kv_router_config=kv_router_config2, kv_router_config=kv_router_config,
) )
# Send 25 requests to second router with initial retry loop # Send 25 requests to second router with initial retry loop
...@@ -995,7 +977,7 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers): ...@@ -995,7 +977,7 @@ def test_indexers_sync(request, runtime_services, predownload_tokenizers):
"router2_state": state2_item, "router2_state": state2_item,
} }
) )
# If there are differences, format them for easier debugging
if differences: if differences:
error_msg = f"Router states are not equal. Found {len(differences)} differences:\n" error_msg = f"Router states are not equal. Found {len(differences)} differences:\n"
for diff in differences: for diff in differences:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment