Unverified commit 57c701fb, authored by Yan Ru Pei and committed by GitHub
Browse files

fix: expose prefill worker id in disagg (#4563)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 550bf98c
......@@ -623,14 +623,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank);
// Check if worker_id is requested in extra_fields
let should_populate_worker_id = backend_input
.extra_fields
.as_deref()
.unwrap_or(&[])
.iter()
.any(|s| s == "worker_id");
// Get prefill worker ID if available (stored by PrefillRouter)
// In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
let decode_worker_id = instance_id;
......@@ -672,24 +664,30 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
prefill_marked = true;
}
// Inject worker_id in first item's disaggregated_params if requested
if first_item && should_populate_worker_id {
if let Some(ref mut data) = item.data {
// Add worker_id to disaggregated_params
let worker_id_json = json!({
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
});
if let Some(ref mut params) = data.disaggregated_params {
if let Some(obj) = params.as_object_mut() {
obj.insert("worker_id".to_string(), worker_id_json);
}
} else {
data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
}
}
// Always inject worker_id in first item's disaggregated_params
// This is needed for:
// 1. PrefillRouter to know which prefill worker was chosen
// 2. Client response when extra_fields contains "worker_id"
if first_item {
first_item = false;
let Some(ref mut data) = item.data else {
yield item;
continue;
};
// prefill_worker_id comes from context (set by PrefillRouter) or falls back to instance_id
// decode_worker_id is always the current instance_id
let worker_id_json = json!({
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
});
if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
obj.insert("worker_id".to_string(), worker_id_json);
} else {
data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
}
}
yield item;
......
......@@ -36,6 +36,7 @@ class KVRouterProcess(ManagedProcess):
frontend_port: int,
namespace: str,
store_backend: str = "etcd",
enforce_disagg: bool = False,
):
command = [
"python3",
......@@ -53,6 +54,9 @@ class KVRouterProcess(ManagedProcess):
namespace,
]
if enforce_disagg:
command.append("--enforce-disagg")
super().__init__(
command=command,
timeout=60,
......@@ -1490,6 +1494,196 @@ def _test_router_indexers_sync(
logger.info("Indexers sync test completed successfully")
def _test_router_disagg_decisions(
    prefill_workers,
    decode_workers,
    block_size: int,
    request,
    frontend_port: int,
    test_payload: dict,
    store_backend: str = "etcd",
):
    """Validate prefill routing in a disaggregated prefill/decode setup via the HTTP frontend.

    Assumes prefill_workers and decode_workers have already been started by the caller.
    This function owns the router lifecycle: it starts a KVRouterProcess with
    enforce_disagg=True and always shuts it down again, even when an assertion fails.

    This test:
    1. Starts the KV router frontend in the decode workers' namespace.
    2. Sends 4 streaming chat requests; request i uses the base message content repeated
       i + 1 times, so every prompt extends the previous one.
    3. Reads prefill_worker_id and decode_worker_id from nvext.worker_id in the SSE chunks.
    4. Verifies all 4 requests reported the same prefill_worker_id (prefix reuse routing).
    5. Verifies that prefill_worker_id is not among the reported decode_worker_ids.

    Args:
        prefill_workers: Prefill workers already initialized with __enter__(). Not read by
            this function; it is accepted so the caller keeps them alive for the test.
        decode_workers: Decode workers already initialized with __enter__(). Provides the
            namespace for the router and the worker count to wait for.
        block_size: Block size for KV cache, forwarded to KVRouterProcess.
        request: Pytest request fixture for managing resources.
        frontend_port: Port for the frontend HTTP server.
        test_payload: Base payload for /v1/chat/completions; messages[0]["content"] is the
            text that gets repeated to build the progressive prompts.
        store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".

    Raises:
        AssertionError: If a request does not return HTTP 200.
        AssertionError: If fewer than 4 prefill_worker_ids or no decode_worker_id were seen.
        AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure).
        AssertionError: If prefill_worker_id is in decode_worker_ids (not true disagg).
    """
    # Bound before the try block so the finally clause can tell whether a router exists.
    kv_router = None
    try:
        # Start KV router frontend - uses decode_workers namespace for discovery
        # The frontend will auto-discover both prefill and decode workers
        logger.info(
            f"Starting KV router frontend on port {frontend_port} for disagg test"
        )
        kv_router = KVRouterProcess(
            request,
            block_size,
            frontend_port,
            decode_workers.namespace,
            store_backend,
            enforce_disagg=True,
        )
        kv_router.__enter__()

        frontend_url = f"http://localhost:{frontend_port}"
        chat_url = f"{frontend_url}/v1/chat/completions"

        # Wait for workers to register with frontend
        logger.info(
            "Waiting for prefill and decode workers to register with frontend..."
        )
        asyncio.run(
            wait_for_frontend_ready(
                frontend_url=frontend_url,
                expected_num_workers=decode_workers.num_workers,
                timeout=120,
            )
        )

        async def send_progressive_requests():
            """Send 4 progressive requests with overlapping prefixes and collect worker IDs."""
            prefill_worker_ids = []
            decode_worker_ids = []

            # Base text that is repeated to build progressively longer prompts
            base_content = test_payload["messages"][0]["content"]

            async with aiohttp.ClientSession() as session:
                for i in range(4):
                    # Build progressive content by repeating base content
                    # Each iteration adds more content to extend the prefix
                    progressive_content = " ".join([base_content] * (i + 1))

                    # Create payload with worker_id in extra_fields to get prefill/decode worker IDs
                    payload = {
                        **test_payload,
                        "messages": [
                            {
                                "role": "user",
                                "content": progressive_content,
                            }
                        ],
                        "nvext": {"extra_fields": ["worker_id"]},
                        "stream": True,
                    }

                    logger.info(
                        f"Sending request {i + 1}/4 with progressive prefix "
                        f"(~{len(progressive_content)} chars)"
                    )

                    async with session.post(chat_url, json=payload) as response:
                        assert (
                            response.status == 200
                        ), f"Request {i + 1} failed with status {response.status}"

                        # Collect all chunks and look for nvext with worker_id
                        prefill_wid = None
                        decode_wid = None

                        async for line in response.content:
                            if not line:
                                continue

                            line_str = line.decode("utf-8", errors="replace").strip()
                            if not line_str.startswith("data:"):
                                continue

                            data_str = line_str[5:].strip()
                            if data_str == "[DONE]":
                                break

                            try:
                                data = json.loads(data_str)
                            except json.JSONDecodeError:
                                continue

                            # Chunks may omit nvext or carry an explicit null; only a
                            # JSON object can hold the worker_id record we are after.
                            nvext = data.get("nvext") if isinstance(data, dict) else None
                            worker_id_info = (
                                nvext.get("worker_id")
                                if isinstance(nvext, dict)
                                else None
                            )
                            if not isinstance(worker_id_info, dict):
                                continue

                            if "prefill_worker_id" in worker_id_info:
                                prefill_wid = worker_id_info["prefill_worker_id"]
                            if "decode_worker_id" in worker_id_info:
                                decode_wid = worker_id_info["decode_worker_id"]

                        logger.info(
                            f"Request {i + 1}: prefill_worker_id={prefill_wid}, "
                            f"decode_worker_id={decode_wid}"
                        )

                        if prefill_wid is not None:
                            prefill_worker_ids.append(prefill_wid)
                        if decode_wid is not None:
                            decode_worker_ids.append(decode_wid)

                    # Small delay between requests
                    await asyncio.sleep(0.5)

            return prefill_worker_ids, decode_worker_ids

        # Run the progressive requests
        prefill_ids, decode_ids = asyncio.run(send_progressive_requests())

        logger.info(f"Collected prefill_worker_ids: {prefill_ids}")
        logger.info(f"Collected decode_worker_ids: {decode_ids}")

        # Verify we got worker IDs from all requests
        assert len(prefill_ids) == 4, (
            f"Expected 4 prefill_worker_ids, got {len(prefill_ids)}. "
            "Make sure nvext.extra_fields=['worker_id'] is being processed."
        )

        # Without at least one decode_worker_id the "not in decode set" check below would
        # compare against an empty set and pass without verifying anything.
        assert decode_ids, (
            "Expected at least one decode_worker_id in the responses, got none. "
            "Make sure nvext.extra_fields=['worker_id'] is being processed."
        )

        # Verify all prefill_worker_ids are the same (prefix reuse)
        unique_prefill_ids = set(prefill_ids)
        assert len(unique_prefill_ids) == 1, (
            f"Expected all prefill requests to route to the same worker due to prefix reuse, "
            f"but found {len(unique_prefill_ids)} unique prefill_worker_ids: {unique_prefill_ids}. "
            f"Full list: {prefill_ids}"
        )

        # Verify prefill_worker_id is NOT in decode_worker_ids (true disagg)
        unique_decode_ids = set(decode_ids)
        prefill_id = prefill_ids[0]
        assert prefill_id not in unique_decode_ids, (
            f"Prefill worker {prefill_id} should NOT be in decode workers {unique_decode_ids}. "
            "This suggests disaggregated mode is not working correctly - "
            "prefill and decode should use separate worker pools."
        )

        logger.info(
            "Successfully verified disaggregated routing:\n"
            f"  - All 4 requests routed to same prefill_worker_id={prefill_id} (prefix reuse)\n"
            f"  - Prefill worker is NOT in decode worker set {unique_decode_ids} (true disagg)"
        )

    finally:
        if kv_router is not None:
            kv_router.__exit__(None, None, None)
def _test_router_decisions(
engine_workers,
endpoint,
......
......@@ -10,6 +10,7 @@ from tests.router.common import ( # utilities
_test_python_router_bindings,
_test_router_basic,
_test_router_decisions,
_test_router_disagg_decisions,
_test_router_indexers_sync,
_test_router_overload_503,
_test_router_query_instance_id,
......@@ -61,6 +62,7 @@ def get_unique_ports(
"test_mocker_two_kv_router": 100,
"test_mocker_kv_router_overload_503": 200,
"test_query_instance_id_returns_worker_and_tokens": 300,
"test_router_disagg_decisions": 400,
}
base_offset = test_offsets.get(test_name, 0)
......@@ -87,8 +89,80 @@ TEST_PAYLOAD: Dict[str, Any] = {
}
def _build_mocker_command(
    endpoint: str,
    store_backend: str,
    num_workers: int,
    mocker_args: Dict[str, Any],
    worker_type: Optional[str] = None,
) -> list[str]:
    """Assemble the argv list used to launch `python -m dynamo.mocker`.

    Args:
        endpoint: The dynamo endpoint string, passed via --endpoint
        store_backend: Storage backend ("etcd" or "file"), passed via --store-kv
        num_workers: Number of workers to spawn, passed via --num-workers
        mocker_args: Optional engine settings; only keys that are present are translated
            into CLI flags, in the fixed order of the table below
        worker_type: "prefill" or "decode" to add the matching disagg role flag; any
            other value (including None) adds no role flag

    Returns:
        List of command arguments for subprocess
    """
    argv = [
        "python",
        "-m",
        "dynamo.mocker",
        "--model-path",
        MODEL_NAME,
        "--endpoint",
        endpoint,
        "--store-kv",
        store_backend,
        "--num-workers",
        str(num_workers),
    ]

    # Role flag for disaggregated mode (at most one is ever appended)
    for role, role_flag in (
        ("prefill", "--is-prefill-worker"),
        ("decode", "--is-decode-worker"),
    ):
        if worker_type == role:
            argv.append(role_flag)
            break

    # (mocker_args key, flag, negated flag). A negated flag of None marks an option that
    # takes a value; otherwise the option is a switch chosen by the value's truthiness.
    # The tuple order is the order in which flags appear on the command line.
    option_table = (
        ("speedup_ratio", "--speedup-ratio", None),
        ("block_size", "--block-size", None),
        ("num_gpu_blocks", "--num-gpu-blocks-override", None),
        ("max_num_seqs", "--max-num-seqs", None),
        ("max_num_batched_tokens", "--max-num-batched-tokens", None),
        (
            "enable_prefix_caching",
            "--enable-prefix-caching",
            "--no-enable-prefix-caching",
        ),
        (
            "enable_chunked_prefill",
            "--enable-chunked-prefill",
            "--no-enable-chunked-prefill",
        ),
        ("watermark", "--watermark", None),
        ("dp_size", "--data-parallel-size", None),
    )

    for key, flag, negated_flag in option_table:
        if key not in mocker_args:
            continue
        setting = mocker_args[key]
        if negated_flag is None:
            argv.extend([flag, str(setting)])
        else:
            argv.append(flag if setting else negated_flag)

    return argv
class MockerProcess:
"""Manages multiple mocker engine instances with the same namespace"""
"""Manages mocker engine instances with shared tokio runtime via --num-workers."""
def __init__(
self,
......@@ -97,90 +171,114 @@ class MockerProcess:
num_mockers: int = 1,
store_backend: str = "etcd",
):
# Generate a unique namespace suffix shared by all mockers
namespace_suffix = generate_random_suffix()
self.namespace = f"test-namespace-{namespace_suffix}"
self.component_name = "mocker"
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_mockers = num_mockers
self.num_workers = self.num_mockers # for compatibility with common.py
self.mocker_processes = []
# Default mocker args if not provided
if mocker_args is None:
mocker_args = {}
# Create multiple mocker processes with the same namespace
for i in range(num_mockers):
command = [
"python",
"-m",
"dynamo.mocker",
"--model-path",
MODEL_NAME,
"--endpoint",
self.endpoint,
"--store-kv",
store_backend,
]
# Add individual CLI arguments from mocker_args
if "speedup_ratio" in mocker_args:
command.extend(["--speedup-ratio", str(mocker_args["speedup_ratio"])])
if "block_size" in mocker_args:
command.extend(["--block-size", str(mocker_args["block_size"])])
if "num_gpu_blocks" in mocker_args:
command.extend(
["--num-gpu-blocks-override", str(mocker_args["num_gpu_blocks"])]
)
if "max_num_seqs" in mocker_args:
command.extend(["--max-num-seqs", str(mocker_args["max_num_seqs"])])
if "max_num_batched_tokens" in mocker_args:
command.extend(
[
"--max-num-batched-tokens",
str(mocker_args["max_num_batched_tokens"]),
]
)
if "enable_prefix_caching" in mocker_args:
if mocker_args["enable_prefix_caching"]:
command.append("--enable-prefix-caching")
else:
command.append("--no-enable-prefix-caching")
if "enable_chunked_prefill" in mocker_args:
if mocker_args["enable_chunked_prefill"]:
command.append("--enable-chunked-prefill")
else:
command.append("--no-enable-chunked-prefill")
if "watermark" in mocker_args:
command.extend(["--watermark", str(mocker_args["watermark"])])
if "dp_size" in mocker_args:
command.extend(["--data-parallel-size", str(mocker_args["dp_size"])])
process = ManagedProcess(
command=command,
timeout=60,
display_output=True,
health_check_ports=[],
health_check_urls=[],
log_dir=request.node.name,
terminate_existing=False,
self.num_workers = num_mockers
mocker_args = mocker_args or {}
command = _build_mocker_command(
endpoint=self.endpoint,
store_backend=store_backend,
num_workers=num_mockers,
mocker_args=mocker_args,
)
self._process = ManagedProcess(
command=command,
timeout=60,
display_output=True,
health_check_ports=[],
health_check_urls=[],
log_dir=request.node.name,
terminate_existing=False,
)
logger.info(
f"Created mocker process with {num_mockers} worker(s), endpoint: {self.endpoint}"
)
def __enter__(self):
logger.info(f"Starting mocker process with {self.num_workers} worker(s)")
self._process.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
logger.info("Stopping mocker process")
self._process.__exit__(exc_type, exc_val, exc_tb)
class DisaggMockerProcess:
    """Manages prefill or decode mocker instances for disaggregated serving.

    A single mocker process is launched with --num-workers (per the original author, so
    the workers share one tokio runtime). For disaggregated serving:
    - Prefill workers: worker_type="prefill", endpoint is namespace.prefill.generate
    - Decode workers: worker_type="decode", endpoint is namespace.backend.generate

    Both prefill and decode workers should share the same namespace for proper discovery.

    Attributes set by __init__: namespace, worker_type, num_workers, component_name,
    endpoint.
    """

    def __init__(
        self,
        request,
        namespace: str,
        worker_type: str,
        mocker_args: Optional[Dict[str, Any]] = None,
        num_mockers: int = 1,
        store_backend: str = "etcd",
    ):
        """Build (but do not start) the mocker process for one worker role.

        Args:
            request: Pytest request fixture; its node name is used as the log directory
            namespace: Namespace shared by the prefill and decode workers
            worker_type: Either "prefill" or "decode"
            mocker_args: Optional mocker settings forwarded to _build_mocker_command
            num_mockers: Number of workers, passed to the mocker as --num-workers
            store_backend: Storage backend ("etcd" or "file"). Defaults to "etcd".

        Raises:
            ValueError: If worker_type is neither "prefill" nor "decode"
        """
        if worker_type not in ("prefill", "decode"):
            raise ValueError(
                f"worker_type must be 'prefill' or 'decode', got {worker_type}"
            )

        self.namespace = namespace
        self.worker_type = worker_type
        self.num_workers = num_mockers

        # Prefill workers register under "prefill"; decode workers under "backend"
        self.component_name = "prefill" if worker_type == "prefill" else "backend"
        self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"

        mocker_args = mocker_args or {}

        command = _build_mocker_command(
            endpoint=self.endpoint,
            store_backend=store_backend,
            num_workers=num_mockers,
            mocker_args=mocker_args,
            worker_type=worker_type,
        )

        self._process = ManagedProcess(
            command=command,
            timeout=60,
            display_output=True,
            health_check_ports=[],
            health_check_urls=[],
            log_dir=request.node.name,
            terminate_existing=False,
        )
        logger.info(
            f"Created {worker_type} mocker process with {num_mockers} worker(s), "
            f"endpoint: {self.endpoint}"
        )

    def __enter__(self):
        """Start the underlying mocker process and return this manager."""
        logger.info(
            f"Starting {self.worker_type} mocker process with {self.num_workers} worker(s)"
        )
        self._process.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the underlying mocker process, forwarding the exception details."""
        logger.info(f"Stopping {self.worker_type} mocker process")
        self._process.__exit__(exc_type, exc_val, exc_tb)
@pytest.mark.pre_merge
......@@ -492,3 +590,71 @@ def test_router_decisions(request, runtime_services_session, predownload_tokeniz
finally:
if "mockers" in locals():
mockers.__exit__(None, None, None)
@pytest.mark.pre_merge
@pytest.mark.parallel
@pytest.mark.model(MODEL_NAME)
def test_router_disagg_decisions(
    request, runtime_services_session, predownload_tokenizers
):
    """Check prefill routing with separate prefill and decode mocker pools.

    Starts 4 prefill and 4 decode mocker workers in one freshly generated namespace and
    hands them to _test_router_disagg_decisions, which sends the progressive requests and
    performs the assertions. Both worker pools are stopped afterwards, decode first.
    """
    logger.info("Starting disaggregated router prefix reuse test")

    # One namespace shared by both worker roles
    shared_namespace = f"test-namespace-{generate_random_suffix()}"

    # Settings common to the prefill and the decode pool
    pool_kwargs = {
        "namespace": shared_namespace,
        "mocker_args": {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE},
        "num_mockers": 4,
    }

    prefill_workers = decode_workers = None

    try:
        # Each pool is bound to its variable before __enter__() so that the finally
        # clause still stops it if startup raises.
        logger.info("Starting 4 prefill mocker instances")
        prefill_workers = DisaggMockerProcess(
            request, worker_type="prefill", **pool_kwargs
        )
        prefill_workers.__enter__()
        logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")

        logger.info("Starting 4 decode mocker instances")
        decode_workers = DisaggMockerProcess(
            request, worker_type="decode", **pool_kwargs
        )
        decode_workers.__enter__()
        logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")

        # Get unique port for this test
        (frontend_port, *_) = get_unique_ports(request, num_ports=1)

        # Run disagg routing test
        _test_router_disagg_decisions(
            prefill_workers=prefill_workers,
            decode_workers=decode_workers,
            block_size=BLOCK_SIZE,
            request=request,
            frontend_port=frontend_port,
            test_payload=TEST_PAYLOAD,
        )

    finally:
        # Tear down in reverse start order: decode pool, then prefill pool
        for workers in (decode_workers, prefill_workers):
            if workers is not None:
                workers.__exit__(None, None, None)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment