Unverified Commit 7a6db48e authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat(global-router): priority-based pool routing from agent hints (#8010)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 8fd1b847
......@@ -202,6 +202,7 @@ def _build_dynamo_preproc(
},
"eos_token_ids": [eos_token_id] if eos_token_id is not None else [],
"annotations": [],
"routing": request.get("routing"),
}
# Forward multimodal URLs so the backend handler can load the media.
......
......@@ -269,6 +269,7 @@ class VllmProcessor:
},
"eos_token_ids": self._get_eos_token_ids(),
"annotations": [],
"routing": request.get("routing"),
}
# Forward multimodal URLs so the backend handler can load the media.
......
......@@ -133,6 +133,43 @@ If `prefill_pool_mapping = [[0, 1], [0, 1]]` and `ttft_resolution=2`:
- High ISL + Low TTFT target → pool 0
- High ISL + High TTFT target → pool 1
### Priority-Based Pool Override
Both prefill and decode strategies support optional `priority_overrides` rules.
When a request carries a priority value (from `nvext.agent_hints.priority`), the
global router evaluates the override rules **after** the grid lookup. The first
rule whose `[min_priority, max_priority]` range contains the request priority
wins, and the request is routed to that rule's `target_pool` instead of the
grid result. If no rule matches (or no priority is present), the grid result
is used as normal.
This is useful for straggler mitigation in RL workloads: the RL framework can
tag slow requests with a high priority, and the global router redirects them to
a dedicated min-latency pool.
```jsonc
"priority_overrides": [
{
"min_priority": 10, // inclusive lower bound
"max_priority": 100, // inclusive upper bound
"target_pool": 1 // pool index to route to
}
]
```
Priority is set by the client via the NVIDIA OpenAI extension:
```json
{
"messages": [...],
"nvext": {
"agent_hints": {
"priority": 50
}
}
}
```
### Passing SLA Targets
Clients can pass TTFT and ITL targets via `extra_args` in the request:
......
......@@ -134,16 +134,20 @@ class GlobalRouterHandler:
extra_args = request.get("extra_args") or {}
ttft_target = extra_args.get("ttft_target") or self.default_ttft_target
# Extract priority from routing hints (set by nvext.agent_hints.priority)
routing = request.get("routing") or {}
priority = routing.get("priority")
# Select prefill pool
pool_idx = self.config.prefill_pool_selection_strategy.select_pool(
isl=isl, ttft_target=ttft_target
isl=isl, ttft_target=ttft_target, priority=priority
)
namespace = self.config.prefill_pool_dynamo_namespaces[pool_idx]
client = self.prefill_clients[namespace]
logger.info(
f"Routing prefill request: ISL={isl}, TTFT_target={ttft_target} -> "
f"pool {pool_idx} ({namespace})"
f"Routing prefill request: ISL={isl}, TTFT_target={ttft_target}, "
f"priority={priority} -> pool {pool_idx} ({namespace})"
)
# Forward request to local router and stream back responses
......@@ -182,15 +186,20 @@ class GlobalRouterHandler:
extra_args = request.get("extra_args") or {}
itl_target = extra_args.get("itl_target") or self.default_itl_target
# Extract priority from routing hints (set by nvext.agent_hints.priority)
routing = request.get("routing") or {}
priority = routing.get("priority")
# Select decode pool
pool_idx = self.config.decode_pool_selection_strategy.select_pool(
context_length=context_length, itl_target=itl_target
context_length=context_length, itl_target=itl_target, priority=priority
)
namespace = self.config.decode_pool_dynamo_namespaces[pool_idx]
client = self.decode_clients[namespace]
logger.info(
f"Routing decode request: context_length={context_length}, ITL_target={itl_target} -> "
f"Routing decode request: context_length={context_length}, "
f"ITL_target={itl_target}, priority={priority} -> "
f"pool {pool_idx} ({namespace})"
)
......
......@@ -12,13 +12,36 @@ The config file defines:
import json
import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
logger = logging.getLogger(__name__)
@dataclass
class PriorityPoolOverride:
    """A single priority-to-pool routing rule.

    A rule applies to a request whose priority lies in the closed interval
    ``[min_priority, max_priority]``.

    Attributes:
        min_priority: Smallest priority the rule matches (inclusive).
        max_priority: Largest priority the rule matches (inclusive).
        target_pool: Index of the pool to route to when the rule matches.
    """

    min_priority: int
    max_priority: int
    target_pool: int
def _apply_priority_overrides(
    base_pool: int,
    priority: Optional[int],
    overrides: List[PriorityPoolOverride],
) -> int:
    """Return the pool chosen after consulting the priority override rules.

    Rules are checked in list order; the first one whose inclusive
    ``[min_priority, max_priority]`` range contains ``priority`` decides the
    result.

    Args:
        base_pool: Pool index to fall back to when no rule applies.
        priority: Priority attached to the request, or None when absent.
        overrides: Ordered override rules; may be empty.

    Returns:
        The ``target_pool`` of the first matching rule, otherwise ``base_pool``.
    """
    # Without a priority there is nothing to match against.
    if priority is None:
        return base_pool

    # A falsy rule collection behaves like "no rules configured".
    rules = overrides or ()
    matching_targets = (
        rule.target_pool
        for rule in rules
        if rule.min_priority <= priority <= rule.max_priority
    )
    # The generator is lazy, so evaluation stops at the first matching rule.
    return next(matching_targets, base_pool)
@dataclass
class PrefillPoolSelectionStrategy:
"""Strategy for selecting prefill pools based on ISL and TTFT target."""
......@@ -30,6 +53,7 @@ class PrefillPoolSelectionStrategy:
isl_max: int
isl_resolution: int
prefill_pool_mapping: List[List[int]]
priority_overrides: List[PriorityPoolOverride] = field(default_factory=list)
@property
def ttft_step(self) -> float:
......@@ -41,16 +65,23 @@ class PrefillPoolSelectionStrategy:
"""Step size for ISL grid."""
return (self.isl_max - self.isl_min) / self.isl_resolution
def select_pool(self, isl: int, ttft_target: Optional[float] = None) -> int:
def select_pool(
self,
isl: int,
ttft_target: Optional[float] = None,
priority: Optional[int] = None,
) -> int:
"""
Select prefill pool based on ISL and TTFT target.
Select prefill pool based on ISL, TTFT target, and optional priority.
Args:
isl: Input sequence length (number of tokens)
ttft_target: Target time to first token in ms. If None, uses middle of range.
priority: Request priority from agent hints. If set and a priority
override rule matches, the override takes precedence over the grid.
Returns:
Pool index from prefill_pool_mapping
Pool index from prefill_pool_mapping or a priority override
"""
if ttft_target is None:
ttft_target = (self.ttft_min + self.ttft_max) / 2
......@@ -64,9 +95,12 @@ class PrefillPoolSelectionStrategy:
)
pool_idx = self.prefill_pool_mapping[isl_idx][ttft_idx]
pool_idx = _apply_priority_overrides(
pool_idx, priority, self.priority_overrides
)
logger.debug(
f"Prefill pool selection: ISL={isl}, TTFT={ttft_target} -> "
f"grid[{isl_idx}][{ttft_idx}] -> pool {pool_idx}"
f"Prefill pool selection: ISL={isl}, TTFT={ttft_target}, "
f"priority={priority} -> pool {pool_idx}"
)
return pool_idx
......@@ -87,6 +121,7 @@ class DecodePoolSelectionStrategy:
context_length_max: int
context_length_resolution: int
decode_pool_mapping: List[List[int]]
priority_overrides: List[PriorityPoolOverride] = field(default_factory=list)
@property
def itl_step(self) -> float:
......@@ -101,17 +136,22 @@ class DecodePoolSelectionStrategy:
) / self.context_length_resolution
def select_pool(
self, context_length: int, itl_target: Optional[float] = None
self,
context_length: int,
itl_target: Optional[float] = None,
priority: Optional[int] = None,
) -> int:
"""
Select decode pool based on context length and ITL target.
Select decode pool based on context length, ITL target, and optional priority.
Args:
context_length: Total context length (prompt + generated tokens so far)
itl_target: Target inter-token latency in ms. If None, uses middle of range.
priority: Request priority from agent hints. If set and a priority
override rule matches, the override takes precedence over the grid.
Returns:
Pool index from decode_pool_mapping
Pool index from decode_pool_mapping or a priority override
"""
if itl_target is None:
itl_target = (self.itl_min + self.itl_max) / 2
......@@ -126,9 +166,12 @@ class DecodePoolSelectionStrategy:
)
pool_idx = self.decode_pool_mapping[ctx_idx][itl_idx]
pool_idx = _apply_priority_overrides(
pool_idx, priority, self.priority_overrides
)
logger.debug(
f"Decode pool selection: context_length={context_length}, ITL={itl_target} -> "
f"grid[{ctx_idx}][{itl_idx}] -> pool {pool_idx}"
f"Decode pool selection: context_length={context_length}, ITL={itl_target}, "
f"priority={priority} -> pool {pool_idx}"
)
return pool_idx
......@@ -228,6 +271,22 @@ class GlobalRouterConfig:
f"(must be 0 to {self.num_prefill_pools - 1})"
)
for i, override in enumerate(prefill_strategy.priority_overrides):
if override.min_priority > override.max_priority:
raise ValueError(
f"Prefill priority_overrides[{i}]: min_priority "
f"({override.min_priority}) must be <= max_priority "
f"({override.max_priority})"
)
if (
override.target_pool < 0
or override.target_pool >= self.num_prefill_pools
):
raise ValueError(
f"Prefill priority_overrides[{i}]: invalid target_pool "
f"{override.target_pool} (must be 0 to {self.num_prefill_pools - 1})"
)
decode_strategy = self.decode_pool_selection_strategy
if (
len(decode_strategy.decode_pool_mapping)
......@@ -251,6 +310,22 @@ class GlobalRouterConfig:
f"(must be 0 to {self.num_decode_pools - 1})"
)
for i, override in enumerate(decode_strategy.priority_overrides):
if override.min_priority > override.max_priority:
raise ValueError(
f"Decode priority_overrides[{i}]: min_priority "
f"({override.min_priority}) must be <= max_priority "
f"({override.max_priority})"
)
if (
override.target_pool < 0
or override.target_pool >= self.num_decode_pools
):
raise ValueError(
f"Decode priority_overrides[{i}]: invalid target_pool "
f"{override.target_pool} (must be 0 to {self.num_decode_pools - 1})"
)
def load_config(config_path: str | Path) -> GlobalRouterConfig:
"""
......@@ -277,6 +352,10 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig:
# Parse prefill selection strategy
prefill_strategy_data = data["prefill_pool_selection_strategy"]
prefill_priority_overrides = [
PriorityPoolOverride(**rule)
for rule in prefill_strategy_data.get("priority_overrides", [])
]
prefill_strategy = PrefillPoolSelectionStrategy(
ttft_min=prefill_strategy_data["ttft_min"],
ttft_max=prefill_strategy_data["ttft_max"],
......@@ -285,10 +364,15 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig:
isl_max=prefill_strategy_data["isl_max"],
isl_resolution=prefill_strategy_data["isl_resolution"],
prefill_pool_mapping=prefill_strategy_data["prefill_pool_mapping"],
priority_overrides=prefill_priority_overrides,
)
# Parse decode selection strategy
decode_strategy_data = data["decode_pool_selection_strategy"]
decode_priority_overrides = [
PriorityPoolOverride(**rule)
for rule in decode_strategy_data.get("priority_overrides", [])
]
decode_strategy = DecodePoolSelectionStrategy(
itl_min=decode_strategy_data["itl_min"],
itl_max=decode_strategy_data["itl_max"],
......@@ -297,6 +381,7 @@ def load_config(config_path: str | Path) -> GlobalRouterConfig:
context_length_max=decode_strategy_data["context_length_max"],
context_length_resolution=decode_strategy_data["context_length_resolution"],
decode_pool_mapping=decode_strategy_data["decode_pool_mapping"],
priority_overrides=decode_priority_overrides,
)
config = GlobalRouterConfig(
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tests for priority-based pool routing in the global router."""
import json
from pathlib import Path
import pytest
from dynamo.global_router.pool_selection import (
DecodePoolSelectionStrategy,
GlobalRouterConfig,
PrefillPoolSelectionStrategy,
PriorityPoolOverride,
_apply_priority_overrides,
load_config,
)
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.parallel,
pytest.mark.unit,
]
# --- _apply_priority_overrides tests ---
class TestApplyPriorityOverrides:
    """Unit tests for the ``_apply_priority_overrides`` helper."""

    @staticmethod
    def _rule(low, high, pool):
        """Build one override rule covering the inclusive range [low, high]."""
        return PriorityPoolOverride(
            min_priority=low, max_priority=high, target_pool=pool
        )

    def test_no_priority_returns_base(self):
        # A missing priority must bypass every rule.
        only_rule = self._rule(1, 10, 1)
        assert _apply_priority_overrides(0, None, [only_rule]) == 0

    def test_no_rules_returns_base(self):
        no_rules = []
        assert _apply_priority_overrides(0, 5, no_rules) == 0

    def test_match_returns_target(self):
        only_rule = self._rule(5, 15, 2)
        assert _apply_priority_overrides(0, 10, [only_rule]) == 2

    def test_no_match_returns_base(self):
        only_rule = self._rule(5, 15, 2)
        assert _apply_priority_overrides(0, 20, [only_rule]) == 0

    def test_first_match_wins(self):
        overlapping = [self._rule(1, 10, 1), self._rule(5, 20, 2)]
        # priority=7 matches both rules; first should win
        assert _apply_priority_overrides(0, 7, overlapping) == 1

    def test_boundary_inclusive(self):
        rules = [self._rule(5, 10, 1)]
        # (priority, expected pool): both edges match, one step outside does not.
        cases = ((5, 1), (10, 1), (4, 0), (11, 0))
        for priority, expected_pool in cases:
            assert _apply_priority_overrides(0, priority, rules) == expected_pool

    def test_empty_priority_and_empty_rules(self):
        assert _apply_priority_overrides(3, None, []) == 3
# --- PrefillPoolSelectionStrategy with priority ---
def _make_prefill_strategy(
    num_pools=2, priority_overrides=None
) -> PrefillPoolSelectionStrategy:
    """Build the 2x2 prefill strategy fixture shared by the tests below.

    Args:
        num_pools: Currently not read by this helper.
        priority_overrides: Override rules to attach; None or an empty list
            produces a strategy without overrides.

    Returns:
        A strategy whose grid maps the low-TTFT column to pool 0 and the
        high-TTFT column to pool 1 for both ISL rows.
    """
    grid_settings = dict(
        ttft_min=10,
        ttft_max=3000,
        ttft_resolution=2,
        isl_min=0,
        isl_max=32000,
        isl_resolution=2,
        prefill_pool_mapping=[[0, 1], [0, 1]],
    )
    rules = priority_overrides if priority_overrides else []
    return PrefillPoolSelectionStrategy(**grid_settings, priority_overrides=rules)
class TestPrefillSelectPoolWithPriority:
    """Prefill ``select_pool`` behaviour when a request priority is supplied."""

    @staticmethod
    def _strategy_with_override():
        """Return a strategy whose single rule sends priorities 10..100 to pool 1."""
        rule = PriorityPoolOverride(min_priority=10, max_priority=100, target_pool=1)
        return _make_prefill_strategy(priority_overrides=[rule])

    def test_priority_override_takes_precedence(self):
        strategy = self._strategy_with_override()
        # Grid would select pool 0 (low ISL, low TTFT), but priority overrides
        chosen = strategy.select_pool(isl=100, ttft_target=50, priority=50)
        assert chosen == 1

    def test_no_priority_uses_grid(self):
        strategy = self._strategy_with_override()
        chosen = strategy.select_pool(isl=100, ttft_target=50)
        # No priority supplied, so the grid result stands.
        assert chosen == 0

    def test_unmatched_priority_uses_grid(self):
        strategy = self._strategy_with_override()
        chosen = strategy.select_pool(isl=100, ttft_target=50, priority=5)
        # priority=5 lies outside [10, 100], so the grid result stands.
        assert chosen == 0

    def test_no_overrides_backward_compatible(self):
        strategy = _make_prefill_strategy()
        chosen = strategy.select_pool(isl=100, ttft_target=50, priority=50)
        # With no rules configured a priority has no effect.
        assert chosen == 0
# --- DecodePoolSelectionStrategy with priority ---
def _make_decode_strategy(
    num_pools=2, priority_overrides=None
) -> DecodePoolSelectionStrategy:
    """Build the 2x2 decode strategy fixture shared by the tests below.

    Args:
        num_pools: Currently not read by this helper.
        priority_overrides: Override rules to attach; None or an empty list
            produces a strategy without overrides.

    Returns:
        A strategy whose grid maps the low-ITL column to pool 0 and the
        high-ITL column to pool 1 for both context-length rows.
    """
    grid_settings = dict(
        itl_min=10,
        itl_max=500,
        itl_resolution=2,
        context_length_min=0,
        context_length_max=32000,
        context_length_resolution=2,
        decode_pool_mapping=[[0, 1], [0, 1]],
    )
    rules = priority_overrides if priority_overrides else []
    return DecodePoolSelectionStrategy(**grid_settings, priority_overrides=rules)
class TestDecodeSelectPoolWithPriority:
    """Decode ``select_pool`` behaviour when a request priority is supplied.

    Covers the same four situations as the prefill tests: a matching
    priority, a missing priority, an out-of-range priority, and a strategy
    with no override rules configured.
    """

    @staticmethod
    def _strategy_with_override():
        """Return a strategy whose single rule sends priorities 10..100 to pool 1."""
        rule = PriorityPoolOverride(min_priority=10, max_priority=100, target_pool=1)
        return _make_decode_strategy(priority_overrides=[rule])

    def test_priority_override_takes_precedence(self):
        strategy = self._strategy_with_override()
        # The grid alone picks pool 0 for these inputs; the rule redirects to 1.
        result = strategy.select_pool(context_length=100, itl_target=20, priority=50)
        assert result == 1

    def test_no_priority_uses_grid(self):
        strategy = self._strategy_with_override()
        result = strategy.select_pool(context_length=100, itl_target=20)
        assert result == 0  # grid result, no priority

    def test_unmatched_priority_uses_grid(self):
        strategy = self._strategy_with_override()
        result = strategy.select_pool(context_length=100, itl_target=20, priority=5)
        assert result == 0  # priority=5 doesn't match [10, 100]

    def test_no_overrides_backward_compatible(self):
        strategy = _make_decode_strategy()
        result = strategy.select_pool(context_length=100, itl_target=20, priority=50)
        assert result == 0  # no overrides configured, grid result
# --- Config loading tests ---
def _write_config(tmp_dir: Path, config_data: dict) -> Path:
    """Serialize ``config_data`` as JSON into ``config.json`` inside ``tmp_dir``.

    Args:
        tmp_dir: Existing directory that receives the file.
        config_data: JSON-serializable configuration mapping.

    Returns:
        Path of the file that was written.
    """
    target = tmp_dir.joinpath("config.json")
    serialized = json.dumps(config_data)
    target.write_text(serialized)
    return target
def _base_config(**overrides) -> dict:
    """Return a fresh two-pool router config dict for the loader tests.

    Args:
        **overrides: Top-level keys to replace or add in the returned dict.

    Returns:
        A newly built dict on every call, so callers may mutate it (including
        the nested strategy dicts) without affecting other tests.
    """
    prefill_strategy = {
        "isl_min": 0,
        "isl_max": 32000,
        "isl_resolution": 2,
        "ttft_min": 10,
        "ttft_max": 3000,
        "ttft_resolution": 2,
        "prefill_pool_mapping": [[0, 1], [0, 1]],
    }
    decode_strategy = {
        "context_length_min": 0,
        "context_length_max": 32000,
        "context_length_resolution": 2,
        "itl_min": 10,
        "itl_max": 500,
        "itl_resolution": 2,
        "decode_pool_mapping": [[0, 1], [0, 1]],
    }
    defaults = {
        "num_prefill_pools": 2,
        "num_decode_pools": 2,
        "prefill_pool_dynamo_namespaces": ["ns-prefill-0", "ns-prefill-1"],
        "decode_pool_dynamo_namespaces": ["ns-decode-0", "ns-decode-1"],
        "prefill_pool_selection_strategy": prefill_strategy,
        "decode_pool_selection_strategy": decode_strategy,
    }
    # Later entries win, so caller-supplied keys replace the defaults.
    return {**defaults, **overrides}
class TestLoadConfigWithPriorityOverrides:
    """``load_config`` parsing of the optional ``priority_overrides`` lists."""

    def test_with_priority_overrides(self, tmp_path):
        raw = _base_config()
        raw["prefill_pool_selection_strategy"]["priority_overrides"] = [
            {"min_priority": 10, "max_priority": 100, "target_pool": 1}
        ]
        raw["decode_pool_selection_strategy"]["priority_overrides"] = [
            {"min_priority": 5, "max_priority": 50, "target_pool": 0}
        ]

        loaded = load_config(_write_config(tmp_path, raw))

        # Each strategy must carry exactly the one rule written for it.
        prefill_rules = loaded.prefill_pool_selection_strategy.priority_overrides
        assert len(prefill_rules) == 1
        prefill_rule = prefill_rules[0]
        assert prefill_rule.min_priority == 10
        assert prefill_rule.max_priority == 100
        assert prefill_rule.target_pool == 1

        decode_rules = loaded.decode_pool_selection_strategy.priority_overrides
        assert len(decode_rules) == 1
        decode_rule = decode_rules[0]
        assert decode_rule.min_priority == 5
        assert decode_rule.max_priority == 50
        assert decode_rule.target_pool == 0

    def test_without_priority_overrides(self, tmp_path):
        # Omitting the key from the file must yield empty rule lists.
        loaded = load_config(_write_config(tmp_path, _base_config()))

        assert loaded.prefill_pool_selection_strategy.priority_overrides == []
        assert loaded.decode_pool_selection_strategy.priority_overrides == []
# --- Validation tests ---
class TestValidatePriorityOverrides:
    """``GlobalRouterConfig.validate`` checks on priority override rules.

    Every test uses a two-pool prefill / two-pool decode configuration, so
    the only valid ``target_pool`` values are 0 and 1.
    """

    @staticmethod
    def _config(prefill_overrides=None, decode_overrides=None):
        """Build a two-pool config carrying the given override rules.

        Args:
            prefill_overrides: Rules for the prefill strategy, or None for none.
            decode_overrides: Rules for the decode strategy, or None for none.
        """
        return GlobalRouterConfig(
            num_prefill_pools=2,
            num_decode_pools=2,
            prefill_pool_dynamo_namespaces=["a", "b"],
            decode_pool_dynamo_namespaces=["c", "d"],
            prefill_pool_selection_strategy=_make_prefill_strategy(
                priority_overrides=prefill_overrides
            ),
            decode_pool_selection_strategy=_make_decode_strategy(
                priority_overrides=decode_overrides
            ),
        )

    def test_invalid_prefill_target_pool(self):
        config = self._config(
            prefill_overrides=[
                PriorityPoolOverride(min_priority=1, max_priority=10, target_pool=5)
            ]
        )
        with pytest.raises(ValueError, match="invalid target_pool"):
            config.validate()

    def test_invalid_decode_target_pool(self):
        config = self._config(
            decode_overrides=[
                PriorityPoolOverride(min_priority=1, max_priority=10, target_pool=3)
            ]
        )
        with pytest.raises(ValueError, match="invalid target_pool"):
            config.validate()

    def test_negative_target_pool(self):
        # Pool indices below zero are rejected as well as ones past the end.
        config = self._config(
            prefill_overrides=[
                PriorityPoolOverride(min_priority=1, max_priority=10, target_pool=-1)
            ]
        )
        with pytest.raises(ValueError, match="invalid target_pool"):
            config.validate()

    def test_min_greater_than_max(self):
        config = self._config(
            prefill_overrides=[
                PriorityPoolOverride(min_priority=20, max_priority=5, target_pool=1)
            ]
        )
        with pytest.raises(ValueError, match="min_priority"):
            config.validate()

    def test_decode_min_greater_than_max(self):
        # Same inverted range on the decode side; the message prefix confirms
        # the error came from the decode rules, not the prefill ones.
        config = self._config(
            decode_overrides=[
                PriorityPoolOverride(min_priority=20, max_priority=5, target_pool=1)
            ]
        )
        with pytest.raises(ValueError, match="Decode priority_overrides"):
            config.validate()

    def test_valid_overrides_pass(self):
        config = self._config(
            prefill_overrides=[
                PriorityPoolOverride(min_priority=1, max_priority=10, target_pool=1)
            ]
        )
        config.validate()  # should not raise
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment