Unverified Commit bddaaa26 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(kv-router): pluggable scheduling policy for router queue [DYN-2454] (#7260)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 12785247
...@@ -1896,6 +1896,7 @@ dependencies = [ ...@@ -1896,6 +1896,7 @@ dependencies = [
"dynamo-tokens", "dynamo-tokens",
"flume", "flume",
"indicatif 0.18.4", "indicatif 0.18.4",
"ordered-float 4.6.0",
"parking_lot", "parking_lot",
"plotters", "plotters",
"prometheus", "prometheus",
...@@ -5262,6 +5263,15 @@ dependencies = [ ...@@ -5262,6 +5263,15 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "ordered-float"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "ordered-multimap" name = "ordered-multimap"
version = "0.7.3" version = "0.7.3"
...@@ -6952,7 +6962,7 @@ version = "0.7.0" ...@@ -6952,7 +6962,7 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c"
dependencies = [ dependencies = [
"ordered-float", "ordered-float 2.10.1",
"serde", "serde",
] ]
......
...@@ -93,6 +93,7 @@ modelexpress-common = { version = "0.2.0" } ...@@ -93,6 +93,7 @@ modelexpress-common = { version = "0.2.0" }
humantime = { version = "2.2.0" } humantime = { version = "2.2.0" }
libc = { version = "0.2" } libc = { version = "0.2" }
oneshot = { version = "0.1.13", features = ["std", "async"] } oneshot = { version = "0.1.13", features = ["std", "async"] }
ordered-float = "4"
parking_lot = "0.12.5" parking_lot = "0.12.5"
prometheus = { version = "0.14"} prometheus = { version = "0.14"}
rand = { version = "0.9.2" } rand = { version = "0.9.2" }
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Shared KV router configuration ArgGroup. """Shared KV router configuration ArgGroup.
Defines the 16 KvRouterConfig parameters once so that both Defines the 17 KvRouterConfig parameters once so that both
``dynamo.frontend`` and ``dynamo.router`` can reuse them without duplication. ``dynamo.frontend`` and ``dynamo.router`` can reuse them without duplication.
Field names on ``KvRouterConfigBase`` match the ``KvRouterConfig`` Python Field names on ``KvRouterConfigBase`` match the ``KvRouterConfig`` Python
constructor kwargs 1:1, so ``kv_router_kwargs()`` returns a dict that can be constructor kwargs 1:1, so ``kv_router_kwargs()`` returns a dict that can be
...@@ -34,11 +34,12 @@ _KV_ROUTER_FIELDS: tuple[str, ...] = ( ...@@ -34,11 +34,12 @@ _KV_ROUTER_FIELDS: tuple[str, ...] = (
"router_queue_threshold", "router_queue_threshold",
"router_event_threads", "router_event_threads",
"router_enable_cache_control", "router_enable_cache_control",
"router_queue_policy",
) )
class KvRouterConfigBase(ConfigBase): class KvRouterConfigBase(ConfigBase):
"""Mixin carrying the 16 KvRouterConfig fields.""" """Mixin carrying the 17 KvRouterConfig fields."""
overlap_score_weight: float overlap_score_weight: float
router_temperature: float router_temperature: float
...@@ -56,6 +57,7 @@ class KvRouterConfigBase(ConfigBase): ...@@ -56,6 +57,7 @@ class KvRouterConfigBase(ConfigBase):
router_queue_threshold: Optional[float] router_queue_threshold: Optional[float]
router_event_threads: int router_event_threads: int
router_enable_cache_control: bool router_enable_cache_control: bool
router_queue_policy: str
def kv_router_kwargs(self) -> dict: def kv_router_kwargs(self) -> dict:
"""Return a dict suitable for ``KvRouterConfig(**kwargs)``.""" """Return a dict suitable for ``KvRouterConfig(**kwargs)``."""
...@@ -63,7 +65,7 @@ class KvRouterConfigBase(ConfigBase): ...@@ -63,7 +65,7 @@ class KvRouterConfigBase(ConfigBase):
class KvRouterArgGroup(ArgGroup): class KvRouterArgGroup(ArgGroup):
"""CLI arguments for the 16 KvRouterConfig parameters.""" """CLI arguments for the 17 KvRouterConfig parameters."""
def add_arguments(self, parser) -> None: def add_arguments(self, parser) -> None:
g = parser.add_argument_group("KV Router Options") g = parser.add_argument_group("KV Router Options")
...@@ -222,11 +224,11 @@ class KvRouterArgGroup(ArgGroup): ...@@ -222,11 +224,11 @@ class KvRouterArgGroup(ArgGroup):
g, g,
flag_name="--router-queue-threshold", flag_name="--router-queue-threshold",
env_var="DYN_ROUTER_QUEUE_THRESHOLD", env_var="DYN_ROUTER_QUEUE_THRESHOLD",
default=None, default=2.0,
help=( help=(
"KV Router: Queue threshold fraction for prefill token capacity. " "KV Router: Queue threshold fraction for prefill token capacity. "
"When set, requests are queued if all workers exceed this fraction of " "Requests are queued if all workers exceed this fraction of "
"max_num_batched_tokens. Must be > 0. If not set, queueing is disabled." "max_num_batched_tokens. Must be > 0."
), ),
arg_type=float, arg_type=float,
) )
...@@ -254,3 +256,16 @@ class KvRouterArgGroup(ArgGroup): ...@@ -254,3 +256,16 @@ class KvRouterArgGroup(ArgGroup):
"requests with nvext.cache_control." "requests with nvext.cache_control."
), ),
) )
add_argument(
g,
flag_name="--router-queue-policy",
env_var="DYN_ROUTER_QUEUE_POLICY",
default="fcfs",
help=(
"KV Router: Scheduling policy for the router queue. "
"'fcfs' (default): first-come first-served with priority bumps — optimizes tail TTFT. "
"'wspt': weighted shortest processing time (Smith's rule) — optimizes average TTFT."
),
arg_type=str,
choices=["fcfs", "wspt"],
)
...@@ -265,6 +265,7 @@ async def launch_workers(args: argparse.Namespace, extra_engine_args_path: Path) ...@@ -265,6 +265,7 @@ async def launch_workers(args: argparse.Namespace, extra_engine_args_path: Path)
model_path=args.model_path, model_path=args.model_path,
model_name=args.model_name, model_name=args.model_name,
endpoint_id=args.endpoint, endpoint_id=args.endpoint,
context_length=0,
extra_engine_args=str(worker_engine_args_path), extra_engine_args=str(worker_engine_args_path),
runtime_config=runtime_config, runtime_config=runtime_config,
kv_cache_block_size=kv_cache_block_size, kv_cache_block_size=kv_cache_block_size,
......
...@@ -43,7 +43,8 @@ The Rust HTTP server also reads these environment variables (not exposed as CLI ...@@ -43,7 +43,8 @@ The Rust HTTP server also reads these environment variables (not exposed as CLI
| `--router-assume-kv-reuse` / `--no-router-assume-kv-reuse` | `DYN_ROUTER_ASSUME_KV_REUSE` | `true` | Assume KV cache reuse when tracking active blocks | | `--router-assume-kv-reuse` / `--no-router-assume-kv-reuse` | `DYN_ROUTER_ASSUME_KV_REUSE` | `true` | Assume KV cache reuse when tracking active blocks |
| `--router-track-output-blocks` / `--no-router-track-output-blocks` | `DYN_ROUTER_TRACK_OUTPUT_BLOCKS` | `false` | Track output blocks with fractional decay during generation | | `--router-track-output-blocks` / `--no-router-track-output-blocks` | `DYN_ROUTER_TRACK_OUTPUT_BLOCKS` | `false` | Track output blocks with fractional decay during generation |
| `--router-event-threads` | `DYN_ROUTER_EVENT_THREADS` | `4` | Event processing threads. >1 enables concurrent radix tree | | `--router-event-threads` | `DYN_ROUTER_EVENT_THREADS` | `4` | Event processing threads. >1 enables concurrent radix tree |
| `--router-queue-threshold` | `DYN_ROUTER_QUEUE_THRESHOLD` | — | Queue threshold fraction of prefill capacity. Enables priority scheduling | | `--router-queue-threshold` | `DYN_ROUTER_QUEUE_THRESHOLD` | `2.0` | Queue threshold fraction of prefill capacity. Enables priority scheduling |
| `--router-queue-policy` | `DYN_ROUTER_QUEUE_POLICY` | `fcfs` | Queue scheduling policy: `fcfs` (tail TTFT) or `wspt` (avg TTFT) |
| `--enable-cache-control` / `--no-enable-cache-control` | `DYN_ENABLE_CACHE_CONTROL` | `false` | Enable TTL-based cache pinning (requires `--router-mode=kv`) | | `--enable-cache-control` / `--no-enable-cache-control` | `DYN_ENABLE_CACHE_CONTROL` | `false` | Enable TTL-based cache pinning (requires `--router-mode=kv`) |
| `--decode-fallback` / `--no-decode-fallback` | `DYN_DECODE_FALLBACK` | `false` | Fall back to aggregated mode when prefill workers unavailable | | `--decode-fallback` / `--no-decode-fallback` | `DYN_DECODE_FALLBACK` | `false` | Fall back to aggregated mode when prefill workers unavailable |
......
...@@ -21,7 +21,8 @@ For Kubernetes, set `DYN_ROUTER_MODE=kv` on the Frontend service. Workers automa ...@@ -21,7 +21,8 @@ For Kubernetes, set `DYN_ROUTER_MODE=kv` on the Frontend service. Workers automa
| `--router-mode kv` | `round_robin` | Enable KV cache-aware routing | | `--router-mode kv` | `round_robin` | Enable KV cache-aware routing |
| `--router-kv-overlap-score-weight` | `1.0` | Balance prefill vs decode optimization (higher = better TTFT) | | `--router-kv-overlap-score-weight` | `1.0` | Balance prefill vs decode optimization (higher = better TTFT) |
| `--no-router-kv-events` | enabled | Fall back to approximate routing (no event consumption from workers) | | `--no-router-kv-events` | enabled | Fall back to approximate routing (no event consumption from workers) |
| `--router-queue-threshold` | disabled | Enable backpressure queue under high concurrency; also enables priority scheduling via `nvext.agent_hints.latency_sensitivity` | | `--router-queue-threshold` | `2.0` | Backpressure queue threshold; enables priority scheduling via `nvext.agent_hints.latency_sensitivity` |
| `--router-queue-policy` | `fcfs` | Queue scheduling policy: `fcfs` (tail TTFT) or `wspt` (avg TTFT) |
### Standalone Router ### Standalone Router
......
...@@ -36,7 +36,8 @@ Backend workers register themselves using the `register_model` API, after which ...@@ -36,7 +36,8 @@ Backend workers register themselves using the `register_model` API, after which
| `--kv-cache-block-size <size>` | Backend-specific | KV cache block size (should match backend config) | | `--kv-cache-block-size <size>` | Backend-specific | KV cache block size (should match backend config) |
| `--router-kv-events` / `--no-router-kv-events` | `--router-kv-events` | Enable/disable real-time KV event tracking | | `--router-kv-events` / `--no-router-kv-events` | `--router-kv-events` | Enable/disable real-time KV event tracking |
| `--router-kv-overlap-score-weight <float>` | `1.0` | Balance prefill vs decode optimization (higher = better TTFT) | | `--router-kv-overlap-score-weight <float>` | `1.0` | Balance prefill vs decode optimization (higher = better TTFT) |
| `--router-queue-threshold <float>` | None (disabled) | Queue threshold fraction; enables priority scheduling via `latency_sensitivity` | | `--router-queue-threshold <float>` | `2.0` | Queue threshold fraction; enables priority scheduling via `latency_sensitivity` |
| `--router-queue-policy <str>` | `fcfs` | Scheduling policy for the queue: `fcfs` (tail TTFT) or `wspt` (avg TTFT) |
For all available options: `python -m dynamo.frontend --help` For all available options: `python -m dynamo.frontend --help`
...@@ -78,6 +79,7 @@ All CLI arguments can be configured via environment variables using the `DYN_` p ...@@ -78,6 +79,7 @@ All CLI arguments can be configured via environment variables using the `DYN_` p
| `--kv-cache-block-size` | `DYN_KV_CACHE_BLOCK_SIZE` | Backend-specific | | `--kv-cache-block-size` | `DYN_KV_CACHE_BLOCK_SIZE` | Backend-specific |
| `--no-router-kv-events` | `DYN_ROUTER_USE_KV_EVENTS=false` | `true` | | `--no-router-kv-events` | `DYN_ROUTER_USE_KV_EVENTS=false` | `true` |
| `--router-kv-overlap-score-weight` | `DYN_ROUTER_KV_OVERLAP_SCORE_WEIGHT` | `1.0` | | `--router-kv-overlap-score-weight` | `DYN_ROUTER_KV_OVERLAP_SCORE_WEIGHT` | `1.0` |
| `--router-queue-policy` | `DYN_ROUTER_QUEUE_POLICY` | `fcfs` |
For complete K8s examples and advanced configuration, see [K8s Examples](router-examples.md#k8s-examples). For complete K8s examples and advanced configuration, see [K8s Examples](router-examples.md#k8s-examples).
For A/B testing and advanced K8s setup, see the [KV Router A/B Benchmarking Guide](../../benchmarks/kv-router-ab-testing.md). For A/B testing and advanced K8s setup, see the [KV Router A/B Benchmarking Guide](../../benchmarks/kv-router-ab-testing.md).
...@@ -178,7 +180,11 @@ The main KV-aware routing arguments (frontend uses the same `--router-*` flag na ...@@ -178,7 +180,11 @@ The main KV-aware routing arguments (frontend uses the same `--router-*` flag na
- `--router-temperature`: Controls worker selection randomness through softmax sampling of router cost logits. A value of 0 (default) ensures deterministic selection of the lowest-cost worker, while higher values introduce more randomness. - `--router-temperature`: Controls worker selection randomness through softmax sampling of router cost logits. A value of 0 (default) ensures deterministic selection of the lowest-cost worker, while higher values introduce more randomness.
- `--router-queue-threshold`: Queue threshold fraction for prefill token capacity. When set, the router holds incoming requests in a priority queue while all workers exceed this fraction of `max_num_batched_tokens`, releasing them when capacity frees up. This defers dispatch (not rejection) so that routing decisions use the most up-to-date load metrics at the moment the request is actually sent to a worker. It also enables **priority scheduling** via `latency_sensitivity` hints in `nvext.agent_hints` — higher values shift a request's effective arrival time earlier in the queue, giving it priority over lower-valued requests. Must be > 0. If not set (default), queueing is disabled and requests are dispatched immediately. - `--router-queue-threshold`: Queue threshold fraction for prefill token capacity (default: 2.0). The router holds incoming requests in a priority queue while all workers exceed this fraction of `max_num_batched_tokens`, releasing them when capacity frees up. This defers dispatch (not rejection) so that routing decisions use the most up-to-date load metrics at the moment the request is actually sent to a worker. It also enables **priority scheduling** via `latency_sensitivity` hints in `nvext.agent_hints` — higher values shift a request's effective arrival time earlier in the queue, giving it priority over lower-valued requests. Must be > 0. Set to None to disable queueing (requests are dispatched immediately).
- `--router-queue-policy`: Scheduling policy for the router queue (default: `fcfs`). Two policies are available:
- **`fcfs`** (first-come first-served): Orders by adjusted arrival time (`priority_jump - arrival_offset`). Optimizes **tail TTFT** — no request waits longer than necessary.
- **`wspt`** (weighted shortest processing time, Smith's rule): Orders by `(1 + priority_jump) / isl_tokens`. Optimizes **average TTFT** — short or high-priority requests are scheduled before long low-priority ones, minimizing total weighted completion time.
### KV Event Transport and Persistence ### KV Event Transport and Persistence
...@@ -224,7 +230,9 @@ Use `--no-router-assume-kv-reuse` in disaggregated setups where the decode worke ...@@ -224,7 +230,9 @@ Use `--no-router-assume-kv-reuse` in disaggregated setups where the decode worke
Use `--router-track-output-blocks` **(experimental)** when your workload is output-heavy and you want the router to account for output-side KV cache growth in load balancing. This is useful in two scenarios: (1) workloads with long output sequences and little multi-turn reuse, where output blocks dominate the KV cache footprint; (2) agentic schedulers (e.g. NAT or other LLM routers) that can accurately predict the expected output sequence length per request. When enabled, the router adds placeholder blocks as tokens are generated. If you additionally pass `nvext.agent_hints.osl` (expected output sequence length in tokens) per request, the router applies fractional decay to output blocks — each output block's weight starts at 1.0 and decays linearly toward 0.0 as generation approaches the expected OSL. This lets the router predict that a request nearing completion will soon free its blocks, effectively modeling the future load trajectory rather than just the current snapshot. Without `osl`, output blocks are added at full weight with no decay. The flag requires `--router-track-active-blocks` (the default). Use `--router-track-output-blocks` **(experimental)** when your workload is output-heavy and you want the router to account for output-side KV cache growth in load balancing. This is useful in two scenarios: (1) workloads with long output sequences and little multi-turn reuse, where output blocks dominate the KV cache footprint; (2) agentic schedulers (e.g. NAT or other LLM routers) that can accurately predict the expected output sequence length per request. When enabled, the router adds placeholder blocks as tokens are generated. If you additionally pass `nvext.agent_hints.osl` (expected output sequence length in tokens) per request, the router applies fractional decay to output blocks — each output block's weight starts at 1.0 and decays linearly toward 0.0 as generation approaches the expected OSL. This lets the router predict that a request nearing completion will soon free its blocks, effectively modeling the future load trajectory rather than just the current snapshot. Without `osl`, output blocks are added at full weight with no decay. The flag requires `--router-track-active-blocks` (the default).
Set `--router-queue-threshold` (e.g. `1.5`) to enable backpressure under very high concurrency workloads. When set, the router holds incoming requests in a priority queue while all workers exceed the given fraction of `max_num_batched_tokens`, releasing them as capacity frees up. This defers the routing decision so it is made with the freshest load metrics, rather than dispatching into an already-saturated system. It also enables priority scheduling via `nvext.agent_hints.latency_sensitivity`. The `--router-queue-threshold` (default: 2.0) controls when incoming requests are held in a priority queue. The router holds requests while all workers exceed the given fraction of `max_num_batched_tokens`, releasing them as capacity frees up. This defers the routing decision so it is made with the freshest load metrics, rather than dispatching into an already-saturated system. It also enables priority scheduling via `nvext.agent_hints.latency_sensitivity`. Set to None to disable queueing entirely.
Use `--router-queue-policy wspt` when your workload has a mix of short and long requests and you want to minimize **average** TTFT. WSPT (Smith's rule) schedules short or high-priority requests first, reducing mean latency across the batch. Use the default `fcfs` when you want to minimize **tail** TTFT — no request waits longer than necessary, since ordering is purely by (adjusted) arrival time.
### Prometheus Metrics ### Prometheus Metrics
......
...@@ -554,6 +554,9 @@ fn kv_router_config_from_env() -> KvRouterConfig { ...@@ -554,6 +554,9 @@ fn kv_router_config_from_env() -> KvRouterConfig {
if let Some(v) = env_bool("DYN_ROUTER_TRACK_OUTPUT_BLOCKS") { if let Some(v) = env_bool("DYN_ROUTER_TRACK_OUTPUT_BLOCKS") {
cfg.router_track_output_blocks = v; cfg.router_track_output_blocks = v;
} }
if let Some(v) = env_f64("DYN_ROUTER_QUEUE_THRESHOLD") {
cfg.router_queue_threshold = Some(v);
}
tracing::info!( tracing::info!(
overlap_score_weight = cfg.overlap_score_weight, overlap_score_weight = cfg.overlap_score_weight,
...@@ -562,6 +565,7 @@ fn kv_router_config_from_env() -> KvRouterConfig { ...@@ -562,6 +565,7 @@ fn kv_router_config_from_env() -> KvRouterConfig {
router_replica_sync = cfg.router_replica_sync, router_replica_sync = cfg.router_replica_sync,
router_track_active_blocks = cfg.router_track_active_blocks, router_track_active_blocks = cfg.router_track_active_blocks,
router_track_output_blocks = cfg.router_track_output_blocks, router_track_output_blocks = cfg.router_track_output_blocks,
router_queue_threshold = ?cfg.router_queue_threshold,
"KvRouterConfig initialized (DYN_* env overrides applied)" "KvRouterConfig initialized (DYN_* env overrides applied)"
); );
......
...@@ -878,7 +878,7 @@ version = "3.1.1" ...@@ -878,7 +878,7 @@ version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34" checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34"
dependencies = [ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
...@@ -1424,7 +1424,7 @@ dependencies = [ ...@@ -1424,7 +1424,7 @@ dependencies = [
"libc", "libc",
"option-ext", "option-ext",
"redox_users", "redox_users",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
...@@ -1525,6 +1525,7 @@ dependencies = [ ...@@ -1525,6 +1525,7 @@ dependencies = [
"dynamo-runtime", "dynamo-runtime",
"dynamo-tokens", "dynamo-tokens",
"flume", "flume",
"ordered-float 4.6.0",
"parking_lot", "parking_lot",
"prometheus", "prometheus",
"rand 0.9.2", "rand 0.9.2",
...@@ -1886,7 +1887,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -1886,7 +1887,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
...@@ -2976,7 +2977,7 @@ dependencies = [ ...@@ -2976,7 +2977,7 @@ dependencies = [
"portable-atomic", "portable-atomic",
"portable-atomic-util", "portable-atomic-util",
"serde_core", "serde_core",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
...@@ -3957,7 +3958,7 @@ version = "0.50.3" ...@@ -3957,7 +3958,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
...@@ -4445,6 +4446,15 @@ dependencies = [ ...@@ -4445,6 +4446,15 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "ordered-float"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "ordered-multimap" name = "ordered-multimap"
version = "0.7.3" version = "0.7.3"
...@@ -4889,7 +4899,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -4889,7 +4899,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf"
dependencies = [ dependencies = [
"heck", "heck",
"itertools 0.14.0", "itertools 0.11.0",
"log", "log",
"multimap", "multimap",
"once_cell", "once_cell",
...@@ -4909,7 +4919,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -4909,7 +4919,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
dependencies = [ dependencies = [
"heck", "heck",
"itertools 0.14.0", "itertools 0.11.0",
"log", "log",
"multimap", "multimap",
"petgraph 0.8.3", "petgraph 0.8.3",
...@@ -4930,7 +4940,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -4930,7 +4940,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"itertools 0.14.0", "itertools 0.11.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn",
...@@ -4943,7 +4953,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -4943,7 +4953,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"itertools 0.14.0", "itertools 0.11.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn",
...@@ -5194,7 +5204,7 @@ dependencies = [ ...@@ -5194,7 +5204,7 @@ dependencies = [
"once_cell", "once_cell",
"socket2 0.6.3", "socket2 0.6.3",
"tracing", "tracing",
"windows-sys 0.60.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
...@@ -5646,7 +5656,7 @@ dependencies = [ ...@@ -5646,7 +5656,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys 0.12.1", "linux-raw-sys 0.12.1",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
...@@ -5947,7 +5957,7 @@ version = "0.7.0" ...@@ -5947,7 +5957,7 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c"
dependencies = [ dependencies = [
"ordered-float", "ordered-float 2.10.1",
"serde", "serde",
] ]
...@@ -6397,7 +6407,7 @@ dependencies = [ ...@@ -6397,7 +6407,7 @@ dependencies = [
"getrandom 0.4.2", "getrandom 0.4.2",
"once_cell", "once_cell",
"rustix 1.1.4", "rustix 1.1.4",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
...@@ -7628,7 +7638,7 @@ version = "0.1.11" ...@@ -7628,7 +7638,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
......
...@@ -1535,6 +1535,7 @@ dependencies = [ ...@@ -1535,6 +1535,7 @@ dependencies = [
"dynamo-runtime", "dynamo-runtime",
"dynamo-tokens", "dynamo-tokens",
"flume", "flume",
"ordered-float 4.6.0",
"parking_lot", "parking_lot",
"prometheus", "prometheus",
"rand 0.9.2", "rand 0.9.2",
...@@ -4502,6 +4503,15 @@ dependencies = [ ...@@ -4502,6 +4503,15 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "ordered-float"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "ordered-multimap" name = "ordered-multimap"
version = "0.7.3" version = "0.7.3"
...@@ -6014,7 +6024,7 @@ version = "0.7.0" ...@@ -6014,7 +6024,7 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c"
dependencies = [ dependencies = [
"ordered-float", "ordered-float 2.10.1",
"serde", "serde",
] ]
......
...@@ -54,7 +54,7 @@ impl KvRouterConfig { ...@@ -54,7 +54,7 @@ impl KvRouterConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=None, router_event_threads=4, router_enable_cache_control=false))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(2.0), router_event_threads=4, router_enable_cache_control=false, router_queue_policy="fcfs"))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
...@@ -73,6 +73,7 @@ impl KvRouterConfig { ...@@ -73,6 +73,7 @@ impl KvRouterConfig {
router_queue_threshold: Option<f64>, router_queue_threshold: Option<f64>,
router_event_threads: u32, router_event_threads: u32,
router_enable_cache_control: bool, router_enable_cache_control: bool,
router_queue_policy: &str,
) -> Self { ) -> Self {
KvRouterConfig { KvRouterConfig {
inner: RsKvRouterConfig { inner: RsKvRouterConfig {
...@@ -92,6 +93,9 @@ impl KvRouterConfig { ...@@ -92,6 +93,9 @@ impl KvRouterConfig {
router_queue_threshold, router_queue_threshold,
router_event_threads, router_event_threads,
router_enable_cache_control, router_enable_cache_control,
router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| {
panic!("invalid router_queue_policy: {router_queue_policy:?}")
}),
}, },
} }
} }
......
...@@ -1099,9 +1099,10 @@ class KvRouterConfig: ...@@ -1099,9 +1099,10 @@ class KvRouterConfig:
router_ttl_secs: float = 120.0, router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576, router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8, router_prune_target_ratio: float = 0.8,
router_queue_threshold: Optional[float] = None, router_queue_threshold: Optional[float] = 2.0,
router_event_threads: int = 4, router_event_threads: int = 4,
router_enable_cache_control: bool = False, router_enable_cache_control: bool = False,
router_queue_policy: str = "fcfs",
) -> None: ) -> None:
""" """
Create a KV router configuration. Create a KV router configuration.
...@@ -1126,14 +1127,17 @@ class KvRouterConfig: ...@@ -1126,14 +1127,17 @@ class KvRouterConfig:
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0) router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20) router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8) router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
router_queue_threshold: Queue threshold fraction for prefill token capacity (default: None). router_queue_threshold: Queue threshold fraction for prefill token capacity (default: 2.0).
When set, requests are queued if all workers exceed this fraction of Requests are queued if all workers exceed this fraction of max_num_batched_tokens.
max_num_batched_tokens. Enables priority scheduling via latency_sensitivity hints. Enables priority scheduling via latency_sensitivity hints.
If None, queueing is disabled and all requests go directly to the scheduler. Set to None to disable queueing (all requests go directly to the scheduler).
router_event_threads: Number of event processing threads (default: 4). router_event_threads: Number of event processing threads (default: 4).
When > 1, uses a concurrent radix tree with a thread pool. When > 1, uses a concurrent radix tree with a thread pool.
router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's
cache_control service mesh endpoint (default: False). cache_control service mesh endpoint (default: False).
router_queue_policy: Scheduling policy for the router queue (default: "fcfs").
"fcfs": first-come first-served with priority bumps — optimizes tail TTFT.
"wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
""" """
... ...
......
...@@ -27,6 +27,7 @@ dynamo-tokens = { workspace = true } ...@@ -27,6 +27,7 @@ dynamo-tokens = { workspace = true }
anyhow = { workspace = true } anyhow = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
dashmap = { workspace = true } dashmap = { workspace = true }
ordered-float = { workspace = true }
derive_builder = { workspace = true } derive_builder = { workspace = true }
derive-getters = { workspace = true } derive-getters = { workspace = true }
prometheus = { workspace = true } prometheus = { workspace = true }
......
...@@ -40,7 +40,7 @@ pub use self::multi_worker_sequence::{ ...@@ -40,7 +40,7 @@ pub use self::multi_worker_sequence::{
}; };
pub use self::sequence::{ActiveSequences, RequestId}; pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree; pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use config::{KvRouterConfig, RouterConfigOverride}; pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy};
pub use event_sink::EventSink; pub use event_sink::EventSink;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer}; pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
...@@ -52,5 +52,6 @@ pub use protocols::{ ...@@ -52,5 +52,6 @@ pub use protocols::{
}; };
pub use queue::SchedulerQueue; pub use queue::SchedulerQueue;
pub use radix_tree::RadixTree; pub use radix_tree::RadixTree;
pub use scheduling::policy::{FcfsPolicy, RouterSchedulingPolicy, SchedulingPolicy, WsptPolicy};
pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse}; pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
pub use selector::{DefaultWorkerSelector, WorkerSelector}; pub use selector::{DefaultWorkerSelector, WorkerSelector};
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::fmt;
use std::str::FromStr;
use derive_builder::Builder; use derive_builder::Builder;
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -8,6 +11,37 @@ use validator::{Validate, ValidationError}; ...@@ -8,6 +11,37 @@ use validator::{Validate, ValidationError};
use crate::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block}; use crate::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterQueuePolicy {
#[default]
Fcfs,
Wspt,
}
impl fmt::Display for RouterQueuePolicy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Fcfs => f.write_str("fcfs"),
Self::Wspt => f.write_str("wspt"),
}
}
}
impl FromStr for RouterQueuePolicy {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"fcfs" => Ok(Self::Fcfs),
"wspt" => Ok(Self::Wspt),
_ => Err(format!(
"unknown queue policy: {s:?}, expected 'fcfs' or 'wspt'"
)),
}
}
}
/// Override configuration for router settings that can be specified per-request /// Override configuration for router settings that can be specified per-request
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Default, Builder, Serialize, Deserialize, Validate)]
pub struct RouterConfigOverride { pub struct RouterConfigOverride {
...@@ -75,8 +109,8 @@ pub struct KvRouterConfig { ...@@ -75,8 +109,8 @@ pub struct KvRouterConfig {
/// Queue threshold fraction for prefill token capacity. /// Queue threshold fraction for prefill token capacity.
/// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens. /// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens.
/// If None (default), queueing is disabled and all requests go directly to ready. /// If None, queueing is disabled and all requests go directly to ready.
/// Must be > 0. /// Default: 2.0. Must be > 0.
#[validate(range(min = 0.0))] #[validate(range(min = 0.0))]
pub router_queue_threshold: Option<f64>, pub router_queue_threshold: Option<f64>,
...@@ -91,6 +125,11 @@ pub struct KvRouterConfig { ...@@ -91,6 +125,11 @@ pub struct KvRouterConfig {
/// requests, firing a pin_prefix call (with TTL) to the worker after generation completes. /// requests, firing a pin_prefix call (with TTL) to the worker after generation completes.
/// When false (default), cache_control is ignored and no cache_control client is created. /// When false (default), cache_control is ignored and no cache_control client is created.
pub router_enable_cache_control: bool, pub router_enable_cache_control: bool,
/// Scheduling policy for the router queue.
/// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
/// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
pub router_queue_policy: RouterQueuePolicy,
} }
impl Default for KvRouterConfig { impl Default for KvRouterConfig {
...@@ -109,9 +148,10 @@ impl Default for KvRouterConfig { ...@@ -109,9 +148,10 @@ impl Default for KvRouterConfig {
router_ttl_secs: 120.0, router_ttl_secs: 120.0,
router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default() router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8, router_prune_target_ratio: 0.8,
router_queue_threshold: None, router_queue_threshold: Some(2.0),
router_event_threads: 4, router_event_threads: 4,
router_enable_cache_control: false, router_enable_cache_control: false,
router_queue_policy: RouterQueuePolicy::default(),
} }
} }
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub mod config; pub mod config;
pub mod policy;
pub mod queue; pub mod queue;
pub mod selector; pub mod selector;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::time::Duration;
use ordered_float::OrderedFloat;
use super::config::RouterQueuePolicy;
use super::types::SchedulingRequest;
/// Pluggable scheduling policy that determines queue ordering.
/// Monomorphized for zero-cost inlining on the hot comparison path.
///
/// Higher key = higher priority (natural max-heap ordering).
pub trait SchedulingPolicy: Send + Sync + 'static {
/// Priority key stored in each queue entry.
type Key: Ord + Eq + Clone + Send + 'static;
/// Compute priority key at enqueue time.
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key;
/// Recompute priority key during update(). Default: return old key unchanged.
fn rekey(&self, _now: Duration, old_key: &Self::Key, _req: &SchedulingRequest) -> Self::Key {
old_key.clone()
}
/// When true, queue rebuilds heap via rekey() on each update() call.
/// When false (default), rekey path is compiled out entirely.
const DYNAMIC: bool = false;
}
/// FCFS with priority bumps: key = priority_jump - arrival_offset.
/// Earlier arrival or higher priority_jump produces a higher key, scheduled first.
///
/// Optimizes for tail TTFT — no request waits longer than necessary,
/// since ordering is purely by (adjusted) arrival time.
pub struct FcfsPolicy;
impl SchedulingPolicy for FcfsPolicy {
type Key = OrderedFloat<f64>;
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
OrderedFloat(request.priority_jump.max(0.0) - arrival_offset.as_secs_f64())
}
}
/// Weighted Shortest Processing Time (Smith's rule):
/// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the
/// actual prefill cost by subtracting the max KV cache overlap from ISL.
/// We use max because the downstream selector routes to the best-overlap
/// worker, so the realized overlap is well-approximated by the best available.
///
/// Optimizes for average TTFT — minimizes total weighted completion time
/// (Smith 1956). Short or high-priority requests are scheduled before
/// long low-priority ones, reducing mean latency across the batch.
pub struct WsptPolicy {
pub block_size: usize,
}
impl SchedulingPolicy for WsptPolicy {
type Key = OrderedFloat<f64>;
fn enqueue_key(&self, _arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
let weight = 1.0 + request.priority_jump.max(0.0);
let max_overlap = request.overlaps.scores.values().copied().max().unwrap_or(0) as usize;
let cached_tokens = max_overlap * self.block_size;
let new_tokens = request.isl_tokens.saturating_sub(cached_tokens).max(1);
OrderedFloat(weight / new_tokens as f64)
}
}
/// Runtime-dispatched scheduling policy selected via configuration.
/// Delegates to the concrete policy variant; the branch is fully predictable
/// since the variant is fixed at queue construction time.
pub enum RouterSchedulingPolicy {
Fcfs(FcfsPolicy),
Wspt(WsptPolicy),
}
impl RouterSchedulingPolicy {
pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self {
match kind {
RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }),
}
}
}
impl SchedulingPolicy for RouterSchedulingPolicy {
type Key = OrderedFloat<f64>;
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
match self {
Self::Fcfs(p) => p.enqueue_key(arrival_offset, request),
Self::Wspt(p) => p.enqueue_key(arrival_offset, request),
}
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use rustc_hash::FxHashMap;
use super::*;
use crate::protocols::{OverlapScores, WorkerWithDpRank};
fn request_with(
isl_tokens: usize,
priority_jump: f64,
overlaps: OverlapScores,
) -> SchedulingRequest {
SchedulingRequest {
maybe_request_id: None,
token_seq: None,
isl_tokens,
overlaps,
decode_blocks: HashMap::new(),
prefill_tokens: HashMap::new(),
router_config_override: None,
update_states: false,
lora_name: None,
priority_jump,
expected_output_tokens: None,
allowed_worker_ids: None,
resp_tx: None,
}
}
fn overlaps_from(scores: &[(u64, u32)]) -> OverlapScores {
let mut map = FxHashMap::default();
for &(worker_id, score) in scores {
map.insert(WorkerWithDpRank::new(worker_id, 0), score);
}
OverlapScores {
scores: map,
frequencies: vec![],
tree_sizes: FxHashMap::default(),
}
}
// ---- FCFS policy tests ----
#[test]
fn fcfs_earlier_arrival_scheduled_first() {
let policy = FcfsPolicy;
let req = request_with(512, 0.0, OverlapScores::default());
let early = policy.enqueue_key(Duration::from_secs(1), &req);
let late = policy.enqueue_key(Duration::from_secs(10), &req);
assert!(early > late, "earlier arrival should have higher key");
}
#[test]
fn fcfs_priority_jump_promotes() {
let policy = FcfsPolicy;
// Both arrive at the same wall-clock offset (10s), but one has priority.
let normal = request_with(512, 0.0, OverlapScores::default());
let boosted = request_with(512, 100.0, OverlapScores::default());
let t = Duration::from_secs(10);
let key_normal = policy.enqueue_key(t, &normal);
let key_boosted = policy.enqueue_key(t, &boosted);
assert!(
key_boosted > key_normal,
"priority_jump should produce a higher key"
);
}
#[test]
fn fcfs_priority_jump_beats_earlier_arrival() {
let policy = FcfsPolicy;
// Request A arrived at t=0 with no priority.
// Request B arrived at t=5 with priority_jump=50s.
// B should be scheduled first despite arriving later.
let a = request_with(512, 0.0, OverlapScores::default());
let b = request_with(512, 50.0, OverlapScores::default());
let key_a = policy.enqueue_key(Duration::from_secs(0), &a);
let key_b = policy.enqueue_key(Duration::from_secs(5), &b);
assert!(key_b > key_a);
}
// ---- WSPT policy tests ----
#[test]
fn wspt_shorter_request_scheduled_first() {
let policy = WsptPolicy { block_size: 16 };
let short = request_with(100, 0.0, OverlapScores::default());
let long = request_with(1000, 0.0, OverlapScores::default());
let t = Duration::ZERO;
assert!(
policy.enqueue_key(t, &short) > policy.enqueue_key(t, &long),
"shorter request should have higher key"
);
}
#[test]
fn wspt_overlap_reduces_effective_cost() {
let policy = WsptPolicy { block_size: 16 };
// Both 1024 ISL tokens, but one has 60 blocks cached (960 tokens).
let no_cache = request_with(1024, 0.0, OverlapScores::default());
let cached = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
let t = Duration::ZERO;
let key_no_cache = policy.enqueue_key(t, &no_cache);
let key_cached = policy.enqueue_key(t, &cached);
assert!(
key_cached > key_no_cache,
"request with overlap should have higher key (fewer new tokens)"
);
}
#[test]
fn wspt_priority_promotes() {
let policy = WsptPolicy { block_size: 16 };
let normal = request_with(512, 0.0, OverlapScores::default());
let boosted = request_with(512, 5.0, OverlapScores::default());
let t = Duration::ZERO;
assert!(
policy.enqueue_key(t, &boosted) > policy.enqueue_key(t, &normal),
"priority_jump should increase key"
);
}
#[test]
fn wspt_uses_max_overlap() {
let policy = WsptPolicy { block_size: 16 };
// 4 workers with overlaps [10, 20, 50, 60]. max = 60.
// new_tokens = 1024 - 60*16 = 64
let req = request_with(
1024,
0.0,
overlaps_from(&[(0, 10), (1, 20), (2, 50), (3, 60)]),
);
let key = policy.enqueue_key(Duration::ZERO, &req);
let expected = OrderedFloat(1.0 / 64.0);
assert_eq!(key, expected);
}
#[test]
fn wspt_no_overlap_falls_back_to_isl() {
let policy = WsptPolicy { block_size: 16 };
let req = request_with(512, 0.0, OverlapScores::default());
let key = policy.enqueue_key(Duration::ZERO, &req);
let expected = OrderedFloat(1.0 / 512.0);
assert_eq!(key, expected);
}
#[test]
fn wspt_full_overlap_clamps_to_one() {
let policy = WsptPolicy { block_size: 16 };
// 512 tokens, 64 blocks cached = 1024 cached tokens > ISL → saturating_sub → 0 → max(1)
let req = request_with(512, 0.0, overlaps_from(&[(0, 64)]));
let key = policy.enqueue_key(Duration::ZERO, &req);
let expected = OrderedFloat(1.0 / 1.0);
assert_eq!(key, expected);
}
}
...@@ -2,14 +2,15 @@ ...@@ -2,14 +2,15 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::cmp::Ordering; use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap}; use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering}; use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::time::{Duration, Instant}; use std::time::Instant;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::sync::watch; use tokio::sync::watch;
use super::policy::{FcfsPolicy, SchedulingPolicy};
use super::selector::WorkerSelector; use super::selector::WorkerSelector;
use super::types::{SchedulingRequest, SchedulingResponse}; use super::types::{SchedulingRequest, SchedulingResponse};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank}; use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
...@@ -18,29 +19,27 @@ use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRe ...@@ -18,29 +19,27 @@ use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRe
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker) /// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
pub const DEFAULT_MAX_BATCHED_TOKENS: u64 = 10_000_000; pub const DEFAULT_MAX_BATCHED_TOKENS: u64 = 10_000_000;
/// Entry in the priority queue, ordered by effective arrival time (lower = higher priority). /// Entry in the priority queue, ordered by key (higher key = higher priority).
/// Effective arrival = elapsed time since queue start minus `priority_jump`. struct QueueEntry<K: Ord + Eq> {
struct QueueEntry { key: K,
effective_offset: Duration,
request: SchedulingRequest, request: SchedulingRequest,
} }
impl Eq for QueueEntry {} impl<K: Ord + Eq> Eq for QueueEntry<K> {}
impl PartialEq for QueueEntry { impl<K: Ord + Eq> PartialEq for QueueEntry<K> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.effective_offset == other.effective_offset self.key == other.key
} }
} }
impl Ord for QueueEntry { impl<K: Ord + Eq> Ord for QueueEntry<K> {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
// BinaryHeap is a max-heap; reverse so lower effective_offset = higher priority self.key.cmp(&other.key)
other.effective_offset.cmp(&self.effective_offset)
} }
} }
impl PartialOrd for QueueEntry { impl<K: Ord + Eq> PartialOrd for QueueEntry<K> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other)) Some(self.cmp(other))
} }
...@@ -50,8 +49,12 @@ impl PartialOrd for QueueEntry { ...@@ -50,8 +49,12 @@ impl PartialOrd for QueueEntry {
/// When all workers exceed `threshold_frac` utilisation the request is parked in `pending`. /// When all workers exceed `threshold_frac` utilisation the request is parked in `pending`.
/// When capacity frees up (`update()`), pending requests are scheduled in priority order. /// When capacity frees up (`update()`), pending requests are scheduled in priority order.
/// If queueing is disabled (threshold_frac is None), requests are scheduled immediately. /// If queueing is disabled (threshold_frac is None), requests are scheduled immediately.
pub struct SchedulerQueue<P: SequencePublisher, C: WorkerConfigLike> { pub struct SchedulerQueue<
pending: Mutex<BinaryHeap<QueueEntry>>, P: SequencePublisher,
C: WorkerConfigLike,
S: SchedulingPolicy = FcfsPolicy,
> {
pending: Mutex<BinaryHeap<QueueEntry<S::Key>>>,
/// Number of requests currently parked in the pending queue. /// Number of requests currently parked in the pending queue.
/// Incremented after push, decremented after pop. Lock-free reads via `Relaxed` load. /// Incremented after push, decremented after pop. Lock-free reads via `Relaxed` load.
pending_count: AtomicUsize, pending_count: AtomicUsize,
...@@ -63,15 +66,19 @@ pub struct SchedulerQueue<P: SequencePublisher, C: WorkerConfigLike> { ...@@ -63,15 +66,19 @@ pub struct SchedulerQueue<P: SequencePublisher, C: WorkerConfigLike> {
start_time: Instant, start_time: Instant,
block_size: u32, block_size: u32,
selector: Box<dyn WorkerSelector<C> + Send + Sync>, selector: Box<dyn WorkerSelector<C> + Send + Sync>,
policy: S,
} }
impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> { impl<P: SequencePublisher + 'static, C: WorkerConfigLike, S: SchedulingPolicy>
SchedulerQueue<P, C, S>
{
pub fn new( pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>, slots: Arc<ActiveSequencesMultiWorker<P>>,
workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>, workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>,
threshold_frac: Option<f64>, threshold_frac: Option<f64>,
block_size: u32, block_size: u32,
selector: Box<dyn WorkerSelector<C> + Send + Sync>, selector: Box<dyn WorkerSelector<C> + Send + Sync>,
policy: S,
) -> Self { ) -> Self {
if let Some(frac) = threshold_frac { if let Some(frac) = threshold_frac {
tracing::info!("Router queue enabled with threshold fraction {frac}"); tracing::info!("Router queue enabled with threshold fraction {frac}");
...@@ -85,17 +92,7 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> { ...@@ -85,17 +92,7 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> {
start_time: Instant::now(), start_time: Instant::now(),
block_size, block_size,
selector, selector,
} policy,
}
/// Build a QueueEntry for a request, computing its effective arrival offset.
fn make_entry(&self, request: SchedulingRequest) -> QueueEntry {
let arrival_offset = self.start_time.elapsed();
let jump = Duration::from_secs_f64(request.priority_jump.max(0.0));
let effective_offset = arrival_offset.saturating_sub(jump);
QueueEntry {
effective_offset,
request,
} }
} }
...@@ -108,10 +105,11 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> { ...@@ -108,10 +105,11 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> {
return; return;
}; };
if self.all_workers_busy(threshold) { if self.all_workers_busy(threshold, request.allowed_worker_ids.as_ref()) {
tracing::debug!("all workers busy, queueing request"); tracing::debug!("all workers busy, queueing request");
let entry = self.make_entry(request); let arrival_offset = self.start_time.elapsed();
self.pending.lock().await.push(entry); let key = self.policy.enqueue_key(arrival_offset, &request);
self.pending.lock().await.push(QueueEntry { key, request });
self.pending_count.fetch_add(1, AtomicOrdering::Relaxed); self.pending_count.fetch_add(1, AtomicOrdering::Relaxed);
} else { } else {
self.schedule(request).await; self.schedule(request).await;
...@@ -126,8 +124,22 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> { ...@@ -126,8 +124,22 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> {
return; return;
}; };
if S::DYNAMIC {
let now = self.start_time.elapsed();
let mut heap = self.pending.lock().await;
let rekeyed: Vec<_> = std::mem::take(&mut *heap)
.into_vec()
.into_iter()
.map(|e| QueueEntry {
key: self.policy.rekey(now, &e.key, &e.request),
request: e.request,
})
.collect();
*heap = BinaryHeap::from(rekeyed);
}
loop { loop {
if self.all_workers_busy(threshold) { if self.all_workers_busy(threshold, None) {
break; break;
} }
let Some(entry) = self.pending.lock().await.pop() else { let Some(entry) = self.pending.lock().await.pop() else {
...@@ -201,13 +213,22 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> { ...@@ -201,13 +213,22 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> {
self.pending_count.load(AtomicOrdering::Relaxed) self.pending_count.load(AtomicOrdering::Relaxed)
} }
/// Check if all workers are busy based on threshold. /// Check if all eligible workers are busy based on threshold.
/// Returns true only if ALL workers exceed the threshold (no worker has capacity). /// When `allowed` is `Some`, only those worker IDs are considered;
fn all_workers_busy(&self, threshold: f64) -> bool { /// otherwise all registered workers are checked.
/// Returns false when no eligible workers exist so the request falls
/// through to `schedule`, which returns a proper `NoEndpoints` error.
fn all_workers_busy(&self, threshold: f64, allowed: Option<&HashSet<WorkerId>>) -> bool {
let active_tokens = self.slots.active_tokens(); let active_tokens = self.slots.active_tokens();
let configs = self.workers_with_configs.borrow(); let configs = self.workers_with_configs.borrow();
let mut checked_any = false;
for (&worker_id, config) in configs.iter() { for (&worker_id, config) in configs.iter() {
if let Some(ids) = allowed
&& !ids.contains(&worker_id)
{
continue;
}
let dp_size = config.data_parallel_size(); let dp_size = config.data_parallel_size();
let dp_start_rank = config.data_parallel_start_rank(); let dp_start_rank = config.data_parallel_start_rank();
let max_batched = config let max_batched = config
...@@ -215,6 +236,7 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> { ...@@ -215,6 +236,7 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> {
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS); .unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
for dp_rank in dp_start_rank..dp_start_rank + dp_size { for dp_rank in dp_start_rank..dp_start_rank + dp_size {
checked_any = true;
let worker = WorkerWithDpRank::new(worker_id, dp_rank); let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let tokens = active_tokens.get(&worker).copied().unwrap_or(0); let tokens = active_tokens.get(&worker).copied().unwrap_or(0);
if (tokens as f64) <= threshold * (max_batched as f64) { if (tokens as f64) <= threshold * (max_batched as f64) {
...@@ -222,7 +244,7 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> { ...@@ -222,7 +244,7 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> {
} }
} }
} }
true checked_any
} }
} }
...@@ -279,6 +301,7 @@ mod tests { ...@@ -279,6 +301,7 @@ mod tests {
threshold_frac, threshold_frac,
block_size, block_size,
selector, selector,
FcfsPolicy,
)); ));
(queue, slots) (queue, slots)
......
...@@ -366,9 +366,14 @@ pub fn convert_event( ...@@ -366,9 +366,14 @@ pub fn convert_event(
event_id, event_id,
"Self-referencing block detected: duplicate hash in store event; dropping" "Self-referencing block detected: duplicate hash in store event; dropping"
); );
// Return an empty Removed instead of Cleared to avoid nuking
// the worker's entire index state. An empty Removed is a no-op
// in the radix tree (zero iterations, returns Ok(())).
return KvCacheEvent { return KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Cleared, data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![],
}),
dp_rank, dp_rank,
}; };
} }
......
...@@ -5,7 +5,11 @@ pub use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS; ...@@ -5,7 +5,11 @@ pub use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS;
use crate::kv_router::sequence::RuntimeSequencePublisher; use crate::kv_router::sequence::RuntimeSequencePublisher;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
/// Concrete `SchedulerQueue` wired to the runtime publisher and config types. /// Concrete `SchedulerQueue` wired to the runtime publisher and config types.
pub type SchedulerQueue = pub type SchedulerQueue = dynamo_kv_router::queue::SchedulerQueue<
dynamo_kv_router::queue::SchedulerQueue<RuntimeSequencePublisher, ModelRuntimeConfig>; RuntimeSequencePublisher,
ModelRuntimeConfig,
RouterSchedulingPolicy,
>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment