Unverified Commit 8af634ba authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat(global-router): add aggregated mode with TTFT x ITL pool selection (#8044)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7ffc699f
......@@ -5,15 +5,16 @@ SPDX-License-Identifier: Apache-2.0
# Global Router
A hierarchical routing service that sits between the Dynamo frontend and local routers in different pool namespaces. The global router enables disaggregated serving with flexible pool selection based on request characteristics.
A hierarchical routing service that sits between the Dynamo frontend and local routers in different pool namespaces. The global router supports both disaggregated and aggregated serving with flexible pool selection based on request characteristics.
## Overview
The Global Router acts as both a prefill and decode worker from the frontend's perspective:
- Registers with `ModelType.Prefill` for prefill requests
- Registers with `ModelType.Chat | ModelType.Completions` for decode requests
The Global Router supports two modes:
Internally, it routes requests to local routers in different namespaces based on a configurable grid-based selection strategy.
- **Disagg mode** (default): Registers as both prefill and decode worker. Routes prefill requests based on (ISL, TTFT) and decode requests based on (context_length, ITL) to separate pool types.
- **Agg mode**: Registers as a single generate worker. Routes all requests based on (TTFT target, ITL target) to unified pools that handle both prefill and decode.
Both modes support priority-based pool overrides from agent hints.
## Supported Backends
......@@ -26,6 +27,8 @@ Internally, it routes requests to local routers in different namespaces based on
## Architecture
### Disagg Mode
```
Frontend
|
......@@ -35,22 +38,39 @@ Global Router (registers as both prefill + decode)
+---> Prefill Pool 0 (namespace: prefill_pool_0)
| |
| +---> Local Router ---> Prefill Worker 0
| |
| +---> Prefill Worker 1
| |
| +---> ...
|
+---> Prefill Pool ...
|
+---> Decode Pool 0 (namespace: decode_pool_0)
| |
| +---> Local Router ---> Decode Worker 0
| |
| +---> Decode Worker 1
| |
| +---> ...
|
+---> Decode Pool ...
```
### Agg Mode
```
Frontend
|
v
Global Router (registers as Chat + Completions)
|
+---> Agg Pool 0 (namespace: agg_pool_0)
| |
| +---> Local Router ---> Worker 0 (prefill + decode)
| +---> Worker 1 (prefill + decode)
|
+---> Agg Pool 1 (namespace: agg_pool_1)
| |
| +---> Local Router ---> Worker 0 (prefill + decode)
| +---> Worker 1 (prefill + decode)
|
+---> Agg Pool ...
```
## Usage
```bash
......@@ -71,46 +91,85 @@ All options can be set via CLI flags or environment variables. CLI flags take pr
| `--namespace` | No | `DYN_NAMESPACE` | "dynamo" | Namespace for global router |
| `--component-name` | No | `DYN_GLOBAL_ROUTER_COMPONENT_NAME` | "global_router" | Component name |
| `--default-ttft-target` | No | `DYN_GLOBAL_ROUTER_DEFAULT_TTFT_TARGET` | None | Default TTFT target (ms) for prefill pool selection |
| `--default-itl-target` | No | `DYN_GLOBAL_ROUTER_DEFAULT_ITL_TARGET` | None | Default ITL target (ms) for decode pool selection |
| `--default-itl-target` | No | `DYN_GLOBAL_ROUTER_DEFAULT_ITL_TARGET` | None | Default ITL target (ms) for pool selection |
## Configuration
The configuration file defines pool namespaces and selection strategies:
The configuration file format depends on the mode. The `mode` field determines which mode is used; if omitted, it defaults to `"disagg"`.
### Disagg Mode Configuration
```jsonc
{
"num_prefill_pools": <int>, // Number of prefill pools
"num_decode_pools": <int>, // Number of decode pools
"prefill_pool_dynamo_namespaces": [], // List of Dynamo namespaces for each prefill pool
"decode_pool_dynamo_namespaces": [], // List of Dynamo namespaces for each decode pool
"mode": "disagg", // Optional, defaults to "disagg"
"num_prefill_pools": <int>,
"num_decode_pools": <int>,
"prefill_pool_dynamo_namespaces": [],
"decode_pool_dynamo_namespaces": [],
"prefill_pool_selection_strategy": {
"isl_min": <int>, // Minimum input sequence length (tokens)
"isl_max": <int>, // Maximum input sequence length (tokens)
"isl_resolution": <int>, // Number of grid rows for ISL dimension
"ttft_min": <float>, // Minimum TTFT target (ms)
"ttft_max": <float>, // Maximum TTFT target (ms)
"ttft_resolution": <int>, // Number of grid columns for TTFT dimension
"prefill_pool_mapping": [[]] // 2D array [isl_resolution][ttft_resolution] -> pool index
"isl_min": <int>,
"isl_max": <int>,
"isl_resolution": <int>,
"ttft_min": <float>,
"ttft_max": <float>,
"ttft_resolution": <int>,
"prefill_pool_mapping": [[]], // 2D array [isl_resolution][ttft_resolution] -> pool index
"priority_overrides": [] // Optional
},
"decode_pool_selection_strategy": {
"context_length_min": <int>, // Minimum context length (tokens)
"context_length_max": <int>, // Maximum context length (tokens)
"context_length_resolution": <int>, // Number of grid rows for context length
"context_length_min": <int>,
"context_length_max": <int>,
"context_length_resolution": <int>,
"itl_min": <float>,
"itl_max": <float>,
"itl_resolution": <int>,
"decode_pool_mapping": [[]], // 2D array [context_length_resolution][itl_resolution] -> pool index
"priority_overrides": [] // Optional
}
}
```
### Agg Mode Configuration
```jsonc
{
"mode": "agg",
"num_agg_pools": <int>,
"agg_pool_dynamo_namespaces": [],
"agg_pool_selection_strategy": {
"ttft_min": <float>, // Minimum TTFT target (ms)
"ttft_max": <float>, // Maximum TTFT target (ms)
"ttft_resolution": <int>, // Number of grid rows for TTFT dimension
"itl_min": <float>, // Minimum ITL target (ms)
"itl_max": <float>, // Maximum ITL target (ms)
"itl_resolution": <int>, // Number of grid columns for ITL dimension
"decode_pool_mapping": [[]] // 2D array [context_length_resolution][itl_resolution] -> pool index
"agg_pool_mapping": [[]], // 2D array [ttft_resolution][itl_resolution] -> pool index
"priority_overrides": [] // Optional
}
}
```
### Why TTFT x ITL for Agg Mode
In aggregated mode, the same pool handles both prefill and decode. Both SLA targets matter for a single routing decision:
- **TTFT target** captures the user's prefill latency requirement. ISL is implicitly accounted for — a user sending a large prompt with a tight TTFT target is saying "I need a fast pool."
- **ITL target** captures the user's decode latency requirement. With chunked prefill, ITL reflects the combined prefill+decode contention. Without chunked prefill, ITL reflects pure decode performance.
This creates natural pool separation:
- Tight TTFT + tight ITL -> premium interactive pool
- Relaxed TTFT + tight ITL -> decode-optimized pool
- Tight TTFT + relaxed ITL -> prefill-optimized pool
- Relaxed TTFT + relaxed ITL -> batch/throughput pool
### Pool Selection
The pool selection uses a 2D grid lookup. Each dimension is divided into buckets based on the resolution.
**Prefill Pool Selection** (based on ISL and TTFT target):
**Prefill Pool Selection** (disagg mode, based on ISL and TTFT target):
1. Compute `isl_step = (isl_max - isl_min) / isl_resolution`
2. Compute `ttft_step = (ttft_max - ttft_min) / ttft_resolution`
......@@ -119,19 +178,42 @@ The pool selection uses a 2D grid lookup. Each dimension is divided into buckets
- `ttft_idx = clamp((ttft_target - ttft_min) / ttft_step, 0, ttft_resolution - 1)`
4. Lookup pool: `pool_index = prefill_pool_mapping[isl_idx][ttft_idx]`
**Decode Pool Selection** (based on context length and ITL target):
**Decode Pool Selection** (disagg mode, based on context length and ITL target):
Same logic but using `context_length` and `itl_target` with `decode_pool_mapping`.
**Example**: With `isl_min=0`, `isl_max=32000`, `isl_resolution=2`:
- ISL in [0, 16000) → `isl_idx = 0`
- ISL in [16000, 32000] → `isl_idx = 1`
**Agg Pool Selection** (agg mode, based on TTFT and ITL targets):
If `prefill_pool_mapping = [[0, 1], [0, 1]]` and `ttft_resolution=2`:
- Low ISL + Low TTFT target → pool 0
- Low ISL + High TTFT target → pool 1
- High ISL + Low TTFT target → pool 0
- High ISL + High TTFT target → pool 1
Same grid logic using `ttft_target` and `itl_target` with `agg_pool_mapping`.
### Priority-Based Pool Override
All strategies support optional `priority_overrides` rules. When a request carries a priority value (from `nvext.agent_hints.priority`), the global router evaluates the override rules after the grid lookup. The first rule whose `[min_priority, max_priority]` range contains the request priority wins, and the request is routed to that rule's `target_pool` instead of the grid result. If no rule matches (or no priority is present), the grid result is used as normal.
This is useful for straggler mitigation in RL workloads: the RL framework can tag slow requests with a high priority, and the global router redirects them to a dedicated min-latency pool.
```jsonc
"priority_overrides": [
{
"min_priority": 10, // inclusive lower bound
"max_priority": 100, // inclusive upper bound
"target_pool": 1 // pool index to route to
}
]
```
Priority is set by the client via the NVIDIA OpenAI extension:
```json
{
"messages": [...],
"nvext": {
"agent_hints": {
"priority": 50
}
}
}
```
### Priority-Based Pool Override
......@@ -178,26 +260,36 @@ Clients can pass TTFT and ITL targets via `extra_args` in the request:
{
"messages": [...],
"extra_args": {
"ttft_target": 100, // Target TTFT in ms for prefill pool selection
"itl_target": 20 // Target ITL in ms for decode pool selection
"ttft_target": 100,
"itl_target": 20
}
}
```
If not provided, the middle of the configured range is used as default.
If not provided, the middle of the configured range is used as default. For disagg mode, `ttft_target` drives prefill pool selection and `itl_target` drives decode pool selection. For agg mode, both `ttft_target` and `itl_target` drive pool selection.
## Request Flow
### Disagg Mode
1. Frontend receives request and sends to Global Router (registered as prefill)
2. Global Router selects prefill pool based on (ISL, TTFT_target)
2. Global Router selects prefill pool based on (ISL, TTFT_target, priority)
3. Request is forwarded to local router in the selected prefill pool namespace
4. Local router forwards to a prefill worker
5. Prefill response returns with `disaggregated_params`
6. Frontend sends decode request to Global Router (registered as decode)
7. Global Router selects decode pool based on (context_length, ITL_target)
7. Global Router selects decode pool based on (context_length, ITL_target, priority)
8. Request is forwarded to local router in the selected decode pool namespace
9. Tokens stream back through the chain
### Agg Mode
1. Frontend receives request and sends to Global Router (registered as Chat + Completions)
2. Global Router selects agg pool based on (TTFT_target, ITL_target, priority)
3. Request is forwarded to local router in the selected agg pool namespace
4. Local router forwards to a worker that handles both prefill and decode
5. Tokens stream back through the chain
## Example
See `examples/global_planner/` for a complete example with:
......
......@@ -6,15 +6,17 @@ Global Router Service for Hierarchical Routing
Usage: python -m dynamo.global_router --config <config.json> --model-name <model>
This service acts as both a prefill and decode worker from the frontend's perspective,
but internally routes requests to local routers in different namespaces based on
a grid-based pool selection strategy.
Key features:
- Registers as BOTH prefill AND decode worker via register_model()
- Routes prefill requests based on (ISL, TTFT) to prefill pools
- Routes decode requests based on (context_length, ITL) to decode pools
- Connects to local routers in each pool's namespace
This service routes requests to local routers in different namespaces based on
a grid-based pool selection strategy. It supports two modes:
- "disagg" mode: Registers as BOTH prefill AND decode worker. Routes prefill
requests based on (ISL, TTFT) and decode requests based on (context_length, ITL)
to separate pool types.
- "agg" mode: Registers as a single generate worker. Routes all requests based
on (ISL, ITL) to unified pools that handle both prefill and decode.
Both modes support priority-based pool overrides from agent hints.
"""
import argparse
......@@ -37,7 +39,7 @@ logger = logging.getLogger(__name__)
def parse_args() -> DynamoGlobalRouterConfig:
"""Parse command-line arguments for the Global Router service."""
parser = argparse.ArgumentParser(
description="Dynamo Global Router Service: Hierarchical routing to prefill/decode pools",
description="Dynamo Global Router Service: Hierarchical routing to worker pools",
formatter_class=argparse.RawTextHelpFormatter,
)
DynamoGlobalRouterArgGroup().add_arguments(parser)
......@@ -72,8 +74,24 @@ async def worker(runtime: DistributedRuntime):
# Initialize connections to local routers
await handler.initialize()
# Create endpoints for prefill and decode
# Note: We use separate endpoints so we can register them with different ModelTypes
logger.info(f"Mode: {handler.config.mode}")
logger.info(f"Pool info: {handler.get_pool_info()}")
if handler.config.mode == "disagg":
await _serve_disagg(runtime, config, handler)
elif handler.config.mode == "agg":
await _serve_agg(runtime, config, handler)
else:
raise ValueError(f"Unknown mode: {handler.config.mode}")
async def _serve_disagg(
runtime: DistributedRuntime,
config: DynamoGlobalRouterConfig,
handler: GlobalRouterHandler,
) -> None:
"""Register and serve disagg-mode endpoints (prefill + decode)."""
assert config.model_name is not None
prefill_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component_name}.prefill_generate"
)
......@@ -82,8 +100,6 @@ async def worker(runtime: DistributedRuntime):
)
logger.info("Registering as prefill worker...")
# Register as prefill worker - frontend will send prefill requests here
# Use model_name as model_path since we don't need tokenizer/model files
await register_model(
model_input=ModelInput.Tokens,
model_type=ModelType.Prefill,
......@@ -96,7 +112,6 @@ async def worker(runtime: DistributedRuntime):
)
logger.info("Registering as decode worker...")
# Register as decode worker - frontend will send decode requests here
await register_model(
model_input=ModelInput.Tokens,
model_type=ModelType.Chat | ModelType.Completions,
......@@ -108,25 +123,70 @@ async def worker(runtime: DistributedRuntime):
f"Registered decode endpoint: {config.namespace}.{config.component_name}.decode_generate"
)
logger.info("Global Router ready - serving endpoints...")
logger.info(f"Pool info: {handler.get_pool_info()}")
logger.info("Global Router ready (disagg mode) - serving endpoints...")
# Serve both endpoints concurrently
try:
await asyncio.gather(
prefill_endpoint.serve_endpoint(
handler.handle_prefill,
graceful_shutdown=True,
metrics_labels=[("service", "global_router"), ("type", "prefill")],
metrics_labels=[
("service", "global_router"),
("type", "prefill"),
],
),
decode_endpoint.serve_endpoint(
handler.handle_decode,
graceful_shutdown=True,
metrics_labels=[("service", "global_router"), ("type", "decode")],
metrics_labels=[
("service", "global_router"),
("type", "decode"),
],
),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
logger.error(f"Failed to serve disagg endpoints: {e}")
raise
finally:
logger.info("Global Router Service shutting down")
async def _serve_agg(
runtime: DistributedRuntime,
config: DynamoGlobalRouterConfig,
handler: GlobalRouterHandler,
) -> None:
"""Register and serve agg-mode endpoint (single generate)."""
assert config.model_name is not None
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component_name}.generate"
)
logger.info("Registering as agg worker (Chat + Completions)...")
await register_model(
model_input=ModelInput.Tokens,
model_type=ModelType.Chat | ModelType.Completions,
endpoint=generate_endpoint,
model_path=config.model_name,
model_name=config.model_name,
)
logger.info(
f"Registered agg endpoint: {config.namespace}.{config.component_name}.generate"
)
logger.info("Global Router ready (agg mode) - serving endpoint...")
try:
await generate_endpoint.serve_endpoint(
handler.handle_generate,
graceful_shutdown=True,
metrics_labels=[
("service", "global_router"),
("type", "agg"),
],
)
except Exception as e:
logger.error(f"Failed to serve agg endpoint: {e}")
raise
finally:
logger.info("Global Router Service shutting down")
......
......@@ -2,12 +2,15 @@
# SPDX-License-Identifier: Apache-2.0
"""
Global Router Handler for hierarchical routing to prefill/decode pools.
Global Router Handler for hierarchical routing to worker pools.
This handler:
1. Receives requests from the frontend (acts as both prefill and decode worker)
2. Selects the appropriate pool based on config-driven grid selection
3. Forwards requests to local routers in the selected pool's namespace
Supports two modes:
- "disagg": Routes prefill and decode requests to separate pool types
based on (ISL, TTFT) and (context_length, ITL) respectively.
- "agg": Routes generate requests to unified pools that handle both
prefill and decode, based on (ISL, ITL).
Both modes support priority-based pool overrides from agent hints.
"""
import logging
......@@ -22,12 +25,13 @@ logger = logging.getLogger(__name__)
class GlobalRouterHandler:
"""
Handler for the Global Router that routes requests to prefill/decode pools.
Handler for the Global Router that routes requests to worker pools.
The global router sits between the frontend and local routers. It:
- Receives prefill requests and routes to appropriate prefill pool
- Receives decode requests and routes to appropriate decode pool
- In disagg mode: routes prefill/decode requests to separate pool types
- In agg mode: routes generate requests to unified pools
- Uses grid-based selection strategy from config to choose pools
- Supports priority-based pool overrides from agent hints
"""
def __init__(
......@@ -38,16 +42,6 @@ class GlobalRouterHandler:
default_ttft_target: Optional[float] = None,
default_itl_target: Optional[float] = None,
):
"""
Initialize the Global Router Handler.
Args:
runtime: Dynamo distributed runtime for creating clients
config_path: Path to the JSON configuration file
model_name: Model name for logging/debugging
default_ttft_target: Default TTFT target (ms) when not in request
default_itl_target: Default ITL target (ms) when not in request
"""
self.runtime = runtime
self.config = load_config(config_path)
self.model_name = model_name
......@@ -58,13 +52,23 @@ class GlobalRouterHandler:
# Will be populated in initialize()
self.prefill_clients: Dict[str, Client] = {}
self.decode_clients: Dict[str, Client] = {}
self.agg_clients: Dict[str, Client] = {}
# Keep track of namespace -> pool index mapping for easy access
if self.config.mode == "disagg":
assert self.config.prefill_pool_dynamo_namespaces is not None
assert self.config.decode_pool_dynamo_namespaces is not None
self.prefill_namespace_to_idx: Dict[str, int] = {
ns: idx for idx, ns in enumerate(self.config.prefill_pool_dynamo_namespaces)
ns: idx
for idx, ns in enumerate(self.config.prefill_pool_dynamo_namespaces)
}
self.decode_namespace_to_idx: Dict[str, int] = {
ns: idx for idx, ns in enumerate(self.config.decode_pool_dynamo_namespaces)
ns: idx
for idx, ns in enumerate(self.config.decode_pool_dynamo_namespaces)
}
elif self.config.mode == "agg":
assert self.config.agg_pool_dynamo_namespaces is not None
self.agg_namespace_to_idx: Dict[str, int] = {
ns: idx for idx, ns in enumerate(self.config.agg_pool_dynamo_namespaces)
}
async def initialize(self) -> None:
......@@ -74,7 +78,17 @@ class GlobalRouterHandler:
This connects to the local router in each pool's namespace.
Local routers are expected at: {namespace}.router.generate
"""
logger.info("Initializing Global Router Handler...")
logger.info(f"Initializing Global Router Handler (mode={self.config.mode})...")
if self.config.mode == "disagg":
await self._initialize_disagg()
elif self.config.mode == "agg":
await self._initialize_agg()
async def _initialize_disagg(self) -> None:
"""Initialize disagg mode clients to prefill and decode pools."""
assert self.config.prefill_pool_dynamo_namespaces is not None
assert self.config.decode_pool_dynamo_namespaces is not None
# Connect to prefill pool local routers
for idx, namespace in enumerate(self.config.prefill_pool_dynamo_namespaces):
......@@ -107,25 +121,39 @@ class GlobalRouterHandler:
raise
logger.info(
f"Global Router initialized: {len(self.prefill_clients)} prefill pools, "
f"Global Router initialized (disagg): {len(self.prefill_clients)} prefill pools, "
f"{len(self.decode_clients)} decode pools"
)
async def _initialize_agg(self) -> None:
"""Initialize agg mode clients to unified pools."""
assert self.config.agg_pool_dynamo_namespaces is not None
for idx, namespace in enumerate(self.config.agg_pool_dynamo_namespaces):
try:
endpoint = self.runtime.endpoint(f"{namespace}.router.generate")
client = await endpoint.client()
self.agg_clients[namespace] = client
logger.info(f"Connected to agg pool {idx}: {namespace}.router.generate")
except Exception as e:
logger.error(f"Failed to connect to agg pool {idx} ({namespace}): {e}")
raise
logger.info(f"Global Router initialized (agg): {len(self.agg_clients)} pools")
async def handle_prefill(
self, request: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Handle prefill requests from the frontend.
Selects the appropriate prefill pool based on ISL and TTFT target,
then forwards the request to the local router in that pool.
Args:
request: PreprocessedRequest dict with token_ids, etc.
Handle prefill requests from the frontend (disagg mode).
Yields:
LLMEngineOutput dicts from the prefill worker
Selects the appropriate prefill pool based on ISL, TTFT target,
and optional priority, then forwards the request to the local
router in that pool.
"""
assert self.config.prefill_pool_selection_strategy is not None
assert self.config.prefill_pool_dynamo_namespaces is not None
# Extract ISL (input sequence length)
token_ids = request.get("token_ids", [])
isl = len(token_ids)
......@@ -165,17 +193,15 @@ class GlobalRouterHandler:
self, request: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Handle decode requests from the frontend.
Selects the appropriate decode pool based on context length and ITL target,
then forwards the request to the local router in that pool.
Args:
request: PreprocessedRequest dict with token_ids, prefill_result, etc.
Handle decode requests from the frontend (disagg mode).
Yields:
LLMEngineOutput dicts from the decode worker
Selects the appropriate decode pool based on context length, ITL target,
and optional priority, then forwards the request to the local
router in that pool.
"""
assert self.config.decode_pool_selection_strategy is not None
assert self.config.decode_pool_dynamo_namespaces is not None
# Extract context length (input tokens + any previously generated)
token_ids = request.get("token_ids", [])
# context_length should be averaged ISL + OSL // 2
......@@ -192,7 +218,9 @@ class GlobalRouterHandler:
# Select decode pool
pool_idx = self.config.decode_pool_selection_strategy.select_pool(
context_length=context_length, itl_target=itl_target, priority=priority
context_length=context_length,
itl_target=itl_target,
priority=priority,
)
namespace = self.config.decode_pool_dynamo_namespaces[pool_idx]
client = self.decode_clients[namespace]
......@@ -214,15 +242,64 @@ class GlobalRouterHandler:
logger.error(f"Error forwarding decode request to {namespace}: {e}")
raise
def get_pool_info(self) -> Dict[str, Any]:
async def handle_generate(
self, request: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Get information about connected pools for debugging/monitoring.
Handle generate requests (agg mode).
Returns:
Dict with pool information
Selects the appropriate agg pool based on TTFT target, ITL target, and
optional priority, then forwards the request to the local router in
that pool. The pool's workers handle both prefill and decode.
"""
return {
assert self.config.agg_pool_selection_strategy is not None
assert self.config.agg_pool_dynamo_namespaces is not None
# Extract SLA targets from extra_args, fallback to CLI defaults.
# Use `is None` checks to preserve explicit 0 values.
extra_args = request.get("extra_args") or {}
ttft_target = extra_args.get("ttft_target")
if ttft_target is None:
ttft_target = self.default_ttft_target
itl_target = extra_args.get("itl_target")
if itl_target is None:
itl_target = self.default_itl_target
# Extract priority from routing hints (set by nvext.agent_hints.priority)
routing = request.get("routing") or {}
priority = routing.get("priority")
# Select agg pool
pool_idx = self.config.agg_pool_selection_strategy.select_pool(
ttft_target=ttft_target, itl_target=itl_target, priority=priority
)
namespace = self.config.agg_pool_dynamo_namespaces[pool_idx]
client = self.agg_clients[namespace]
logger.info(
f"Routing agg request: TTFT_target={ttft_target}, ITL_target={itl_target}, "
f"priority={priority} -> pool {pool_idx} ({namespace})"
)
# Forward request to local router and stream back responses
try:
stream = await client.generate(request)
async for output in stream:
data = output.data() if hasattr(output, "data") else output
yield data
except Exception as e:
logger.error(f"Error forwarding agg request to {namespace}: {e}")
raise
def get_pool_info(self) -> Dict[str, Any]:
"""Get information about connected pools for debugging/monitoring."""
info: Dict[str, Any] = {
"model_name": self.model_name,
"mode": self.config.mode,
}
if self.config.mode == "disagg":
info.update(
{
"num_prefill_pools": self.config.num_prefill_pools,
"num_decode_pools": self.config.num_decode_pools,
"prefill_pools": self.config.prefill_pool_dynamo_namespaces,
......@@ -230,3 +307,13 @@ class GlobalRouterHandler:
"prefill_connected": list(self.prefill_clients.keys()),
"decode_connected": list(self.decode_clients.keys()),
}
)
elif self.config.mode == "agg":
info.update(
{
"num_agg_pools": self.config.num_agg_pools,
"agg_pools": self.config.agg_pool_dynamo_namespaces,
"agg_connected": list(self.agg_clients.keys()),
}
)
return info
......@@ -4,10 +4,14 @@
"""
Configuration loading and pool selection logic for the Global Router.
The config file defines:
- Prefill and decode pool namespaces
- Grid-based pool selection strategies mapping (ISL, TTFT) -> prefill pool
and (context_length, ITL) -> decode pool
Supports two modes:
- "disagg" (default): Separate prefill and decode pools with independent
grid-based selection strategies mapping (ISL, TTFT) -> prefill pool
and (context_length, ITL) -> decode pool.
- "agg": Unified pools handling both prefill and decode (chunked prefill),
with grid-based selection mapping (ISL, ITL) -> agg pool.
Both modes support optional priority-based pool overrides from agent hints.
"""
import json
......@@ -181,19 +185,134 @@ class DecodePoolSelectionStrategy:
return max(0, min(int(value), resolution - 1))
@dataclass
class AggPoolSelectionStrategy:
"""Strategy for selecting agg pools based on TTFT and ITL targets.
In aggregated mode, each pool handles both prefill and decode. Since both
phases happen in the same pool, both SLA targets matter for a single routing
decision. The grid maps (TTFT target, ITL target) -> pool index.
This works regardless of whether chunked prefill is enabled:
- With chunked prefill: ITL reflects combined prefill+decode contention.
- Without chunked prefill: TTFT captures the blocking prefill cost,
ITL captures pure decode performance.
"""
ttft_min: float
ttft_max: float
ttft_resolution: int
itl_min: float
itl_max: float
itl_resolution: int
agg_pool_mapping: List[List[int]]
priority_overrides: List[PriorityPoolOverride] = field(default_factory=list)
@property
def ttft_step(self) -> float:
"""Step size for TTFT grid."""
return (self.ttft_max - self.ttft_min) / self.ttft_resolution
@property
def itl_step(self) -> float:
"""Step size for ITL grid."""
return (self.itl_max - self.itl_min) / self.itl_resolution
def select_pool(
self,
ttft_target: Optional[float] = None,
itl_target: Optional[float] = None,
priority: Optional[int] = None,
) -> int:
"""
Select agg pool based on TTFT target, ITL target, and optional priority.
Args:
ttft_target: Target time to first token in ms. If None, uses middle of range.
itl_target: Target inter-token latency in ms. If None, uses middle of range.
priority: Request priority from agent hints. If set and a priority
override rule matches, the override takes precedence over the grid.
Returns:
Pool index from agg_pool_mapping or a priority override
"""
if ttft_target is None:
ttft_target = (self.ttft_min + self.ttft_max) / 2
if itl_target is None:
itl_target = (self.itl_min + self.itl_max) / 2
# Compute grid indices with clamping
ttft_idx = self._clamp_index(
(ttft_target - self.ttft_min) / self.ttft_step, self.ttft_resolution
)
itl_idx = self._clamp_index(
(itl_target - self.itl_min) / self.itl_step, self.itl_resolution
)
pool_idx = self.agg_pool_mapping[ttft_idx][itl_idx]
pool_idx = _apply_priority_overrides(
pool_idx, priority, self.priority_overrides
)
logger.debug(
f"Agg pool selection: TTFT={ttft_target}, ITL={itl_target}, "
f"priority={priority} -> pool {pool_idx}"
)
return pool_idx
@staticmethod
def _clamp_index(value: float, resolution: int) -> int:
"""Clamp index to valid grid range."""
return max(0, min(int(value), resolution - 1))
@dataclass
class GlobalRouterConfig:
"""Configuration for the Global Router."""
"""Configuration for the Global Router.
Supports two modes:
- "disagg" (default): separate prefill and decode pools
- "agg": unified pools handling both prefill and decode
"""
mode: str = "disagg" # "disagg" or "agg"
num_prefill_pools: int
num_decode_pools: int
prefill_pool_dynamo_namespaces: List[str]
decode_pool_dynamo_namespaces: List[str]
prefill_pool_selection_strategy: PrefillPoolSelectionStrategy
decode_pool_selection_strategy: DecodePoolSelectionStrategy
# --- disagg-only fields (required when mode="disagg") ---
num_prefill_pools: Optional[int] = None
num_decode_pools: Optional[int] = None
prefill_pool_dynamo_namespaces: Optional[List[str]] = None
decode_pool_dynamo_namespaces: Optional[List[str]] = None
prefill_pool_selection_strategy: Optional[PrefillPoolSelectionStrategy] = None
decode_pool_selection_strategy: Optional[DecodePoolSelectionStrategy] = None
# --- agg-only fields (required when mode="agg") ---
num_agg_pools: Optional[int] = None
agg_pool_dynamo_namespaces: Optional[List[str]] = None
agg_pool_selection_strategy: Optional[AggPoolSelectionStrategy] = None
def validate(self) -> None:
"""Validate configuration consistency."""
if self.mode == "disagg":
self._validate_disagg()
elif self.mode == "agg":
self._validate_agg()
else:
raise ValueError(f"Unknown mode '{self.mode}', must be 'disagg' or 'agg'")
def _validate_disagg(self) -> None:
"""Validate disagg mode configuration."""
if self.num_prefill_pools is None:
raise ValueError("num_prefill_pools required for disagg mode")
if self.num_decode_pools is None:
raise ValueError("num_decode_pools required for disagg mode")
if self.prefill_pool_dynamo_namespaces is None:
raise ValueError("prefill_pool_dynamo_namespaces required for disagg mode")
if self.decode_pool_dynamo_namespaces is None:
raise ValueError("decode_pool_dynamo_namespaces required for disagg mode")
if self.prefill_pool_selection_strategy is None:
raise ValueError("prefill_pool_selection_strategy required for disagg mode")
if self.decode_pool_selection_strategy is None:
raise ValueError("decode_pool_selection_strategy required for disagg mode")
if len(self.prefill_pool_dynamo_namespaces) != self.num_prefill_pools:
raise ValueError(
f"num_prefill_pools ({self.num_prefill_pools}) does not match "
......@@ -206,7 +325,7 @@ class GlobalRouterConfig:
f"decode_pool_dynamo_namespaces length ({len(self.decode_pool_dynamo_namespaces)})"
)
# Validate prefill strategy ranges and resolutions
# Validate prefill strategy
prefill_strategy = self.prefill_pool_selection_strategy
if prefill_strategy.isl_resolution <= 0:
raise ValueError(
......@@ -227,27 +346,6 @@ class GlobalRouterConfig:
f"ttft_max ({prefill_strategy.ttft_max})"
)
# Validate decode strategy ranges and resolutions
decode_strategy = self.decode_pool_selection_strategy
if decode_strategy.context_length_resolution <= 0:
raise ValueError(
f"context_length_resolution must be positive, got {decode_strategy.context_length_resolution}"
)
if decode_strategy.itl_resolution <= 0:
raise ValueError(
f"itl_resolution must be positive, got {decode_strategy.itl_resolution}"
)
if decode_strategy.context_length_min >= decode_strategy.context_length_max:
raise ValueError(
f"context_length_min ({decode_strategy.context_length_min}) must be less than "
f"context_length_max ({decode_strategy.context_length_max})"
)
if decode_strategy.itl_min >= decode_strategy.itl_max:
raise ValueError(
f"itl_min ({decode_strategy.itl_min}) must be less than "
f"itl_max ({decode_strategy.itl_max})"
)
# Validate mapping dimensions match resolution
if (
len(prefill_strategy.prefill_pool_mapping)
......@@ -287,7 +385,27 @@ class GlobalRouterConfig:
f"{override.target_pool} (must be 0 to {self.num_prefill_pools - 1})"
)
# Validate decode strategy
decode_strategy = self.decode_pool_selection_strategy
if decode_strategy.context_length_resolution <= 0:
raise ValueError(
f"context_length_resolution must be positive, got {decode_strategy.context_length_resolution}"
)
if decode_strategy.itl_resolution <= 0:
raise ValueError(
f"itl_resolution must be positive, got {decode_strategy.itl_resolution}"
)
if decode_strategy.context_length_min >= decode_strategy.context_length_max:
raise ValueError(
f"context_length_min ({decode_strategy.context_length_min}) must be less than "
f"context_length_max ({decode_strategy.context_length_max})"
)
if decode_strategy.itl_min >= decode_strategy.itl_max:
raise ValueError(
f"itl_min ({decode_strategy.itl_min}) must be less than "
f"itl_max ({decode_strategy.itl_max})"
)
if (
len(decode_strategy.decode_pool_mapping)
!= decode_strategy.context_length_resolution
......@@ -326,6 +444,74 @@ class GlobalRouterConfig:
f"{override.target_pool} (must be 0 to {self.num_decode_pools - 1})"
)
def _validate_agg(self) -> None:
"""Validate agg mode configuration."""
if self.num_agg_pools is None:
raise ValueError("num_agg_pools required for agg mode")
if self.agg_pool_dynamo_namespaces is None:
raise ValueError("agg_pool_dynamo_namespaces required for agg mode")
if self.agg_pool_selection_strategy is None:
raise ValueError("agg_pool_selection_strategy required for agg mode")
if len(self.agg_pool_dynamo_namespaces) != self.num_agg_pools:
raise ValueError(
f"num_agg_pools ({self.num_agg_pools}) does not match "
f"agg_pool_dynamo_namespaces length ({len(self.agg_pool_dynamo_namespaces)})"
)
agg_strategy = self.agg_pool_selection_strategy
if agg_strategy.ttft_resolution <= 0:
raise ValueError(
f"ttft_resolution must be positive, got {agg_strategy.ttft_resolution}"
)
if agg_strategy.itl_resolution <= 0:
raise ValueError(
f"itl_resolution must be positive, got {agg_strategy.itl_resolution}"
)
if agg_strategy.ttft_min >= agg_strategy.ttft_max:
raise ValueError(
f"ttft_min ({agg_strategy.ttft_min}) must be less than "
f"ttft_max ({agg_strategy.ttft_max})"
)
if agg_strategy.itl_min >= agg_strategy.itl_max:
raise ValueError(
f"itl_min ({agg_strategy.itl_min}) must be less than "
f"itl_max ({agg_strategy.itl_max})"
)
# Validate mapping dimensions
if len(agg_strategy.agg_pool_mapping) != agg_strategy.ttft_resolution:
raise ValueError(
f"agg_pool_mapping rows ({len(agg_strategy.agg_pool_mapping)}) "
f"does not match ttft_resolution ({agg_strategy.ttft_resolution})"
)
for i, row in enumerate(agg_strategy.agg_pool_mapping):
if len(row) != agg_strategy.itl_resolution:
raise ValueError(
f"agg_pool_mapping row {i} length ({len(row)}) "
f"does not match itl_resolution ({agg_strategy.itl_resolution})"
)
for pool_idx in row:
if pool_idx < 0 or pool_idx >= self.num_agg_pools:
raise ValueError(
f"Invalid agg pool index {pool_idx} in mapping "
f"(must be 0 to {self.num_agg_pools - 1})"
)
for i, override in enumerate(agg_strategy.priority_overrides):
if override.min_priority > override.max_priority:
raise ValueError(
f"Agg priority_overrides[{i}]: min_priority "
f"({override.min_priority}) must be <= max_priority "
f"({override.max_priority})"
)
if override.target_pool < 0 or override.target_pool >= self.num_agg_pools:
raise ValueError(
f"Agg priority_overrides[{i}]: invalid target_pool "
f"{override.target_pool} (must be 0 to {self.num_agg_pools - 1})"
)
def load_config(config_path: str | Path) -> GlobalRouterConfig:
"""
......@@ -350,6 +536,21 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig:
logger.info(f"Loading global router config from {config_path}")
mode = data.get("mode", "disagg")
if mode == "disagg":
config = _load_disagg_config(data, mode)
elif mode == "agg":
config = _load_agg_config(data, mode)
else:
raise ValueError(f"Unknown mode '{mode}' in config")
config.validate()
return config
def _load_disagg_config(data: dict, mode: str) -> GlobalRouterConfig:
"""Load disagg mode configuration from parsed JSON data."""
# Parse prefill selection strategy
prefill_strategy_data = data["prefill_pool_selection_strategy"]
prefill_priority_overrides = [
......@@ -385,6 +586,7 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig:
)
config = GlobalRouterConfig(
mode=mode,
num_prefill_pools=data["num_prefill_pools"],
num_decode_pools=data["num_decode_pools"],
prefill_pool_dynamo_namespaces=data["prefill_pool_dynamo_namespaces"],
......@@ -393,10 +595,37 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig:
decode_pool_selection_strategy=decode_strategy,
)
config.validate()
logger.info(
f"Loaded config: {config.num_prefill_pools} prefill pools, "
f"Loaded disagg config: {config.num_prefill_pools} prefill pools, "
f"{config.num_decode_pools} decode pools"
)
return config
def _load_agg_config(data: dict, mode: str) -> GlobalRouterConfig:
"""Load agg mode configuration from parsed JSON data."""
agg_strategy_data = data["agg_pool_selection_strategy"]
agg_priority_overrides = [
PriorityPoolOverride(**rule)
for rule in agg_strategy_data.get("priority_overrides", [])
]
agg_strategy = AggPoolSelectionStrategy(
ttft_min=agg_strategy_data["ttft_min"],
ttft_max=agg_strategy_data["ttft_max"],
ttft_resolution=agg_strategy_data["ttft_resolution"],
itl_min=agg_strategy_data["itl_min"],
itl_max=agg_strategy_data["itl_max"],
itl_resolution=agg_strategy_data["itl_resolution"],
agg_pool_mapping=agg_strategy_data["agg_pool_mapping"],
priority_overrides=agg_priority_overrides,
)
config = GlobalRouterConfig(
mode=mode,
num_agg_pools=data["num_agg_pools"],
agg_pool_dynamo_namespaces=data["agg_pool_dynamo_namespaces"],
agg_pool_selection_strategy=agg_strategy,
)
logger.info(f"Loaded agg config: {config.num_agg_pools} agg pools")
return config
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tests for aggregated (agg) mode pool routing in the global router."""
import json
from pathlib import Path
import pytest
from dynamo.global_router.pool_selection import (
AggPoolSelectionStrategy,
GlobalRouterConfig,
PriorityPoolOverride,
load_config,
)
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.parallel,
pytest.mark.unit,
]
# --- Helpers ---
def _make_agg_strategy(
num_pools=2, priority_overrides=None
) -> AggPoolSelectionStrategy:
return AggPoolSelectionStrategy(
ttft_min=10,
ttft_max=3000,
ttft_resolution=2,
itl_min=5,
itl_max=200,
itl_resolution=2,
agg_pool_mapping=[[0, 1], [1, 1]],
priority_overrides=priority_overrides or [],
)
def _write_config(tmp_dir: Path, config_data: dict) -> Path:
config_path = tmp_dir / "config.json"
config_path.write_text(json.dumps(config_data))
return config_path
def _agg_base_config(**overrides) -> dict:
config = {
"mode": "agg",
"num_agg_pools": 2,
"agg_pool_dynamo_namespaces": ["ns-agg-0", "ns-agg-1"],
"agg_pool_selection_strategy": {
"ttft_min": 10,
"ttft_max": 3000,
"ttft_resolution": 2,
"itl_min": 5,
"itl_max": 200,
"itl_resolution": 2,
"agg_pool_mapping": [[0, 1], [1, 1]],
},
}
config.update(overrides)
return config
# --- AggPoolSelectionStrategy tests ---
class TestAggPoolSelection:
"""Tests for the TTFT x ITL grid selection.
Default strategy mapping [[0, 1], [1, 1]] with:
- TTFT: [10, 3000], resolution 2 -> step=1495, boundary at 1505
- ITL: [5, 200], resolution 2 -> step=97.5, boundary at 102.5
mapping[ttft_idx][itl_idx]:
- [0][0] = 0: tight TTFT + tight ITL -> pool 0
- [0][1] = 1: tight TTFT + relaxed ITL -> pool 1
- [1][0] = 1: relaxed TTFT + tight ITL -> pool 1
- [1][1] = 1: relaxed TTFT + relaxed ITL -> pool 1
"""
def test_tight_ttft_tight_itl(self):
strategy = _make_agg_strategy()
# ttft_idx=0, itl_idx=0 -> pool 0
result = strategy.select_pool(ttft_target=100, itl_target=10)
assert result == 0
def test_tight_ttft_relaxed_itl(self):
strategy = _make_agg_strategy()
# ttft_idx=0, itl_idx=1 -> pool 1
result = strategy.select_pool(ttft_target=100, itl_target=150)
assert result == 1
def test_relaxed_ttft_tight_itl(self):
strategy = _make_agg_strategy()
# ttft_idx=1, itl_idx=0 -> pool 1
result = strategy.select_pool(ttft_target=2000, itl_target=10)
assert result == 1
def test_relaxed_ttft_relaxed_itl(self):
strategy = _make_agg_strategy()
# ttft_idx=1, itl_idx=1 -> pool 1
result = strategy.select_pool(ttft_target=2000, itl_target=150)
assert result == 1
def test_default_ttft_uses_midpoint(self):
strategy = _make_agg_strategy()
# ttft_target=None -> midpoint=(10+3000)/2=1505 -> ttft_idx=1
# itl_target=10 -> itl_idx=0
# [1][0] = 1
result = strategy.select_pool(itl_target=10)
assert result == 1
def test_default_itl_uses_midpoint(self):
strategy = _make_agg_strategy()
# ttft_target=100 -> ttft_idx=0
# itl_target=None -> midpoint=(5+200)/2=102.5 -> itl_idx=1
# [0][1] = 1
result = strategy.select_pool(ttft_target=100)
assert result == 1
def test_both_defaults_use_midpoints(self):
strategy = _make_agg_strategy()
# ttft midpoint=1505 -> ttft_idx=1, itl midpoint=102.5 -> itl_idx=1
# [1][1] = 1
result = strategy.select_pool()
assert result == 1
def test_clamping_below_min(self):
strategy = _make_agg_strategy()
# Both below min -> both clamp to idx 0
result = strategy.select_pool(ttft_target=0, itl_target=0)
assert result == 0
def test_clamping_above_max(self):
strategy = _make_agg_strategy()
# Both above max -> both clamp to max idx
result = strategy.select_pool(ttft_target=10000, itl_target=1000)
assert result == 1
def test_priority_override_takes_precedence(self):
strategy = _make_agg_strategy(
priority_overrides=[
PriorityPoolOverride(min_priority=10, max_priority=100, target_pool=0)
]
)
# Grid: relaxed TTFT + relaxed ITL -> pool 1, but priority overrides to 0
result = strategy.select_pool(ttft_target=2000, itl_target=150, priority=50)
assert result == 0
def test_no_priority_uses_grid(self):
strategy = _make_agg_strategy(
priority_overrides=[
PriorityPoolOverride(min_priority=10, max_priority=100, target_pool=0)
]
)
result = strategy.select_pool(ttft_target=2000, itl_target=150)
assert result == 1 # grid result, no priority
def test_unmatched_priority_uses_grid(self):
strategy = _make_agg_strategy(
priority_overrides=[
PriorityPoolOverride(min_priority=10, max_priority=100, target_pool=0)
]
)
result = strategy.select_pool(ttft_target=2000, itl_target=150, priority=5)
assert result == 1 # priority=5 doesn't match [10, 100]
def test_no_overrides_backward_compatible(self):
strategy = _make_agg_strategy()
result = strategy.select_pool(ttft_target=100, itl_target=10, priority=50)
assert result == 0 # no overrides configured, grid result
# --- AggPoolSelectionStrategy with custom mapping ---
class TestAggPoolSelectionCustomMapping:
def test_3x3_grid(self):
strategy = AggPoolSelectionStrategy(
ttft_min=10,
ttft_max=3010,
ttft_resolution=3,
itl_min=10,
itl_max=100,
itl_resolution=3,
agg_pool_mapping=[[0, 1, 2], [1, 2, 0], [2, 0, 1]],
)
# Low TTFT, low ITL -> pool 0
assert strategy.select_pool(ttft_target=100, itl_target=15) == 0
# Low TTFT, high ITL -> pool 2
assert strategy.select_pool(ttft_target=100, itl_target=90) == 2
# Mid TTFT, mid ITL -> pool 2
assert strategy.select_pool(ttft_target=1500, itl_target=55) == 2
# High TTFT, low ITL -> pool 2
assert strategy.select_pool(ttft_target=2500, itl_target=15) == 2
# --- GlobalRouterConfig agg validation tests ---
class TestAggConfigValidation:
def test_valid_agg_config(self):
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=2,
agg_pool_dynamo_namespaces=["a", "b"],
agg_pool_selection_strategy=_make_agg_strategy(),
)
config.validate() # should not raise
def test_missing_num_agg_pools(self):
config = GlobalRouterConfig(
mode="agg",
agg_pool_dynamo_namespaces=["a", "b"],
agg_pool_selection_strategy=_make_agg_strategy(),
)
with pytest.raises(ValueError, match="num_agg_pools required"):
config.validate()
def test_missing_namespaces(self):
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=2,
agg_pool_selection_strategy=_make_agg_strategy(),
)
with pytest.raises(ValueError, match="agg_pool_dynamo_namespaces required"):
config.validate()
def test_missing_strategy(self):
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=2,
agg_pool_dynamo_namespaces=["a", "b"],
)
with pytest.raises(ValueError, match="agg_pool_selection_strategy required"):
config.validate()
def test_namespace_count_mismatch(self):
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=3,
agg_pool_dynamo_namespaces=["a", "b"],
agg_pool_selection_strategy=_make_agg_strategy(),
)
with pytest.raises(ValueError, match="num_agg_pools.*does not match"):
config.validate()
def test_invalid_pool_idx_in_mapping(self):
strategy = AggPoolSelectionStrategy(
ttft_min=10,
ttft_max=3000,
ttft_resolution=2,
itl_min=5,
itl_max=200,
itl_resolution=2,
agg_pool_mapping=[[0, 5], [0, 1]], # pool 5 is out of range
)
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=2,
agg_pool_dynamo_namespaces=["a", "b"],
agg_pool_selection_strategy=strategy,
)
with pytest.raises(ValueError, match="Invalid agg pool index"):
config.validate()
def test_mapping_row_count_mismatch(self):
strategy = AggPoolSelectionStrategy(
ttft_min=10,
ttft_max=3000,
ttft_resolution=3, # expects 3 rows
itl_min=5,
itl_max=200,
itl_resolution=2,
agg_pool_mapping=[[0, 1], [0, 1]], # only 2 rows
)
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=2,
agg_pool_dynamo_namespaces=["a", "b"],
agg_pool_selection_strategy=strategy,
)
with pytest.raises(ValueError, match="agg_pool_mapping rows.*does not match"):
config.validate()
def test_mapping_col_count_mismatch(self):
strategy = AggPoolSelectionStrategy(
ttft_min=10,
ttft_max=3000,
ttft_resolution=2,
itl_min=5,
itl_max=200,
itl_resolution=3, # expects 3 columns
agg_pool_mapping=[[0, 1], [0, 1]], # only 2 columns per row
)
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=2,
agg_pool_dynamo_namespaces=["a", "b"],
agg_pool_selection_strategy=strategy,
)
with pytest.raises(ValueError, match="agg_pool_mapping row.*does not match"):
config.validate()
def test_priority_override_invalid_target(self):
strategy = _make_agg_strategy(
priority_overrides=[
PriorityPoolOverride(min_priority=1, max_priority=10, target_pool=5)
]
)
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=2,
agg_pool_dynamo_namespaces=["a", "b"],
agg_pool_selection_strategy=strategy,
)
with pytest.raises(ValueError, match="invalid target_pool"):
config.validate()
def test_priority_override_inverted_range(self):
strategy = _make_agg_strategy(
priority_overrides=[
PriorityPoolOverride(min_priority=20, max_priority=5, target_pool=1)
]
)
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=2,
agg_pool_dynamo_namespaces=["a", "b"],
agg_pool_selection_strategy=strategy,
)
with pytest.raises(ValueError, match="min_priority"):
config.validate()
def test_unknown_mode(self):
config = GlobalRouterConfig(mode="invalid")
with pytest.raises(ValueError, match="Unknown mode"):
config.validate()
def test_ttft_range_invalid(self):
strategy = AggPoolSelectionStrategy(
ttft_min=3000,
ttft_max=10, # min > max
ttft_resolution=2,
itl_min=5,
itl_max=200,
itl_resolution=2,
agg_pool_mapping=[[0, 1], [0, 1]],
)
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=2,
agg_pool_dynamo_namespaces=["a", "b"],
agg_pool_selection_strategy=strategy,
)
with pytest.raises(ValueError, match="ttft_min.*must be less than"):
config.validate()
def test_itl_range_invalid(self):
strategy = AggPoolSelectionStrategy(
ttft_min=10,
ttft_max=3000,
ttft_resolution=2,
itl_min=200,
itl_max=5, # min > max
itl_resolution=2,
agg_pool_mapping=[[0, 1], [0, 1]],
)
config = GlobalRouterConfig(
mode="agg",
num_agg_pools=2,
agg_pool_dynamo_namespaces=["a", "b"],
agg_pool_selection_strategy=strategy,
)
with pytest.raises(ValueError, match="itl_min.*must be less than"):
config.validate()
# --- Config loading tests ---
class TestLoadAggConfig:
def test_load_agg_config(self, tmp_path):
config_data = _agg_base_config()
config_path = _write_config(tmp_path, config_data)
config = load_config(config_path)
assert config.mode == "agg"
assert config.num_agg_pools == 2
assert config.agg_pool_dynamo_namespaces == ["ns-agg-0", "ns-agg-1"]
assert config.agg_pool_selection_strategy is not None
assert config.agg_pool_selection_strategy.ttft_min == 10
assert config.agg_pool_selection_strategy.ttft_max == 3000
assert config.agg_pool_selection_strategy.itl_min == 5
assert config.agg_pool_selection_strategy.itl_max == 200
def test_agg_config_with_priority_overrides(self, tmp_path):
config_data = _agg_base_config()
config_data["agg_pool_selection_strategy"]["priority_overrides"] = [
{"min_priority": 10, "max_priority": 100, "target_pool": 1}
]
config_path = _write_config(tmp_path, config_data)
config = load_config(config_path)
assert len(config.agg_pool_selection_strategy.priority_overrides) == 1
override = config.agg_pool_selection_strategy.priority_overrides[0]
assert override.min_priority == 10
assert override.max_priority == 100
assert override.target_pool == 1
def test_agg_config_without_priority_overrides(self, tmp_path):
config_data = _agg_base_config()
config_path = _write_config(tmp_path, config_data)
config = load_config(config_path)
assert config.agg_pool_selection_strategy.priority_overrides == []
def test_unknown_mode_in_config(self, tmp_path):
config_data = {"mode": "invalid"}
config_path = _write_config(tmp_path, config_data)
with pytest.raises(ValueError, match="Unknown mode"):
load_config(config_path)
def test_agg_pool_selection_routes_correctly(self, tmp_path):
"""End-to-end: load config, then verify pool selection works."""
config_data = _agg_base_config()
config_data["agg_pool_selection_strategy"]["priority_overrides"] = [
{"min_priority": 50, "max_priority": 100, "target_pool": 0}
]
config_path = _write_config(tmp_path, config_data)
config = load_config(config_path)
strategy = config.agg_pool_selection_strategy
# Tight TTFT + tight ITL -> pool 0 from grid
assert strategy.select_pool(ttft_target=100, itl_target=10) == 0
# Relaxed TTFT + relaxed ITL -> pool 1 from grid
assert strategy.select_pool(ttft_target=2000, itl_target=150) == 1
# Priority override: relaxed would be pool 1, but priority 75 -> pool 0
assert strategy.select_pool(ttft_target=2000, itl_target=150, priority=75) == 0
......@@ -235,6 +235,7 @@ class TestValidatePriorityOverrides:
]
)
config = GlobalRouterConfig(
mode="disagg",
num_prefill_pools=2,
num_decode_pools=2,
prefill_pool_dynamo_namespaces=["a", "b"],
......@@ -252,6 +253,7 @@ class TestValidatePriorityOverrides:
]
)
config = GlobalRouterConfig(
mode="disagg",
num_prefill_pools=2,
num_decode_pools=2,
prefill_pool_dynamo_namespaces=["a", "b"],
......@@ -269,6 +271,7 @@ class TestValidatePriorityOverrides:
]
)
config = GlobalRouterConfig(
mode="disagg",
num_prefill_pools=2,
num_decode_pools=2,
prefill_pool_dynamo_namespaces=["a", "b"],
......@@ -286,6 +289,7 @@ class TestValidatePriorityOverrides:
]
)
config = GlobalRouterConfig(
mode="disagg",
num_prefill_pools=2,
num_decode_pools=2,
prefill_pool_dynamo_namespaces=["a", "b"],
......@@ -294,3 +298,11 @@ class TestValidatePriorityOverrides:
decode_pool_selection_strategy=_make_decode_strategy(),
)
config.validate() # should not raise
def test_default_mode_is_disagg(self, tmp_path):
"""Config without explicit mode defaults to disagg."""
config_data = _base_config()
# No "mode" key
config_path = _write_config(tmp_path, config_data)
config = load_config(config_path)
assert config.mode == "disagg"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment