Unverified commit a2154ba5, authored by Yan Ru Pei, committed by GitHub
Browse files

chore(mocker): batch live output signal sends (#7647)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 06f17011
...@@ -333,9 +333,6 @@ pub mod model { ...@@ -333,9 +333,6 @@ pub mod model {
/// KV Router configuration environment variables /// KV Router configuration environment variables
pub mod router { pub mod router {
/// Minimum number of workers required before KV router startup continues.
pub const DYN_ROUTER_MIN_INITIAL_WORKERS: &str = "DYN_ROUTER_MIN_INITIAL_WORKERS";
/// Queue threshold fraction for prefill token capacity. /// Queue threshold fraction for prefill token capacity.
/// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens. /// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens.
pub const DYN_ROUTER_QUEUE_THRESHOLD: &str = "DYN_ROUTER_QUEUE_THRESHOLD"; pub const DYN_ROUTER_QUEUE_THRESHOLD: &str = "DYN_ROUTER_QUEUE_THRESHOLD";
...@@ -495,7 +492,6 @@ mod tests { ...@@ -495,7 +492,6 @@ mod tests {
model::huggingface::HF_HOME, model::huggingface::HF_HOME,
model::huggingface::HF_HUB_OFFLINE, model::huggingface::HF_HUB_OFFLINE,
// Router // Router
router::DYN_ROUTER_MIN_INITIAL_WORKERS,
router::DYN_ROUTER_QUEUE_THRESHOLD, router::DYN_ROUTER_QUEUE_THRESHOLD,
router::DYN_ROUTER_QUEUE_POLICY, router::DYN_ROUTER_QUEUE_POLICY,
// Event Plane // Event Plane
......
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import contextlib
import json import json
import logging import logging
import os
import random import random
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
...@@ -36,6 +38,20 @@ logger = logging.getLogger(__name__) ...@@ -36,6 +38,20 @@ logger = logging.getLogger(__name__)
NUM_REQUESTS = 100 NUM_REQUESTS = 100
BLOCK_SIZE = 16 BLOCK_SIZE = 16
# Name of the environment variable that min_initial_workers_env() overrides.
MIN_INITIAL_WORKERS_ENV = "DYN_ROUTER_MIN_INITIAL_WORKERS"


@contextlib.contextmanager
def min_initial_workers_env(min_initial_workers: int):
    """Temporarily set ``DYN_ROUTER_MIN_INITIAL_WORKERS`` for the ``with`` body.

    The variable is set to ``str(min_initial_workers)`` on entry. On exit,
    including when the body raises, the prior state is restored: the old value
    is written back if there was one, otherwise the variable is removed.

    Args:
        min_initial_workers: Value to expose through the environment variable.

    Yields:
        None. The context manager exists only for its environment side effect.
    """
    # Remember whether the variable existed so exit can tell "unset" from "set".
    previous = os.environ.get(MIN_INITIAL_WORKERS_ENV)
    os.environ[MIN_INITIAL_WORKERS_ENV] = str(min_initial_workers)
    try:
        yield
    finally:
        if previous is None:
            # The variable was absent before entry; pop with a default so a
            # body that already deleted it does not trigger KeyError.
            os.environ.pop(MIN_INITIAL_WORKERS_ENV, None)
        else:
            os.environ[MIN_INITIAL_WORKERS_ENV] = previous
######################################################## ########################################################
...@@ -55,6 +71,7 @@ def _test_router_basic( ...@@ -55,6 +71,7 @@ def _test_router_basic(
request_plane: str = "nats", request_plane: str = "nats",
router_mode: str = "kv", router_mode: str = "kv",
enforce_disagg: bool = False, enforce_disagg: bool = False,
min_initial_workers: int | None = None,
): ):
"""Basic router test: start router, wait for workers and send concurrent requests via HTTP frontend. """Basic router test: start router, wait for workers and send concurrent requests via HTTP frontend.
...@@ -78,6 +95,7 @@ def _test_router_basic( ...@@ -78,6 +95,7 @@ def _test_router_basic(
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "nats". request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "nats".
router_mode: Router mode ("kv", "round-robin", "random", "power-of-two", "direct"). Defaults to "kv". router_mode: Router mode ("kv", "round-robin", "random", "power-of-two", "direct"). Defaults to "kv".
enforce_disagg: Whether to pass --enforce-disagg to the frontend. Defaults to False. enforce_disagg: Whether to pass --enforce-disagg to the frontend. Defaults to False.
min_initial_workers: Optional frontend startup worker gate. Defaults to None.
Raises: Raises:
AssertionError: If requests fail or frontend doesn't become ready AssertionError: If requests fail or frontend doesn't become ready
...@@ -92,6 +110,7 @@ def _test_router_basic( ...@@ -92,6 +110,7 @@ def _test_router_basic(
enforce_disagg=enforce_disagg, enforce_disagg=enforce_disagg,
request_plane=request_plane, request_plane=request_plane,
router_mode=router_mode, router_mode=router_mode,
min_initial_workers=min_initial_workers,
): ):
# Start router frontend # Start router frontend
logger.info( logger.info(
...@@ -327,14 +346,15 @@ def _test_python_router_bindings( ...@@ -327,14 +346,15 @@ def _test_python_router_bindings(
AssertionError: If requests fail or router doesn't work correctly AssertionError: If requests fail or router doesn't work correctly
""" """
# Create KvRouterConfig with default settings # Create KvRouterConfig with default settings
kv_router_config = KvRouterConfig(min_initial_workers=num_workers) kv_router_config = KvRouterConfig()
# Create KvRouter Python object # Create KvRouter Python object
kv_router = KvRouter( with min_initial_workers_env(num_workers):
endpoint=endpoint, kv_router = KvRouter(
block_size=block_size, endpoint=endpoint,
kv_router_config=kv_router_config, block_size=block_size,
) kv_router_config=kv_router_config,
)
logger.info("Created KvRouter Python object") logger.info("Created KvRouter Python object")
...@@ -620,6 +640,7 @@ def _test_router_overload_503( ...@@ -620,6 +640,7 @@ def _test_router_overload_503(
namespace=engine_workers.namespace, namespace=engine_workers.namespace,
blocks_threshold=blocks_threshold, blocks_threshold=blocks_threshold,
): ):
frontend_url = f"http://localhost:{frontend_port}"
url = f"http://localhost:{frontend_port}/v1/chat/completions" url = f"http://localhost:{frontend_port}/v1/chat/completions"
# Custom payload for 503 test with more tokens to consume resources # Custom payload for 503 test with more tokens to consume resources
...@@ -628,86 +649,104 @@ def _test_router_overload_503( ...@@ -628,86 +649,104 @@ def _test_router_overload_503(
"max_tokens": 50, # Longer output to consume more blocks "max_tokens": 50, # Longer output to consume more blocks
} }
# First, send one request with retry to ensure system is ready logger.info("Waiting for frontend readiness before overload test...")
logger.info("Sending initial request to ensure system is ready...") asyncio.run(
asyncio.run(send_inflight_requests([url], test_payload_503, 1)) wait_for_frontend_ready(
frontend_url=frontend_url,
expected_num_workers=1,
timeout=60,
)
)
# Now send 50 concurrent requests to exhaust resources, then verify 503 logger.info("Launching streaming requests until the router returns 503...")
logger.info("Sending 50 concurrent requests to exhaust resources...")
async def exhaust_resources_and_verify_503(): async def exhaust_resources_and_verify_503():
stop_event = asyncio.Event()
overload_response = {}
unexpected_statuses = []
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
# Start 50 long-running requests concurrently
tasks = [] tasks = []
for i in range(50):
# Create unique shuffled content for each request
content_words = test_payload["messages"][0]["content"].split()
random.shuffle(content_words)
shuffled_content = " ".join(content_words)
# Create unique payload for this request
unique_payload = {
**test_payload,
"max_tokens": 50,
"messages": [
{**test_payload["messages"][0], "content": shuffled_content}
],
}
async def send_long_request(req_id, payload):
try:
async with session.post(url, json=payload) as response:
if response.status == 200:
# Don't read the response fully, just hold the connection
await asyncio.sleep(
10
) # Hold connection for 10 seconds
return True
else:
logger.info(
f"Request {req_id} got status {response.status}"
)
return False
except Exception as e:
logger.info(f"Request {req_id} failed: {e}")
return False
tasks.append( async def send_request(req_id, payload):
asyncio.create_task(send_long_request(i, unique_payload)) try:
) async with session.post(url, json=payload) as response:
if response.status == 200:
logger.info(f"Request {req_id} accepted")
await stop_event.wait()
return response.status
if response.status == 503:
body = await response.json()
logger.info(
f"Request {req_id} got expected 503: {body}"
)
overload_response["status"] = response.status
overload_response["body"] = body
stop_event.set()
return response.status
# Wait briefly to ensure requests are in-flight body = await response.text()
await asyncio.sleep(0.8) logger.info(
f"Request {req_id} got unexpected status {response.status}: {body}"
)
unexpected_statuses.append((response.status, body))
return response.status
except asyncio.CancelledError:
raise
except Exception as e:
logger.info(f"Request {req_id} failed: {e}")
unexpected_statuses.append(("exception", str(e)))
return None
# Now send one more request that should get 503
logger.info("Sending additional request that should receive 503...")
try: try:
async with session.post(url, json=test_payload_503) as response: for i in range(50):
status_code = response.status if stop_event.is_set():
if status_code == 503: break
body = await response.json()
logger.info(f"Got expected 503 response: {body}") content_words = test_payload["messages"][0]["content"].split()
error_msg = body.get("message", "") random.shuffle(content_words)
assert ( shuffled_content = " ".join(content_words)
"Service temporarily unavailable" in error_msg unique_payload = {
or "All workers are busy" in error_msg **test_payload_503,
), f"Expected service overload error message, got: {body}" "messages": [
return True {
else: **test_payload["messages"][0],
logger.error(f"Expected 503 but got {status_code}") "content": shuffled_content,
if status_code == 200: }
logger.error( ],
"Request unexpectedly succeeded when it should have been rejected" }
) tasks.append(
return False asyncio.create_task(send_request(i, unique_payload))
except Exception as e: )
logger.error(f"Failed to send overload test request: {e}") await asyncio.sleep(0.1)
return False
if not stop_event.is_set():
try:
await asyncio.wait_for(stop_event.wait(), timeout=10)
except asyncio.TimeoutError:
logger.error("Timed out waiting for overload 503")
finally: finally:
# Cancel all background tasks stop_event.set()
for task in tasks: done, pending = await asyncio.wait(tasks, timeout=3)
for task in pending:
task.cancel() task.cancel()
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*pending, return_exceptions=True)
for task in done:
task.result()
if overload_response.get("status") != 503:
logger.error(
f"Observed statuses before timeout: {unexpected_statuses}"
)
return False
error_msg = overload_response["body"].get("message", "")
assert (
"Service temporarily unavailable" in error_msg
or "All workers are busy" in error_msg
), f"Expected service overload error message, got: {overload_response['body']}"
return True
# Run the test # Run the test
success = asyncio.run(exhaust_resources_and_verify_503()) success = asyncio.run(exhaust_resources_and_verify_503())
...@@ -830,7 +869,6 @@ def _test_router_indexers_sync( ...@@ -830,7 +869,6 @@ def _test_router_indexers_sync(
router_snapshot_threshold=20, router_snapshot_threshold=20,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads, router_event_threads=router_event_threads,
min_initial_workers=num_workers,
) )
# If standalone indexer mode, launch mockers one-by-one and register. # If standalone indexer mode, launch mockers one-by-one and register.
...@@ -884,11 +922,12 @@ def _test_router_indexers_sync( ...@@ -884,11 +922,12 @@ def _test_router_indexers_sync(
f"{engine_workers.namespace}.{engine_workers.component_name}.generate" f"{engine_workers.namespace}.{engine_workers.component_name}.generate"
) )
kv_router1 = KvRouter( with min_initial_workers_env(num_workers):
endpoint=endpoint1, kv_router1 = KvRouter(
block_size=block_size, endpoint=endpoint1,
kv_router_config=kv_router_config, block_size=block_size,
) kv_router_config=kv_router_config,
)
# Wait for workers to be ready # Wait for workers to be ready
await wait_for_workers_ready(endpoint1, kv_router1, num_workers, model_name) await wait_for_workers_ready(endpoint1, kv_router1, num_workers, model_name)
...@@ -975,11 +1014,12 @@ def _test_router_indexers_sync( ...@@ -975,11 +1014,12 @@ def _test_router_indexers_sync(
f"{engine_workers.namespace}.{engine_workers.component_name}.generate" f"{engine_workers.namespace}.{engine_workers.component_name}.generate"
) )
kv_router2 = KvRouter( with min_initial_workers_env(num_workers):
endpoint=endpoint2, kv_router2 = KvRouter(
block_size=block_size, endpoint=endpoint2,
kv_router_config=kv_router_config, block_size=block_size,
) kv_router_config=kv_router_config,
)
# Launch Indexer B alongside Router 2. Workers are passed via --workers # Launch Indexer B alongside Router 2. Workers are passed via --workers
# so ZMQ sockets connect before recovery, avoiding the slow-joiner problem. # so ZMQ sockets connect before recovery, avoiding the slow-joiner problem.
...@@ -1256,6 +1296,7 @@ def _test_router_decisions_disagg( ...@@ -1256,6 +1296,7 @@ def _test_router_decisions_disagg(
enforce_disagg=True, enforce_disagg=True,
request_plane=request_plane, request_plane=request_plane,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
min_initial_workers=decode_workers.num_workers,
): ):
# Start KV router frontend - uses decode_workers namespace for discovery # Start KV router frontend - uses decode_workers namespace for discovery
# The frontend will auto-discover both prefill and decode workers # The frontend will auto-discover both prefill and decode workers
...@@ -1483,13 +1524,13 @@ def _test_router_decisions( ...@@ -1483,13 +1524,13 @@ def _test_router_decisions(
use_kv_events=use_kv_events, use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
router_event_threads=router_event_threads, router_event_threads=router_event_threads,
min_initial_workers=expected_num_instances,
)
kv_router = KvRouter(
endpoint=endpoint,
block_size=block_size,
kv_router_config=kv_router_config,
) )
with min_initial_workers_env(expected_num_instances):
kv_router = KvRouter(
endpoint=endpoint,
block_size=block_size,
kv_router_config=kv_router_config,
)
# Wait for workers to be ready and get their instance IDs # Wait for workers to be ready and get their instance IDs
worker_ids = await wait_for_workers_ready( worker_ids = await wait_for_workers_ready(
......
...@@ -728,6 +728,7 @@ def test_mocker_router( ...@@ -728,6 +728,7 @@ def test_mocker_router(
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
request_plane=request_plane, request_plane=request_plane,
router_mode=router_mode, router_mode=router_mode,
min_initial_workers=mockers.num_workers,
) )
...@@ -801,7 +802,7 @@ def test_mocker_kv_router_overload_503( ...@@ -801,7 +802,7 @@ def test_mocker_kv_router_overload_503(
logger.info("Starting mocker KV router overload test for 503 status") logger.info("Starting mocker KV router overload test for 503 status")
# Create mocker args dictionary with limited resources - use local indexer (NATS Core mode) # Create mocker args dictionary with limited resources - use local indexer (NATS Core mode)
mocker_args = { mocker_args = {
"speedup_ratio": 10, "speedup_ratio": 0.01,
"block_size": 4, # Smaller block size "block_size": 4, # Smaller block size
"num_gpu_blocks": 64, # Limited GPU blocks to exhaust quickly "num_gpu_blocks": 64, # Limited GPU blocks to exhaust quickly
"durable_kv_events": durable_kv_events, "durable_kv_events": durable_kv_events,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment