Unverified Commit 4ede59a2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: speculative prefill (#6230)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarJanelle Cai <jcai18@mit.edu>
parent f4f82762
......@@ -164,9 +164,9 @@ dependencies = [
[[package]]
name = "arc-swap"
version = "1.8.1"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ded5f9a03ac8f24d1b8a25101ee812cd32cdc8c50a4c50237de2c4915850e73"
checksum = "f9f3647c145568cec02c42054e07bdf9a5a698e15b466fb2341bfc393cd24aa5"
dependencies = [
"rustversion",
]
......@@ -1926,6 +1926,21 @@ dependencies = [
"uuid",
]
[[package]]
name = "dynamo-bench"
version = "0.9.0"
dependencies = [
"anyhow",
"clap 4.5.58",
"futures-util",
"indicatif 0.18.4",
"rand 0.9.2",
"reqwest 0.12.28",
"serde",
"serde_json",
"tokio",
]
[[package]]
name = "dynamo-codegen"
version = "0.1.0"
......@@ -1951,12 +1966,13 @@ dependencies = [
"async-trait",
"clap 4.5.58",
"dashmap 6.1.0",
"dynamo-bench",
"dynamo-mocker",
"dynamo-runtime",
"dynamo-tokens",
"flume",
"futures",
"indicatif 0.18.3",
"indicatif 0.18.4",
"minstant",
"parking_lot",
"prometheus",
......@@ -2007,6 +2023,7 @@ dependencies = [
"derive_builder",
"dialoguer",
"dynamo-async-openai",
"dynamo-bench",
"dynamo-kv-router",
"dynamo-memory",
"dynamo-mocker",
......@@ -2026,7 +2043,7 @@ dependencies = [
"hyper 1.8.1",
"hyper-util",
"image",
"indicatif 0.18.3",
"indicatif 0.18.4",
"insta",
"itertools 0.14.0",
"json-five",
......@@ -3665,9 +3682,9 @@ dependencies = [
[[package]]
name = "indicatif"
version = "0.18.3"
version = "0.18.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9375e112e4b463ec1b1c6c011953545c65a30164fbab5b581df32b3abf0dcb88"
checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb"
dependencies = [
"console 0.16.2",
"portable-atomic",
......@@ -5800,9 +5817,9 @@ dependencies = [
[[package]]
name = "png"
version = "0.18.0"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97baced388464909d42d89643fe4361939af9b7ce7a31ee32a168f832a70f2a0"
checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61"
dependencies = [
"bitflags 2.11.0",
"crc32fast",
......
......@@ -13,6 +13,7 @@ members = [
"lib/kvbm-logical",
"lib/async-openai",
"lib/parsers",
"lib/bench",
"lib/bindings/c",
"lib/bindings/python/codegen",
"lib/config",
......@@ -29,6 +30,7 @@ default-members = [
"lib/kvbm-logical",
"lib/async-openai",
"lib/parsers",
"lib/bench",
"lib/bindings/c",
]
resolver = "3"
......
......@@ -14,6 +14,7 @@ USE_MOCKERS=false
USE_TRTLLM=false
MODE="agg" # Options: agg (default), decode, prefill
BASE_GPU_OFFSET=0
REASONING=""
EXTRA_ARGS=()
# Parse arguments
......@@ -55,6 +56,10 @@ while [[ $# -gt 0 ]]; do
BASE_GPU_OFFSET="$2"
shift 2
;;
--reasoning)
REASONING="$2"
shift 2
;;
--)
shift
EXTRA_ARGS+=("$@")
......@@ -192,6 +197,9 @@ if [ "$USE_MOCKERS" = true ]; then
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
MOCKER_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
if [ -n "$REASONING" ]; then
MOCKER_ARGS+=("--reasoning" "$REASONING")
fi
MOCKER_ARGS+=("${EXTRA_ARGS[@]}")
python -m dynamo.mocker "${MOCKER_ARGS[@]}" &
......
......@@ -119,6 +119,11 @@ def create_temp_engine_args_file(args) -> Path:
# Note: bootstrap_port is NOT included here - it's set per-worker in launch_workers()
}
# Parse --reasoning JSON string into a nested object
reasoning_str = getattr(args, "reasoning", None)
if reasoning_str:
engine_args["reasoning"] = json.loads(reasoning_str)
# Remove None values to only include explicitly set arguments
engine_args = {k: v for k, v in engine_args.items() if v is not None}
......@@ -279,6 +284,16 @@ def parse_args():
"All workers share the same tokio runtime and thread pool.",
)
# Reasoning token output
parser.add_argument(
"--reasoning",
type=str,
default=None,
help="Enable reasoning token output. JSON object with fields: "
"start_thinking_token_id (u32), end_thinking_token_id (u32), thinking_ratio (0.0-1.0). "
'Example: \'{"start_thinking_token_id": 123, "end_thinking_token_id": 456, "thinking_ratio": 0.6}\'',
)
# Legacy support - allow direct JSON file specification
parser.add_argument(
"--extra-engine-args",
......
---
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
title: Agent Hints
subtitle: Per-request hints for scheduling, load balancing, and KV cache optimization
---
# Agent Hints
Agent hints are optional per-request hints passed via the `nvext.agent_hints` field in the request body. They allow the calling agent or application to communicate request-level metadata that the router uses to improve scheduling, load balancing, and KV cache utilization.
```json
{
"nvext": {
"agent_hints": {
"latency_sensitivity": 5.0,
"osl": 512,
"speculative_prefill": true
}
}
}
```
All three fields are optional and independent — you can use any combination.
## `latency_sensitivity`
Priority scheduling hint, specified in seconds. When `--router-queue-threshold` is set and the queue is active, this value shifts the request's effective arrival time earlier in the queue, giving it priority over requests with lower (or no) `latency_sensitivity`. A value of `5.0` means the request is treated as if it arrived 5 seconds earlier than it actually did. Has no effect when queueing is disabled.
- **Type**: `f64` (optional)
- **Recommended default**: `1.2` for latency-sensitive agentic requests
- **Requires**: `--router-queue-threshold` to be set
### Example
```json
{
"nvext": {
"agent_hints": {
"latency_sensitivity": 5.0
}
}
}
```
A request with `latency_sensitivity: 5.0` arriving at time `T` is treated as if it arrived at `T - 5s`, so it will be scheduled ahead of requests that arrived within the last 5 seconds (unless they have even higher sensitivity).
## `osl`
Expected output sequence length — the estimated number of output tokens the request will generate. The router uses this hint in two ways:
1. **Output block tracking**: When `--track-output-blocks` is enabled, the router adds placeholder blocks during generation and applies fractional decay based on progress toward `osl`. This gives the router a more accurate picture of each worker's KV cache utilization for long-running requests.
2. **Resource estimation**: Helps the router estimate total resource requirements when making routing decisions.
- **Type**: `u32` (optional)
- **Requires**: `--track-output-blocks` for output block tracking behavior
### Example
```json
{
"nvext": {
"agent_hints": {
"osl": 1024
}
}
}
```
If the request is expected to generate ~1024 tokens, providing `osl: 1024` lets the router account for the output-side KV cache growth when balancing load across workers.
## `speculative_prefill`
When set to `true`, the system speculatively prefills the predicted next-turn prompt after the current assistant turn completes. This is designed for multi-turn agentic workloads where the next request's prefix is predictable.
- **Type**: `bool` (optional, defaults to `false`)
- **No additional CLI flags required**; works automatically when the hint is set in the request
### How it works
1. As the assistant response streams, the system accumulates the full response text.
2. Once the response finishes (indicated by `finish_reason`), a background task constructs the next-turn prompt by appending the assistant response to the conversation history (with thinking content stripped by the chat template for non-last assistant turns).
3. The constructed prompt is tokenized and sent through the pipeline as a `max_tokens=1` request to warm the KV cache on a worker.
4. When the actual next request arrives, it benefits from the already-warm KV cache, reducing TTFT.
### Example
```json
{
"nvext": {
"agent_hints": {
"speculative_prefill": true
}
}
}
```
This is most effective for reasoning models in agentic loops, where the conversation grows incrementally and the next turn's prefix (everything up to the new user message) is the same as the current conversation.
## See Also
- **[Router Guide](router-guide.md)**: Full router configuration and CLI arguments
- **[Router Examples](router-examples.md)**: Usage patterns and benchmarking
......@@ -202,6 +202,8 @@ The main KV-aware routing arguments:
To implement KV event publishing for custom inference engines, enabling them to participate in Dynamo's KV cache-aware routing, see [KV Event Publishing for Custom Engines](../../integrations/kv-events-custom-engines.md).
For details on per-request agent hints (`latency_sensitivity`, `osl`, `speculative_prefill`), see the [Agent Hints Guide](agent-hints.md).
## Basic Routing
Dynamo supports several routing strategies when sending requests from one component to another component's endpoint.
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
[package]
name = "dynamo-bench"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
description = "Lightweight HTTP benchmarks for Dynamo endpoints"
[[bin]]
name = "multiturn_bench"
path = "src/bin/multiturn_bench.rs"
[dependencies]
anyhow = { workspace = true }
clap = { version = "4.5", features = ["derive"] }
futures-util = "0.3"
indicatif = "0.18"
rand = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Shared utilities for benchmark binaries.
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
// ---------------------------------------------------------------------------
// Latency statistics
// ---------------------------------------------------------------------------
/// Summary statistics over a set of latency measurements.
#[derive(Debug, Clone)]
pub struct LatencyStats {
    /// Smallest observed latency.
    pub min: Duration,
    /// Largest observed latency.
    pub max: Duration,
    /// Arithmetic mean of all latencies (rounded down to a whole nanosecond).
    pub avg: Duration,
    /// Median: the sample at index `n / 2` of the sorted input.
    pub p50: Duration,
    /// 95th percentile: the sample at index `n * 95 / 100` of the sorted input.
    pub p95: Duration,
    /// 99th percentile: the sample at index `n * 99 / 100` of the sorted input.
    pub p99: Duration,
    /// Sample count divided by the sum of all latencies, in operations per
    /// second. Infinite when every latency is zero.
    pub throughput_ops_sec: f64,
}

impl LatencyStats {
    /// Computes statistics from a slice of durations.
    ///
    /// The input is copied before sorting, so the caller's slice is left
    /// untouched and may be in any order.
    ///
    /// Returns `None` if `durations` is empty.
    pub fn from_durations(durations: &[Duration]) -> Option<Self> {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        if durations.is_empty() {
            return None;
        }

        let mut sorted = durations.to_vec();
        // Equal durations are indistinguishable, so stability is not needed.
        sorted.sort_unstable();

        let n = sorted.len();
        let total: Duration = sorted.iter().sum();

        // Divide in u128 nanoseconds rather than `total / n as u32`: the `as`
        // cast truncates, which yields a wrong mean (or a divide-by-zero
        // panic) once the sample count no longer fits in a u32.
        let avg_nanos = total.as_nanos() / n as u128;
        let avg = Duration::new(
            (avg_nanos / NANOS_PER_SEC) as u64,
            (avg_nanos % NANOS_PER_SEC) as u32,
        );

        Some(Self {
            min: sorted[0],
            max: sorted[n - 1],
            avg,
            p50: sorted[n / 2],
            p95: sorted[n * 95 / 100],
            p99: sorted[n * 99 / 100],
            throughput_ops_sec: n as f64 / total.as_secs_f64(),
        })
    }

    /// Print formatted latency statistics to stdout.
    ///
    /// `operation` labels the report. `blocks_per_op` scales the ops/sec
    /// throughput into the blocks/sec figure printed on the last line.
    pub fn print(&self, operation: &str, blocks_per_op: usize) {
        println!("\n{} Latency Statistics:", operation);
        println!(" min: {:>12?}", self.min);
        println!(" avg: {:>12?}", self.avg);
        println!(" p50: {:>12?}", self.p50);
        println!(" p95: {:>12?}", self.p95);
        println!(" p99: {:>12?}", self.p99);
        println!(" max: {:>12?}", self.max);
        println!(" throughput: {:.2} ops/sec", self.throughput_ops_sec);
        println!(
            " throughput: {:.2} blocks/sec",
            self.throughput_ops_sec * blocks_per_op as f64
        );
    }
}
// ---------------------------------------------------------------------------
// Time-bucketed latency statistics
// ---------------------------------------------------------------------------
/// Latency statistics for the requests that completed inside one time bucket.
///
/// Bucket bounds are in seconds; latency fields are in microseconds.
#[derive(Debug, Clone, Serialize)]
pub struct TimeBucketStats {
/// Bucket start time in seconds from measurement start.
pub bucket_start_sec: u64,
/// Bucket end time in seconds (start plus the bucket size).
pub bucket_end_sec: u64,
/// Number of requests completed in this bucket.
pub count: usize,
/// Minimum latency in this bucket, in microseconds.
pub latency_min_us: u64,
/// Median (p50) latency in this bucket, in microseconds.
pub latency_p50_us: u64,
/// 95th-percentile latency in this bucket, in microseconds.
pub latency_p95_us: u64,
/// Maximum latency in this bucket, in microseconds.
pub latency_max_us: u64,
}
/// Groups latencies into fixed-width completion-time windows and summarizes
/// each window.
///
/// Every element of `items` is a `(latency, completion_time)` pair, with
/// `completion_time` measured from the start of the measurement. A pair lands
/// in bucket `completion_time.as_secs() / bucket_size_secs`.
///
/// Returns one entry per non-empty bucket, ordered by bucket start. The result
/// is empty when `items` is empty or `bucket_size_secs` is zero.
pub fn compute_time_bucket_stats(
    items: &[(Duration, Duration)],
    bucket_size_secs: u64,
) -> Vec<TimeBucketStats> {
    if bucket_size_secs == 0 {
        return Vec::new();
    }
    // The latest completion decides how many buckets are needed; no items
    // means there is nothing to report.
    let Some(latest) = items.iter().map(|&(_, completed_at)| completed_at).max() else {
        return Vec::new();
    };

    let bucket_count = (latest.as_secs() / bucket_size_secs) + 1;
    let mut grouped: Vec<Vec<Duration>> = vec![Vec::new(); bucket_count as usize];

    for &(latency, completed_at) in items {
        let slot = (completed_at.as_secs() / bucket_size_secs) as usize;
        if let Some(bucket) = grouped.get_mut(slot) {
            bucket.push(latency);
        }
    }

    let mut report = Vec::new();
    for (idx, latencies) in grouped.iter().enumerate() {
        // Windows in which nothing completed are left out of the report.
        if latencies.is_empty() {
            continue;
        }
        let Some(stats) = LatencyStats::from_durations(latencies) else {
            continue;
        };
        report.push(TimeBucketStats {
            bucket_start_sec: idx as u64 * bucket_size_secs,
            bucket_end_sec: (idx as u64 + 1) * bucket_size_secs,
            count: latencies.len(),
            latency_min_us: stats.min.as_micros() as u64,
            latency_p50_us: stats.p50.as_micros() as u64,
            latency_p95_us: stats.p95.as_micros() as u64,
            latency_max_us: stats.max.as_micros() as u64,
        });
    }
    report
}
/// Prints the per-bucket latency table to stdout.
///
/// Latencies are stored in microseconds and shown in milliseconds. When
/// `buckets` is empty a single "no data" line is printed instead of the table.
pub fn print_time_bucket_report(buckets: &[TimeBucketStats]) {
    if buckets.is_empty() {
        println!(" No time bucket data available");
        return;
    }

    // Microseconds -> milliseconds for display.
    let to_ms = |micros: u64| micros as f64 / 1000.0;

    println!(
        " {:>8} {:>8} {:>12} {:>12} {:>12} {:>12}",
        "Time(s)", "Count", "Min(ms)", "P50(ms)", "P95(ms)", "Max(ms)"
    );
    let rule = "-".repeat(68);
    println!(" {}", rule);

    for row in buckets {
        println!(
            " {:>3}-{:<4} {:>8} {:>12.1} {:>12.1} {:>12.1} {:>12.1}",
            row.bucket_start_sec,
            row.bucket_end_sec,
            row.count,
            to_ms(row.latency_min_us),
            to_ms(row.latency_p50_us),
            to_ms(row.latency_p95_us),
            to_ms(row.latency_max_us),
        );
    }
}
// ---------------------------------------------------------------------------
// Latency sample (for raw JSON export)
// ---------------------------------------------------------------------------
/// Individual latency sample for raw data export.
#[derive(Debug, Clone, Serialize)]
pub struct LatencySample {
/// Latency in microseconds.
pub latency_us: u64,
/// Completion time in milliseconds from measurement start.
pub completion_time_ms: u64,
/// Whether the request succeeded.
pub success: bool,
}
// ---------------------------------------------------------------------------
// OpenAI-style chat types
// ---------------------------------------------------------------------------
/// A single message in an OpenAI-style chat completion request.
#[derive(Debug, Clone, Serialize)]
pub struct ChatMessage {
/// Role of the message author, serialized as the `role` field.
pub role: String,
/// Text of the message, serialized as the `content` field.
pub content: String,
}
/// OpenAI-style chat completion request body.
#[derive(Debug, Serialize)]
pub struct ChatCompletionRequest {
/// Name of the model to send the request to.
pub model: String,
/// Conversation messages, in order.
pub messages: Vec<ChatMessage>,
/// Optional `max_tokens` value; the field is omitted from the serialized
/// request entirely when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
}
// ---------------------------------------------------------------------------
// Model auto-detection
// ---------------------------------------------------------------------------
/// Response from the frontend's `/v1/models` endpoint.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
/// Models reported by the endpoint.
data: Vec<ModelInfo>,
}
/// Model info from the `/v1/models` endpoint.
#[derive(Debug, Deserialize)]
struct ModelInfo {
/// Model identifier; this is the value returned by model auto-detection.
id: String,
}
/// Fetch the model name from the frontend's `/v1/models` endpoint.
///
/// Any trailing `/` on `frontend_url` is ignored, so `http://host:8000` and
/// `http://host:8000/` query the same URL. Progress is printed to stdout.
///
/// Returns the model ID if exactly one model is available.
///
/// # Errors
///
/// Returns an error if the request cannot be sent, the endpoint answers with
/// a non-success status, the body cannot be parsed, or the endpoint reports
/// zero or multiple models (the latter requires an explicit `--model`).
pub async fn fetch_model_name(frontend_url: &str) -> Result<String> {
    let client = reqwest::Client::new();
    // Trim trailing slashes so a base URL ending in `/` does not produce
    // `//v1/models`.
    let url = format!("{}/v1/models", frontend_url.trim_end_matches('/'));
    println!(" Auto-detecting model from {}...", url);

    let response = client
        .get(&url)
        .send()
        .await
        .context("Failed to connect to frontend /v1/models endpoint")?;
    if !response.status().is_success() {
        anyhow::bail!("Models endpoint returned status: {}", response.status());
    }

    let models: ModelsResponse = response
        .json()
        .await
        .context("Failed to parse models response")?;

    match models.data.as_slice() {
        [] => anyhow::bail!("No models found at endpoint. Is a backend running?"),
        [only] => {
            println!(" Auto-detected model: {}", only.id);
            Ok(only.id.clone())
        }
        many => {
            // Ambiguous: list the candidates so the user can pick one.
            println!(" Multiple models available ({}):", many.len());
            for m in many {
                println!(" - {}", m.id);
            }
            anyhow::bail!("Multiple models available. Please specify --model explicitly.")
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod common;
......@@ -44,6 +44,7 @@ indicatif = { version = "0.18.0", optional = true }
uuid = { workspace = true, optional = true }
[dev-dependencies]
dynamo-bench = { path = "../bench" }
rstest = "0.18.2"
rstest_reuse = "0.7.0"
serde_json = { workspace = true }
......
......@@ -14,9 +14,10 @@
//! cargo bench --package dynamo-kv-router --bench kv_indexer_bench --features bench -- stress --help
use clap::{Args, Parser, Subcommand, ValueEnum};
use dynamo_bench::common::LatencyStats;
use dynamo_kv_router::{
ConcurrentRadixTree,
bench_utils::{LatencyStats, SequenceData, generate_sequences},
bench_utils::{SequenceData, generate_sequences},
indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded, ThreadPoolIndexer,
},
......@@ -415,7 +416,7 @@ async fn bench_store<I: BenchableIndexer>(
}
}
LatencyStats::from_durations(durations).unwrap()
LatencyStats::from_durations(&durations).unwrap()
}
/// Benchmark find_matches operation (hit case)
......@@ -442,7 +443,7 @@ async fn bench_find_matches_hit<I: BenchableIndexer>(
}
}
LatencyStats::from_durations(durations).unwrap()
LatencyStats::from_durations(&durations).unwrap()
}
/// Benchmark find_matches operation (miss case)
......@@ -470,7 +471,7 @@ async fn bench_find_matches_miss<I: BenchableIndexer>(
}
}
LatencyStats::from_durations(durations).unwrap()
LatencyStats::from_durations(&durations).unwrap()
}
/// Benchmark apply_event (remove) operation
......@@ -501,7 +502,7 @@ async fn bench_remove<I: BenchableIndexer>(
}
}
LatencyStats::from_durations(durations).unwrap()
LatencyStats::from_durations(&durations).unwrap()
}
/// Run all microbenchmarks for an indexer
......@@ -854,7 +855,7 @@ async fn run_stress_test<I: BenchableIndexer + 'static>(
let _ = indexer.find_matches(seq.local_hashes.clone()).await;
baseline_durations.push(start.elapsed());
}
let stats = LatencyStats::from_durations(baseline_durations.clone()).unwrap();
let stats = LatencyStats::from_durations(&baseline_durations).unwrap();
let baseline_service_time = stats.p50;
let theoretical_max = stats.throughput_ops_sec;
......@@ -1041,7 +1042,7 @@ fn print_stress_results(args: &StressArgs, results: &StressResults) {
println!(" Achieved: {:.1} req/sec", achieved_throughput);
println!();
if let Some(stats) = LatencyStats::from_durations(results.latencies.clone()) {
if let Some(stats) = LatencyStats::from_durations(&results.latencies) {
println!(" Latency (end-to-end, includes queue wait):");
println!(" min: {:>12?}", stats.min);
println!(" p50: {:>12?}", stats.p50);
......@@ -1163,7 +1164,7 @@ fn print_stress_comparison(results: &[StressResults], args: &StressArgs) {
// Latency p50
let mut row = format!("{:<35}", "Latency p50 (us)");
for result in results {
if let Some(stats) = LatencyStats::from_durations(result.latencies.clone()) {
if let Some(stats) = LatencyStats::from_durations(&result.latencies) {
row.push_str(&format!(" {:>18.2}", stats.p50.as_nanos() as f64 / 1000.0));
} else {
row.push_str(&format!(" {:>18}", "-"));
......@@ -1174,7 +1175,7 @@ fn print_stress_comparison(results: &[StressResults], args: &StressArgs) {
// Latency p99
let mut row = format!("{:<35}", "Latency p99 (us)");
for result in results {
if let Some(stats) = LatencyStats::from_durations(result.latencies.clone()) {
if let Some(stats) = LatencyStats::from_durations(&result.latencies) {
row.push_str(&format!(" {:>18.2}", stats.p99.as_nanos() as f64 / 1000.0));
} else {
row.push_str(&format!(" {:>18}", "-"));
......
......@@ -14,9 +14,10 @@
//! Run with: cargo bench --package dynamo-kv-router --bench radix_tree_microbench --features bench -- --help
use clap::{Parser, ValueEnum};
use dynamo_bench::common::LatencyStats;
use dynamo_kv_router::{
ConcurrentRadixTree, OverlapScores, PositionalIndexer, RadixTree, RouterEvent, SyncIndexer,
bench_utils::{LatencyStats, SequenceData, generate_sequences},
bench_utils::{SequenceData, generate_sequences},
compute_block_hash_for_seq,
protocols::LocalBlockHash,
};
......@@ -313,7 +314,7 @@ fn bench_hash(args: &Args) {
}
}
let stats = LatencyStats::from_durations(durations).unwrap();
let stats = LatencyStats::from_durations(&durations).unwrap();
stats.print("COMPUTE_BLOCK_HASH", args.depth);
}
......@@ -374,7 +375,7 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) {
}
}
let stats = LatencyStats::from_durations(durations).unwrap();
let stats = LatencyStats::from_durations(&durations).unwrap();
stats.print(op_name, args.depth);
}
......@@ -426,7 +427,7 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(hit_durations)
LatencyStats::from_durations(&hit_durations)
.unwrap()
.print("FIND_MATCHES (HIT)", args.depth);
......@@ -442,7 +443,7 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(miss_durations)
LatencyStats::from_durations(&miss_durations)
.unwrap()
.print("FIND_MATCHES (MISS)", args.depth);
......@@ -459,7 +460,7 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(partial_durations)
LatencyStats::from_durations(&partial_durations)
.unwrap()
.print("FIND_MATCHES (PARTIAL)", args.depth);
......@@ -473,7 +474,7 @@ fn bench_find_matches(args: &Args) {
early_exit_durations.push(elapsed);
}
}
LatencyStats::from_durations(early_exit_durations)
LatencyStats::from_durations(&early_exit_durations)
.unwrap()
.print("FIND_MATCHES (EARLY_EXIT)", args.depth);
}
......
......@@ -4,7 +4,6 @@
//! Benchmark utilities for kv-router benchmarks.
//!
//! This module provides shared data structures for benchmarking:
//! - `LatencyStats`: Statistics for latency measurements
//! - `SequenceData`: Pre-generated sequence data for benchmarking
use crate::protocols::{
......@@ -14,60 +13,6 @@ use crate::protocols::{
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::time::Duration;
/// Statistics for latency measurements.
#[derive(Debug, Clone)]
pub struct LatencyStats {
pub min: Duration,
pub max: Duration,
pub avg: Duration,
pub p50: Duration,
pub p95: Duration,
pub p99: Duration,
pub throughput_ops_sec: f64,
}
impl LatencyStats {
/// Compute statistics from a vector of durations.
///
/// Returns `None` if the input is empty.
pub fn from_durations(mut durations: Vec<Duration>) -> Option<Self> {
if durations.is_empty() {
return None;
}
durations.sort();
let n = durations.len();
let total: Duration = durations.iter().sum();
let avg = total / n as u32;
Some(Self {
min: durations[0],
max: durations[n - 1],
avg,
p50: durations[n / 2],
p95: durations[n * 95 / 100],
p99: durations[n * 99 / 100],
throughput_ops_sec: n as f64 / total.as_secs_f64(),
})
}
/// Print formatted latency statistics to stdout.
pub fn print(&self, operation: &str, blocks_per_op: usize) {
println!("\n{} Latency Statistics:", operation);
println!(" min: {:>12?}", self.min);
println!(" avg: {:>12?}", self.avg);
println!(" p50: {:>12?}", self.p50);
println!(" p95: {:>12?}", self.p95);
println!(" p99: {:>12?}", self.p99);
println!(" max: {:>12?}", self.max);
println!(" throughput: {:.2} ops/sec", self.throughput_ops_sec);
println!(
" throughput: {:.2} blocks/sec",
self.throughput_ops_sec * blocks_per_op as f64
);
}
}
/// Pre-generated sequence data for benchmarking.
#[derive(Clone)]
pub struct SequenceData {
......
......@@ -206,6 +206,7 @@ insta = { version = "1.41", features = [
lazy_static = "1.4"
mockito = "1.7.0"
dynamo-bench = { path = "../bench" }
[[bin]]
name = "generate-frontend-openapi"
......
......@@ -15,6 +15,10 @@
use anyhow::{Context, Result};
use bytes::Bytes;
use clap::Parser;
use dynamo_bench::common::{
ChatCompletionRequest, ChatMessage, LatencySample, LatencyStats, TimeBucketStats,
compute_time_bucket_stats, fetch_model_name, print_time_bucket_report,
};
use dynamo_runtime::transports::event_plane::EventEnvelope;
use hf_hub;
use indicatif::{ProgressBar, ProgressStyle};
......@@ -592,60 +596,6 @@ struct HealthInstance {
endpoint: String,
}
/// Response from the frontend's /v1/models endpoint
#[derive(Debug, Deserialize)]
struct ModelsResponse {
data: Vec<ModelInfo>,
}
/// Model info from /v1/models endpoint
#[derive(Debug, Deserialize)]
struct ModelInfo {
id: String,
}
/// Fetch the model name from the frontend's /v1/models endpoint.
///
/// Returns the model ID if exactly one model is available.
/// Returns an error if zero or multiple models are found (requiring explicit --model).
async fn fetch_model_name(frontend_url: &str) -> Result<String> {
let client = reqwest::Client::new();
let url = format!("{}/v1/models", frontend_url);
println!(" Auto-detecting model from {}...", url);
let response = client
.get(&url)
.send()
.await
.context("Failed to connect to frontend /v1/models endpoint")?;
if !response.status().is_success() {
anyhow::bail!("Models endpoint returned status: {}", response.status());
}
let models: ModelsResponse = response
.json()
.await
.context("Failed to parse models response")?;
match models.data.len() {
0 => anyhow::bail!("No models found at endpoint. Is a backend running?"),
1 => {
let model_id = models.data[0].id.clone();
println!(" Auto-detected model: {}", model_id);
Ok(model_id)
}
n => {
println!(" Multiple models available ({}):", n);
for m in &models.data {
println!(" - {}", m.id);
}
anyhow::bail!("Multiple models available. Please specify --model explicitly.")
}
}
}
/// Discover worker IDs from the frontend's /health endpoint.
///
/// Returns a list of instance_ids (worker_ids) that are currently registered.
......@@ -838,32 +788,6 @@ struct RequestResult {
success: bool,
}
/// Individual latency sample for raw data export
#[derive(Debug, Clone, Serialize)]
struct LatencySample {
/// Latency in microseconds
latency_us: u64,
/// Completion time in milliseconds from measurement start
completion_time_ms: u64,
/// Whether the request succeeded
success: bool,
}
/// OpenAI-style chat completion request
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
}
#[derive(Debug, Serialize)]
struct ChatMessage {
role: String,
content: String,
}
/// Generate prefix text content.
/// These are long enough to span multiple KV blocks when tokenized.
/// Each prefix is designed to be distinct and consistent across requests.
......@@ -1302,129 +1226,6 @@ async fn publish_events_at_rate(
}
}
/// Latency statistics
struct LatencyStats {
min: Duration,
max: Duration,
p50: Duration,
p95: Duration,
p99: Duration,
}
impl LatencyStats {
fn from_durations(durations: &[Duration]) -> Option<Self> {
if durations.is_empty() {
return None;
}
let mut sorted = durations.to_vec();
sorted.sort();
let n = sorted.len();
Some(Self {
min: sorted[0],
max: sorted[n - 1],
p50: sorted[n / 2],
p95: sorted[n * 95 / 100],
p99: sorted[n * 99 / 100],
})
}
}
/// Time-bucketed latency statistics for tracking latency over time
#[derive(Debug, Clone, Serialize)]
struct TimeBucketStats {
/// Bucket start time in seconds from measurement start
bucket_start_sec: u64,
/// Bucket end time in seconds
bucket_end_sec: u64,
/// Number of requests completed in this bucket
count: usize,
/// Latency stats for this bucket (in microseconds)
latency_min_us: u64,
latency_p50_us: u64,
latency_p95_us: u64,
latency_max_us: u64,
}
/// Compute per-bucket latency statistics
fn compute_time_bucket_stats(
results: &[RequestResult],
bucket_size_secs: u64,
) -> Vec<TimeBucketStats> {
if results.is_empty() {
return Vec::new();
}
// Find the max completion time to determine bucket count
let max_completion = results
.iter()
.map(|r| r.completion_time)
.max()
.unwrap_or(Duration::ZERO);
let num_buckets = (max_completion.as_secs() / bucket_size_secs) + 1;
let mut bucket_latencies: Vec<Vec<Duration>> = vec![Vec::new(); num_buckets as usize];
// Group latencies by completion time bucket
for result in results {
let bucket_idx = (result.completion_time.as_secs() / bucket_size_secs) as usize;
if bucket_idx < bucket_latencies.len() {
bucket_latencies[bucket_idx].push(result.latency);
}
}
// Compute stats for each bucket
bucket_latencies
.iter()
.enumerate()
.filter_map(|(idx, latencies)| {
if latencies.is_empty() {
return None;
}
let stats = LatencyStats::from_durations(latencies)?;
Some(TimeBucketStats {
bucket_start_sec: idx as u64 * bucket_size_secs,
bucket_end_sec: (idx as u64 + 1) * bucket_size_secs,
count: latencies.len(),
latency_min_us: stats.min.as_micros() as u64,
latency_p50_us: stats.p50.as_micros() as u64,
latency_p95_us: stats.p95.as_micros() as u64,
latency_max_us: stats.max.as_micros() as u64,
})
})
.collect()
}
/// Print time-bucket latency report
fn print_time_bucket_report(buckets: &[TimeBucketStats]) {
if buckets.is_empty() {
println!(" No time bucket data available");
return;
}
println!(
" {:>8} {:>8} {:>12} {:>12} {:>12} {:>12}",
"Time(s)", "Count", "Min(ms)", "P50(ms)", "P95(ms)", "Max(ms)"
);
println!(" {}", "-".repeat(68));
for bucket in buckets {
println!(
" {:>3}-{:<4} {:>8} {:>12.1} {:>12.1} {:>12.1} {:>12.1}",
bucket.bucket_start_sec,
bucket.bucket_end_sec,
bucket.count,
bucket.latency_min_us as f64 / 1000.0,
bucket.latency_p50_us as f64 / 1000.0,
bucket.latency_p95_us as f64 / 1000.0,
bucket.latency_max_us as f64 / 1000.0,
);
}
}
/// Stress test results
#[derive(Debug, Serialize)]
struct StressResults {
......@@ -1847,7 +1648,11 @@ async fn main() -> Result<()> {
// Compute time-bucketed stats for latency-over-time tracking
let time_buckets = if args.bucket_size > 0 {
compute_time_bucket_stats(&results, args.bucket_size)
let pairs: Vec<(Duration, Duration)> = results
.iter()
.map(|r| (r.latency, r.completion_time))
.collect();
compute_time_bucket_stats(&pairs, args.bucket_size)
} else {
Vec::new()
};
......
......@@ -340,10 +340,15 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let active_requests = self.active_requests.clone();
let async_context = ctx.context();
let bootstrap_server = self.bootstrap_server.clone();
let reasoning = self.engine_args.reasoning.clone();
// Spawn a task to handle the complex async logic
tokio::spawn(async move {
let mut token_count = 0;
let think_len = reasoning
.as_ref()
.map(|cfg| cfg.num_thinking_tokens(max_output_tokens))
.unwrap_or(0);
loop {
tokio::select! {
......@@ -353,8 +358,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
break;
};
// Generate a new token
let token_id = generate_random_token();
// Generate a token (with thinking boundaries if configured)
let token_id = if token_count == 0 && think_len > 0 {
reasoning.as_ref().unwrap().start_thinking_token_id
} else if think_len > 0 && token_count == think_len - 1 {
reasoning.as_ref().unwrap().end_thinking_token_id
} else {
generate_random_token()
};
token_count += 1;
let output = LLMEngineOutput {
......
......@@ -13,6 +13,7 @@
pub mod media;
pub mod prompt;
pub mod speculative_prefill;
pub mod tools;
use anyhow::Context;
use anyhow::{Result, bail};
......@@ -1165,6 +1166,15 @@ impl
transformed_stream
};
// Step 5: Speculative next-turn prefill
let final_stream = speculative_prefill::maybe_wrap_stream(
final_stream,
&request,
&next,
&self.formatter,
&self.tokenizer,
);
// prepend the annotations to the response stream
let stream = annotations_stream.chain(final_stream);
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Speculative next-turn prefill for reasoning models.
//!
//! After an assistant turn completes, we know what the next turn's prompt prefix
//! will look like: the full conversation history (with thinking content stripped by
//! the Jinja template for non-last assistant turns). We render it, tokenize it,
//! and send a `max_tokens=1` request through the pipeline to warm the KV cache.
use std::pin::Pin;
use std::sync::Arc;
use anyhow::Result;
use dynamo_async_openai::types::{
ChatCompletionMessageContent, ChatCompletionRequestAssistantMessage,
ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
};
use futures::Stream;
use futures::stream::StreamExt;
use minijinja::value::Value;
use dynamo_runtime::engine::AsyncEngine;
use dynamo_runtime::pipeline::{Context as PipelineContext, Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
use crate::preprocessor::prompt::{OAIChatLikeRequest, OAIPromptFormatter};
use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::tokenizers::traits::Tokenizer;
/// Chat-style request used solely to render the predicted next-turn prefix.
///
/// It carries the complete conversation, including the freshly generated
/// assistant message. Its `OAIChatLikeRequest` implementation reports
/// `should_add_generation_prompt() == false`, so rendering yields the prefix
/// the following user turn is expected to share.
pub struct SpeculativePrefillRequest {
    /// Conversation messages handed to the prompt formatter.
    messages: Vec<ChatCompletionRequestMessage>,
}
impl SpeculativePrefillRequest {
pub fn new(messages: Vec<ChatCompletionRequestMessage>) -> Self {
Self { messages }
}
}
impl OAIChatLikeRequest for SpeculativePrefillRequest {
    /// Fixed placeholder name; this request is never built from a client's
    /// model field.
    fn model(&self) -> String {
        String::from("speculative_prefill")
    }

    /// Conversation as a template value, converted through a JSON
    /// intermediate.
    ///
    /// # Panics
    /// Panics if the messages cannot be serialized to JSON.
    fn messages(&self) -> Value {
        serde_json::to_value(&self.messages)
            .map(|as_json| Value::from_serialize(&as_json))
            .unwrap()
    }

    /// Borrowed view of the typed messages.
    fn typed_messages(&self) -> Option<&[ChatCompletionRequestMessage]> {
        Some(self.messages.as_slice())
    }

    /// Always `false`: the rendered prompt must stop after the last message
    /// instead of opening a new assistant turn.
    fn should_add_generation_prompt(&self) -> bool {
        false
    }
}
/// Optionally wraps a chat completion response stream to enable speculative
/// next-turn prefill.
///
/// When `nvext.agent_hints.speculative_prefill` is `Some(true)`, the returned
/// stream passes every item through unchanged while accumulating the assistant
/// text of the first choice (`index == 0`). As soon as that choice reports a
/// `finish_reason`, the accumulated text is handed to a background task that
/// renders the next-turn prefix and fires a `max_tokens=1` request through
/// `next` to warm the KV cache.
///
/// Only choice 0 is tracked: when several choices are streamed, mixing their
/// deltas into one buffer would produce an assistant message that no next turn
/// could share as a prefix.
///
/// When the flag is absent or `false`, the input stream is returned unmodified.
///
/// If the wrapped stream is dropped before choice 0 finishes, the sender is
/// dropped with it and the background task exits without issuing a request.
pub fn maybe_wrap_stream(
    stream: Pin<Box<dyn Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>>,
    request: &NvCreateChatCompletionRequest,
    next: &Arc<
        dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
    >,
    formatter: &Arc<dyn OAIPromptFormatter>,
    tokenizer: &Arc<dyn Tokenizer>,
) -> Pin<Box<dyn Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>> {
    let enabled = request
        .nvext
        .as_ref()
        .and_then(|ext| ext.agent_hints.as_ref())
        .and_then(|hints| hints.speculative_prefill)
        .unwrap_or(false);
    if !enabled {
        return stream;
    }

    let (tx, rx) = tokio::sync::oneshot::channel::<String>();
    let next = Arc::clone(next);
    let formatter = Arc::clone(formatter);
    let tokenizer = Arc::clone(tokenizer);
    let messages = request.inner.messages.clone();

    // The task parks on the oneshot until the response text is complete; a
    // dropped sender (stream abandoned early) makes it return immediately.
    tokio::spawn(async move {
        let Ok(response_text) = rx.await else {
            return;
        };
        if let Err(e) = prefill_task(next, formatter, tokenizer, messages, response_text).await {
            tracing::warn!(error = %e, "Speculative prefill failed");
        }
    });

    let mut accumulated_text = String::new();
    let mut prefill_tx = Some(tx);
    Box::pin(stream.map(move |item| {
        // Once the text has been handed off there is nothing left to track,
        // so later items are forwarded without inspection.
        if prefill_tx.is_some()
            && let Some(ref resp) = item.data
        {
            for choice in resp.choices.iter().filter(|choice| choice.index == 0) {
                if let Some(ChatCompletionMessageContent::Text(ref text)) = choice.delta.content {
                    accumulated_text.push_str(text);
                }
                // Send accumulated text once we see finish_reason (works
                // regardless of whether usage reporting is enabled).
                if choice.finish_reason.is_some()
                    && let Some(tx) = prefill_tx.take()
                {
                    // Move the buffer out instead of cloning it; it is never
                    // read again after this point.
                    let _ = tx.send(std::mem::take(&mut accumulated_text));
                }
            }
        }
        item
    }))
}
/// Fire-and-forget task that renders the next-turn prefix and sends it
/// through the pipeline as a `max_tokens=1` request to warm the KV cache.
async fn prefill_task(
next: Arc<
dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
>,
formatter: Arc<dyn OAIPromptFormatter>,
tokenizer: Arc<dyn Tokenizer>,
original_messages: Vec<ChatCompletionRequestMessage>,
response_text: String,
) -> Result<()> {
let assistant_msg =
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
response_text,
)),
..Default::default()
});
let mut messages = original_messages;
messages.push(assistant_msg);
let prefill_request = SpeculativePrefillRequest::new(messages);
let formatted_prompt = formatter.render(&prefill_request)?;
let encoding = tokenizer.encode(&formatted_prompt)?;
let token_ids = encoding.token_ids().to_vec();
tracing::info!(
num_tokens = token_ids.len(),
"Speculative prefill: sending next-turn prefix"
);
let preprocessed = PreprocessedRequest::builder()
.model("speculative_prefill".to_string())
.token_ids(token_ids)
.stop_conditions(StopConditions {
max_tokens: Some(1),
..Default::default()
})
.sampling_options(SamplingOptions::default())
.output_options(OutputOptions::default())
.eos_token_ids(vec![])
.annotations(vec![])
.build()?;
let context = PipelineContext::with_id(preprocessed, uuid::Uuid::new_v4().to_string());
// Drain the stream so the KV router's RequestGuard runs its full lifecycle
// (mark_prefill_completed, block tracking, free) instead of relying on drop.
if let Ok(mut stream) = next.generate(context).await {
while stream.next().await.is_some() {}
}
Ok(())
}
......@@ -173,6 +173,13 @@ pub struct AgentHints {
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub osl: Option<u32>,
/// When true, after the assistant turn completes, the system will speculatively
/// prefill the predicted next-turn prefix (conversation history with thinking
/// content stripped) on a worker to warm the KV cache for the next request.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub speculative_prefill: Option<bool>,
}
impl Default for NvExt {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment