Unverified Commit fc2ad4eb authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: mocker can use planner profile data (#4422)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
Co-authored-by: default avatarHongkuan Zhou <tedzhouhk@gmail.com>
parent 48c340f4
...@@ -2722,6 +2722,8 @@ dependencies = [ ...@@ -2722,6 +2722,8 @@ dependencies = [
"modelexpress-client", "modelexpress-client",
"modelexpress-common", "modelexpress-common",
"ndarray", "ndarray",
"ndarray-interp",
"ndarray-npy",
"nix 0.26.4", "nix 0.26.4",
"nixl-sys", "nixl-sys",
"offset-allocator", "offset-allocator",
...@@ -6527,6 +6529,31 @@ dependencies = [ ...@@ -6527,6 +6529,31 @@ dependencies = [
"rawpointer", "rawpointer",
] ]
[[package]]
name = "ndarray-interp"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e43087829efb5ec2736598e88587df286425b59df5a9ce991994cdd2c5855d3f"
dependencies = [
"ndarray",
"num-traits",
"thiserror 2.0.17",
]
[[package]]
name = "ndarray-npy"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b313788c468c49141a9d9b6131fc15f403e6ef4e8446a0b2e18f664ddb278a9"
dependencies = [
"byteorder",
"ndarray",
"num-complex",
"num-traits",
"py_literal",
"zip 2.4.2",
]
[[package]] [[package]]
name = "neli" name = "neli"
version = "0.6.5" version = "0.6.5"
...@@ -8126,6 +8153,19 @@ dependencies = [ ...@@ -8126,6 +8153,19 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "py_literal"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "102df7a3d46db9d3891f178dcc826dc270a6746277a9ae6436f8d29fd490a8e1"
dependencies = [
"num-bigint",
"num-complex",
"num-traits",
"pest",
"pest_derive",
]
[[package]] [[package]]
name = "qoi" name = "qoi"
version = "0.4.1" version = "0.4.1"
...@@ -12700,6 +12740,23 @@ dependencies = [ ...@@ -12700,6 +12740,23 @@ dependencies = [
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
[[package]]
name = "zip"
version = "2.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fabe6324e908f85a1c52063ce7aa26b68dcb7eb6dbc83a2d148403c9bc3eba50"
dependencies = [
"arbitrary",
"crc32fast",
"crossbeam-utils",
"displaydoc",
"flate2",
"indexmap 2.12.0",
"memchr",
"thiserror 2.0.17",
"zopfli",
]
[[package]] [[package]]
name = "zip" name = "zip"
version = "3.0.0" version = "3.0.0"
......
...@@ -43,4 +43,74 @@ python -m dynamo.frontend --http-port 8000 ...@@ -43,4 +43,74 @@ python -m dynamo.frontend --http-port 8000
``` ```
> [!Note] > [!Note]
> Each mocker instance runs as a single process, and each DP worker (specified by `--data-parallel-size`) is spawned as a lightweight async task within that process. For benchmarking (e.g., router testing), you can use `--num-workers` to launch multiple mocker engines in the same process, which is more efficient than launching separate processes since they all share the same tokio runtime and thread pool. > Each mocker instance runs as a single process, and each DP worker (specified by `--data-parallel-size`) is spawned as a lightweight async task within that process. For benchmarking (e.g., router testing), you can use `--num-workers` to launch multiple mocker engines in the same process, which is more efficient than launching separate processes since they all share the same tokio runtime and thread pool.
\ No newline at end of file
## Performance modeling with planner profile data
By default, the mocker uses hardcoded polynomial formulas to estimate prefill and decode timing. For more realistic simulations, you can load performance data from actual profiling results.
### Using profiled performance data
Add the `--planner-profile-data` flag to load an NPZ file containing interpolation grids from the planner profiler:
```bash
python -m dynamo.mocker \
--model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--planner-profile-data /path/to/profiling_results/perf_data.npz \
--speedup-ratio 1.0
```
The NPZ file should contain the following arrays:
- `prefill_isl`: 1D array of input sequence lengths
- `prefill_ttft_ms`: 1D array of time-to-first-token values (ms)
- `decode_active_kv_tokens`: 1D array of active KV token counts
- `decode_context_length`: 1D array of context lengths
- `decode_itl`: 2D array of inter-token latencies (ms), with shape `(len(decode_active_kv_tokens), len(decode_context_length))`
### Generating performance data from profiler results
#### Option 1: Use existing pre-swept results
The repository includes pre-swept profiling results for common models and hardware configurations. For example, to use Llama-3.1-8B-Instruct-FP8 on H200 SXM:
```bash
# Convert existing pre-swept results to mocker-compatible NPZ format
python components/src/dynamo/mocker/utils/planner_profiler_perf_data_converter.py \
--profile_results_dir tests/planner/profiling_results/H200_TP1P_TP1D \
--output_dir mocker_perf_data \
--resolution 100
# Use the generated NPZ with mocker
python -m dynamo.mocker \
--model-path nvidia/Llama-3.1-8B-Instruct-FP8 \
--planner-profile-data mocker_perf_data/perf_data.npz
```
#### Option 2: Generate from custom profiler runs
To use your own profiler results, first run the profiler (see [SLA-driven profiling documentation](../../../../docs/benchmarks/sla_driven_profiling.md) for details; note that this is generally run in a Kubernetes environment), then convert its output into the NPZ format expected by the mocker.
```bash
# Run the profiler
python benchmarks/profiler/profile_sla.py \
--profile-config your_profile_config.yaml
# Convert profiler results to mocker-compatible NPZ format
python components/src/dynamo/mocker/utils/planner_profiler_perf_data_converter.py \
--profile_results_dir profiling_results/selected_prefill_interpolation \
--output_dir profiling_results \
--resolution 100
# This creates profiling_results/perf_data.npz
```
The converter script combines prefill and decode interpolation data into a single NPZ file with the appropriate array structure.
### How it works
When you provide `--planner-profile-data`:
1. The mocker loads the NPZ file during initialization
2. Prefill timing uses 1D linear interpolation on the ISL grid
3. Decode timing uses 2D bilinear interpolation on (active_kv_tokens, context_length)
Without `--planner-profile-data`, the mocker falls back to the default polynomial formulas for backward compatibility.
\ No newline at end of file
...@@ -38,6 +38,9 @@ def create_temp_engine_args_file(args) -> Path: ...@@ -38,6 +38,9 @@ def create_temp_engine_args_file(args) -> Path:
"speedup_ratio": getattr(args, "speedup_ratio", None), "speedup_ratio": getattr(args, "speedup_ratio", None),
"dp_size": getattr(args, "dp_size", None), "dp_size": getattr(args, "dp_size", None),
"startup_time": getattr(args, "startup_time", None), "startup_time": getattr(args, "startup_time", None),
"planner_profile_data": str(getattr(args, "planner_profile_data", None))
if getattr(args, "planner_profile_data", None)
else None,
"is_prefill": getattr(args, "is_prefill_worker", None), "is_prefill": getattr(args, "is_prefill_worker", None),
"is_decode": getattr(args, "is_decode_worker", None), "is_decode": getattr(args, "is_decode_worker", None),
} }
...@@ -175,6 +178,12 @@ def parse_args(): ...@@ -175,6 +178,12 @@ def parse_args():
default=None, default=None,
help="Simulated engine startup time in seconds (default: None)", help="Simulated engine startup time in seconds (default: None)",
) )
parser.add_argument(
"--planner-profile-data",
type=Path,
default=None,
help="Path to JSON configmap or NPZ file containing performance profiling data from planner_profiler_perf_data_converter.py (default: None, uses hardcoded polynomials)",
)
parser.add_argument( parser.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
......
...@@ -31,6 +31,7 @@ and might leads to slightly higher latency. ...@@ -31,6 +31,7 @@ and might leads to slightly higher latency.
import argparse import argparse
import logging import logging
import os import os
from pathlib import Path
import numpy as np import numpy as np
...@@ -50,8 +51,16 @@ if __name__ == "__main__": ...@@ -50,8 +51,16 @@ if __name__ == "__main__":
parser.add_argument("--output_dir", type=str, default="") parser.add_argument("--output_dir", type=str, default="")
args = parser.parse_args() args = parser.parse_args()
# Convert to absolute paths to handle relative directories properly
args.profile_results_dir = str(Path(args.profile_results_dir).resolve())
if not args.output_dir: if not args.output_dir:
args.output_dir = args.profile_results_dir args.output_dir = args.profile_results_dir
else:
args.output_dir = str(Path(args.output_dir).resolve())
# Create output directory if it doesn't exist
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
logger.info( logger.info(
f"Converting profile results from {args.profile_results_dir} to {args.output_dir}..." f"Converting profile results from {args.profile_results_dir} to {args.output_dir}..."
......
...@@ -1435,7 +1435,7 @@ dependencies = [ ...@@ -1435,7 +1435,7 @@ dependencies = [
"libc", "libc",
"option-ext", "option-ext",
"redox_users", "redox_users",
"windows-sys 0.59.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
...@@ -1580,6 +1580,8 @@ dependencies = [ ...@@ -1580,6 +1580,8 @@ dependencies = [
"modelexpress-client", "modelexpress-client",
"modelexpress-common", "modelexpress-common",
"ndarray", "ndarray",
"ndarray-interp",
"ndarray-npy",
"offset-allocator", "offset-allocator",
"oneshot", "oneshot",
"parking_lot", "parking_lot",
...@@ -1872,7 +1874,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -1872,7 +1874,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
...@@ -2807,7 +2809,7 @@ dependencies = [ ...@@ -2807,7 +2809,7 @@ dependencies = [
"libc", "libc",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"socket2 0.5.10", "socket2 0.6.1",
"system-configuration", "system-configuration",
"tokio", "tokio",
"tower-service", "tower-service",
...@@ -3197,7 +3199,7 @@ dependencies = [ ...@@ -3197,7 +3199,7 @@ dependencies = [
"portable-atomic", "portable-atomic",
"portable-atomic-util", "portable-atomic-util",
"serde_core", "serde_core",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
...@@ -3933,6 +3935,31 @@ dependencies = [ ...@@ -3933,6 +3935,31 @@ dependencies = [
"rawpointer", "rawpointer",
] ]
[[package]]
name = "ndarray-interp"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e43087829efb5ec2736598e88587df286425b59df5a9ce991994cdd2c5855d3f"
dependencies = [
"ndarray",
"num-traits",
"thiserror 2.0.17",
]
[[package]]
name = "ndarray-npy"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b313788c468c49141a9d9b6131fc15f403e6ef4e8446a0b2e18f664ddb278a9"
dependencies = [
"byteorder",
"ndarray",
"num-complex",
"num-traits",
"py_literal",
"zip 2.4.2",
]
[[package]] [[package]]
name = "neli" name = "neli"
version = "0.6.5" version = "0.6.5"
...@@ -4068,7 +4095,7 @@ version = "0.50.3" ...@@ -4068,7 +4095,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
...@@ -4916,6 +4943,19 @@ dependencies = [ ...@@ -4916,6 +4943,19 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "py_literal"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "102df7a3d46db9d3891f178dcc826dc270a6746277a9ae6436f8d29fd490a8e1"
dependencies = [
"num-bigint",
"num-complex",
"num-traits",
"pest",
"pest_derive",
]
[[package]] [[package]]
name = "pyo3" name = "pyo3"
version = "0.23.5" version = "0.23.5"
...@@ -5045,7 +5085,7 @@ dependencies = [ ...@@ -5045,7 +5085,7 @@ dependencies = [
"quinn-udp", "quinn-udp",
"rustc-hash 2.1.1", "rustc-hash 2.1.1",
"rustls", "rustls",
"socket2 0.5.10", "socket2 0.6.1",
"thiserror 2.0.17", "thiserror 2.0.17",
"tokio", "tokio",
"tracing", "tracing",
...@@ -5082,9 +5122,9 @@ dependencies = [ ...@@ -5082,9 +5122,9 @@ dependencies = [
"cfg_aliases", "cfg_aliases",
"libc", "libc",
"once_cell", "once_cell",
"socket2 0.5.10", "socket2 0.6.1",
"tracing", "tracing",
"windows-sys 0.52.0", "windows-sys 0.60.2",
] ]
[[package]] [[package]]
...@@ -5587,7 +5627,7 @@ dependencies = [ ...@@ -5587,7 +5627,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
...@@ -6383,7 +6423,7 @@ dependencies = [ ...@@ -6383,7 +6423,7 @@ dependencies = [
"getrandom 0.3.4", "getrandom 0.3.4",
"once_cell", "once_cell",
"rustix", "rustix",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
...@@ -7563,7 +7603,7 @@ version = "0.1.11" ...@@ -7563,7 +7603,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
...@@ -8130,6 +8170,23 @@ dependencies = [ ...@@ -8130,6 +8170,23 @@ dependencies = [
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
[[package]]
name = "zip"
version = "2.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fabe6324e908f85a1c52063ce7aa26b68dcb7eb6dbc83a2d148403c9bc3eba50"
dependencies = [
"arbitrary",
"crc32fast",
"crossbeam-utils",
"displaydoc",
"flate2",
"indexmap 2.12.0",
"memchr",
"thiserror 2.0.17",
"zopfli",
]
[[package]] [[package]]
name = "zip" name = "zip"
version = "3.0.0" version = "3.0.0"
......
...@@ -148,6 +148,8 @@ base64 = { version = "0.22" } ...@@ -148,6 +148,8 @@ base64 = { version = "0.22" }
image = { version = "0.25" } image = { version = "0.25" }
tokio-rayon = {version = "2" } tokio-rayon = {version = "2" }
ndarray = { version = "0.16" } ndarray = { version = "0.16" }
ndarray-npy = { version = "0.9" }
ndarray-interp = { version = "0.5" }
# Publishers # Publishers
zeromq = "0.4.1" zeromq = "0.4.1"
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
pub mod engine; pub mod engine;
pub mod evictor; pub mod evictor;
pub mod kv_manager; pub mod kv_manager;
pub mod perf_model;
pub mod protocols; pub mod protocols;
pub mod running_mean; pub mod running_mean;
pub mod scheduler; pub mod scheduler;
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Performance model for timing simulations in the mocker.
//!
//! This module provides two timing models:
//! 1. Polynomial: Hardcoded polynomial formulas (default, backward compatible)
//! 2. Interpolated: Grid-based interpolation from profiler data (loaded from NPZ files)
use anyhow::{Context, Result};
use ndarray::{Array1, Array2};
use ndarray_interp::InterpolateError;
use ndarray_interp::interp1d::{Interp1DBuilder, Linear};
use ndarray_interp::interp2d::{Bilinear, Interp2DBuilder};
use std::path::Path;
use std::sync::Arc;
/// One-dimensional interpolator used to estimate prefill timing.
///
/// The trait hides the concrete interpolator type so that [`PerfModel`] can hold
/// it as a `Send + Sync` trait object.
pub trait PrefillInterpolator: Send + Sync {
    /// Evaluates the interpolant at `x`.
    ///
    /// # Errors
    /// Returns an [`InterpolateError`] when the implementation cannot produce a
    /// value for `x`.
    fn interp(&self, x: f64) -> Result<f64, InterpolateError>;
}
/// Two-dimensional interpolator used to estimate decode timing.
///
/// The trait hides the concrete interpolator type so that [`PerfModel`] can hold
/// it as a `Send + Sync` trait object.
pub trait DecodeInterpolator: Send + Sync {
    /// Evaluates the interpolant at the point `(x, y)`.
    ///
    /// # Errors
    /// Returns an [`InterpolateError`] when the implementation cannot produce a
    /// value for `(x, y)`.
    fn interp(&self, x: f64, y: f64) -> Result<f64, InterpolateError>;
}
/// Adapter that exposes a concrete `ndarray_interp` 1D linear interpolator
/// through the [`PrefillInterpolator`] trait.
struct PrefillInterp1D {
    /// The wrapped interpolator, backed by owned `f64` storage.
    inner: ndarray_interp::interp1d::Interp1D<
        ndarray::OwnedRepr<f64>,
        ndarray::OwnedRepr<f64>,
        ndarray::Ix1,
        Linear,
    >,
}
impl PrefillInterpolator for PrefillInterp1D {
    /// Delegates to the wrapped interpolator's scalar query at `x`.
    fn interp(&self, x: f64) -> Result<f64, InterpolateError> {
        self.inner.interp_scalar(x)
    }
}
/// Adapter that exposes a concrete `ndarray_interp` 2D bilinear interpolator
/// through the [`DecodeInterpolator`] trait.
struct DecodeInterp2D {
    /// The wrapped interpolator, backed by owned `f64` storage.
    inner: ndarray_interp::interp2d::Interp2D<
        ndarray::OwnedRepr<f64>,
        ndarray::OwnedRepr<f64>,
        ndarray::OwnedRepr<f64>,
        ndarray::Ix2,
        Bilinear,
    >,
}
impl DecodeInterpolator for DecodeInterp2D {
    /// Delegates to the wrapped interpolator's scalar query at `(x, y)`.
    fn interp(&self, x: f64, y: f64) -> Result<f64, InterpolateError> {
        self.inner.interp_scalar(x, y)
    }
}
/// Performance model for predicting prefill and decode timing.
///
/// Cloning is cheap: the `Interpolated` variant only clones its `Arc` handles,
/// so every clone shares the same interpolators.
#[derive(Clone, Default)]
pub enum PerfModel {
    /// Default polynomial-based model using hardcoded formulas.
    #[default]
    Polynomial,
    /// Interpolation-based model using profiler data.
    ///
    /// The interpolators are built once and stored as shared trait objects.
    Interpolated {
        /// 1D interpolator queried for prefill time.
        prefill_interp: Arc<dyn PrefillInterpolator>,
        /// 2D interpolator queried for decode time.
        decode_interp: Arc<dyn DecodeInterpolator>,
    },
}
impl std::fmt::Debug for PerfModel {
    /// Writes only the variant name; the interpolator fields are elided as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            PerfModel::Polynomial => "PerfModel::Polynomial",
            PerfModel::Interpolated { .. } => "PerfModel::Interpolated { .. }",
        };
        f.write_str(label)
    }
}
impl PerfModel {
    /// Quadratic coefficient of the built-in prefill polynomial.
    const PREFILL_POLY_QUADRATIC: f64 = 4.209989e-07;
    /// Linear coefficient of the built-in prefill polynomial.
    const PREFILL_POLY_LINEAR: f64 = 1.518344e-02;
    /// Constant term of the built-in prefill polynomial.
    const PREFILL_POLY_CONSTANT: f64 = 1.650142e+01;

    /// Default KV capacity (in tokens) used to turn `active_kv_tokens` into the
    /// active fraction fed to the built-in decode polynomial.
    const DECODE_POLY_KV_CAPACITY: f64 = 16384.0;
    /// Quadratic coefficient of the built-in decode polynomial.
    const DECODE_POLY_QUADRATIC: f64 = -25.74;
    /// Linear coefficient of the built-in decode polynomial.
    const DECODE_POLY_LINEAR: f64 = 54.01;
    /// Constant term of the built-in decode polynomial.
    const DECODE_POLY_CONSTANT: f64 = 5.74;

    /// Loads an interpolation-based performance model from an NPZ file.
    ///
    /// Expected arrays in the NPZ file (all read as `f64`):
    /// - `prefill_isl`: 1D array of input sequence lengths
    /// - `prefill_ttft_ms`: 1D array of time to first token in milliseconds
    /// - `decode_active_kv_tokens`: 1D array of active KV token counts
    /// - `decode_context_length`: 1D array of context lengths
    /// - `decode_itl`: 2D array of inter-token latencies in milliseconds, with one
    ///   row per `decode_active_kv_tokens` entry and one column per
    ///   `decode_context_length` entry
    ///
    /// Both interpolators are built once here, with extrapolation enabled.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened or read as NPZ, if any of the
    /// arrays cannot be loaded, if the array dimensions are inconsistent, or if an
    /// interpolator cannot be built from the data.
    pub fn from_npz(path: &Path) -> Result<Self> {
        use ndarray_npy::NpzReader;
        use std::fs::File;

        tracing::info!("Loading performance model from NPZ file: {:?}", path);

        let file =
            File::open(path).with_context(|| format!("Failed to open NPZ file: {:?}", path))?;
        let mut npz = NpzReader::new(file)
            .with_context(|| format!("Failed to create NPZ reader for: {:?}", path))?;

        // Load prefill arrays
        let prefill_isl: Array1<f64> = npz
            .by_name("prefill_isl")
            .context("Failed to load prefill_isl from NPZ")?;
        let prefill_ttft_ms: Array1<f64> = npz
            .by_name("prefill_ttft_ms")
            .context("Failed to load prefill_ttft_ms from NPZ")?;

        // Load decode arrays
        let decode_active_kv_tokens: Array1<f64> = npz
            .by_name("decode_active_kv_tokens")
            .context("Failed to load decode_active_kv_tokens from NPZ")?;
        let decode_context_length: Array1<f64> = npz
            .by_name("decode_context_length")
            .context("Failed to load decode_context_length from NPZ")?;
        let decode_itl: Array2<f64> = npz
            .by_name("decode_itl")
            .context("Failed to load decode_itl from NPZ")?;

        // Validate dimensions up front so a mismatch is reported in terms of the
        // NPZ array names.
        anyhow::ensure!(
            prefill_isl.len() == prefill_ttft_ms.len(),
            "Prefill array length mismatch: isl={}, ttft={}",
            prefill_isl.len(),
            prefill_ttft_ms.len()
        );
        anyhow::ensure!(
            decode_itl.nrows() == decode_active_kv_tokens.len()
                && decode_itl.ncols() == decode_context_length.len(),
            "Decode array dimension mismatch: itl shape=({}, {}), active_kv={}, context={}",
            decode_itl.nrows(),
            decode_itl.ncols(),
            decode_active_kv_tokens.len(),
            decode_context_length.len()
        );

        tracing::info!(
            "Loaded performance model: prefill_points={}, decode_grid={}x{}",
            prefill_isl.len(),
            decode_itl.nrows(),
            decode_itl.ncols()
        );

        // Build interpolators once during loading
        let prefill_interp = Interp1DBuilder::new(prefill_ttft_ms)
            .x(prefill_isl)
            .strategy(Linear::new().extrapolate(true))
            .build()
            .context("Failed to build prefill interpolator")?;
        let decode_interp = Interp2DBuilder::new(decode_itl)
            .x(decode_active_kv_tokens)
            .y(decode_context_length)
            .strategy(Bilinear::new().extrapolate(true))
            .build()
            .context("Failed to build decode interpolator")?;

        Ok(PerfModel::Interpolated {
            prefill_interp: Arc::new(PrefillInterp1D {
                inner: prefill_interp,
            }),
            decode_interp: Arc::new(DecodeInterp2D {
                inner: decode_interp,
            }),
        })
    }

    /// Predicts prefill time in milliseconds given the number of new tokens.
    ///
    /// The `Polynomial` variant evaluates the built-in quadratic in `new_tokens`;
    /// the `Interpolated` variant queries the pre-built 1D interpolator.
    ///
    /// The result is clamped to be non-negative. If the interpolator reports an
    /// error, a warning is logged and 0 ms is returned.
    pub fn predict_prefill_time(&self, new_tokens: usize) -> f64 {
        let tokens = new_tokens as f64;
        let time = match self {
            PerfModel::Polynomial => {
                Self::PREFILL_POLY_QUADRATIC * tokens.powi(2)
                    + Self::PREFILL_POLY_LINEAR * tokens
                    + Self::PREFILL_POLY_CONSTANT
            }
            PerfModel::Interpolated { prefill_interp, .. } => {
                prefill_interp.interp(tokens).unwrap_or_else(|err| {
                    // Keep the simulation running, but make the failure visible
                    // instead of silently reporting a zero latency.
                    tracing::warn!(
                        "Prefill interpolation failed for new_tokens={new_tokens}: {err}; using 0ms"
                    );
                    0.0
                })
            }
        };

        // Ensure non-negative timing (extrapolation can dip below zero).
        let result = time.max(0.0);
        tracing::debug!("Prefill time prediction: new_tokens={new_tokens}, time={result:.2}ms");
        result
    }

    /// Predicts decode time in milliseconds given active KV tokens and context length.
    ///
    /// For the `Polynomial` variant, this computes the active fraction as
    /// `active_kv_tokens / 16384` and evaluates the built-in quadratic;
    /// `context_length` is not used. For the `Interpolated` variant, this queries
    /// the pre-built 2D interpolator at `(active_kv_tokens, context_length)`.
    ///
    /// The result is clamped to be non-negative. If the interpolator reports an
    /// error, a warning is logged and 0 ms is returned.
    pub fn predict_decode_time(&self, active_kv_tokens: usize, context_length: usize) -> f64 {
        let time = match self {
            PerfModel::Polynomial => {
                // Compute active percentage using default capacity
                let active_perc = active_kv_tokens as f64 / Self::DECODE_POLY_KV_CAPACITY;
                Self::DECODE_POLY_QUADRATIC * active_perc.powi(2)
                    + Self::DECODE_POLY_LINEAR * active_perc
                    + Self::DECODE_POLY_CONSTANT
            }
            PerfModel::Interpolated { decode_interp, .. } => {
                let query_x = active_kv_tokens as f64;
                let query_y = context_length as f64;
                decode_interp.interp(query_x, query_y).unwrap_or_else(|err| {
                    // Keep the simulation running, but make the failure visible
                    // instead of silently reporting a zero latency.
                    tracing::warn!(
                        "Decode interpolation failed for active_kv_tokens={active_kv_tokens}, context_length={context_length}: {err}; using 0ms"
                    );
                    0.0
                })
            }
        };

        // Ensure non-negative timing (extrapolation can dip below zero).
        let result = time.max(0.0);
        tracing::debug!(
            "Decode time prediction: active_kv_tokens={active_kv_tokens}, context_length={context_length}, time={result:.2}ms"
        );
        result
    }
}
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
use derive_builder::Builder; use derive_builder::Builder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::path::Path; use std::path::{Path, PathBuf};
use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
use crate::mocker::perf_model::PerfModel;
use crate::tokens::blocks::UniqueBlock; use crate::tokens::blocks::UniqueBlock;
use crate::tokens::{BlockHash, SequenceHash, Token}; use crate::tokens::{BlockHash, SequenceHash, Token};
...@@ -44,9 +46,13 @@ pub struct PrefillCost { ...@@ -44,9 +46,13 @@ pub struct PrefillCost {
} }
impl PrefillCost { impl PrefillCost {
pub fn predict_prefill_compute(&self, new_tokens: Option<usize>) -> f64 { pub fn predict_prefill_compute(
&self,
new_tokens: Option<usize>,
perf_model: &PerfModel,
) -> f64 {
let tokens = new_tokens.unwrap_or(self.new_tokens); let tokens = new_tokens.unwrap_or(self.new_tokens);
4.209989e-07 * (tokens as f64).powi(2) + 1.518344e-02 * (tokens as f64) + 1.650142e+01 perf_model.predict_prefill_time(tokens)
} }
} }
...@@ -109,6 +115,11 @@ pub struct MockEngineArgs { ...@@ -109,6 +115,11 @@ pub struct MockEngineArgs {
/// Worker type for disaggregated serving (Aggregated, Prefill, or Decode) /// Worker type for disaggregated serving (Aggregated, Prefill, or Decode)
#[builder(default = "WorkerType::Aggregated")] #[builder(default = "WorkerType::Aggregated")]
pub worker_type: WorkerType, pub worker_type: WorkerType,
/// Performance model for timing predictions (not serialized, loaded from planner_profile_data)
#[serde(skip)]
#[builder(default = "Arc::new(PerfModel::default())")]
pub perf_model: Arc<PerfModel>,
} }
impl Default for MockEngineArgs { impl Default for MockEngineArgs {
...@@ -146,6 +157,7 @@ impl MockEngineArgs { ...@@ -146,6 +157,7 @@ impl MockEngineArgs {
"startup_time", "startup_time",
"is_prefill", "is_prefill",
"is_decode", "is_decode",
"planner_profile_data",
] ]
.iter() .iter()
.cloned() .cloned()
...@@ -249,6 +261,30 @@ impl MockEngineArgs { ...@@ -249,6 +261,30 @@ impl MockEngineArgs {
}; };
builder = builder.worker_type(worker_type); builder = builder.worker_type(worker_type);
// Load performance model from NPZ file if provided
let perf_model = if let Some(path_str) = extra_args.get("planner_profile_data")
&& let Some(path_str) = path_str.as_str()
{
let npz_path = PathBuf::from(path_str);
match PerfModel::from_npz(&npz_path) {
Ok(model) => {
tracing::info!("Successfully loaded performance model from: {:?}", npz_path);
Arc::new(model)
}
Err(e) => {
tracing::error!(
"Failed to load performance model from {:?}: {}. Falling back to polynomial model.",
npz_path,
e
);
Arc::new(PerfModel::default())
}
}
} else {
Arc::new(PerfModel::default())
};
builder = builder.perf_model(perf_model);
// Build the MockEngineArgs with either defaults or overridden values // Build the MockEngineArgs with either defaults or overridden values
builder builder
.build() .build()
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats}; use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats};
use crate::mocker::evictor::LRUEvictor; use crate::mocker::evictor::LRUEvictor;
use crate::mocker::kv_manager::KvManager; use crate::mocker::kv_manager::KvManager;
use crate::mocker::perf_model::PerfModel;
use crate::mocker::protocols::{ use crate::mocker::protocols::{
DirectRequest, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost, WorkerType, DirectRequest, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost, WorkerType,
}; };
...@@ -112,7 +113,7 @@ impl SchedulerState { ...@@ -112,7 +113,7 @@ impl SchedulerState {
/// - `prefill_compute`: The compute time in milliseconds for this prefill operation /// - `prefill_compute`: The compute time in milliseconds for this prefill operation
/// - `creation_signal`: Optional MoveBlock signal for KV cache block creation /// - `creation_signal`: Optional MoveBlock signal for KV cache block creation
/// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked /// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked
fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, bool)> { fn try_prefill(&mut self, perf_model: &PerfModel) -> Option<(f64, Option<MoveBlock>, bool)> {
let uuid = self.prefill.pop_front()?; let uuid = self.prefill.pop_front()?;
// Remove and extract prefill_compute from prefill_costs // Remove and extract prefill_compute from prefill_costs
...@@ -134,11 +135,12 @@ impl SchedulerState { ...@@ -134,11 +135,12 @@ impl SchedulerState {
let (prefill_compute, is_full_prefill) = if let Some(prefill_tokens) = maybe_prefill_tokens let (prefill_compute, is_full_prefill) = if let Some(prefill_tokens) = maybe_prefill_tokens
{ {
let prefill_compute = prefill_cost.predict_prefill_compute(Some(prefill_tokens)); let prefill_compute =
prefill_cost.predict_prefill_compute(Some(prefill_tokens), perf_model);
prefill_cost.new_tokens -= prefill_tokens; prefill_cost.new_tokens -= prefill_tokens;
assert!( assert!(
(prefill_cost.new_tokens > 0) && (prefill_compute > 0.0), prefill_cost.new_tokens > 0,
"Encountered negative prefill tokens or prefill compute cost." "Encountered negative prefill tokens."
); );
self.prefill.push_front(uuid); self.prefill.push_front(uuid);
...@@ -155,7 +157,7 @@ impl SchedulerState { ...@@ -155,7 +157,7 @@ impl SchedulerState {
self.active_tokens += new_tokens; self.active_tokens += new_tokens;
self.waiting_tokens -= new_tokens; self.waiting_tokens -= new_tokens;
(prefill_cost.predict_prefill_compute(None), true) (prefill_cost.predict_prefill_compute(None, perf_model), true)
}; };
// NOTE: the current behavior allocates the KV blocks for the entire sequence, // NOTE: the current behavior allocates the KV blocks for the entire sequence,
...@@ -312,7 +314,7 @@ impl Scheduler { ...@@ -312,7 +314,7 @@ impl Scheduler {
// Process prefilling // Process prefilling
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
state.try_prefill() state.try_prefill(&args.perf_model)
{ {
// NOTE: Prefill cost/time is always incremented for new blocks, even if they // NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior. // could be cached by other requests in the same batch. This matches vLLM behavior.
...@@ -333,9 +335,24 @@ impl Scheduler { ...@@ -333,9 +335,24 @@ impl Scheduler {
} }
} }
let active_perc = kv_manager.get_active_perc(); // Compute decode timing
// TODO: share the same logic with Planner let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size;
let decoding_time = -25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74; // Compute average context length across all active decode requests
let (total_length, count) = state
.decode
.keys()
.filter_map(|uuid| state.requests.get(uuid))
.fold((0usize, 0usize), |(sum, cnt), req| {
if let Request::Active(seq) = req {
(sum + seq.len(), cnt + 1)
} else {
(sum, cnt)
}
});
let context_length = if count > 0 { total_length / count } else { 0 };
let decoding_time = args
.perf_model
.predict_decode_time(active_kv_tokens, context_length);
total_time += Duration::from_secs_f64(decoding_time / 1000.0); total_time += Duration::from_secs_f64(decoding_time / 1000.0);
state.reset_active_tokens(); state.reset_active_tokens();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment