Unverified Commit bc514fbe authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: router priority queue (#6010)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent 4673e47f
...@@ -246,6 +246,74 @@ python real_data_benchmark.py --input-dataset trace.jsonl --prefix-root-multipli ...@@ -246,6 +246,74 @@ python real_data_benchmark.py --input-dataset trace.jsonl --prefix-root-multipli
> ``` > ```
> However, by the time of release, the aiperf version included in the vLLM runtime container should be up to date enough to use as-is. > However, by the time of release, the aiperf version included in the vLLM runtime container should be up to date enough to use as-is.
### Step 4 (Alternative): Priority Queue Benchmark
`real_data_priority_benchmark.py` measures whether the router's priority queue correctly differentiates high-, medium-, and low-priority requests. It splits a trace into three tiers, runs a **baseline** (no priority tagging) and a **priority-tagged** run using the same split, then produces a bar chart comparing TTFT across tiers.
#### How it works
1. The trace is synthesized (same parameters as `real_data_benchmark.py`) and split into low / medium / high tiers according to `--priority-distribution`.
2. Each tier is sent to aiperf as a concurrent stream. In the priority-tagged run, every request carries an OpenAI-compatible extension header:
```json
{"nvext": {"agent_hints": {"latency_sensitivity": <value>}}}
```
The `latency_sensitivity` value acts as a **priority jump** (in seconds) inside the router's scheduler queue -- a higher value shifts the request's effective arrival time earlier, giving it priority over lower-valued requests.
3. Two separate aiperf seeds are used for baseline vs. priority runs to ensure different generated prompt content and prevent mocker KV cache cross-contamination.
#### Prerequisites: enable the priority queue
The router queue only activates when `--router-queue-threshold` is set. Without it, requests bypass the queue entirely and priority has no effect.
```bash
# Launch the router with priority queue enabled.
# The fraction (e.g. 1.2) controls the busy threshold:
# workers are considered "busy" when active prefill tokens exceed
# threshold * max_num_batched_tokens. Values > 1.0 effectively make
# the queue always active.
python -m dynamo.frontend \
--router-mode kv \
--router-reset-states \
--router-queue-threshold 1.2
```
#### Running the benchmark
Because the mocker default speedup ratio is 1.0 (real-time), you need a sufficiently high `--speedup-ratio` to generate enough concurrent load for requests to actually queue up. A ratio of 8 or higher is recommended:
```bash
python real_data_priority_benchmark.py \
--input-dataset mooncake_trace.jsonl \
--num-requests 5000 \
--speedup-ratio 8 \
--prefix-len-multiplier 4 \
--prefix-root-multiplier 4
```
**Priority-specific parameters:**
| Parameter | Default | Description |
|-----------|---------|-------------|
| `--priority-distribution` | `0.5,0.3,0.2` | Fraction of requests assigned to low/medium/high tiers (must sum to 1.0) |
| `--priority-values` | `0,1,2` | `latency_sensitivity` values for low/medium/high tiers (seconds of priority jump) |
Examples:
```bash
# Equal tier sizes with aggressive priority differentiation.
# --priority-values sets the latency_sensitivity per tier (low, medium, high).
# Each value is a priority jump in seconds: the router subtracts it from the
# request's arrival time, so higher values move the request further ahead
# in the queue. Here low gets no boost, medium jumps 2s ahead, high jumps 5s.
python real_data_priority_benchmark.py \
--input-dataset mooncake_trace.jsonl \
--num-requests 5000 \
--speedup-ratio 8 \
--priority-distribution 0.33,0.34,0.33 \
--priority-values 0,2,5
```
The benchmark outputs a `ttft_comparison.png` bar chart in the results directory showing TTFT (p50 with p25-p75 error bars) for each tier, comparing baseline vs. priority-tagged runs. If the priority queue is working correctly, high-priority requests should show lower TTFT in the priority run compared to baseline, while low-priority requests may show slightly higher TTFT.
### Step 4 (Alternative): Agent Benchmark (Concurrency-Based Multi-Turn) ### Step 4 (Alternative): Agent Benchmark (Concurrency-Based Multi-Turn)
For benchmarking with multi-turn conversation traces using concurrency-based load generation (instead of timestamp-based replay), use `agent_benchmark.py`. This is useful for testing how the system handles multiple concurrent agent sessions. For benchmarking with multi-turn conversation traces using concurrency-based load generation (instead of timestamp-based replay), use `agent_benchmark.py`. This is useful for testing how the system handles multiple concurrent agent sessions.
......
...@@ -5,7 +5,12 @@ ...@@ -5,7 +5,12 @@
"""Common utilities shared across router benchmark scripts.""" """Common utilities shared across router benchmark scripts."""
import json
import logging import logging
import os
import numpy as np
from prefix_data_generator.synthesizer import Synthesizer
# Default values # Default values
DEFAULT_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" DEFAULT_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
...@@ -58,7 +63,83 @@ def add_common_args(parser): ...@@ -58,7 +63,83 @@ def add_common_args(parser):
parser.add_argument( parser.add_argument(
"--use-expected-osl", "--use-expected-osl",
action="store_true", action="store_true",
help="Pass expected_output_tokens to nvext for router tracking", help="Pass agent_hints.osl to nvext for router output block tracking",
)
def add_synthesis_args(parser):
"""Add CLI arguments for trace dataset synthesis, shared across benchmark scripts."""
parser.add_argument(
"--output-dir",
type=str,
default="real_data_benchmark_results",
help="Output directory for results",
)
parser.add_argument(
"--input-dataset",
type=str,
default="mooncake_trace.jsonl",
help="Path to the input mooncake-style trace dataset file",
)
parser.add_argument(
"--num-requests",
type=int,
default=None,
help="Number of requests to synthesize (default: use all from input file)",
)
parser.add_argument(
"--speedup-ratio",
type=float,
default=1.0,
help="Factor to speed up request intervals (default: 1.0)",
)
parser.add_argument(
"--prefix-len-multiplier",
type=float,
default=1.0,
help="Multiplier for prefix lengths (default: 1.0)",
)
parser.add_argument(
"--prefix-root-multiplier",
type=int,
default=1,
help="Number of times to replicate the core radix tree (default: 1)",
)
parser.add_argument(
"--prompt-len-multiplier",
type=float,
default=1.0,
help="Multiplier for leaf path lengths (default: 1.0, use <1 for shorter prompts)",
)
parser.add_argument(
"--max-isl",
type=int,
default=None,
help="Maximum input sequence length to include in output (default: None, no filtering)",
)
parser.add_argument(
"--min-isl",
type=int,
default=None,
help="Minimum input sequence length to include in output (default: None, no filtering)",
)
parser.add_argument(
"--min-osl",
type=int,
default=None,
help="Minimum output sequence length - clips values below this threshold (default: None, no clipping)",
)
parser.add_argument(
"--max-osl",
type=int,
default=None,
help="Maximum output sequence length - clips values above this threshold (default: None, no clipping)",
)
parser.add_argument(
"--block-size",
type=int,
default=DEFAULT_MOONCAKE_BLOCK_SIZE,
help=f"Block size for prefilling and decoding (default: {DEFAULT_MOONCAKE_BLOCK_SIZE})",
) )
...@@ -84,3 +165,164 @@ def get_common_aiperf_flags(): ...@@ -84,3 +165,164 @@ def get_common_aiperf_flags():
"-H", "-H",
"Accept: text/event-stream", "Accept: text/event-stream",
] ]
def get_aiperf_cmd_for_trace(
model,
tokenizer,
input_dataset,
artifact_dir,
seed,
block_size,
url="http://localhost:8888",
):
"""Build the aiperf CLI command for a mooncake trace run."""
cmd = [
"aiperf",
"profile",
"--model",
model,
"--tokenizer",
tokenizer,
"--url",
url,
"--input-file",
f"{input_dataset}",
"--custom-dataset-type",
"mooncake_trace",
"--fixed-schedule",
"--fixed-schedule-auto-offset",
"--prompt-input-tokens-block-size",
str(block_size),
"--random-seed",
str(seed),
"--artifact-dir",
artifact_dir,
]
cmd.extend(get_common_aiperf_flags())
return cmd
def prepare_trace_dataset(args, output_dir, logger):
"""Prepare a trace dataset, optionally synthesizing or modifying it.
Handles three paths:
1. No synthesis needed: use the original dataset as-is
2. Expected OSL injection only: inject agent_hints.osl into nvext
3. Full synthesis: generate synthetic data from the input dataset
Returns:
tuple[list[dict], str]: (list of request dicts, path to the trace file)
"""
needs_synthesis = (
args.num_requests is not None
or args.speedup_ratio != 1.0
or args.prefix_len_multiplier != 1.0
or args.prefix_root_multiplier != 1
or args.prompt_len_multiplier != 1.0
or args.max_isl is not None
or args.min_isl is not None
or args.min_osl is not None
or args.max_osl is not None
)
if not needs_synthesis and not args.use_expected_osl:
# No synthesis or modification needed, use original dataset
trace_dataset_path = args.input_dataset
logger.info(
f"Using original trace dataset (no synthesis parameters modified): {trace_dataset_path}"
)
requests = []
with open(args.input_dataset, "r") as f:
for line in f:
requests.append(json.loads(line.strip()))
return requests, trace_dataset_path
if not needs_synthesis and args.use_expected_osl:
# Only inject agent_hints.osl into nvext, no other synthesis
logger.info("Injecting agent_hints.osl into original trace dataset...")
requests = []
with open(args.input_dataset, "r") as f:
for line in f:
requests.append(json.loads(line.strip()))
for request in requests:
osl = request.get("output_tokens", 0)
if "nvext" not in request:
request["nvext"] = {}
request["nvext"].setdefault("agent_hints", {})["osl"] = osl
trace_dataset_path = os.path.join(output_dir, "trace_with_expected_osl.jsonl")
with open(trace_dataset_path, "w") as f:
for request in requests:
f.write(json.dumps(request) + "\n")
logger.info(f"Modified trace data saved to: {trace_dataset_path}")
return requests, trace_dataset_path
# Generate synthetic data based on input dataset
logger.info("Generating synthetic trace data...")
logger.info(f" Base dataset: {args.input_dataset}")
logger.info(f" Num requests: {args.num_requests if args.num_requests else 'all'}")
logger.info(f" Speedup ratio: {args.speedup_ratio}")
logger.info(f" Prefix len multiplier: {args.prefix_len_multiplier}")
logger.info(f" Prefix root multiplier: {args.prefix_root_multiplier}")
logger.info(f" Prompt len multiplier: {args.prompt_len_multiplier}")
logger.info(
f" Max ISL: {args.max_isl if args.max_isl else 'no limit'} (filtering)"
)
logger.info(
f" Min ISL: {args.min_isl if args.min_isl else 'no limit'} (filtering)"
)
logger.info(
f" Min OSL: {args.min_osl if args.min_osl else 'no clipping'} (clipping)"
)
logger.info(
f" Max OSL: {args.max_osl if args.max_osl else 'no clipping'} (clipping)"
)
logger.info(f" Random seed: {args.seed}")
np.random.seed(args.seed)
synthesizer = Synthesizer(
args.input_dataset,
block_size=args.block_size,
speedup_ratio=args.speedup_ratio,
prefix_len_multiplier=args.prefix_len_multiplier,
prefix_root_multiplier=args.prefix_root_multiplier,
prompt_len_multiplier=args.prompt_len_multiplier,
)
if args.num_requests is None:
with open(args.input_dataset, "r") as f:
num_requests = sum(1 for _ in f)
logger.info(f"Using all {num_requests} requests from input dataset")
else:
num_requests = args.num_requests
requests = synthesizer.synthesize_requests(
num_requests,
max_isl=args.max_isl,
min_isl=args.min_isl,
min_osl=args.min_osl,
max_osl=args.max_osl,
)
logger.info(f"Generated {len(requests)} synthetic requests")
trace_dataset_path = os.path.join(output_dir, "synthetic_trace.jsonl")
if args.use_expected_osl:
for request in requests:
osl = request.get("output_tokens", 0)
if "nvext" not in request:
request["nvext"] = {}
request["nvext"].setdefault("agent_hints", {})["osl"] = osl
logger.info("Injected agent_hints.osl into nvext for each request")
with open(trace_dataset_path, "w") as f:
for request in requests:
f.write(json.dumps(request) + "\n")
logger.info(f"Synthetic trace data saved to: {trace_dataset_path}")
return requests, trace_dataset_path
...@@ -41,10 +41,10 @@ def get_aiperf_cmd( ...@@ -41,10 +41,10 @@ def get_aiperf_cmd(
prefix_length = int(isl * prefix_ratio) prefix_length = int(isl * prefix_ratio)
synthetic_input_length = int(isl * (1 - prefix_ratio)) synthetic_input_length = int(isl * (1 - prefix_ratio))
# Build nvext JSON with optional expected_output_tokens # Build nvext JSON with optional agent_hints.osl
nvext_dict = {"ignore_eos": True} nvext_dict = {"ignore_eos": True}
if use_expected_osl: if use_expected_osl:
nvext_dict["expected_output_tokens"] = osl nvext_dict["agent_hints"] = {"osl": osl}
nvext_json = json.dumps({"nvext": nvext_dict}) nvext_json = json.dumps({"nvext": nvext_dict})
cmd = [ cmd = [
......
...@@ -4,57 +4,21 @@ ...@@ -4,57 +4,21 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import json
import os import os
import subprocess import subprocess
import numpy as np
from common import ( from common import (
DEFAULT_MOONCAKE_BLOCK_SIZE,
add_common_args, add_common_args,
get_common_aiperf_flags, add_synthesis_args,
get_aiperf_cmd_for_trace,
prepare_trace_dataset,
resolve_tokenizer, resolve_tokenizer,
setup_logger, setup_logger,
) )
from prefix_data_generator.synthesizer import Synthesizer
logger = setup_logger(__name__) logger = setup_logger(__name__)
def get_aiperf_cmd_for_trace(
model,
tokenizer,
input_dataset,
artifact_dir,
seed,
block_size,
url="http://localhost:8888",
):
cmd = [
"aiperf",
"profile",
"--model",
model,
"--tokenizer",
tokenizer,
"--url",
url,
"--input-file",
f"{input_dataset}",
"--custom-dataset-type",
"mooncake_trace",
"--fixed-schedule-auto-offset",
"--prompt-input-tokens-block-size",
str(block_size),
"--random-seed",
str(seed),
"--artifact-dir",
artifact_dir,
]
cmd.extend(get_common_aiperf_flags())
return cmd
def run_benchmark_with_trace( def run_benchmark_with_trace(
model, model,
tokenizer, tokenizer,
...@@ -95,216 +59,16 @@ def main(): ...@@ -95,216 +59,16 @@ def main():
description="Benchmark with real or synthesized mooncake-style trace data" description="Benchmark with real or synthesized mooncake-style trace data"
) )
# Common arguments
add_common_args(parser) add_common_args(parser)
add_synthesis_args(parser)
parser.add_argument(
"--output-dir",
type=str,
default="real_data_benchmark_results",
help="Output directory for results",
)
# Trace dataset and synthesis configuration (similar to synthesizer.py)
parser.add_argument(
"--input-dataset",
type=str,
default="mooncake_trace.jsonl",
help="Path to the input mooncake-style trace dataset file",
)
parser.add_argument(
"--num-requests",
type=int,
default=None,
help="Number of requests to synthesize (default: use all from input file)",
)
parser.add_argument(
"--speedup-ratio",
type=float,
default=1.0,
help="Factor to speed up request intervals (default: 1.0)",
)
parser.add_argument(
"--prefix-len-multiplier",
type=float,
default=1.0,
help="Multiplier for prefix lengths (default: 1.0)",
)
parser.add_argument(
"--prefix-root-multiplier",
type=int,
default=1,
help="Number of times to replicate the core radix tree (default: 1)",
)
parser.add_argument(
"--prompt-len-multiplier",
type=float,
default=1.0,
help="Multiplier for leaf path lengths (default: 1.0, use <1 for shorter prompts)",
)
parser.add_argument(
"--max-isl",
type=int,
default=None,
help="Maximum input sequence length to include in output (default: None, no filtering)",
)
parser.add_argument(
"--min-isl",
type=int,
default=None,
help="Minimum input sequence length to include in output (default: None, no filtering)",
)
parser.add_argument(
"--min-osl",
type=int,
default=None,
help="Minimum output sequence length - clips values below this threshold (default: None, no clipping)",
)
parser.add_argument(
"--max-osl",
type=int,
default=None,
help="Maximum output sequence length - clips values above this threshold (default: None, no clipping)",
)
parser.add_argument(
"--block-size",
type=int,
default=DEFAULT_MOONCAKE_BLOCK_SIZE,
help=f"Block size for prefilling and decoding (default: {DEFAULT_MOONCAKE_BLOCK_SIZE})",
)
args = parser.parse_args() args = parser.parse_args()
resolve_tokenizer(args) resolve_tokenizer(args)
# Create output directory
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
# Determine whether to use original or synthesized data _, trace_dataset_path = prepare_trace_dataset(args, args.output_dir, logger)
# Check if any synthesis parameters are non-default
needs_synthesis = (
args.num_requests is not None
or args.speedup_ratio != 1.0
or args.prefix_len_multiplier != 1.0
or args.prefix_root_multiplier != 1
or args.prompt_len_multiplier != 1.0
or args.max_isl is not None
or args.min_isl is not None
or args.min_osl is not None
or args.max_osl is not None
)
if not needs_synthesis and not args.use_expected_osl:
# No synthesis or modification needed, use original dataset
trace_dataset_path = args.input_dataset
logger.info(
f"Using original trace dataset (no synthesis parameters modified): {trace_dataset_path}"
)
elif not needs_synthesis and args.use_expected_osl:
# Only inject expected_output_tokens into nvext, no other synthesis
logger.info("Injecting expected_output_tokens into original trace dataset...")
# Read original dataset
requests = []
with open(args.input_dataset, "r") as f:
for line in f:
requests.append(json.loads(line.strip()))
# Inject expected_output_tokens into nvext for each request
for request in requests:
osl = request.get("output_tokens", 0)
if "nvext" not in request:
request["nvext"] = {}
request["nvext"]["expected_output_tokens"] = osl
# Write modified data to output directory
trace_dataset_path = os.path.join(
args.output_dir, "trace_with_expected_osl.jsonl"
)
with open(trace_dataset_path, "w") as f:
for request in requests:
f.write(json.dumps(request) + "\n")
logger.info(f"Modified trace data saved to: {trace_dataset_path}")
else:
# Generate synthetic data based on input dataset
logger.info("Generating synthetic trace data...")
logger.info(f" Base dataset: {args.input_dataset}")
logger.info(
f" Num requests: {args.num_requests if args.num_requests else 'all'}"
)
logger.info(f" Speedup ratio: {args.speedup_ratio}")
logger.info(f" Prefix len multiplier: {args.prefix_len_multiplier}")
logger.info(f" Prefix root multiplier: {args.prefix_root_multiplier}")
logger.info(f" Prompt len multiplier: {args.prompt_len_multiplier}")
logger.info(
f" Max ISL: {args.max_isl if args.max_isl else 'no limit'} (filtering)"
)
logger.info(
f" Min ISL: {args.min_isl if args.min_isl else 'no limit'} (filtering)"
)
logger.info(
f" Min OSL: {args.min_osl if args.min_osl else 'no clipping'} (clipping)"
)
logger.info(
f" Max OSL: {args.max_osl if args.max_osl else 'no clipping'} (clipping)"
)
logger.info(f" Random seed: {args.seed}")
# Set random seed for reproducibility
np.random.seed(args.seed)
# Create synthesizer
synthesizer = Synthesizer(
args.input_dataset,
block_size=args.block_size,
speedup_ratio=args.speedup_ratio,
prefix_len_multiplier=args.prefix_len_multiplier,
prefix_root_multiplier=args.prefix_root_multiplier,
prompt_len_multiplier=args.prompt_len_multiplier,
)
# Determine number of requests
if args.num_requests is None:
# Count requests in original dataset
with open(args.input_dataset, "r") as f:
num_requests = sum(1 for _ in f)
logger.info(f"Using all {num_requests} requests from input dataset")
else:
num_requests = args.num_requests
# Generate synthetic requests
requests = synthesizer.synthesize_requests(
num_requests,
max_isl=args.max_isl,
min_isl=args.min_isl,
min_osl=args.min_osl,
max_osl=args.max_osl,
)
logger.info(f"Generated {len(requests)} synthetic requests")
# Save synthetic data to a permanent file in output directory
synthetic_trace_filename = "synthetic_trace.jsonl"
trace_dataset_path = os.path.join(args.output_dir, synthetic_trace_filename)
# Optionally inject expected_output_tokens into nvext for each request
if args.use_expected_osl:
for request in requests:
# Get the output_tokens (OSL) for this request
osl = request.get("output_tokens", 0)
# Initialize or update nvext with expected_output_tokens
if "nvext" not in request:
request["nvext"] = {}
request["nvext"]["expected_output_tokens"] = osl
logger.info("Injected expected_output_tokens into nvext for each request")
# Write synthetic data to file
with open(trace_dataset_path, "w") as f:
for request in requests:
f.write(json.dumps(request) + "\n")
logger.info(f"Synthetic trace data saved to: {trace_dataset_path}")
# Run benchmark with the trace dataset
artifact_dir = os.path.join(args.output_dir, "aiperf_artifacts") artifact_dir = os.path.join(args.output_dir, "aiperf_artifacts")
os.makedirs(artifact_dir, exist_ok=True) os.makedirs(artifact_dir, exist_ok=True)
......
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Priority queue benchmark: splits a trace into priority tiers, runs a baseline
(no priority tagging) and a priority-tagged run with the same split, then
produces a bar chart comparing TTFT across tiers."""
import argparse
import copy
import json
import os
import subprocess
import matplotlib.pyplot as plt
import numpy as np
from common import (
add_common_args,
add_synthesis_args,
get_aiperf_cmd_for_trace,
prepare_trace_dataset,
resolve_tokenizer,
setup_logger,
)
logger = setup_logger(__name__)
TIERS = ["low", "medium", "high"]
def parse_float_list(s):
"""Parse a comma-separated string into a list of floats."""
return [float(x.strip()) for x in s.split(",")]
def split_trace(requests, distribution, seed):
"""Split requests into priority tiers by distribution. Deterministic given seed."""
rng = np.random.RandomState(seed)
labels = rng.choice(len(distribution), size=len(requests), p=distribution)
return {
tier: [r for r, label in zip(requests, labels) if label == i]
for i, tier in enumerate(TIERS)
}
def offset_hash_ids(tier_requests):
"""Return a deep copy of tier_requests with all hash_ids shifted by max_hash_id + 1.
Preserves the prefix tree structure (relative ordering and sharing)
while ensuring no KV cache hits from a previous run.
"""
max_hash_id = max(
h
for requests in tier_requests.values()
for req in requests
for h in req["hash_ids"]
)
offset = max_hash_id + 1
shifted = {}
for tier, requests in tier_requests.items():
shifted[tier] = []
for req in requests:
r = copy.copy(req)
r["hash_ids"] = [h + offset for h in r["hash_ids"]]
shifted[tier].append(r)
return shifted
def write_trace_file(requests, path):
"""Write a list of request dicts to a JSONL file."""
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w") as f:
for request in requests:
f.write(json.dumps(request) + "\n")
def run_concurrent_streams(
args, tier_requests, priority_values, run_dir, tag_priority, logger, seed=None
):
"""Launch concurrent aiperf subprocesses for each tier.
Args:
tag_priority: If True, inject nvext.agent_hints.latency_sensitivity per tier.
"""
processes = []
log_files = []
for tier, pj in zip(TIERS, priority_values):
tier_dir = os.path.join(run_dir, f"{tier}_priority")
os.makedirs(tier_dir, exist_ok=True)
trace_path = os.path.join(tier_dir, "trace.jsonl")
write_trace_file(tier_requests[tier], trace_path)
artifact_dir = os.path.join(tier_dir, "aiperf_artifacts")
os.makedirs(artifact_dir, exist_ok=True)
cmd = get_aiperf_cmd_for_trace(
args.model,
args.tokenizer,
trace_path,
artifact_dir,
seed if seed is not None else args.seed,
args.block_size,
args.url,
)
cmd.extend(["--log-level", "WARNING", "--ui-type", "none"])
if tag_priority:
cmd.extend(
[
"--extra-inputs",
json.dumps({"nvext": {"agent_hints": {"latency_sensitivity": pj}}}),
]
)
log_path = os.path.join(tier_dir, "aiperf.log")
log_file = open(log_path, "w")
log_files.append(log_file)
label = "priority" if tag_priority else "baseline"
logger.info(f"Launching {tier} tier ({label}, latency_sensitivity={pj})")
logger.info(f" Command: {' '.join(cmd)}")
proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)
processes.append((tier, proc))
failed = []
for tier, proc in processes:
proc.wait()
if proc.returncode == 0:
logger.info(f" {tier} tier completed successfully")
else:
logger.error(f" {tier} tier failed with exit code {proc.returncode}")
failed.append(tier)
for log_file in log_files:
log_file.close()
if failed:
label = "priority" if tag_priority else "baseline"
logger.error(f"Failed tiers in {label} run: {', '.join(failed)}")
logger.error("Check the aiperf.log files in each tier directory for details")
raise SystemExit(1)
def load_ttft(run_dir, tier):
"""Load TTFT stats from an aiperf result JSON."""
result_path = os.path.join(
run_dir, f"{tier}_priority", "aiperf_artifacts", "profile_export_aiperf.json"
)
with open(result_path, "r") as f:
data = json.load(f)
ttft = data["time_to_first_token"]
return ttft["p50"], ttft["p25"], ttft["p75"]
def plot_ttft_comparison(baseline_dir, priority_dir, output_path, priority_values):
"""Create a grouped bar chart comparing TTFT between baseline and priority runs."""
x = np.arange(len(TIERS))
width = 0.35
baseline_medians = []
baseline_lo = []
baseline_hi = []
priority_medians = []
priority_lo = []
priority_hi = []
for tier in TIERS:
p50, p25, p75 = load_ttft(baseline_dir, tier)
baseline_medians.append(p50)
baseline_lo.append(p50 - p25)
baseline_hi.append(p75 - p50)
p50, p25, p75 = load_ttft(priority_dir, tier)
priority_medians.append(p50)
priority_lo.append(p50 - p25)
priority_hi.append(p75 - p50)
fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(
x - width / 2,
baseline_medians,
width,
yerr=[baseline_lo, baseline_hi],
label="Baseline (no priority)",
capsize=4,
)
ax.bar(
x + width / 2,
priority_medians,
width,
yerr=[priority_lo, priority_hi],
label="With latency_sensitivity",
capsize=4,
)
tier_labels = [
f"{tier.capitalize()}\n(ls={pj})" for tier, pj in zip(TIERS, priority_values)
]
ax.set_xticks(x)
ax.set_xticklabels(tier_labels)
ax.set_ylabel("TTFT (ms)")
ax.set_title("Time to First Token by Priority Tier")
ax.legend()
ax.grid(axis="y", alpha=0.3)
fig.tight_layout()
fig.savefig(output_path, dpi=150)
logger.info(f"Plot saved to: {output_path}")
plt.close(fig)
def main():
parser = argparse.ArgumentParser(
description="Priority benchmark: compare TTFT with and without priority tagging"
)
add_common_args(parser)
add_synthesis_args(parser)
parser.add_argument(
"--priority-distribution",
type=str,
default="0.5,0.3,0.2",
help="Comma-separated fractions for low/medium/high tiers (default: 0.5,0.3,0.2)",
)
parser.add_argument(
"--priority-values",
type=str,
default="0,1,2",
help="Comma-separated latency_sensitivity values for low/medium/high tiers (default: 0,1,2)",
)
args = parser.parse_args()
resolve_tokenizer(args)
distribution = parse_float_list(args.priority_distribution)
priority_values = parse_float_list(args.priority_values)
if len(distribution) != len(TIERS):
parser.error(
f"--priority-distribution must have {len(TIERS)} values, got {len(distribution)}"
)
if len(priority_values) != len(TIERS):
parser.error(
f"--priority-values must have {len(TIERS)} values, got {len(priority_values)}"
)
if abs(sum(distribution) - 1.0) > 1e-6:
parser.error(
f"--priority-distribution must sum to 1.0, got {sum(distribution)}"
)
os.makedirs(args.output_dir, exist_ok=True)
# Prepare the trace dataset (synthesis if needed)
requests, _ = prepare_trace_dataset(args, args.output_dir, logger)
# Split into priority tiers (deterministic via seed)
tier_requests = split_trace(requests, distribution, args.seed)
for tier in TIERS:
logger.info(f" {tier} priority: {len(tier_requests[tier])} requests")
# Use different aiperf random seeds per run so that the generated prompts
# differ, preventing mocker KV cache hits between runs.
baseline_seed = args.seed
priority_seed = args.seed + 1
# Run 1: Baseline (same split, no priority tagging)
baseline_dir = os.path.join(args.output_dir, "baseline")
logger.info("=== Running baseline (no priority tagging) ===")
run_concurrent_streams(
args,
tier_requests,
priority_values,
baseline_dir,
tag_priority=False,
logger=logger,
seed=baseline_seed,
)
# Run 2: With priority tagging
priority_dir = os.path.join(args.output_dir, "priority")
logger.info("=== Running with priority tagging ===")
run_concurrent_streams(
args,
tier_requests,
priority_values,
priority_dir,
tag_priority=True,
logger=logger,
seed=priority_seed,
)
# Plot comparison
plot_path = os.path.join(args.output_dir, "ttft_comparison.png")
plot_ttft_comparison(baseline_dir, priority_dir, plot_path, priority_values)
logger.info(f"All runs completed. Results saved to: {args.output_dir}")
if __name__ == "__main__":
main()
...@@ -62,6 +62,7 @@ class FrontendConfig(ConfigBase): ...@@ -62,6 +62,7 @@ class FrontendConfig(ConfigBase):
router_assume_kv_reuse: bool router_assume_kv_reuse: bool
router_track_output_blocks: bool router_track_output_blocks: bool
router_event_threads: int router_event_threads: int
router_queue_threshold: Optional[float]
enforce_disagg: bool enforce_disagg: bool
migration_limit: int migration_limit: int
...@@ -336,6 +337,19 @@ class FrontendArgGroup(ArgGroup): ...@@ -336,6 +337,19 @@ class FrontendArgGroup(ArgGroup):
), ),
arg_type=int, arg_type=int,
) )
add_argument(
g,
flag_name="--router-queue-threshold",
env_var="DYN_ROUTER_QUEUE_THRESHOLD",
default=None,
help=(
"KV Router: Queue threshold fraction for prefill token capacity. "
"When set, requests are queued if all workers exceed this fraction of "
"max_num_batched_tokens. Enables priority scheduling via latency_sensitivity "
"hints. Must be > 0. If not set, queueing is disabled."
),
arg_type=float,
)
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--enforce-disagg", flag_name="--enforce-disagg",
......
...@@ -191,6 +191,7 @@ async def async_main(): ...@@ -191,6 +191,7 @@ async def async_main():
router_ttl_secs=config.router_ttl, router_ttl_secs=config.router_ttl,
router_max_tree_size=config.router_max_tree_size, router_max_tree_size=config.router_max_tree_size,
router_prune_target_ratio=config.router_prune_target_ratio, router_prune_target_ratio=config.router_prune_target_ratio,
router_queue_threshold=config.router_queue_threshold,
router_event_threads=config.router_event_threads, router_event_threads=config.router_event_threads,
) )
elif config.router_mode == "random": elif config.router_mode == "random":
......
...@@ -233,7 +233,7 @@ def parse_args(): ...@@ -233,7 +233,7 @@ def parse_args():
action="store_true", action="store_true",
dest="router_track_output_blocks", dest="router_track_output_blocks",
default=False, default=False,
help="KV Router: Track output blocks during generation. When enabled, the router adds placeholder blocks as tokens are generated and applies fractional decay based on progress toward expected_output_tokens (default: False)", help="KV Router: Track output blocks during generation. When enabled, the router adds placeholder blocks as tokens are generated and applies fractional decay based on progress toward expected output sequence length (agent_hints.osl in nvext). Default: False.",
) )
parser.add_argument( parser.add_argument(
......
...@@ -38,6 +38,7 @@ Backend workers register themselves using the `register_llm` API, after which th ...@@ -38,6 +38,7 @@ Backend workers register themselves using the `register_llm` API, after which th
| `--kv-cache-block-size <size>` | Backend-specific | KV cache block size (should match backend config) | | `--kv-cache-block-size <size>` | Backend-specific | KV cache block size (should match backend config) |
| `--kv-events` / `--no-kv-events` | `--kv-events` | Enable/disable real-time KV event tracking | | `--kv-events` / `--no-kv-events` | `--kv-events` | Enable/disable real-time KV event tracking |
| `--kv-overlap-score-weight <float>` | `1.0` | Balance prefill vs decode optimization (higher = better TTFT) | | `--kv-overlap-score-weight <float>` | `1.0` | Balance prefill vs decode optimization (higher = better TTFT) |
| `--router-queue-threshold <float>` | None (disabled) | Queue threshold fraction; enables priority scheduling via `latency_sensitivity` |
For all available options: `python -m dynamo.frontend --help` For all available options: `python -m dynamo.frontend --help`
...@@ -159,10 +160,12 @@ The main KV-aware routing arguments: ...@@ -159,10 +160,12 @@ The main KV-aware routing arguments:
- `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management. - `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management.
- `--track-output-blocks`: Enables tracking of output blocks during generation (default: disabled). When enabled, the router adds placeholder blocks as tokens are generated and applies fractional decay based on progress toward `expected_output_tokens`. This improves load balancing accuracy for long-running generation requests by accounting for output-side KV cache growth. - `--track-output-blocks`: Enables tracking of output blocks during generation (default: disabled). When enabled, the router adds placeholder blocks as tokens are generated and applies fractional decay based on progress toward the expected output sequence length (`agent_hints.osl` in nvext). This improves load balancing accuracy for long-running generation requests by accounting for output-side KV cache growth.
- `--no-assume-kv-reuse`: When tracking active blocks, disables the assumption of KV cache reuse. By default (`router_assume_kv_reuse=true`), the router computes actual block hashes for sequence tracking to deduplicate blocks and optimize load balancing. When disabled via this flag, the router generates random hashes for sequence blocks, treating each request's blocks as unique. This is useful in disaggregated setups where prefill transfers blocks to decode workers that may already have those blocks cached, but the engine cannot coordinate transfers to avoid duplication. Without this flag, the router's load balancing heuristics would undercount decode blocks when duplicates exist. - `--no-assume-kv-reuse`: When tracking active blocks, disables the assumption of KV cache reuse. By default (`router_assume_kv_reuse=true`), the router computes actual block hashes for sequence tracking to deduplicate blocks and optimize load balancing. When disabled via this flag, the router generates random hashes for sequence blocks, treating each request's blocks as unique. This is useful in disaggregated setups where prefill transfers blocks to decode workers that may already have those blocks cached, but the engine cannot coordinate transfers to avoid duplication. Without this flag, the router's load balancing heuristics would undercount decode blocks when duplicates exist.
- `--router-queue-threshold`: Queue threshold fraction for prefill token capacity. When set, the router holds incoming requests in a priority queue while all workers exceed this fraction of `max_num_batched_tokens`, releasing them when capacity frees up. This defers dispatch (not rejection) so that routing decisions use the most up-to-date load metrics at the moment the request is actually sent to a worker. It also enables **priority scheduling** via `latency_sensitivity` hints in `nvext.agent_hints` — higher values shift a request's effective arrival time earlier in the queue, giving it priority over lower-valued requests. Must be > 0. If not set (default), queueing is disabled and requests are dispatched immediately.
- `--active-decode-blocks-threshold`: Initial threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache block utilization. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, blocks-based busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines publish load metrics. The threshold can be dynamically updated at runtime via the `/busy_threshold` HTTP endpoint (see [Dynamic Threshold Configuration](#dynamic-threshold-configuration)). - `--active-decode-blocks-threshold`: Initial threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache block utilization. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, blocks-based busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines publish load metrics. The threshold can be dynamically updated at runtime via the `/busy_threshold` HTTP endpoint (see [Dynamic Threshold Configuration](#dynamic-threshold-configuration)).
- `--active-prefill-tokens-threshold`: Literal token count threshold for determining when a worker is considered busy based on prefill token utilization. When active prefill tokens exceed this threshold, the worker is marked as busy. If not set, tokens-based busy detection is disabled. - `--active-prefill-tokens-threshold`: Literal token count threshold for determining when a worker is considered busy based on prefill token utilization. When active prefill tokens exceed this threshold, the worker is marked as busy. If not set, tokens-based busy detection is disabled.
...@@ -184,10 +187,7 @@ The main KV-aware routing arguments: ...@@ -184,10 +187,7 @@ The main KV-aware routing arguments:
> - **No KV events** (`--no-kv-events`): State persistence is not supported. > - **No KV events** (`--no-kv-events`): State persistence is not supported.
> >
> **Request plane is independent of KV event transport.** > **Request plane is independent of KV event transport.**
> The router can run without etcd or NATS when using ZMQ event plane (`--event-plane zmq`) and file/mem store (`--store-kv file` or `--store-kv mem`); in this case, KV events use ZMQ transport instead of NATS. > The request plane (`DYN_REQUEST_PLANE` / `--request-plane`) controls how requests reach workers (TCP/HTTP/NATS), while KV events travel over a separate path. KV events use NATS in JetStream or NATS Core modes, or ZMQ when `--event-plane zmq` is set. With `--event-plane zmq` and `--store-kv file` or `mem`, the router can run entirely without etcd or NATS. When using a NATS-based event plane (the default), NATS is initialized automatically; set `NATS_SERVER=nats://...` to override the default `localhost:4222`. Use `--no-kv-events` to disable KV event transport entirely.
> `DYN_REQUEST_PLANE` controls how **requests** are sent (TCP/HTTP/NATS), but KV-aware routing uses **NATS** for KV events only in JetStream or NATS Core modes (not ZMQ mode).
> When KV events are enabled (default) with NATS-based event plane, NATS is automatically initialized. You can optionally set `NATS_SERVER=nats://...` to specify a custom NATS server; otherwise, it defaults to `localhost:4222`.
> `--no-kv-events` disables KV event transport entirely.
> >
> When `--kv-overlap-score-weight` is set to 0, no KVIndexer is created and prefix matching is disabled (pure load balancing). When `--no-kv-events` is set, a KVIndexer is still created but no event subscriber is launched to consume KV events from workers. Instead, the router predicts cache state based on its own routing decisions with TTL-based expiration and pruning. > When `--kv-overlap-score-weight` is set to 0, no KVIndexer is created and prefix matching is disabled (pure load balancing). When `--no-kv-events` is set, a KVIndexer is still created but no event subscriber is launched to consume KV events from workers. Instead, the router predicts cache state based on its own routing decisions with TTL-based expiration and pruning.
> >
...@@ -197,6 +197,8 @@ The main KV-aware routing arguments: ...@@ -197,6 +197,8 @@ The main KV-aware routing arguments:
> - **TRT-LLM**: Do not use `--publish-events-and-metrics` > - **TRT-LLM**: Do not use `--publish-events-and-metrics`
> >
> The cli args `--router-ttl`, `--router-max-tree-size`, and `--router-prune-target-ratio` control local cache management when the router operates without receiving events from workers. When KV events are enabled (default), the router relies on worker-side eviction events and these parameters are ignored. > The cli args `--router-ttl`, `--router-max-tree-size`, and `--router-prune-target-ratio` control local cache management when the router operates without receiving events from workers. When KV events are enabled (default), the router relies on worker-side eviction events and these parameters are ignored.
>
> **Queue threshold vs. busy rejection thresholds:** `--router-queue-threshold` and the busy thresholds (`--active-decode-blocks-threshold`, `--active-prefill-tokens-threshold`, `--active-prefill-tokens-threshold-frac`) serve different purposes. The busy thresholds **reject** a worker entirely from the candidate set when it exceeds a utilization limit — no traffic is sent until it drops below the threshold. In contrast, `--router-queue-threshold` does not reject workers; it **defers the entire routing decision** until at least one worker has capacity, so the request is routed with the freshest load metrics. The queue also enables priority scheduling via `nvext.agent_hints.latency_sensitivity`.
To implement KV event publishing for custom inference engines, enabling them to participate in Dynamo's KV cache-aware routing, see [KV Event Publishing for Custom Engines](../../integrations/kv-events-custom-engines.md). To implement KV event publishing for custom inference engines, enabling them to participate in Dynamo's KV cache-aware routing, see [KV Event Publishing for Custom Engines](../../integrations/kv-events-custom-engines.md).
......
...@@ -52,7 +52,7 @@ impl KvRouterConfig { ...@@ -52,7 +52,7 @@ impl KvRouterConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_event_threads=1))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=None, router_event_threads=1))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
...@@ -68,6 +68,7 @@ impl KvRouterConfig { ...@@ -68,6 +68,7 @@ impl KvRouterConfig {
router_ttl_secs: f64, router_ttl_secs: f64,
router_max_tree_size: usize, router_max_tree_size: usize,
router_prune_target_ratio: f64, router_prune_target_ratio: f64,
router_queue_threshold: Option<f64>,
router_event_threads: u32, router_event_threads: u32,
) -> Self { ) -> Self {
KvRouterConfig { KvRouterConfig {
...@@ -85,6 +86,7 @@ impl KvRouterConfig { ...@@ -85,6 +86,7 @@ impl KvRouterConfig {
router_ttl_secs, router_ttl_secs,
router_max_tree_size, router_max_tree_size,
router_prune_target_ratio, router_prune_target_ratio,
router_queue_threshold,
router_event_threads, router_event_threads,
}, },
} }
......
...@@ -1007,6 +1007,7 @@ impl KvPushRouter { ...@@ -1007,6 +1007,7 @@ impl KvPushRouter {
router_config_override.as_ref(), router_config_override.as_ref(),
update_states, update_states,
None, // lora_name not exposed in Python API yet None, // lora_name not exposed in Python API yet
0.0,
) )
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -995,6 +995,7 @@ class KvRouterConfig: ...@@ -995,6 +995,7 @@ class KvRouterConfig:
router_ttl_secs: float = 120.0, router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576, router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8, router_prune_target_ratio: float = 0.8,
router_queue_threshold: Optional[float] = None,
router_event_threads: int = 1, router_event_threads: int = 1,
) -> None: ) -> None:
""" """
...@@ -1011,7 +1012,8 @@ class KvRouterConfig: ...@@ -1011,7 +1012,8 @@ class KvRouterConfig:
router_track_active_blocks: Track active blocks for load balancing (default: True) router_track_active_blocks: Track active blocks for load balancing (default: True)
router_track_output_blocks: Track output blocks during generation (default: False). router_track_output_blocks: Track output blocks during generation (default: False).
When enabled, the router adds placeholder blocks as tokens are generated When enabled, the router adds placeholder blocks as tokens are generated
and applies fractional decay based on progress toward expected_output_tokens. and applies fractional decay based on progress toward expected output
sequence length (agent_hints.osl in nvext).
router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True). router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True).
When True, computes actual block hashes. When False, generates random hashes. When True, computes actual block hashes. When False, generates random hashes.
router_snapshot_threshold: Number of messages before snapshot (default: 1000000) router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
...@@ -1019,6 +1021,10 @@ class KvRouterConfig: ...@@ -1019,6 +1021,10 @@ class KvRouterConfig:
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0) router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20) router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8) router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
router_queue_threshold: Queue threshold fraction for prefill token capacity (default: None).
When set, requests are queued if all workers exceed this fraction of
max_num_batched_tokens. Enables priority scheduling via latency_sensitivity hints.
If None, queueing is disabled and all requests go directly to the scheduler.
router_event_threads: Number of event processing threads (default: 1). router_event_threads: Number of event processing threads (default: 1).
When > 1, uses a concurrent radix tree with a thread pool. When > 1, uses a concurrent radix tree with a thread pool.
""" """
......
...@@ -32,6 +32,7 @@ pub mod metrics; ...@@ -32,6 +32,7 @@ pub mod metrics;
pub mod prefill_router; pub mod prefill_router;
pub mod publisher; pub mod publisher;
pub mod push_router; pub mod push_router;
pub mod queue;
pub mod recorder; pub mod recorder;
pub mod scheduler; pub mod scheduler;
pub mod sequence; pub mod sequence;
...@@ -53,7 +54,7 @@ use crate::{ ...@@ -53,7 +54,7 @@ use crate::{
compute_block_hash_for_seq, compute_block_hash_for_seq,
}, },
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
sequence::SequenceError, sequence::{SequenceError, SequenceRequest},
}, },
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
}; };
...@@ -315,6 +316,7 @@ impl KvRouter { ...@@ -315,6 +316,7 @@ impl KvRouter {
kv_router_config.router_replica_sync, kv_router_config.router_replica_sync,
router_id, router_id,
worker_type, worker_type,
kv_router_config.router_queue_threshold,
) )
.await?; .await?;
...@@ -371,6 +373,7 @@ impl KvRouter { ...@@ -371,6 +373,7 @@ impl KvRouter {
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
lora_name: Option<String>, lora_name: Option<String>,
priority_jump: f64,
) -> anyhow::Result<(WorkerWithDpRank, u32)> { ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
let start = Instant::now(); let start = Instant::now();
...@@ -404,6 +407,7 @@ impl KvRouter { ...@@ -404,6 +407,7 @@ impl KvRouter {
router_config_override, router_config_override,
update_states, update_states,
lora_name, lora_name,
priority_jump,
) )
.await?; .await?;
let total_elapsed = start.elapsed(); let total_elapsed = start.elapsed();
...@@ -458,15 +462,15 @@ impl KvRouter { ...@@ -458,15 +462,15 @@ impl KvRouter {
if let Err(e) = self if let Err(e) = self
.scheduler .scheduler
.add_request( .add_request(SequenceRequest {
request_id.clone(), request_id: request_id.clone(),
maybe_seq_hashes, token_sequence: maybe_seq_hashes,
isl_tokens, isl: isl_tokens,
overlap_blocks, overlap: overlap_blocks,
expected_output_tokens, expected_output_tokens,
worker, worker,
lora_name, lora_name,
) })
.await .await
{ {
tracing::warn!("Failed to add request {request_id}: {e}"); tracing::warn!("Failed to add request {request_id}: {e}");
...@@ -555,7 +559,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -555,7 +559,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
let response = match request { let response = match request {
RouterRequest::New { tokens } => { RouterRequest::New { tokens } => {
let (best_worker, overlap_blocks) = self let (best_worker, overlap_blocks) = self
.find_best_match(Some(&context_id), &tokens, None, true, None) .find_best_match(Some(&context_id), &tokens, None, true, None, 0.0)
.await?; .await?;
RouterResponse::New { RouterResponse::New {
......
...@@ -47,7 +47,7 @@ pub struct KvRouterConfig { ...@@ -47,7 +47,7 @@ pub struct KvRouterConfig {
/// Whether to track output blocks during generation (default: false) /// Whether to track output blocks during generation (default: false)
/// When enabled, the router adds placeholder blocks as tokens are generated /// When enabled, the router adds placeholder blocks as tokens are generated
/// and applies fractional decay based on progress toward expected_output_tokens. /// and applies fractional decay based on progress toward agent_hints.osl.
pub router_track_output_blocks: bool, pub router_track_output_blocks: bool,
/// Whether to assume KV cache reuse when tracking active blocks (default: true). /// Whether to assume KV cache reuse when tracking active blocks (default: true).
...@@ -74,6 +74,13 @@ pub struct KvRouterConfig { ...@@ -74,6 +74,13 @@ pub struct KvRouterConfig {
#[validate(range(min = 0.0, max = 1.0))] #[validate(range(min = 0.0, max = 1.0))]
pub router_prune_target_ratio: f64, pub router_prune_target_ratio: f64,
/// Queue threshold fraction for prefill token capacity.
/// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens.
/// If None (default), queueing is disabled and all requests go directly to ready.
/// Must be > 0.
#[validate(range(min = 0.0))]
pub router_queue_threshold: Option<f64>,
/// Number of event processing threads for the KV indexer. /// Number of event processing threads for the KV indexer.
/// When > 1, uses ConcurrentRadixTree with a thread pool instead of the /// When > 1, uses ConcurrentRadixTree with a thread pool instead of the
/// single-threaded RadixTree. Default: 1. /// single-threaded RadixTree. Default: 1.
...@@ -97,6 +104,7 @@ impl Default for KvRouterConfig { ...@@ -97,6 +104,7 @@ impl Default for KvRouterConfig {
router_ttl_secs: 120.0, router_ttl_secs: 120.0,
router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default() router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8, router_prune_target_ratio: 0.8,
router_queue_threshold: None,
router_event_threads: 1, router_event_threads: 1,
} }
} }
......
...@@ -290,12 +290,17 @@ impl PrefillRouter { ...@@ -290,12 +290,17 @@ impl PrefillRouter {
InnerPrefillRouter::KvRouter(r) => r, InnerPrefillRouter::KvRouter(r) => r,
_ => return None, _ => return None,
}; };
// Extract LORA name from routing hints // Extract LORA name and priority jump from routing hints
let lora_name = req.routing.as_ref().and_then(|r| r.lora_name.clone()); let lora_name = req.routing.as_ref().and_then(|r| r.lora_name.clone());
let priority_jump = req
.routing
.as_ref()
.and_then(|r| r.priority_jump)
.unwrap_or(0.0);
match async { match async {
kv_router kv_router
.chooser .chooser
.find_best_match(None, &req.token_ids, None, false, lora_name) .find_best_match(None, &req.token_ids, None, false, lora_name, priority_jump)
.await .await
} }
.instrument(tracing::info_span!("kv_find_best_match")) .instrument(tracing::info_span!("kv_find_best_match"))
......
...@@ -124,6 +124,7 @@ impl KvPushRouter { ...@@ -124,6 +124,7 @@ impl KvPushRouter {
) -> Result<WorkerSelection, Error> { ) -> Result<WorkerSelection, Error> {
let routing = request.routing.as_ref(); let routing = request.routing.as_ref();
let lora_name = routing.and_then(|r| r.lora_name.clone()); let lora_name = routing.and_then(|r| r.lora_name.clone());
let priority_jump = routing.and_then(|r| r.priority_jump).unwrap_or(0.0);
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0); let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens); let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
...@@ -147,6 +148,7 @@ impl KvPushRouter { ...@@ -147,6 +148,7 @@ impl KvPushRouter {
request.router_config_override.as_ref(), request.router_config_override.as_ref(),
!is_query_only, !is_query_only,
lora_name, lora_name,
priority_jump,
) )
.await?; .await?;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Notify};
use crate::discovery::RuntimeConfigWatch;
use super::protocols::WorkerWithDpRank;
use super::scheduler::SchedulingRequest;
use super::sequence::ActiveSequencesMultiWorker;
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
const DEFAULT_MAX_BATCHED_TOKENS: u64 = 10_000_000;
/// Entry in the priority queue, ordered by effective arrival time (lower = higher priority).
/// Effective arrival = elapsed time since queue start minus `priority_jump`.
struct QueueEntry {
effective_offset: Duration,
request: SchedulingRequest,
}
impl Eq for QueueEntry {}
impl PartialEq for QueueEntry {
fn eq(&self, other: &Self) -> bool {
self.effective_offset == other.effective_offset
}
}
impl Ord for QueueEntry {
fn cmp(&self, other: &Self) -> Ordering {
// BinaryHeap is a max-heap; reverse so lower effective_offset = higher priority
other.effective_offset.cmp(&self.effective_offset)
}
}
impl PartialOrd for QueueEntry {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
/// Queue for managing scheduling requests with interior mutability.
/// Requests are held in `pending` when all workers are busy, and moved to `ready` when capacity frees up.
/// If queueing is disabled (threshold_frac is None), all requests go directly to `ready`.
/// Requests are ordered by effective arrival time: arrival_offset - priority_jump.
pub struct SchedulerQueue {
pending: Mutex<BinaryHeap<QueueEntry>>,
ready: Mutex<VecDeque<SchedulingRequest>>,
slots: Arc<ActiveSequencesMultiWorker>,
workers_with_configs: RuntimeConfigWatch,
ready_notify: Arc<Notify>,
/// Cached threshold fraction; None means queueing is disabled.
threshold_frac: Option<f64>,
/// Reference instant for computing arrival offsets.
start_time: Instant,
}
impl SchedulerQueue {
pub fn new(
slots: Arc<ActiveSequencesMultiWorker>,
workers_with_configs: RuntimeConfigWatch,
ready_notify: Arc<Notify>,
threshold_frac: Option<f64>,
) -> Self {
if let Some(frac) = threshold_frac {
tracing::info!("Router queue enabled with threshold fraction {frac}");
}
Self {
pending: Mutex::new(BinaryHeap::new()),
ready: Mutex::new(VecDeque::new()),
slots,
workers_with_configs,
ready_notify,
threshold_frac,
start_time: Instant::now(),
}
}
/// Build a QueueEntry for a request, computing its effective arrival offset.
fn make_entry(&self, request: SchedulingRequest) -> QueueEntry {
let arrival_offset = self.start_time.elapsed();
let jump = Duration::from_secs_f64(request.priority_jump.max(0.0));
let effective_offset = arrival_offset.saturating_sub(jump);
QueueEntry {
effective_offset,
request,
}
}
/// Enqueue a new request.
/// If queueing is disabled (threshold not set), fast-track to ready.
/// Otherwise, check busy condition and place in ready or pending.
pub async fn enqueue(&self, request: SchedulingRequest) {
let Some(threshold) = self.threshold_frac else {
self.ready.lock().await.push_back(request);
return;
};
if self.all_workers_busy(threshold).await {
tracing::debug!("all workers busy, queueing request");
let entry = self.make_entry(request);
self.pending.lock().await.push(entry);
} else {
self.ready.lock().await.push_back(request);
}
}
/// Try to dequeue the highest-priority request from the ready queue.
pub async fn try_dequeue(&self) -> Option<SchedulingRequest> {
self.ready.lock().await.pop_front()
}
/// Called on prefill_complete/free. Re-checks pending requests and moves eligible to ready.
/// Notifies scheduler loop if any requests were moved.
pub async fn update(&self) {
let Some(threshold) = self.threshold_frac else {
return;
};
let mut moved = false;
loop {
if self.pending.lock().await.is_empty() {
break;
}
if self.all_workers_busy(threshold).await {
break;
}
let entry = self.pending.lock().await.pop();
if let Some(entry) = entry {
tracing::debug!("moving request from pending to ready");
self.ready.lock().await.push_back(entry.request);
moved = true;
} else {
break;
}
}
if moved {
self.ready_notify.notify_one();
}
}
/// Check if all workers are busy based on threshold.
/// Returns true only if ALL workers exceed the threshold (no worker has capacity).
async fn all_workers_busy(&self, threshold: f64) -> bool {
let active_tokens = self.slots.active_tokens().await;
let configs = self.workers_with_configs.borrow();
for (&worker_id, config) in configs.iter() {
let dp_size = config.data_parallel_size;
let max_batched = config
.max_num_batched_tokens
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
for dp_rank in 0..dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let tokens = active_tokens.get(&worker).copied().unwrap_or(0);
if (tokens as f64) <= threshold * (max_batched as f64) {
return false;
}
}
}
true
}
}
...@@ -13,12 +13,14 @@ use std::sync::Arc; ...@@ -13,12 +13,14 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
use std::time::Instant; use std::time::Instant;
use tokio::sync::Notify;
use super::KvRouterConfig; use super::KvRouterConfig;
use super::RouterConfigOverride; use super::RouterConfigOverride;
use super::WorkerSelector; use super::WorkerSelector;
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank}; use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::sequence::{ActiveSequencesMultiWorker, SequenceError}; use super::queue::SchedulerQueue;
use super::sequence::{ActiveSequencesMultiWorker, SequenceError, SequenceRequest};
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
...@@ -61,6 +63,8 @@ pub struct SchedulingRequest { ...@@ -61,6 +63,8 @@ pub struct SchedulingRequest {
pub update_states: bool, pub update_states: bool,
// LORA adapter name extracted from request.model field // LORA adapter name extracted from request.model field
pub lora_name: Option<String>, pub lora_name: Option<String>,
/// Priority jump in seconds; decreases effective arrival time in the queue.
pub priority_jump: f64,
// Option to take it out to send the response without moving the struct // Option to take it out to send the response without moving the struct
resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>, resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>,
} }
...@@ -82,9 +86,11 @@ impl SchedulingRequest { ...@@ -82,9 +86,11 @@ impl SchedulingRequest {
pub struct KvScheduler { pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>, request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMultiWorker>, slots: Arc<ActiveSequencesMultiWorker>,
queue: Arc<SchedulerQueue>,
} }
impl KvScheduler { impl KvScheduler {
#[allow(clippy::too_many_arguments)]
pub async fn start( pub async fn start(
component: Component, component: Component,
block_size: u32, block_size: u32,
...@@ -93,6 +99,7 @@ impl KvScheduler { ...@@ -93,6 +99,7 @@ impl KvScheduler {
replica_sync: bool, replica_sync: bool,
router_id: u64, router_id: u64,
worker_type: &'static str, worker_type: &'static str,
queue_threshold: Option<f64>,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
...@@ -150,25 +157,48 @@ impl KvScheduler { ...@@ -150,25 +157,48 @@ impl KvScheduler {
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024); let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
let scheduler_cancel_token = component.drt().primary_token(); let scheduler_cancel_token = component.drt().primary_token();
// Create queue with shared notify for waking the scheduler loop
let ready_notify = Arc::new(Notify::new());
let queue = Arc::new(SchedulerQueue::new(
slots.clone(),
workers_with_configs.clone(),
ready_notify.clone(),
queue_threshold,
));
let queue_clone = queue.clone();
// Background task to handle scheduling requests // Background task to handle scheduling requests
tokio::spawn(async move { tokio::spawn(async move {
let mut request_rx = request_rx; let mut request_rx = request_rx;
let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
tracing::trace!("scheduler background task started"); tracing::trace!("scheduler background task started");
loop { loop {
// Check for cancellation at beginning of loop // Use select! to wait on: new request, ready_notify, periodic recheck, or cancellation
if scheduler_cancel_token.is_cancelled() { tokio::select! {
_ = scheduler_cancel_token.cancelled() => {
tracing::trace!("scheduler background task shutting down"); tracing::trace!("scheduler background task shutting down");
break; break;
} }
request = request_rx.recv() => {
// Wait for a new request let Some(request) = request else {
let Some(mut request) = request_rx.recv().await else {
tracing::warn!("scheduler shutdown"); tracing::warn!("scheduler shutdown");
break; break;
}; };
tracing::trace!("received request to be scheduled"); tracing::trace!("received request to be scheduled");
queue_clone.enqueue(request).await;
}
_ = ready_notify.notified() => {
// Woken by update() after prefill_complete/free - just continue to drain ready queue
}
_ = recheck_interval.tick() => {
// Periodic recheck to prevent requests stuck in pending
queue_clone.update().await;
}
}
// Drain ALL ready requests (each iteration uses fresh slot state)
while let Some(mut request) = queue_clone.try_dequeue().await {
let (decode_blocks, prefill_tokens) = slots_clone let (decode_blocks, prefill_tokens) = slots_clone
.potential_blocks_and_tokens( .potential_blocks_and_tokens(
request.token_seq.clone(), request.token_seq.clone(),
...@@ -180,7 +210,8 @@ impl KvScheduler { ...@@ -180,7 +210,8 @@ impl KvScheduler {
request.prefill_tokens = prefill_tokens; request.prefill_tokens = prefill_tokens;
// Read the current workers configuration from watch receiver // Read the current workers configuration from watch receiver
let workers: HashMap<WorkerId, ModelRuntimeConfig> = scheduler_rx.borrow().clone(); let workers: HashMap<WorkerId, ModelRuntimeConfig> =
scheduler_rx.borrow().clone();
match selector.select_worker(&workers, &request, block_size) { match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => { Ok(selection) => {
...@@ -203,28 +234,26 @@ impl KvScheduler { ...@@ -203,28 +234,26 @@ impl KvScheduler {
}; };
if let Err(e) = slots_clone if let Err(e) = slots_clone
.add_request( .add_request(SequenceRequest {
request_id.clone(), request_id: request_id.clone(),
request.token_seq, token_sequence: request.token_seq,
request.isl_tokens, isl: request.isl_tokens,
selection.overlap_blocks, overlap: selection.overlap_blocks,
None, // expected_output_tokens not available in scheduler loop expected_output_tokens: None,
selection.worker, worker: selection.worker,
request.lora_name.clone(), lora_name: request.lora_name.clone(),
) })
.await .await
{ {
tracing::warn!("Failed to add request {request_id}: {e}"); tracing::warn!("Failed to add request {request_id}: {e}");
} }
} }
Err(KvSchedulerError::NoEndpoints) => { Err(KvSchedulerError::NoEndpoints) => {
tracing::trace!("no endpoints available; waiting for endpoints update"); tracing::warn!("no endpoints available, dropping request");
tokio::time::sleep(Duration::from_millis(5)).await;
continue;
} }
Err(e) => { Err(e) => {
tracing::error!("error scheduling request: {:?}", e); tracing::error!("error scheduling request: {:?}", e);
break; }
} }
} }
} }
...@@ -232,7 +261,11 @@ impl KvScheduler { ...@@ -232,7 +261,11 @@ impl KvScheduler {
tracing::trace!("background endpoint subscriber shutting down"); tracing::trace!("background endpoint subscriber shutting down");
}); });
Ok(KvScheduler { request_tx, slots }) Ok(KvScheduler {
request_tx,
slots,
queue,
})
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
...@@ -245,6 +278,7 @@ impl KvScheduler { ...@@ -245,6 +278,7 @@ impl KvScheduler {
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
lora_name: Option<String>, lora_name: Option<String>,
priority_jump: f64,
) -> Result<WorkerWithDpRank, KvSchedulerError> { ) -> Result<WorkerWithDpRank, KvSchedulerError> {
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
let start = Instant::now(); let start = Instant::now();
...@@ -260,7 +294,8 @@ impl KvScheduler { ...@@ -260,7 +294,8 @@ impl KvScheduler {
router_config_override: router_config_override.cloned(), router_config_override: router_config_override.cloned(),
update_states, update_states,
lora_name, lora_name,
resp_tx: Some(resp_tx), // Wrap in Some() priority_jump,
resp_tx: Some(resp_tx),
}; };
self.request_tx self.request_tx
...@@ -288,38 +323,22 @@ impl KvScheduler { ...@@ -288,38 +323,22 @@ impl KvScheduler {
Ok(response.best_worker) Ok(response.best_worker)
} }
#[allow(clippy::too_many_arguments)] pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
pub async fn add_request( self.slots.add_request(req).await
&self,
request_id: String,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank,
lora_name: Option<String>,
) -> Result<(), SequenceError> {
self.slots
.add_request(
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
worker,
lora_name,
)
.await
} }
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> { pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots self.slots
.mark_prefill_completed(&request_id.to_string()) .mark_prefill_completed(&request_id.to_string())
.await .await?;
self.queue.update().await;
Ok(())
} }
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> { pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.free(&request_id.to_string()).await self.slots.free(&request_id.to_string()).await?;
self.queue.update().await;
Ok(())
} }
/// Get the worker type for this scheduler ("prefill" or "decode"). /// Get the worker type for this scheduler ("prefill" or "decode").
......
...@@ -73,6 +73,17 @@ const EXPIRY_DURATION: Duration = Duration::from_secs(300); ...@@ -73,6 +73,17 @@ const EXPIRY_DURATION: Duration = Duration::from_secs(300);
// TODO: use the common request_id if it exists in the repo // TODO: use the common request_id if it exists in the repo
pub type RequestId = String; pub type RequestId = String;
/// Bundled parameters for adding a request to the sequence tracker.
pub struct SequenceRequest {
pub request_id: RequestId,
pub token_sequence: Option<Vec<SequenceHash>>,
pub isl: usize,
pub overlap: u32,
pub expected_output_tokens: Option<u32>,
pub worker: WorkerWithDpRank,
pub lora_name: Option<String>,
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache /// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)] #[derive(Debug, Getters)]
pub struct ActiveSequences { pub struct ActiveSequences {
...@@ -770,17 +781,17 @@ impl ActiveSequencesMultiWorker { ...@@ -770,17 +781,17 @@ impl ActiveSequencesMultiWorker {
} }
} }
#[allow(clippy::too_many_arguments)] pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
pub async fn add_request( let SequenceRequest {
&self, request_id,
request_id: RequestId, token_sequence,
token_sequence: Option<Vec<SequenceHash>>, isl,
isl: usize, overlap,
overlap: u32, expected_output_tokens,
expected_output_tokens: Option<u32>, worker,
worker: WorkerWithDpRank, lora_name,
lora_name: Option<String>, } = req;
) -> Result<(), SequenceError> {
// Clone the sender upfront so we don't hold the DashMap Ref across // Clone the sender upfront so we don't hold the DashMap Ref across
// the .await points below. Also eliminates the TOCTOU between // the .await points below. Also eliminates the TOCTOU between
// contains_key and a later get().unwrap(). // contains_key and a later get().unwrap().
...@@ -791,7 +802,6 @@ impl ActiveSequencesMultiWorker { ...@@ -791,7 +802,6 @@ impl ActiveSequencesMultiWorker {
.value() .value()
.clone(); .clone();
// Check for duplicate request
if let Some(existing_worker) = self.request_to_worker.get(&request_id) { if let Some(existing_worker) = self.request_to_worker.get(&request_id) {
return Err(SequenceError::DuplicateRequest { return Err(SequenceError::DuplicateRequest {
request_id, request_id,
...@@ -799,10 +809,8 @@ impl ActiveSequencesMultiWorker { ...@@ -799,10 +809,8 @@ impl ActiveSequencesMultiWorker {
}); });
} }
// Create response channel
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
// Publish event only if replica_sync is enabled
if self.replica_sync { if self.replica_sync {
let event = ActiveSequenceEvent { let event = ActiveSequenceEvent {
request_id: request_id.clone(), request_id: request_id.clone(),
...@@ -819,10 +827,8 @@ impl ActiveSequencesMultiWorker { ...@@ -819,10 +827,8 @@ impl ActiveSequencesMultiWorker {
self.event_publisher.publish(&event).await?; self.event_publisher.publish(&event).await?;
} }
// Update local state with full WorkerWithDpRank
self.request_to_worker.insert(request_id.clone(), worker); self.request_to_worker.insert(request_id.clone(), worker);
// Store lora_name for later use in Free/MarkPrefillCompleted events
if let Some(lora) = lora_name { if let Some(lora) = lora_name {
self.request_to_lora.insert(request_id.clone(), lora); self.request_to_lora.insert(request_id.clone(), lora);
} }
...@@ -838,18 +844,15 @@ impl ActiveSequencesMultiWorker { ...@@ -838,18 +844,15 @@ impl ActiveSequencesMultiWorker {
}) })
.map_err(|_| SequenceError::WorkerChannelClosed)?; .map_err(|_| SequenceError::WorkerChannelClosed)?;
// Wait for response and handle removed requests
let removed_requests = resp_rx let removed_requests = resp_rx
.await .await
.map_err(|_| SequenceError::WorkerChannelClosed)?; .map_err(|_| SequenceError::WorkerChannelClosed)?;
// Remove expired requests from request_to_worker mapping
for expired_id in &removed_requests { for expired_id in &removed_requests {
self.request_to_worker.remove(expired_id); self.request_to_worker.remove(expired_id);
self.request_to_lora.remove(expired_id); self.request_to_lora.remove(expired_id);
} }
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await; self.publish_active_load_for_worker(worker).await;
Ok(()) Ok(())
...@@ -1356,41 +1359,41 @@ mod tests { ...@@ -1356,41 +1359,41 @@ mod tests {
// Add request_0 to worker 0, dp_rank 0: sequence [0, 1, 2] // Add request_0 to worker 0, dp_rank 0: sequence [0, 1, 2]
seq_manager_1 seq_manager_1
.add_request( .add_request(SequenceRequest {
"request_0".to_string(), request_id: "request_0".to_string(),
Some(vec![0, 1, 2]), token_sequence: Some(vec![0, 1, 2]),
12, // ISL (3 blocks * 4 block_size) isl: 12,
0, // no overlap overlap: 0,
None, // expected_output_tokens expected_output_tokens: None,
WorkerWithDpRank::new(0, 0), worker: WorkerWithDpRank::new(0, 0),
None, // lora_name lora_name: None,
) })
.await?; .await?;
// Add request_1 to worker 0, dp_rank 1: sequence [3, 4] // Add request_1 to worker 0, dp_rank 1: sequence [3, 4]
seq_manager_1 seq_manager_1
.add_request( .add_request(SequenceRequest {
"request_1".to_string(), request_id: "request_1".to_string(),
Some(vec![3, 4]), token_sequence: Some(vec![3, 4]),
8, // ISL (2 blocks * 4 block_size) isl: 8,
0, // no overlap overlap: 0,
None, // expected_output_tokens expected_output_tokens: None,
WorkerWithDpRank::new(0, 1), worker: WorkerWithDpRank::new(0, 1),
None, // lora_name lora_name: None,
) })
.await?; .await?;
// Add request_2 to worker 1, dp_rank 0: sequence [0, 1, 2, 3] using seq_manager_2 // Add request_2 to worker 1, dp_rank 0: sequence [0, 1, 2, 3] using seq_manager_2
seq_manager_2 seq_manager_2
.add_request( .add_request(SequenceRequest {
"request_2".to_string(), request_id: "request_2".to_string(),
Some(vec![0, 1, 2, 3]), token_sequence: Some(vec![0, 1, 2, 3]),
16, // ISL (4 blocks * 4 block_size) isl: 16,
0, // no overlap overlap: 0,
None, // expected_output_tokens expected_output_tokens: None,
WorkerWithDpRank::new(1, 0), worker: WorkerWithDpRank::new(1, 0),
None, // lora_name lora_name: None,
) })
.await?; .await?;
// Give some time for synchronization // Give some time for synchronization
...@@ -1535,41 +1538,41 @@ mod tests { ...@@ -1535,41 +1538,41 @@ mod tests {
// Add request_0 to worker 0 with no token sequence // Add request_0 to worker 0 with no token sequence
seq_manager_1 seq_manager_1
.add_request( .add_request(SequenceRequest {
"request_0".to_string(), request_id: "request_0".to_string(),
None, // No token sequence token_sequence: None,
12, // ISL (12 tokens) isl: 12,
0, // no overlap overlap: 0,
None, // expected_output_tokens expected_output_tokens: None,
WorkerWithDpRank::from_worker_id(0), worker: WorkerWithDpRank::from_worker_id(0),
None, // lora_name lora_name: None,
) })
.await?; .await?;
// Add request_1 to worker 1 with no token sequence // Add request_1 to worker 1 with no token sequence
seq_manager_1 seq_manager_1
.add_request( .add_request(SequenceRequest {
"request_1".to_string(), request_id: "request_1".to_string(),
None, // No token sequence token_sequence: None,
8, // ISL (8 tokens) isl: 8,
0, // no overlap overlap: 0,
None, // expected_output_tokens expected_output_tokens: None,
WorkerWithDpRank::from_worker_id(1), worker: WorkerWithDpRank::from_worker_id(1),
None, // lora_name lora_name: None,
) })
.await?; .await?;
// Add request_2 to worker 2 with no token sequence using seq_manager_2 // Add request_2 to worker 2 with no token sequence using seq_manager_2
seq_manager_2 seq_manager_2
.add_request( .add_request(SequenceRequest {
"request_2".to_string(), request_id: "request_2".to_string(),
None, // No token sequence token_sequence: None,
16, // ISL (16 tokens) isl: 16,
0, // no overlap overlap: 0,
None, // expected_output_tokens expected_output_tokens: None,
WorkerWithDpRank::from_worker_id(2), worker: WorkerWithDpRank::from_worker_id(2),
None, // lora_name lora_name: None,
) })
.await?; .await?;
// Give some time for synchronization // Give some time for synchronization
......
...@@ -274,13 +274,15 @@ impl OpenAIPreprocessor { ...@@ -274,13 +274,15 @@ impl OpenAIPreprocessor {
// Extract routing hints from nvext if present // Extract routing hints from nvext if present
if let Some(nvext) = request.nvext() { if let Some(nvext) = request.nvext() {
// Build routing hints from nvext fields // Build routing hints from nvext fields
let hints = nvext.agent_hints.as_ref();
let routing = RoutingHints { let routing = RoutingHints {
backend_instance_id: nvext.backend_instance_id, backend_instance_id: nvext.backend_instance_id,
prefill_worker_id: nvext.prefill_worker_id, prefill_worker_id: nvext.prefill_worker_id,
decode_worker_id: nvext.decode_worker_id, decode_worker_id: nvext.decode_worker_id,
dp_rank: None, // dp_rank is set later in the pipeline dp_rank: None, // dp_rank is set later in the pipeline
enable_local_updates: nvext.enable_local_updates, enable_local_updates: nvext.enable_local_updates,
expected_output_tokens: nvext.expected_output_tokens, expected_output_tokens: hints.and_then(|h| h.osl),
priority_jump: hints.and_then(|h| h.latency_sensitivity),
lora_name, lora_name,
}; };
builder.routing(Some(routing)); builder.routing(Some(routing));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment