Unverified Commit 7a6db48e authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat(global-router): priority-based pool routing from agent hints (#8010)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 8fd1b847
...@@ -202,6 +202,7 @@ def _build_dynamo_preproc( ...@@ -202,6 +202,7 @@ def _build_dynamo_preproc(
}, },
"eos_token_ids": [eos_token_id] if eos_token_id is not None else [], "eos_token_ids": [eos_token_id] if eos_token_id is not None else [],
"annotations": [], "annotations": [],
"routing": request.get("routing"),
} }
# Forward multimodal URLs so the backend handler can load the media. # Forward multimodal URLs so the backend handler can load the media.
......
...@@ -269,6 +269,7 @@ class VllmProcessor: ...@@ -269,6 +269,7 @@ class VllmProcessor:
}, },
"eos_token_ids": self._get_eos_token_ids(), "eos_token_ids": self._get_eos_token_ids(),
"annotations": [], "annotations": [],
"routing": request.get("routing"),
} }
# Forward multimodal URLs so the backend handler can load the media. # Forward multimodal URLs so the backend handler can load the media.
......
...@@ -133,6 +133,43 @@ If `prefill_pool_mapping = [[0, 1], [0, 1]]` and `ttft_resolution=2`: ...@@ -133,6 +133,43 @@ If `prefill_pool_mapping = [[0, 1], [0, 1]]` and `ttft_resolution=2`:
- High ISL + Low TTFT target → pool 0 - High ISL + Low TTFT target → pool 0
- High ISL + High TTFT target → pool 1 - High ISL + High TTFT target → pool 1
### Priority-Based Pool Override
Both prefill and decode strategies support optional `priority_overrides` rules.
When a request carries a priority value (from `nvext.agent_hints.priority`), the
global router evaluates the override rules **after** the grid lookup. Rules are
checked in the order they are listed, and the first rule whose inclusive
`[min_priority, max_priority]` range contains the request priority wins: the
request is routed to that rule's `target_pool` instead of the grid result. If no
rule matches (or no priority is present), the grid result is used as normal.
This is useful for straggler mitigation in RL workloads: the RL framework can
tag slow requests with a high priority, and the global router redirects them to
a dedicated min-latency pool.
```jsonc
"priority_overrides": [
{
"min_priority": 10, // inclusive lower bound
"max_priority": 100, // inclusive upper bound
"target_pool": 1 // pool index to route to
}
]
```
Priority is set by the client via the NVIDIA OpenAI extension:
```json
{
"messages": [...],
"nvext": {
"agent_hints": {
"priority": 50
}
}
}
```
### Passing SLA Targets ### Passing SLA Targets
Clients can pass TTFT and ITL targets via `extra_args` in the request: Clients can pass TTFT and ITL targets via `extra_args` in the request:
......
...@@ -134,16 +134,20 @@ class GlobalRouterHandler: ...@@ -134,16 +134,20 @@ class GlobalRouterHandler:
extra_args = request.get("extra_args") or {} extra_args = request.get("extra_args") or {}
ttft_target = extra_args.get("ttft_target") or self.default_ttft_target ttft_target = extra_args.get("ttft_target") or self.default_ttft_target
# Extract priority from routing hints (set by nvext.agent_hints.priority)
routing = request.get("routing") or {}
priority = routing.get("priority")
# Select prefill pool # Select prefill pool
pool_idx = self.config.prefill_pool_selection_strategy.select_pool( pool_idx = self.config.prefill_pool_selection_strategy.select_pool(
isl=isl, ttft_target=ttft_target isl=isl, ttft_target=ttft_target, priority=priority
) )
namespace = self.config.prefill_pool_dynamo_namespaces[pool_idx] namespace = self.config.prefill_pool_dynamo_namespaces[pool_idx]
client = self.prefill_clients[namespace] client = self.prefill_clients[namespace]
logger.info( logger.info(
f"Routing prefill request: ISL={isl}, TTFT_target={ttft_target} -> " f"Routing prefill request: ISL={isl}, TTFT_target={ttft_target}, "
f"pool {pool_idx} ({namespace})" f"priority={priority} -> pool {pool_idx} ({namespace})"
) )
# Forward request to local router and stream back responses # Forward request to local router and stream back responses
...@@ -182,15 +186,20 @@ class GlobalRouterHandler: ...@@ -182,15 +186,20 @@ class GlobalRouterHandler:
extra_args = request.get("extra_args") or {} extra_args = request.get("extra_args") or {}
itl_target = extra_args.get("itl_target") or self.default_itl_target itl_target = extra_args.get("itl_target") or self.default_itl_target
# Extract priority from routing hints (set by nvext.agent_hints.priority)
routing = request.get("routing") or {}
priority = routing.get("priority")
# Select decode pool # Select decode pool
pool_idx = self.config.decode_pool_selection_strategy.select_pool( pool_idx = self.config.decode_pool_selection_strategy.select_pool(
context_length=context_length, itl_target=itl_target context_length=context_length, itl_target=itl_target, priority=priority
) )
namespace = self.config.decode_pool_dynamo_namespaces[pool_idx] namespace = self.config.decode_pool_dynamo_namespaces[pool_idx]
client = self.decode_clients[namespace] client = self.decode_clients[namespace]
logger.info( logger.info(
f"Routing decode request: context_length={context_length}, ITL_target={itl_target} -> " f"Routing decode request: context_length={context_length}, "
f"ITL_target={itl_target}, priority={priority} -> "
f"pool {pool_idx} ({namespace})" f"pool {pool_idx} ({namespace})"
) )
......
...@@ -12,13 +12,36 @@ The config file defines: ...@@ -12,13 +12,36 @@ The config file defines:
import json import json
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class PriorityPoolOverride:
    """A priority-range rule that redirects matching requests to a fixed pool.

    Both bounds are inclusive. The rule itself performs no validation of its
    fields on construction.
    """

    # Smallest request priority (inclusive) covered by this rule.
    min_priority: int
    # Largest request priority (inclusive) covered by this rule.
    max_priority: int
    # Index of the pool that receives requests whose priority matches.
    target_pool: int
def _apply_priority_overrides(
    base_pool: int,
    priority: Optional[int],
    overrides: List[PriorityPoolOverride],
) -> int:
    """Return the pool to use after consulting the priority override rules.

    Args:
        base_pool: Pool index to fall back to when no rule applies.
        priority: Request priority, or None when the request carries none.
        overrides: Rules examined in list order; each covers the inclusive
            range ``[min_priority, max_priority]``.

    Returns:
        The ``target_pool`` of the first rule whose range contains
        ``priority``; ``base_pool`` when ``priority`` is None, when there
        are no rules, or when no rule matches.
    """
    if priority is None or not overrides:
        return base_pool

    # Lazily scan in list order so an earlier rule shadows later overlaps.
    matching_targets = (
        candidate.target_pool
        for candidate in overrides
        if candidate.min_priority <= priority <= candidate.max_priority
    )
    return next(matching_targets, base_pool)
@dataclass @dataclass
class PrefillPoolSelectionStrategy: class PrefillPoolSelectionStrategy:
"""Strategy for selecting prefill pools based on ISL and TTFT target.""" """Strategy for selecting prefill pools based on ISL and TTFT target."""
...@@ -30,6 +53,7 @@ class PrefillPoolSelectionStrategy: ...@@ -30,6 +53,7 @@ class PrefillPoolSelectionStrategy:
isl_max: int isl_max: int
isl_resolution: int isl_resolution: int
prefill_pool_mapping: List[List[int]] prefill_pool_mapping: List[List[int]]
priority_overrides: List[PriorityPoolOverride] = field(default_factory=list)
@property @property
def ttft_step(self) -> float: def ttft_step(self) -> float:
...@@ -41,16 +65,23 @@ class PrefillPoolSelectionStrategy: ...@@ -41,16 +65,23 @@ class PrefillPoolSelectionStrategy:
"""Step size for ISL grid.""" """Step size for ISL grid."""
return (self.isl_max - self.isl_min) / self.isl_resolution return (self.isl_max - self.isl_min) / self.isl_resolution
def select_pool(self, isl: int, ttft_target: Optional[float] = None) -> int: def select_pool(
self,
isl: int,
ttft_target: Optional[float] = None,
priority: Optional[int] = None,
) -> int:
""" """
Select prefill pool based on ISL and TTFT target. Select prefill pool based on ISL, TTFT target, and optional priority.
Args: Args:
isl: Input sequence length (number of tokens) isl: Input sequence length (number of tokens)
ttft_target: Target time to first token in ms. If None, uses middle of range. ttft_target: Target time to first token in ms. If None, uses middle of range.
priority: Request priority from agent hints. If set and a priority
override rule matches, the override takes precedence over the grid.
Returns: Returns:
Pool index from prefill_pool_mapping Pool index from prefill_pool_mapping or a priority override
""" """
if ttft_target is None: if ttft_target is None:
ttft_target = (self.ttft_min + self.ttft_max) / 2 ttft_target = (self.ttft_min + self.ttft_max) / 2
...@@ -64,9 +95,12 @@ class PrefillPoolSelectionStrategy: ...@@ -64,9 +95,12 @@ class PrefillPoolSelectionStrategy:
) )
pool_idx = self.prefill_pool_mapping[isl_idx][ttft_idx] pool_idx = self.prefill_pool_mapping[isl_idx][ttft_idx]
pool_idx = _apply_priority_overrides(
pool_idx, priority, self.priority_overrides
)
logger.debug( logger.debug(
f"Prefill pool selection: ISL={isl}, TTFT={ttft_target} -> " f"Prefill pool selection: ISL={isl}, TTFT={ttft_target}, "
f"grid[{isl_idx}][{ttft_idx}] -> pool {pool_idx}" f"priority={priority} -> pool {pool_idx}"
) )
return pool_idx return pool_idx
...@@ -87,6 +121,7 @@ class DecodePoolSelectionStrategy: ...@@ -87,6 +121,7 @@ class DecodePoolSelectionStrategy:
context_length_max: int context_length_max: int
context_length_resolution: int context_length_resolution: int
decode_pool_mapping: List[List[int]] decode_pool_mapping: List[List[int]]
priority_overrides: List[PriorityPoolOverride] = field(default_factory=list)
@property @property
def itl_step(self) -> float: def itl_step(self) -> float:
...@@ -101,17 +136,22 @@ class DecodePoolSelectionStrategy: ...@@ -101,17 +136,22 @@ class DecodePoolSelectionStrategy:
) / self.context_length_resolution ) / self.context_length_resolution
def select_pool( def select_pool(
self, context_length: int, itl_target: Optional[float] = None self,
context_length: int,
itl_target: Optional[float] = None,
priority: Optional[int] = None,
) -> int: ) -> int:
""" """
Select decode pool based on context length and ITL target. Select decode pool based on context length, ITL target, and optional priority.
Args: Args:
context_length: Total context length (prompt + generated tokens so far) context_length: Total context length (prompt + generated tokens so far)
itl_target: Target inter-token latency in ms. If None, uses middle of range. itl_target: Target inter-token latency in ms. If None, uses middle of range.
priority: Request priority from agent hints. If set and a priority
override rule matches, the override takes precedence over the grid.
Returns: Returns:
Pool index from decode_pool_mapping Pool index from decode_pool_mapping or a priority override
""" """
if itl_target is None: if itl_target is None:
itl_target = (self.itl_min + self.itl_max) / 2 itl_target = (self.itl_min + self.itl_max) / 2
...@@ -126,9 +166,12 @@ class DecodePoolSelectionStrategy: ...@@ -126,9 +166,12 @@ class DecodePoolSelectionStrategy:
) )
pool_idx = self.decode_pool_mapping[ctx_idx][itl_idx] pool_idx = self.decode_pool_mapping[ctx_idx][itl_idx]
pool_idx = _apply_priority_overrides(
pool_idx, priority, self.priority_overrides
)
logger.debug( logger.debug(
f"Decode pool selection: context_length={context_length}, ITL={itl_target} -> " f"Decode pool selection: context_length={context_length}, ITL={itl_target}, "
f"grid[{ctx_idx}][{itl_idx}] -> pool {pool_idx}" f"priority={priority} -> pool {pool_idx}"
) )
return pool_idx return pool_idx
...@@ -228,6 +271,22 @@ class GlobalRouterConfig: ...@@ -228,6 +271,22 @@ class GlobalRouterConfig:
f"(must be 0 to {self.num_prefill_pools - 1})" f"(must be 0 to {self.num_prefill_pools - 1})"
) )
for i, override in enumerate(prefill_strategy.priority_overrides):
if override.min_priority > override.max_priority:
raise ValueError(
f"Prefill priority_overrides[{i}]: min_priority "
f"({override.min_priority}) must be <= max_priority "
f"({override.max_priority})"
)
if (
override.target_pool < 0
or override.target_pool >= self.num_prefill_pools
):
raise ValueError(
f"Prefill priority_overrides[{i}]: invalid target_pool "
f"{override.target_pool} (must be 0 to {self.num_prefill_pools - 1})"
)
decode_strategy = self.decode_pool_selection_strategy decode_strategy = self.decode_pool_selection_strategy
if ( if (
len(decode_strategy.decode_pool_mapping) len(decode_strategy.decode_pool_mapping)
...@@ -251,6 +310,22 @@ class GlobalRouterConfig: ...@@ -251,6 +310,22 @@ class GlobalRouterConfig:
f"(must be 0 to {self.num_decode_pools - 1})" f"(must be 0 to {self.num_decode_pools - 1})"
) )
for i, override in enumerate(decode_strategy.priority_overrides):
if override.min_priority > override.max_priority:
raise ValueError(
f"Decode priority_overrides[{i}]: min_priority "
f"({override.min_priority}) must be <= max_priority "
f"({override.max_priority})"
)
if (
override.target_pool < 0
or override.target_pool >= self.num_decode_pools
):
raise ValueError(
f"Decode priority_overrides[{i}]: invalid target_pool "
f"{override.target_pool} (must be 0 to {self.num_decode_pools - 1})"
)
def load_config(config_path: str | Path) -> GlobalRouterConfig: def load_config(config_path: str | Path) -> GlobalRouterConfig:
""" """
...@@ -277,6 +352,10 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig: ...@@ -277,6 +352,10 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig:
# Parse prefill selection strategy # Parse prefill selection strategy
prefill_strategy_data = data["prefill_pool_selection_strategy"] prefill_strategy_data = data["prefill_pool_selection_strategy"]
prefill_priority_overrides = [
PriorityPoolOverride(**rule)
for rule in prefill_strategy_data.get("priority_overrides", [])
]
prefill_strategy = PrefillPoolSelectionStrategy( prefill_strategy = PrefillPoolSelectionStrategy(
ttft_min=prefill_strategy_data["ttft_min"], ttft_min=prefill_strategy_data["ttft_min"],
ttft_max=prefill_strategy_data["ttft_max"], ttft_max=prefill_strategy_data["ttft_max"],
...@@ -285,10 +364,15 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig: ...@@ -285,10 +364,15 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig:
isl_max=prefill_strategy_data["isl_max"], isl_max=prefill_strategy_data["isl_max"],
isl_resolution=prefill_strategy_data["isl_resolution"], isl_resolution=prefill_strategy_data["isl_resolution"],
prefill_pool_mapping=prefill_strategy_data["prefill_pool_mapping"], prefill_pool_mapping=prefill_strategy_data["prefill_pool_mapping"],
priority_overrides=prefill_priority_overrides,
) )
# Parse decode selection strategy # Parse decode selection strategy
decode_strategy_data = data["decode_pool_selection_strategy"] decode_strategy_data = data["decode_pool_selection_strategy"]
decode_priority_overrides = [
PriorityPoolOverride(**rule)
for rule in decode_strategy_data.get("priority_overrides", [])
]
decode_strategy = DecodePoolSelectionStrategy( decode_strategy = DecodePoolSelectionStrategy(
itl_min=decode_strategy_data["itl_min"], itl_min=decode_strategy_data["itl_min"],
itl_max=decode_strategy_data["itl_max"], itl_max=decode_strategy_data["itl_max"],
...@@ -297,6 +381,7 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig: ...@@ -297,6 +381,7 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig:
context_length_max=decode_strategy_data["context_length_max"], context_length_max=decode_strategy_data["context_length_max"],
context_length_resolution=decode_strategy_data["context_length_resolution"], context_length_resolution=decode_strategy_data["context_length_resolution"],
decode_pool_mapping=decode_strategy_data["decode_pool_mapping"], decode_pool_mapping=decode_strategy_data["decode_pool_mapping"],
priority_overrides=decode_priority_overrides,
) )
config = GlobalRouterConfig( config = GlobalRouterConfig(
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tests for priority-based pool routing in the global router."""
import json
from pathlib import Path
import pytest
from dynamo.global_router.pool_selection import (
DecodePoolSelectionStrategy,
GlobalRouterConfig,
PrefillPoolSelectionStrategy,
PriorityPoolOverride,
_apply_priority_overrides,
load_config,
)
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.parallel,
pytest.mark.unit,
]
# --- _apply_priority_overrides tests ---
class TestApplyPriorityOverrides:
    """Exercise ``_apply_priority_overrides`` directly, without a strategy."""

    def test_no_priority_returns_base(self):
        # A request without a priority must leave the base pool untouched.
        only_rule = PriorityPoolOverride(
            min_priority=1, max_priority=10, target_pool=1
        )
        assert _apply_priority_overrides(0, None, [only_rule]) == 0

    def test_no_rules_returns_base(self):
        selected = _apply_priority_overrides(0, 5, [])
        assert selected == 0

    def test_match_returns_target(self):
        only_rule = PriorityPoolOverride(
            min_priority=5, max_priority=15, target_pool=2
        )
        assert _apply_priority_overrides(0, 10, [only_rule]) == 2

    def test_no_match_returns_base(self):
        only_rule = PriorityPoolOverride(
            min_priority=5, max_priority=15, target_pool=2
        )
        assert _apply_priority_overrides(0, 20, [only_rule]) == 0

    def test_first_match_wins(self):
        earlier = PriorityPoolOverride(min_priority=1, max_priority=10, target_pool=1)
        later = PriorityPoolOverride(min_priority=5, max_priority=20, target_pool=2)
        # priority=7 lies inside both ranges; the rule listed first decides.
        assert _apply_priority_overrides(0, 7, [earlier, later]) == 1

    def test_boundary_inclusive(self):
        rules = [PriorityPoolOverride(min_priority=5, max_priority=10, target_pool=1)]
        # Both end points match; the values just outside them do not.
        expected_by_priority = {5: 1, 10: 1, 4: 0, 11: 0}
        for priority, expected in expected_by_priority.items():
            assert _apply_priority_overrides(0, priority, rules) == expected

    def test_empty_priority_and_empty_rules(self):
        selected = _apply_priority_overrides(3, None, [])
        assert selected == 3
# --- PrefillPoolSelectionStrategy with priority ---
def _make_prefill_strategy(
    num_pools=2, priority_overrides=None
) -> PrefillPoolSelectionStrategy:
    """Build a 2x2 prefill strategy fixture, optionally with override rules.

    ``num_pools`` is accepted but not consulted by this helper. A falsy
    ``priority_overrides`` (including None) is replaced by an empty list.
    """
    grid_settings = {
        "ttft_min": 10,
        "ttft_max": 3000,
        "ttft_resolution": 2,
        "isl_min": 0,
        "isl_max": 32000,
        "isl_resolution": 2,
        "prefill_pool_mapping": [[0, 1], [0, 1]],
    }
    rules = priority_overrides if priority_overrides else []
    return PrefillPoolSelectionStrategy(**grid_settings, priority_overrides=rules)
class TestPrefillSelectPoolWithPriority:
    """Check how the prefill strategy's ``select_pool`` reacts to priority."""

    def test_priority_override_takes_precedence(self):
        rule = PriorityPoolOverride(min_priority=10, max_priority=100, target_pool=1)
        strategy = _make_prefill_strategy(priority_overrides=[rule])
        # The grid alone gives pool 0 for this ISL/TTFT; priority 50 matches
        # the rule, so the override's pool is expected instead.
        selected = strategy.select_pool(isl=100, ttft_target=50, priority=50)
        assert selected == 1

    def test_no_priority_uses_grid(self):
        rule = PriorityPoolOverride(min_priority=10, max_priority=100, target_pool=1)
        strategy = _make_prefill_strategy(priority_overrides=[rule])
        selected = strategy.select_pool(isl=100, ttft_target=50)
        # No priority supplied, so the grid result stands.
        assert selected == 0

    def test_unmatched_priority_uses_grid(self):
        rule = PriorityPoolOverride(min_priority=10, max_priority=100, target_pool=1)
        strategy = _make_prefill_strategy(priority_overrides=[rule])
        selected = strategy.select_pool(isl=100, ttft_target=50, priority=5)
        # priority=5 is outside [10, 100], so the grid result stands.
        assert selected == 0

    def test_no_overrides_backward_compatible(self):
        strategy = _make_prefill_strategy()
        selected = strategy.select_pool(isl=100, ttft_target=50, priority=50)
        # With no rules configured a priority has no effect on the grid result.
        assert selected == 0
# --- DecodePoolSelectionStrategy with priority ---
def _make_decode_strategy(
    num_pools=2, priority_overrides=None
) -> DecodePoolSelectionStrategy:
    """Build a 2x2 decode strategy fixture, optionally with override rules.

    ``num_pools`` is accepted but not consulted by this helper. A falsy
    ``priority_overrides`` (including None) is replaced by an empty list.
    """
    grid_settings = {
        "itl_min": 10,
        "itl_max": 500,
        "itl_resolution": 2,
        "context_length_min": 0,
        "context_length_max": 32000,
        "context_length_resolution": 2,
        "decode_pool_mapping": [[0, 1], [0, 1]],
    }
    rules = priority_overrides if priority_overrides else []
    return DecodePoolSelectionStrategy(**grid_settings, priority_overrides=rules)
class TestDecodeSelectPoolWithPriority:
    """Check how the decode strategy's ``select_pool`` reacts to priority."""

    def test_priority_override_takes_precedence(self):
        rule = PriorityPoolOverride(min_priority=10, max_priority=100, target_pool=1)
        strategy = _make_decode_strategy(priority_overrides=[rule])
        # priority=50 lies inside the rule's range, so its pool is expected.
        selected = strategy.select_pool(context_length=100, itl_target=20, priority=50)
        assert selected == 1

    def test_no_priority_uses_grid(self):
        rule = PriorityPoolOverride(min_priority=10, max_priority=100, target_pool=1)
        strategy = _make_decode_strategy(priority_overrides=[rule])
        # Without a priority the configured rule must not be applied.
        selected = strategy.select_pool(context_length=100, itl_target=20)
        assert selected == 0
# --- Config loading tests ---
def _write_config(tmp_dir: Path, config_data: dict) -> Path:
    """Serialize ``config_data`` as JSON into ``tmp_dir / "config.json"``.

    Returns the path of the written file.
    """
    serialized = json.dumps(config_data)
    destination = tmp_dir / "config.json"
    destination.write_text(serialized)
    return destination
def _base_config(**overrides) -> dict:
    """Return a fresh two-pool router config dict without priority overrides.

    Keyword arguments replace or add top-level keys; nested strategy dicts
    are not merged, only replaced wholesale when their key is passed.
    """
    prefill_strategy = {
        "isl_min": 0,
        "isl_max": 32000,
        "isl_resolution": 2,
        "ttft_min": 10,
        "ttft_max": 3000,
        "ttft_resolution": 2,
        "prefill_pool_mapping": [[0, 1], [0, 1]],
    }
    decode_strategy = {
        "context_length_min": 0,
        "context_length_max": 32000,
        "context_length_resolution": 2,
        "itl_min": 10,
        "itl_max": 500,
        "itl_resolution": 2,
        "decode_pool_mapping": [[0, 1], [0, 1]],
    }
    defaults = {
        "num_prefill_pools": 2,
        "num_decode_pools": 2,
        "prefill_pool_dynamo_namespaces": ["ns-prefill-0", "ns-prefill-1"],
        "decode_pool_dynamo_namespaces": ["ns-decode-0", "ns-decode-1"],
        "prefill_pool_selection_strategy": prefill_strategy,
        "decode_pool_selection_strategy": decode_strategy,
    }
    # Later entries win, so caller-supplied keys shadow the defaults.
    return {**defaults, **overrides}
class TestLoadConfigWithPriorityOverrides:
    """Round-trip configs through a JSON file and ``load_config``."""

    def test_with_priority_overrides(self, tmp_path):
        prefill_rule = {"min_priority": 10, "max_priority": 100, "target_pool": 1}
        decode_rule = {"min_priority": 5, "max_priority": 50, "target_pool": 0}

        config_data = _base_config()
        prefill_section = config_data["prefill_pool_selection_strategy"]
        decode_section = config_data["decode_pool_selection_strategy"]
        prefill_section["priority_overrides"] = [prefill_rule]
        decode_section["priority_overrides"] = [decode_rule]

        config = load_config(_write_config(tmp_path, config_data))

        # Each strategy should carry exactly the one rule written for it.
        prefill_rules = config.prefill_pool_selection_strategy.priority_overrides
        assert len(prefill_rules) == 1
        loaded = prefill_rules[0]
        assert loaded.min_priority == 10
        assert loaded.max_priority == 100
        assert loaded.target_pool == 1

        decode_rules = config.decode_pool_selection_strategy.priority_overrides
        assert len(decode_rules) == 1
        loaded = decode_rules[0]
        assert loaded.min_priority == 5
        assert loaded.max_priority == 50
        assert loaded.target_pool == 0

    def test_without_priority_overrides(self, tmp_path):
        # A config with no "priority_overrides" key loads with empty rule lists.
        config = load_config(_write_config(tmp_path, _base_config()))

        assert config.prefill_pool_selection_strategy.priority_overrides == []
        assert config.decode_pool_selection_strategy.priority_overrides == []
# --- Validation tests ---
class TestValidatePriorityOverrides:
    """Check ``GlobalRouterConfig.validate`` against priority override rules."""

    @staticmethod
    def _two_pool_config(prefill_strategy=None, decode_strategy=None):
        """Build a two-pool config; omitted strategies use the plain fixtures."""
        if prefill_strategy is None:
            prefill_strategy = _make_prefill_strategy()
        if decode_strategy is None:
            decode_strategy = _make_decode_strategy()
        return GlobalRouterConfig(
            num_prefill_pools=2,
            num_decode_pools=2,
            prefill_pool_dynamo_namespaces=["a", "b"],
            decode_pool_dynamo_namespaces=["c", "d"],
            prefill_pool_selection_strategy=prefill_strategy,
            decode_pool_selection_strategy=decode_strategy,
        )

    def test_invalid_prefill_target_pool(self):
        # target_pool=5 is out of range for a two-pool prefill setup.
        rule = PriorityPoolOverride(min_priority=1, max_priority=10, target_pool=5)
        config = self._two_pool_config(
            prefill_strategy=_make_prefill_strategy(priority_overrides=[rule])
        )
        with pytest.raises(ValueError, match="invalid target_pool"):
            config.validate()

    def test_invalid_decode_target_pool(self):
        # target_pool=3 is out of range for a two-pool decode setup.
        rule = PriorityPoolOverride(min_priority=1, max_priority=10, target_pool=3)
        config = self._two_pool_config(
            decode_strategy=_make_decode_strategy(priority_overrides=[rule])
        )
        with pytest.raises(ValueError, match="invalid target_pool"):
            config.validate()

    def test_min_greater_than_max(self):
        # An inverted range must be rejected by validation.
        rule = PriorityPoolOverride(min_priority=20, max_priority=5, target_pool=1)
        config = self._two_pool_config(
            prefill_strategy=_make_prefill_strategy(priority_overrides=[rule])
        )
        with pytest.raises(ValueError, match="min_priority"):
            config.validate()

    def test_valid_overrides_pass(self):
        rule = PriorityPoolOverride(min_priority=1, max_priority=10, target_pool=1)
        config = self._two_pool_config(
            prefill_strategy=_make_prefill_strategy(priority_overrides=[rule])
        )
        config.validate()  # should not raise
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment