Unverified Commit e82fa448 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

Add attention benchmarking tools (#26835)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarClaude <noreply@anthropic.com>
parent d9aa39a3
# vLLM Attention Benchmarking Suite
Fast, flexible benchmarking for vLLM attention and MLA backends with an extended batch specification grammar.
## Quick Start
```bash
cd benchmarks/attention_benchmarks
# Run a pre-configured benchmark
python benchmark.py --config configs/mla_decode.yaml
python benchmark.py --config configs/mla_mixed_batch.yaml
python benchmark.py --config configs/speculative_decode.yaml
python benchmark.py --config configs/standard_attention.yaml
python benchmark.py --config configs/reorder_threshold.yaml
# Or run custom benchmarks
python benchmark.py \
--backends flash flashinfer \
--batch-specs "q2k" "8q1s1k" "2q2k_32q1s1k" \
--output-csv results.csv
```
## Simplified Batch Specification Grammar
Express workloads concisely using query length and sequence length:
```python
"q2k" # 2048-token prefill (q_len=2048, seq_len=2048)
"q1s1k" # Decode: 1 token with 1K sequence
"8q1s1k" # 8 decode requests
"q4s1k" # 4-token extend (e.g., spec decode)
"2q2k_32q1s1k" # Mixed: 2 prefills + 32 decodes
"16q4s1k" # 16 spec decode (4 tokens each)
```
### Grammar Rule
```text
Format: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
- count: Number of identical requests (optional, default=1)
- q_len: Query length (number of new tokens)
- seq_len: Total sequence length (optional, defaults to q_len for prefill)
- 'k': Multiplies value by 1024
Mixed batches: Use _ to combine (e.g., "2q2k_32q1s1k")
```
**Note**: Decode, prefill, and spec decode are just different query lengths - no special syntax needed!
## Pre-configured Benchmarks
The suite includes several pre-configured YAML benchmark configurations:
### MLA Decode Benchmark
Tests pure decode performance across MLA backends with varying batch sizes and sequence lengths.
```bash
python benchmark.py --config configs/mla_decode.yaml
```
### MLA Mixed Batch Benchmark
Tests chunked prefill performance with mixed prefill + decode batches.
```bash
python benchmark.py --config configs/mla_mixed_batch.yaml
```
### Speculative Decoding Benchmark
Tests speculative decode scenarios (K-token verification) and reorder_batch_threshold optimization.
```bash
python benchmark.py --config configs/speculative_decode.yaml
```
### Standard Attention Benchmark
Tests standard attention backends (Flash/Triton/FlashInfer) with pure prefill, decode, and mixed batches.
```bash
python benchmark.py --config configs/standard_attention.yaml
```
### Reorder Threshold Study
**Question:** At what query length does the prefill pipeline become faster than the decode pipeline?
Tests query lengths from 1-1024 across 9 batch sizes to find the crossover point. Uses `decode_vs_prefill` mode to compare both pipelines for each query length.
```bash
python benchmark.py --config configs/reorder_threshold.yaml
```
---
## Universal Benchmark
The `benchmark.py` script handles **all** backends - both standard attention and MLA.
### Standard Attention (Flash/Triton/FlashInfer)
```bash
python benchmark.py \
--backends flash triton flashinfer \
--batch-specs "q2k" "8q1s1k" "2q2k_32q1s1k" \
--num-layers 10 \
--repeats 5 \
--output-csv results.csv
```
### MLA Backends
```bash
# Compare all MLA backends
python benchmark.py \
--backends cutlass_mla flashinfer_mla flashattn_mla flashmla \
--batch-specs "64q1s1k" "64q1s4k" \
--output-csv mla_results.csv
```
### Parameter Sweeps
Use `--sweep-param` and `--sweep-values` to run parameter sweeps from the CLI:
#### CUTLASS MLA num-splits Optimization
**Question:** What is the optimal `num_kv_splits` for CUTLASS MLA?
```bash
python benchmark.py \
--backend cutlass_mla \
--batch-specs "64q1s1k" "64q1s4k" "64q1s16k" \
--sweep-param num_kv_splits \
--sweep-values 1 2 4 8 16 \
--output-json optimal_splits.json
```
#### Reorder Batch Threshold Optimization
**Question:** What's the optimal `reorder_batch_threshold` for speculative decoding?
```bash
python benchmark.py \
--backend flashmla \
--batch-specs "q4s1k" "q8s2k" \
--sweep-param reorder_batch_threshold \
--sweep-values 1 4 16 64 256 512 \
--output-csv threshold_sweep.csv
```
### All Command-Line Options
```text
--config CONFIG # Path to YAML config file (overrides other args)
--backends BACKEND [BACKEND ...] # flash, triton, flashinfer, cutlass_mla,
# flashinfer_mla, flashattn_mla, flashmla
--backend BACKEND # Single backend (alternative to --backends)
--batch-specs SPEC [SPEC ...] # Batch specifications using extended grammar
# Model configuration
--num-layers N # Number of layers
--head-dim N # Head dimension
--num-q-heads N # Query heads
--num-kv-heads N # KV heads
--block-size N # Block size
# Benchmark settings
--device DEVICE # Device (default: cuda:0)
--repeats N # Repetitions
--warmup-iters N # Warmup iterations
--profile-memory # Profile memory usage
# Parameter sweeps
--sweep-param PARAM # Parameter name to sweep (e.g., num_kv_splits,
# reorder_batch_threshold)
--sweep-values N [N ...] # Values to sweep for the parameter
# Output
--output-csv FILE # Save to CSV
--output-json FILE # Save to JSON
```
## Hardware Requirements
| Backend | Hardware |
|---------|----------|
| Flash/Triton/FlashInfer | Any CUDA GPU |
| CUTLASS MLA | Blackwell (SM100+) |
| FlashAttn MLA | Hopper (SM90+) |
| FlashMLA | Hopper (SM90+) |
| FlashInfer-MLA | Any CUDA GPU |
## Using MLA Runner Directly
All MLA backends are available through `mla_runner.run_mla_benchmark()`:
```python
from mla_runner import run_mla_benchmark
from common import BenchmarkConfig
config = BenchmarkConfig(
backend="cutlass_mla",
batch_spec="64q1s4k",
num_layers=10,
head_dim=576,
num_q_heads=128,
num_kv_heads=1,
block_size=128,
device="cuda:0",
repeats=5,
warmup_iters=3,
)
# CUTLASS MLA with specific num_kv_splits
result = run_mla_benchmark("cutlass_mla", config, num_kv_splits=4)
print(f"Time: {result.mean_time:.6f}s")
# FlashInfer-MLA
result = run_mla_benchmark("flashinfer_mla", config)
# FlashAttn MLA (Hopper SM90+)
result = run_mla_benchmark("flashattn_mla", config, reorder_batch_threshold=64)
# FlashMLA (Hopper SM90+)
result = run_mla_benchmark("flashmla", config, reorder_batch_threshold=64)
```
## Python API
```python
from batch_spec import parse_batch_spec, format_batch_spec, get_batch_stats
from common import BenchmarkConfig, BenchmarkResult, ResultsFormatter
# Parse batch specs
requests = parse_batch_spec("2q2k_q4s1k_32q1s1k")
print(format_batch_spec(requests))
# "2 prefill (2x2k), 1 extend (1xq4kv1k), 32 decode (32x1k)"
# Get batch statistics
stats = get_batch_stats(requests)
print(f"Total tokens: {stats['total_tokens']}")
print(f"Num decode: {stats['num_decode']}, Num prefill: {stats['num_prefill']}")
# Format results
formatter = ResultsFormatter()
formatter.save_csv(results, "output.csv")
formatter.save_json(results, "output.json")
```
## Tips
**1. Warmup matters** - Use `--warmup-iters 10` for stable results
**2. Multiple repeats** - Use `--repeats 20` for low variance
**3. Save results** - Always use `--output-csv` or `--output-json`
**4. Test incrementally** - Start with `--num-layers 1 --repeats 1`
**5. Extended grammar** - Leverage spec decode, chunked prefill patterns
**6. Parameter sweeps** - Use `--sweep-param` and `--sweep-values` to find optimal values
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM Attention Benchmarking Suite."""
from .batch_spec import (
BatchRequest,
format_batch_spec,
get_batch_stats,
parse_batch_spec,
reorder_for_flashinfer,
split_by_type,
)
from .common import (
BenchmarkConfig,
BenchmarkResult,
MockLayer,
MockModelConfig,
ResultsFormatter,
get_attention_scale,
is_mla_backend,
setup_mla_dims,
)
__all__ = [
# Batch specification
"BatchRequest",
"parse_batch_spec",
"format_batch_spec",
"reorder_for_flashinfer",
"split_by_type",
"get_batch_stats",
# Benchmarking infrastructure
"BenchmarkConfig",
"BenchmarkResult",
"ResultsFormatter",
# Mock objects
"MockLayer",
"MockModelConfig",
# Utilities
"setup_mla_dims",
"get_attention_scale",
"is_mla_backend",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simplified batch specification grammar for attention benchmarks.
Grammar (underscore-separated segments):
Format: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
- count: Number of identical requests (optional, default=1)
- q_len: Query length (number of new tokens)
- seq_len: Total sequence length (optional, defaults to q_len for prefill)
- 'k' suffix: Multiplies value by 1024
Common patterns:
- Prefill: q_len == seq_len (e.g., "q2k" → 2048 new tokens, 2048 seq)
- Decode: q_len == 1 (e.g., "q1s1k" → 1 token, 1024 seq length)
- Extend: q_len < seq_len (e.g., "q4s1k" → 4 tokens, 1024 seq length)
Examples:
q2k -> [(2048, 2048)] # Prefill: 2048 tokens
q1s1k -> [(1, 1024)] # Decode: 1 token, 1K sequence
8q1s1k -> [(1, 1024)] * 8 # 8 decode requests
q4s1k -> [(4, 1024)] # 4-token extend (spec decode)
2q1k_32q1s1k -> [(1024, 1024)] * 2 + [(1, 1024)] * 32 # Mixed batch
16q4s1k -> [(4, 1024)] * 16 # 16 spec decode requests
"""
from collections import Counter
from dataclasses import dataclass
import regex as re
@dataclass
class BatchRequest:
"""Represents a single request in a batch."""
q_len: int # Query length (number of new tokens)
kv_len: int # Total KV cache length
@property
def is_decode(self) -> bool:
"""True if this is a decode request (q_len == 1)."""
return self.q_len == 1
@property
def is_prefill(self) -> bool:
"""True if this is a pure prefill (q_len == kv_len)."""
return self.q_len == self.kv_len
@property
def is_extend(self) -> bool:
"""True if this is context extension (q_len > 1, kv_len > q_len)."""
return self.q_len > 1 and self.kv_len > self.q_len
@property
def context_len(self) -> int:
"""Context length (KV cache - query)."""
return self.kv_len - self.q_len
def as_tuple(self) -> tuple[int, int]:
"""Return as (q_len, kv_len) tuple for compatibility."""
return (self.q_len, self.kv_len)
def _parse_size(size_str: str, k_suffix: str) -> int:
"""Parse size string with optional 'k' suffix."""
size = int(size_str)
return size * 1024 if k_suffix == "k" else size
def parse_batch_spec(spec: str) -> list[BatchRequest]:
"""
Parse batch specification string into list of BatchRequest objects.
Grammar: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
Args:
spec: Batch specification string (see module docstring for grammar)
Returns:
List of BatchRequest objects
Raises:
ValueError: If spec format is invalid
"""
requests = []
for seg in spec.split("_"):
# Unified pattern: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
m = re.match(r"^(?:(\d+))?q(\d+)(k?)(?:s(\d+)(k?))?$", seg)
if m:
cnt = int(m.group(1)) if m.group(1) else 1
q_len = _parse_size(m.group(2), m.group(3))
kv_len = _parse_size(m.group(4), m.group(5)) if m.group(4) else q_len
requests.extend([BatchRequest(q_len=q_len, kv_len=kv_len)] * cnt)
continue
raise ValueError(f"Invalid batch spec segment: '{seg}'")
return requests
def format_batch_spec(requests: list[BatchRequest]) -> str:
"""
Format list of BatchRequest into human-readable string.
Groups requests by type and provides counts and sizes.
Args:
requests: List of BatchRequest objects
Returns:
Formatted string describing the batch
"""
kinds = {
"prefill": [],
"extend": [],
"decode": [],
}
for req in requests:
tup = (req.q_len, req.kv_len)
if req.is_prefill:
kinds["prefill"].append(tup)
elif req.is_extend:
kinds["extend"].append(tup)
elif req.is_decode:
kinds["decode"].append(tup)
parts = []
for kind in ["prefill", "extend", "decode"]:
lst = kinds[kind]
if not lst:
continue
cnt_total = len(lst)
ctr = Counter(lst)
inner = []
for (q, kv), cnt in ctr.items():
if kind == "prefill":
size = f"{q // 1024}k" if q % 1024 == 0 else str(q)
inner.append(f"{cnt}x{size}")
elif kind == "decode":
size = f"{kv // 1024}k" if kv % 1024 == 0 else str(kv)
inner.append(f"{cnt}x{size}")
else: # extend
qstr = f"{q // 1024}k" if q % 1024 == 0 else str(q)
kstr = f"{kv // 1024}k" if kv % 1024 == 0 else str(kv)
inner.append(f"{cnt}xq{qstr}kv{kstr}")
parts.append(f"{cnt_total} {kind} ({', '.join(inner)})")
return ", ".join(parts)
def reorder_for_flashinfer(requests: list[BatchRequest]) -> list[BatchRequest]:
"""
Reorder requests for FlashInfer: decode first, then prefill.
FlashInfer expects decode requests before prefill requests for
optimal performance.
Args:
requests: Original list of BatchRequest
Returns:
Reordered list with decode requests first
"""
decodes = [r for r in requests if r.is_decode]
non_decodes = [r for r in requests if not r.is_decode]
return decodes + non_decodes
def split_by_type(
requests: list[BatchRequest],
) -> dict[str, list[BatchRequest]]:
"""
Split requests by type for analysis.
Args:
requests: List of BatchRequest
Returns:
Dict with keys: 'decode', 'prefill', 'extend'
"""
result = {
"decode": [],
"prefill": [],
"extend": [],
}
for req in requests:
if req.is_decode:
result["decode"].append(req)
elif req.is_prefill:
result["prefill"].append(req)
elif req.is_extend:
result["extend"].append(req)
return result
def get_batch_stats(requests: list[BatchRequest]) -> dict:
"""
Compute statistics about a batch.
Args:
requests: List of BatchRequest
Returns:
Dict with batch statistics
"""
by_type = split_by_type(requests)
return {
"total_requests": len(requests),
"num_decode": len(by_type["decode"]),
"num_prefill": len(by_type["prefill"]),
"num_extend": len(by_type["extend"]),
"total_tokens": sum(r.q_len for r in requests),
"total_kv_cache": sum(r.kv_len for r in requests),
"max_q_len": max((r.q_len for r in requests), default=0),
"max_kv_len": max((r.kv_len for r in requests), default=0),
"avg_q_len": sum(r.q_len for r in requests) / len(requests) if requests else 0,
"avg_kv_len": (
sum(r.kv_len for r in requests) / len(requests) if requests else 0
),
}
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Universal vLLM Attention Benchmark
Benchmark any attention backend with the extended grammar.
Supports standard attention (Flash/Triton/FlashInfer) and MLA backends.
Examples:
# Standard attention
python benchmark.py --backends flash flashinfer --batch-specs "q2k" "8q1s1k"
# MLA backends
python benchmark.py --backends cutlass_mla flashinfer_mla --batch-specs "64q1s1k"
# Parameter sweep (CLI)
python benchmark.py --backend cutlass_mla \
--batch-specs "64q1s1k" \
--sweep-param num_kv_splits \
--sweep-values 1 4 8 16
# Parameter sweep (YAML config - recommended)
python benchmark.py --config configs/cutlass_numsplits.yaml
"""
import argparse
import sys
from dataclasses import replace
from pathlib import Path
import yaml
from rich.console import Console
from tqdm import tqdm
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from batch_spec import parse_batch_spec
from common import (
BenchmarkConfig,
BenchmarkResult,
ModelParameterSweep,
ParameterSweep,
ResultsFormatter,
is_mla_backend,
)
def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
"""Run standard attention benchmark (Flash/Triton/FlashInfer)."""
from runner import run_attention_benchmark
return run_attention_benchmark(config)
def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
"""Run MLA benchmark with appropriate backend."""
from mla_runner import run_mla_benchmark as run_mla
return run_mla(config.backend, config, **kwargs)
def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
"""
Run a single benchmark with proper backend selection.
Args:
config: BenchmarkConfig with backend, batch_spec, and model params
**kwargs: Additional arguments passed to MLA benchmarks
Returns:
BenchmarkResult (may have error field set on failure)
"""
try:
if is_mla_backend(config.backend):
return run_mla_benchmark(config, **kwargs)
else:
return run_standard_attention_benchmark(config)
except Exception as e:
return BenchmarkResult(
config=config,
mean_time=float("inf"),
std_time=0,
min_time=float("inf"),
max_time=float("inf"),
error=str(e),
)
def run_model_parameter_sweep(
backends: list[str],
batch_specs: list[str],
base_config_args: dict,
sweep: ModelParameterSweep,
console: Console,
) -> list[BenchmarkResult]:
"""
Run model parameter sweep for given backends and batch specs.
Args:
backends: List of backend names
batch_specs: List of batch specifications
base_config_args: Base configuration arguments (num_layers, head_dim, etc.)
sweep: ModelParameterSweep configuration
console: Rich console for output
Returns:
List of BenchmarkResult objects
"""
all_results = []
console.print(
f"[yellow]Model sweep mode: testing {sweep.param_name} = {sweep.values}[/]"
)
total = len(backends) * len(batch_specs) * len(sweep.values)
with tqdm(total=total, desc="Benchmarking") as pbar:
for backend in backends:
for spec in batch_specs:
for value in sweep.values:
# Create config with modified model parameter
config_args = base_config_args.copy()
config_args[sweep.param_name] = value
# Create config with original backend for running
clean_config = BenchmarkConfig(
backend=backend, batch_spec=spec, **config_args
)
# Run benchmark
result = run_benchmark(clean_config)
# Replace backend with labeled version for display
backend_label = sweep.get_label(backend, value)
labeled_config = replace(result.config, backend=backend_label)
result = replace(result, config=labeled_config)
all_results.append(result)
if not result.success:
console.print(
f"[red]Error {backend} {spec} {sweep.param_name}="
f"{value}: {result.error}[/]"
)
pbar.update(1)
# Display sweep results - create separate table for each parameter value
console.print("\n[bold green]Model Parameter Sweep Results:[/]")
formatter = ResultsFormatter(console)
# Group results by parameter value and extract backend mapping
by_param_value = {}
backend_mapping = {} # Maps labeled backend -> original backend
for r in all_results:
# Extract original backend and param value from labeled backend
# The label format is: {backend}_{param_name}_{value}
# We need to reverse engineer this
labeled_backend = r.config.backend
# Try each backend to find which one this result belongs to
for backend in backends:
for value in sweep.values:
expected_label = sweep.get_label(backend, value)
if labeled_backend == expected_label:
backend_mapping[labeled_backend] = backend
param_value = str(value)
if param_value not in by_param_value:
by_param_value[param_value] = []
by_param_value[param_value].append(r)
break
# Create a table for each parameter value
sorted_param_values = sorted(
by_param_value.keys(), key=lambda x: int(x) if x.isdigit() else x
)
for param_value in sorted_param_values:
console.print(f"\n[bold cyan]{sweep.param_name} = {param_value}[/]")
param_results = by_param_value[param_value]
# Create modified results with original backend names
modified_results = []
for r in param_results:
# Get the original backend name from our mapping
original_backend = backend_mapping[r.config.backend]
modified_config = replace(r.config, backend=original_backend)
modified_result = replace(r, config=modified_config)
modified_results.append(modified_result)
# Print table with original backend names
formatter.print_table(modified_results, backends, compare_to_fastest=True)
# Show optimal backend for each (param_value, batch_spec) combination
console.print(
f"\n[bold cyan]Optimal backend for each ({sweep.param_name}, batch_spec):[/]"
)
# Group by (param_value, batch_spec)
by_param_and_spec = {}
for r in all_results:
if r.success:
# Find which (backend, value) this result corresponds to
labeled_backend = r.config.backend
for backend in backends:
for value in sweep.values:
expected_label = sweep.get_label(backend, value)
if labeled_backend == expected_label:
param_value = str(value)
spec = r.config.batch_spec
key = (param_value, spec)
if key not in by_param_and_spec:
by_param_and_spec[key] = []
by_param_and_spec[key].append(r)
break
# Sort by param value then spec
sorted_keys = sorted(
by_param_and_spec.keys(),
key=lambda x: (int(x[0]) if x[0].isdigit() else x[0], x[1]),
)
current_param_value = None
for param_value, spec in sorted_keys:
# Print header when param value changes
if param_value != current_param_value:
console.print(f"\n [bold]{sweep.param_name}={param_value}:[/]")
current_param_value = param_value
results = by_param_and_spec[(param_value, spec)]
best = min(results, key=lambda r: r.mean_time)
# Extract original backend name using the mapping
backend_name = backend_mapping[best.config.backend]
# Show all backends' times for comparison
times_str = " | ".join(
[
f"{backend_mapping[r.config.backend]}: {r.mean_time:.6f}s"
for r in sorted(results, key=lambda r: r.mean_time)
]
)
console.print(
f" {spec:12s} -> [bold green]{backend_name:15s}[/] ({times_str})"
)
return all_results
def run_parameter_sweep(
backends: list[str],
batch_specs: list[str],
base_config_args: dict,
sweep: ParameterSweep,
console: Console,
) -> list[BenchmarkResult]:
"""
Run parameter sweep for given backends and batch specs.
Args:
backends: List of backend names
batch_specs: List of batch specifications
base_config_args: Base configuration arguments (num_layers, head_dim, etc.)
sweep: ParameterSweep configuration
console: Rich console for output
Returns:
List of BenchmarkResult objects
"""
all_results = []
# Build list of values to sweep (including auto if requested)
sweep_values = list(sweep.values)
if sweep.include_auto:
sweep_values.append("auto")
console.print(f"[yellow]Sweep mode: testing {sweep.param_name} = {sweep_values}[/]")
total = len(backends) * len(batch_specs) * len(sweep_values)
with tqdm(total=total, desc="Benchmarking") as pbar:
for backend in backends:
for spec in batch_specs:
for value in sweep_values:
# Create config with original backend for running
config = BenchmarkConfig(
backend=backend, batch_spec=spec, **base_config_args
)
# Prepare kwargs for benchmark runner
kwargs = {}
if value != "auto":
kwargs[sweep.param_name] = value
# Run benchmark
result = run_benchmark(config, **kwargs)
# Replace backend with labeled version for display
backend_label = sweep.get_label(backend, value)
labeled_config = replace(result.config, backend=backend_label)
result = replace(result, config=labeled_config)
all_results.append(result)
if not result.success:
console.print(
f"[red]Error {backend} {spec} {sweep.param_name}="
f"{value}: {result.error}[/]"
)
pbar.update(1)
# Display sweep results
console.print("\n[bold green]Sweep Results:[/]")
backend_labels = [sweep.get_label(b, v) for b in backends for v in sweep_values]
formatter = ResultsFormatter(console)
formatter.print_table(all_results, backend_labels)
# Show optimal values
console.print(f"\n[bold cyan]Optimal {sweep.param_name} per batch spec:[/]")
by_spec = {}
for r in all_results:
if r.success:
spec = r.config.batch_spec
if spec not in by_spec:
by_spec[spec] = []
by_spec[spec].append(r)
for spec in sorted(by_spec.keys()):
results = by_spec[spec]
best = min(results, key=lambda r: r.mean_time)
console.print(
f" {spec}: [bold green]{best.config.backend}[/] ({best.mean_time:.6f}s)"
)
return all_results
def load_config_from_yaml(config_path: str) -> dict:
"""Load configuration from YAML file."""
with open(config_path) as f:
return yaml.safe_load(f)
def generate_batch_specs_from_ranges(ranges: list[dict]) -> list[str]:
"""
Generate batch specs from range specifications.
Args:
ranges: List of range specifications, each containing:
- template: Batch spec template (e.g., "q{q_len}kv1k")
- q_len: Dict with start, stop, step, end_inclusive (optional)
- Other parameters can also be ranges
Returns:
List of generated batch spec strings
Example:
ranges = [
{
"template": "q{q_len}kv1k",
"q_len": {
"start": 1,
"stop": 16,
"step": 1,
"end_inclusive": true # Optional, defaults to true
}
}
]
Returns: ["q1kv1k", "q2kv1k", ..., "q16kv1k"]
"""
all_specs = []
for range_spec in ranges:
template = range_spec.get("template")
if not template:
raise ValueError("Range specification must include 'template'")
# Extract all range parameters from the spec
range_params = {}
for key, value in range_spec.items():
if key == "template":
continue
if isinstance(value, dict) and "start" in value:
# This is a range specification
start = value["start"]
stop = value["stop"]
step = value.get("step", 1)
# Check if end should be inclusive (default: True)
end_inclusive = value.get("end_inclusive", True)
# Adjust stop based on end_inclusive
if end_inclusive:
range_params[key] = list(range(start, stop + 1, step))
else:
range_params[key] = list(range(start, stop, step))
else:
# This is a fixed value
range_params[key] = [value]
# Generate all combinations (Cartesian product)
if range_params:
import itertools
param_names = list(range_params.keys())
param_values = [range_params[name] for name in param_names]
for values in itertools.product(*param_values):
params = dict(zip(param_names, values))
spec = template.format(**params)
all_specs.append(spec)
else:
# No parameters, just use template as-is
all_specs.append(template)
return all_specs
def main():
parser = argparse.ArgumentParser(
description="Universal vLLM attention benchmark",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
# Config file
parser.add_argument(
"--config",
help="Path to YAML config file (overrides other args)",
)
# Backend selection
parser.add_argument(
"--backends",
nargs="+",
help="Backends to benchmark (flash, triton, flashinfer, cutlass_mla, "
"flashinfer_mla, flashattn_mla, flashmla)",
)
parser.add_argument(
"--backend",
help="Single backend (alternative to --backends)",
)
# Batch specifications
parser.add_argument(
"--batch-specs",
nargs="+",
default=["q2k", "8q1s1k"],
help="Batch specifications using extended grammar",
)
# Model config
parser.add_argument("--num-layers", type=int, default=10, help="Number of layers")
parser.add_argument("--head-dim", type=int, default=128, help="Head dimension")
parser.add_argument("--num-q-heads", type=int, default=32, help="Query heads")
parser.add_argument("--num-kv-heads", type=int, default=8, help="KV heads")
parser.add_argument("--block-size", type=int, default=16, help="Block size")
# Benchmark settings
parser.add_argument("--device", default="cuda:0", help="Device")
parser.add_argument("--repeats", type=int, default=1, help="Repetitions")
parser.add_argument("--warmup-iters", type=int, default=3, help="Warmup iterations")
parser.add_argument("--profile-memory", action="store_true", help="Profile memory")
# Parameter sweep (use YAML config for advanced sweeps)
parser.add_argument(
"--sweep-param",
help="Parameter name to sweep (e.g., num_kv_splits, reorder_batch_threshold)",
)
parser.add_argument(
"--sweep-values",
type=int,
nargs="+",
help="Values to sweep for the parameter",
)
# Output
parser.add_argument("--output-csv", help="Save to CSV")
parser.add_argument("--output-json", help="Save to JSON")
args = parser.parse_args()
console = Console()
console.print("[bold cyan]vLLM Attention Benchmark[/]")
# Load config from YAML if provided
if args.config:
console.print(f"[yellow]Loading config from: {args.config}[/]")
yaml_config = load_config_from_yaml(args.config)
# Show description if available
if "description" in yaml_config:
console.print(f"[dim]{yaml_config['description']}[/]")
# Override args with YAML values
# (YAML takes precedence unless CLI arg was explicitly set)
# Backend(s)
if "backend" in yaml_config:
args.backend = yaml_config["backend"]
args.backends = None
elif "backends" in yaml_config:
args.backends = yaml_config["backends"]
args.backend = None
# Check for special modes
if "mode" in yaml_config:
args.mode = yaml_config["mode"]
else:
args.mode = None
# Batch specs and sizes
# Support both explicit batch_specs and generated batch_spec_ranges
if "batch_spec_ranges" in yaml_config:
# Generate batch specs from ranges
generated_specs = generate_batch_specs_from_ranges(
yaml_config["batch_spec_ranges"]
)
# Combine with any explicit batch_specs
if "batch_specs" in yaml_config:
args.batch_specs = yaml_config["batch_specs"] + generated_specs
else:
args.batch_specs = generated_specs
console.print(
f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]"
)
elif "batch_specs" in yaml_config:
args.batch_specs = yaml_config["batch_specs"]
if "batch_sizes" in yaml_config:
args.batch_sizes = yaml_config["batch_sizes"]
else:
args.batch_sizes = None
# Model config
if "model" in yaml_config:
model = yaml_config["model"]
args.num_layers = model.get("num_layers", args.num_layers)
args.head_dim = model.get("head_dim", args.head_dim)
args.num_q_heads = model.get("num_q_heads", args.num_q_heads)
args.num_kv_heads = model.get("num_kv_heads", args.num_kv_heads)
args.block_size = model.get("block_size", args.block_size)
# Benchmark settings
if "benchmark" in yaml_config:
bench = yaml_config["benchmark"]
args.device = bench.get("device", args.device)
args.repeats = bench.get("repeats", args.repeats)
args.warmup_iters = bench.get("warmup_iters", args.warmup_iters)
args.profile_memory = bench.get("profile_memory", args.profile_memory)
# Parameter sweep configuration
if "parameter_sweep" in yaml_config:
sweep_config = yaml_config["parameter_sweep"]
args.parameter_sweep = ParameterSweep(
param_name=sweep_config["param_name"],
values=sweep_config["values"],
include_auto=sweep_config.get("include_auto", False),
label_format=sweep_config.get(
"label_format", "{backend}_{param_name}_{value}"
),
)
else:
args.parameter_sweep = None
# Model parameter sweep configuration
if "model_parameter_sweep" in yaml_config:
sweep_config = yaml_config["model_parameter_sweep"]
args.model_parameter_sweep = ModelParameterSweep(
param_name=sweep_config["param_name"],
values=sweep_config["values"],
label_format=sweep_config.get(
"label_format", "{backend}_{param_name}_{value}"
),
)
else:
args.model_parameter_sweep = None
# Output
if "output" in yaml_config:
output = yaml_config["output"]
if "csv" in output and not args.output_csv:
args.output_csv = output["csv"]
if "json" in output and not args.output_json:
args.output_json = output["json"]
console.print()
# Handle CLI-based parameter sweep (if not from YAML)
if (
(not hasattr(args, "parameter_sweep") or args.parameter_sweep is None)
and args.sweep_param
and args.sweep_values
):
args.parameter_sweep = ParameterSweep(
param_name=args.sweep_param,
values=args.sweep_values,
include_auto=False,
label_format="{backend}_{param_name}_{value}",
)
# Determine backends
backends = args.backends or ([args.backend] if args.backend else ["flash"])
console.print(f"Backends: {', '.join(backends)}")
console.print(f"Batch specs: {', '.join(args.batch_specs)}")
console.print()
# Run benchmarks
all_results = []
# Handle special mode: decode_vs_prefill comparison
if hasattr(args, "mode") and args.mode == "decode_vs_prefill":
console.print("[yellow]Mode: Decode vs Prefill pipeline comparison[/]")
console.print(
"[dim]For each query length, testing both decode and prefill pipelines[/]"
)
console.print("[dim]Using batched execution for optimal performance[/]")
# Extract batch sizes from config
batch_sizes = getattr(args, "batch_sizes", [1])
backend = backends[0] # Use first backend (should only be one)
# Calculate total benchmarks
total = len(batch_sizes)
with tqdm(total=total, desc="Benchmarking") as pbar:
for batch_size in batch_sizes:
# Prepare all configs for this batch size
configs_with_thresholds = []
for spec in args.batch_specs:
# Parse the batch spec to get query length
requests = parse_batch_spec(spec)
if not requests:
console.print(
f"[red]Error: Could not parse batch spec '{spec}'[/]"
)
continue
# Get query length from first request
query_length = requests[0].q_len
# Create batch spec for this batch size
# For batch_size > 1, we need to prepend the count
batch_spec = f"{batch_size}{spec}" if batch_size > 1 else spec
# Create base config (without backend name)
base_config = BenchmarkConfig(
backend=backend, # Will be overridden later
batch_spec=batch_spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
)
# Add decode pipeline config
decode_threshold = query_length
config_decode = replace(
base_config,
backend=f"{backend}_decode_qlen{query_length}_bs{batch_size}",
)
configs_with_thresholds.append((config_decode, decode_threshold))
# Add prefill pipeline config if query_length > 1
if query_length > 1:
prefill_threshold = query_length - 1
config_prefill = replace(
base_config,
backend=f"{backend}_prefill_qlen{query_length}"
f"_bs{batch_size}",
)
configs_with_thresholds.append(
(config_prefill, prefill_threshold)
)
# Run all benchmarks for this batch size in one go (batched mode)
try:
from mla_runner import run_mla_benchmark as run_mla
# Use batched API: pass list of (config, threshold) tuples
timing_results = run_mla(backend, configs_with_thresholds)
# Create BenchmarkResult objects from timing results
for (config, _), timing in zip(
configs_with_thresholds, timing_results
):
result = BenchmarkResult(
config=config,
mean_time=timing["mean"],
std_time=timing["std"],
min_time=timing["min"],
max_time=timing["max"],
throughput_tokens_per_sec=timing.get("throughput", None),
)
all_results.append(result)
except Exception as e:
import traceback
console.print(
f"[red]Error running batched benchmarks for "
f"batch_size={batch_size}: {e}[/]"
)
console.print("[red]Traceback:[/]")
traceback.print_exc()
# Add error results for all configs
for config, _ in configs_with_thresholds:
result = BenchmarkResult(
config=config,
mean_time=float("inf"),
std_time=0,
min_time=float("inf"),
max_time=float("inf"),
error=str(e),
)
all_results.append(result)
pbar.update(1)
# Display decode vs prefill results
console.print("\n[bold green]Decode vs Prefill Results:[/]")
# Group by batch size
by_batch_size = {}
for r in all_results:
if r.success:
# Extract batch size from backend name
parts = r.config.backend.split("_")
bs_part = [p for p in parts if p.startswith("bs")]
if bs_part:
bs = int(bs_part[0][2:])
if bs not in by_batch_size:
by_batch_size[bs] = []
by_batch_size[bs].append(r)
# For each batch size, analyze crossover point
for bs in sorted(by_batch_size.keys()):
console.print(f"\n[bold cyan]Batch size: {bs}[/]")
results = by_batch_size[bs]
# Group by query length
by_qlen = {}
for r in results:
parts = r.config.backend.split("_")
qlen_part = [p for p in parts if p.startswith("qlen")]
if qlen_part:
qlen = int(qlen_part[0][4:])
if qlen not in by_qlen:
by_qlen[qlen] = {}
pipeline = "decode" if "decode" in r.config.backend else "prefill"
by_qlen[qlen][pipeline] = r
# Find crossover point
last_decode_faster = None
for qlen in sorted(by_qlen.keys()):
pipelines = by_qlen[qlen]
if "decode" in pipelines and "prefill" in pipelines:
decode_time = pipelines["decode"].mean_time
prefill_time = pipelines["prefill"].mean_time
faster = "decode" if decode_time < prefill_time else "prefill"
speedup = (
prefill_time / decode_time
if decode_time < prefill_time
else decode_time / prefill_time
)
console.print(
f" qlen={qlen:3d}: decode={decode_time:.6f}s, "
f"prefill={prefill_time:.6f}s -> "
f"[bold]{faster}[/] ({speedup:.2f}x)"
)
if faster == "decode":
last_decode_faster = qlen
if last_decode_faster is not None:
optimal_threshold = last_decode_faster
console.print(
f"\n [bold green]Optimal threshold for batch_size={bs}: "
f"{optimal_threshold}[/]"
)
console.print(
f" [dim](Use decode pipeline for query_length <= "
f"{optimal_threshold})[/]"
)
else:
console.print(
f"\n [yellow]Prefill always faster for batch_size={bs}[/]"
)
# Handle model parameter sweep mode
elif hasattr(args, "model_parameter_sweep") and args.model_parameter_sweep:
# Model parameter sweep
base_config_args = {
"num_layers": args.num_layers,
"head_dim": args.head_dim,
"num_q_heads": args.num_q_heads,
"num_kv_heads": args.num_kv_heads,
"block_size": args.block_size,
"device": args.device,
"repeats": args.repeats,
"warmup_iters": args.warmup_iters,
"profile_memory": args.profile_memory,
}
all_results = run_model_parameter_sweep(
backends,
args.batch_specs,
base_config_args,
args.model_parameter_sweep,
console,
)
# Handle parameter sweep mode (unified)
elif hasattr(args, "parameter_sweep") and args.parameter_sweep:
# Unified parameter sweep
base_config_args = {
"num_layers": args.num_layers,
"head_dim": args.head_dim,
"num_q_heads": args.num_q_heads,
"num_kv_heads": args.num_kv_heads,
"block_size": args.block_size,
"device": args.device,
"repeats": args.repeats,
"warmup_iters": args.warmup_iters,
"profile_memory": args.profile_memory,
}
all_results = run_parameter_sweep(
backends, args.batch_specs, base_config_args, args.parameter_sweep, console
)
else:
# Normal mode: compare backends
total = len(backends) * len(args.batch_specs)
with tqdm(total=total, desc="Benchmarking") as pbar:
for spec in args.batch_specs:
for backend in backends:
config = BenchmarkConfig(
backend=backend,
batch_spec=spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
)
result = run_benchmark(config)
all_results.append(result)
if not result.success:
console.print(f"[red]Error {backend} {spec}: {result.error}[/]")
pbar.update(1)
# Display results
console.print("\n[bold green]Results:[/]")
formatter = ResultsFormatter(console)
formatter.print_table(all_results, backends)
# Save results
if all_results:
formatter = ResultsFormatter(console)
if args.output_csv:
formatter.save_csv(all_results, args.output_csv)
if args.output_json:
formatter.save_json(all_results, args.output_json)
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Common utilities for attention benchmarking."""
import csv
import json
import math
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
import numpy as np
import torch
from rich.console import Console
from rich.table import Table
# Mock classes for vLLM attention infrastructure
class MockHfConfig:
"""Mock HuggingFace config that satisfies vLLM's requirements."""
def __init__(self, mla_dims: dict):
self.num_attention_heads = mla_dims["num_q_heads"]
self.num_key_value_heads = mla_dims["num_kv_heads"]
self.hidden_size = mla_dims["head_dim"] * mla_dims["num_q_heads"]
self.model_type = "deepseek_v2"
self.is_encoder_decoder = False
self.kv_lora_rank = mla_dims["kv_lora_rank"]
self.qk_nope_head_dim = mla_dims["qk_nope_head_dim"]
self.qk_rope_head_dim = mla_dims["qk_rope_head_dim"]
self.v_head_dim = mla_dims["v_head_dim"]
self.qk_head_dim = mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]
def get_text_config(self):
return self
# Import AttentionLayerBase at module level to avoid circular dependencies
try:
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
_HAS_ATTENTION_LAYER_BASE = True
except ImportError:
_HAS_ATTENTION_LAYER_BASE = False
AttentionLayerBase = object # Fallback
class MockKVBProj:
"""Mock KV projection layer for MLA prefill mode.
Mimics ColumnParallelLinear behavior for kv_b_proj in MLA backends.
Projects kv_c_normed to [qk_nope_head_dim + v_head_dim] per head.
"""
def __init__(self, num_heads: int, qk_nope_head_dim: int, v_head_dim: int):
self.num_heads = num_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.v_head_dim = v_head_dim
self.out_dim = qk_nope_head_dim + v_head_dim
def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]:
"""
Project kv_c_normed to output space.
Args:
x: Input tensor [num_tokens, kv_lora_rank]
Returns:
Tuple containing output tensor
[num_tokens, num_heads, qk_nope_head_dim + v_head_dim]
"""
num_tokens = x.shape[0]
result = torch.randn(
num_tokens,
self.num_heads,
self.out_dim,
device=x.device,
dtype=x.dtype,
)
return (result,) # Return as tuple to match ColumnParallelLinear API
class MockLayer(AttentionLayerBase):
"""Mock attention layer with scale parameters and impl.
Inherits from AttentionLayerBase so it passes isinstance checks
in get_layers_from_vllm_config when FlashInfer prefill is enabled.
"""
def __init__(self, device: torch.device, impl=None, kv_cache_spec=None):
# Don't call super().__init__() as AttentionLayerBase doesn't have __init__
self._k_scale = torch.tensor(1.0, device=device)
self._v_scale = torch.tensor(1.0, device=device)
self._q_scale = torch.tensor(1.0, device=device)
# Scalar floats for kernels that need them
self._k_scale_float = float(self._k_scale.item())
self._v_scale_float = float(self._v_scale.item())
self._q_scale_float = float(self._q_scale.item())
# AttentionImpl for metadata builders to query
self.impl = impl
# KV cache spec for get_kv_cache_spec
self._kv_cache_spec = kv_cache_spec
def get_attn_backend(self):
"""Get the attention backend class (required by AttentionLayerBase)."""
# Return None as this is just a mock layer for benchmarking
return None
def get_kv_cache_spec(self):
"""Get the KV cache spec (required by AttentionLayerBase)."""
return self._kv_cache_spec
class MockModelConfig:
"""Mock model configuration."""
def __init__(
self,
num_q_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.float16,
max_model_len: int = 32768,
):
self._n_q = num_q_heads
self._n_kv = num_kv_heads
self._d = head_dim
self.dtype = dtype
self.max_model_len = max_model_len
def get_num_attention_heads(self, _=None) -> int:
return self._n_q
def get_num_kv_heads(self, _=None) -> int:
return self._n_kv
def get_head_size(self) -> int:
return self._d
def get_num_layers(self) -> int:
"""Mock method for layer count queries."""
return 1
def get_sliding_window_for_layer(self, _layer_idx: int):
"""Mock method for sliding window queries."""
return None
def get_logits_soft_cap_for_layer(self, _layer_idx: int):
"""Mock method for logits soft cap queries."""
return None
def get_sm_scale_for_layer(self, _layer_idx: int) -> float:
"""Mock method for SM scale queries."""
return 1.0 / (self.get_head_size() ** 0.5)
class MockParallelConfig:
"""Mock parallel configuration."""
pass
class MockCompilationConfig:
"""Mock compilation configuration."""
def __init__(self):
self.full_cuda_graph = False
self.static_forward_context = {}
class MockVLLMConfig:
"""Mock VLLM configuration."""
def __init__(self):
self.compilation_config = MockCompilationConfig()
class MockRunner:
"""Mock GPU runner for metadata builders."""
def __init__(
self,
seq_lens: np.ndarray,
query_start_locs: np.ndarray,
device: torch.device,
num_q_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
):
self.model_config = MockModelConfig(num_q_heads, num_kv_heads, head_dim, dtype)
self.parallel_config = MockParallelConfig()
self.vllm_config = MockVLLMConfig()
self.seq_lens_np = seq_lens
self.query_start_loc_np = query_start_locs
self.device = device
self.attention_chunk_size = None
self.num_query_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.dtype = dtype
@dataclass
class ParameterSweep:
"""Configuration for sweeping a backend parameter."""
param_name: str # Name of the backend parameter to sweep
values: list[Any] # List of values to test
include_auto: bool = False # Also test with param unset (auto mode)
label_format: str = "{backend}_{param_name}_{value}" # Result label template
def get_label(self, backend: str, value: Any) -> str:
"""Generate a label for a specific parameter value."""
return self.label_format.format(
backend=backend, param_name=self.param_name, value=value
)
@dataclass
class ModelParameterSweep:
"""Configuration for sweeping a model configuration parameter."""
param_name: str # Name of the model config parameter to sweep (e.g., "num_q_heads")
values: list[Any] # List of values to test
label_format: str = "{backend}_{param_name}_{value}" # Result label template
def get_label(self, backend: str, value: Any) -> str:
"""Generate a label for a specific parameter value."""
return self.label_format.format(
backend=backend, param_name=self.param_name, value=value
)
@dataclass
class BenchmarkConfig:
"""Configuration for a single benchmark run."""
backend: str
batch_spec: str
num_layers: int
head_dim: int
num_q_heads: int
num_kv_heads: int
block_size: int
device: str
dtype: torch.dtype = torch.float16
repeats: int = 1
warmup_iters: int = 3
profile_memory: bool = False
use_cuda_graphs: bool = False
# MLA-specific
kv_lora_rank: int | None = None
qk_nope_head_dim: int | None = None
qk_rope_head_dim: int | None = None
v_head_dim: int | None = None
# Backend-specific tuning
num_kv_splits: int | None = None # CUTLASS MLA
reorder_batch_threshold: int | None = None # FlashAttn MLA, FlashMLA
@dataclass
class BenchmarkResult:
"""Results from a single benchmark run."""
config: BenchmarkConfig
mean_time: float # seconds
std_time: float # seconds
min_time: float # seconds
max_time: float # seconds
throughput_tokens_per_sec: float | None = None
memory_allocated_mb: float | None = None
memory_reserved_mb: float | None = None
error: str | None = None
@property
def success(self) -> bool:
"""Whether benchmark completed successfully."""
return self.error is None
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"config": asdict(self.config),
"mean_time": self.mean_time,
"std_time": self.std_time,
"min_time": self.min_time,
"max_time": self.max_time,
"throughput_tokens_per_sec": self.throughput_tokens_per_sec,
"memory_allocated_mb": self.memory_allocated_mb,
"memory_reserved_mb": self.memory_reserved_mb,
"error": self.error,
}
class ResultsFormatter:
"""Format and display benchmark results."""
def __init__(self, console: Console | None = None):
self.console = console or Console()
def print_table(
self,
results: list[BenchmarkResult],
backends: list[str],
compare_to_fastest: bool = True,
):
"""
Print results as a rich table.
Args:
results: List of BenchmarkResult
backends: List of backend names being compared
compare_to_fastest: Show percentage comparison to fastest
"""
# Group by batch spec
by_spec = {}
for r in results:
spec = r.config.batch_spec
if spec not in by_spec:
by_spec[spec] = {}
by_spec[spec][r.config.backend] = r
# Create shortened backend names for display
def shorten_backend_name(name: str) -> str:
"""Shorten long backend names for table display."""
# Remove common prefixes
name = name.replace("flashattn_mla", "famla")
name = name.replace("flashinfer_mla", "fimla")
name = name.replace("flashmla", "fmla")
name = name.replace("cutlass_mla", "cmla")
name = name.replace("numsplits", "ns")
return name
table = Table(title="Attention Benchmark Results")
table.add_column("Batch\nSpec", no_wrap=True)
multi = len(backends) > 1
for backend in backends:
short_name = shorten_backend_name(backend)
# Time column
col_time = f"{short_name}\nTime (s)"
table.add_column(col_time, justify="right", no_wrap=False)
if multi and compare_to_fastest:
# Relative performance column
col_rel = f"{short_name}\nvs Best"
table.add_column(col_rel, justify="right", no_wrap=False)
# Add rows
for spec in sorted(by_spec.keys()):
spec_results = by_spec[spec]
times = {b: r.mean_time for b, r in spec_results.items() if r.success}
best_time = min(times.values()) if times else 0.0
row = [spec]
for backend in backends:
if backend in spec_results:
r = spec_results[backend]
if r.success:
row.append(f"{r.mean_time:.6f}")
if multi and compare_to_fastest:
pct = (
(r.mean_time / best_time * 100) if best_time > 0 else 0
)
pct_str = f"{pct:.1f}%"
if r.mean_time == best_time:
pct_str = f"[bold green]{pct_str}[/]"
row.append(pct_str)
else:
row.append("[red]ERROR[/]")
if multi and compare_to_fastest:
row.append("-")
else:
row.append("-")
if multi and compare_to_fastest:
row.append("-")
table.add_row(*row)
self.console.print(table)
def save_csv(self, results: list[BenchmarkResult], path: str):
"""Save results to CSV file."""
if not results:
return
path_obj = Path(path)
path_obj.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w", newline="") as f:
writer = csv.DictWriter(
f,
fieldnames=[
"backend",
"batch_spec",
"num_layers",
"mean_time",
"std_time",
"throughput",
"memory_mb",
],
)
writer.writeheader()
for r in results:
writer.writerow(
{
"backend": r.config.backend,
"batch_spec": r.config.batch_spec,
"num_layers": r.config.num_layers,
"mean_time": r.mean_time,
"std_time": r.std_time,
"throughput": r.throughput_tokens_per_sec or 0,
"memory_mb": r.memory_allocated_mb or 0,
}
)
self.console.print(f"[green]Saved CSV results to {path}[/]")
def save_json(self, results: list[BenchmarkResult], path: str):
"""Save results to JSON file."""
path_obj = Path(path)
path_obj.parent.mkdir(parents=True, exist_ok=True)
data = [r.to_dict() for r in results]
with open(path, "w") as f:
json.dump(data, f, indent=2, default=str)
self.console.print(f"[green]Saved JSON results to {path}[/]")
def setup_mla_dims(model_name: str = "deepseek-v3") -> dict:
"""
Get MLA dimensions for known models.
Args:
model_name: Model identifier
Returns:
Dict with MLA dimension configuration
"""
configs = {
"deepseek-v2": {
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"num_q_heads": 128,
"num_kv_heads": 1,
"head_dim": 576,
},
"deepseek-v3": {
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"num_q_heads": 128,
"num_kv_heads": 1,
"head_dim": 576,
},
"deepseek-v2-lite": {
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"num_q_heads": 16,
"num_kv_heads": 1,
"head_dim": 576,
},
}
if model_name not in configs:
raise ValueError(
f"Unknown model '{model_name}'. Known models: {list(configs.keys())}"
)
return configs[model_name]
def get_attention_scale(head_dim: int) -> float:
"""Compute attention scale factor (1/sqrt(d))."""
return 1.0 / math.sqrt(head_dim)
def is_mla_backend(backend: str) -> bool:
"""
Check if backend is an MLA backend using the backend's is_mla() property.
Args:
backend: Backend name (e.g., "CUTLASS_MLA", "FLASHINFER_MLA")
Returns:
True if the backend is an MLA backend, False otherwise
"""
from vllm.v1.attention.backends.registry import AttentionBackendEnum
try:
backend_class = AttentionBackendEnum[backend.upper()].get_class()
return backend_class.is_mla()
except (KeyError, ValueError, ImportError):
return False
# MLA decode-only benchmark configuration
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128
num_kv_heads: 1 # MLA uses single latent KV
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128
batch_specs:
# Small batches, varying sequence lengths
- "16q1s512" # 16 requests, 512 KV cache
- "16q1s1k" # 16 requests, 1k KV cache
- "16q1s2k" # 16 requests, 2k KV cache
- "16q1s4k" # 16 requests, 4k KV cache
# Medium batches
- "32q1s1k" # 32 requests, 1k KV cache
- "32q1s2k" # 32 requests, 2k KV cache
- "32q1s4k" # 32 requests, 4k KV cache
- "32q1s8k" # 32 requests, 8k KV cache
# Large batches
- "64q1s1k" # 64 requests, 1k KV cache
- "64q1s2k" # 64 requests, 2k KV cache
- "64q1s4k" # 64 requests, 4k KV cache
- "64q1s8k" # 64 requests, 8k KV cache
# Very large batches
- "128q1s1k" # 128 requests, 1k KV cache
- "128q1s2k" # 128 requests, 2k KV cache
# Long context
- "32q1s16k" # 32 requests, 16k KV cache
- "32q1s32k" # 32 requests, 32k KV cache
backends:
- cutlass_mla
- flashinfer_mla
- flashattn_mla # Hopper only
- flashmla # Hopper only
device: "cuda:0"
repeats: 5
warmup_iters: 3
profile_memory: true
# Backend-specific tuning
cutlass_mla:
num_kv_splits: auto # or specific value like 4, 8, 16
flashattn_mla:
reorder_batch_threshold: 512
flashmla:
reorder_batch_threshold: 1
# MLA mixed batch benchmark (prefill + decode)
# Tests chunked prefill performance
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128
num_kv_heads: 1
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128
batch_specs:
# Small prefill + decode
- "1q1k_8q1s1k" # 1 prefill + 8 decode
- "2q2k_16q1s1k" # 2 prefill + 16 decode
- "4q1k_32q1s2k" # 4 prefill + 32 decode
# Medium prefill + decode
- "2q4k_32q1s2k" # 2 medium prefill + 32 decode
- "4q4k_64q1s2k" # 4 medium prefill + 64 decode
- "8q2k_64q1s4k" # 8 prefill + 64 decode
# Large prefill + decode (chunked prefill stress test)
- "2q8k_32q1s1k" # 2 large prefill + 32 decode
- "1q16k_16q1s2k" # 1 very large prefill + 16 decode
- "2q16k_32q1s4k" # 2 very large prefill + 32 decode
# Context extension + decode
- "2q1kkv2k_16q1s1k" # 2 extend + 16 decode
- "4q2kkv4k_32q1s2k" # 4 extend + 32 decode
- "2q1kkv8k_32q1s2k" # 2 large extend + 32 decode
# Explicitly chunked prefill
- "q8k" # 8k prefill with chunking hint
- "q16k" # 16k prefill with chunking hint
- "2q8k_32q1s2k" # 2 chunked prefill + 32 decode
# High decode ratio (realistic serving)
- "1q2k_63q1s1k" # 1 prefill + 63 decode
- "2q2k_62q1s2k" # 2 prefill + 62 decode
- "4q4k_60q1s4k" # 4 prefill + 60 decode
backends:
- cutlass_mla
- flashinfer_mla
- flashattn_mla # Hopper only
- flashmla # Hopper only
device: "cuda:0"
repeats: 5
warmup_iters: 3
profile_memory: true
# Analyze chunked prefill workspace size impact
chunked_prefill:
test_workspace_sizes: [4096, 8192, 16384, 32768, 65536]
# Study 4: What is optimal reorder_batch_threshold for MLA backends supporting query length > 1?
# Question: At what query length does prefill pipeline become faster than decode pipeline?
# Methodology: For each query length, compare decode vs prefill performance to find crossover point
# Applies to: FlashAttn MLA, FlashMLA
description: "Decode vs Prefill pipeline crossover analysis"
# Test FlashAttn MLA
backend: flashattn_mla
# Mode: decode_vs_prefill comparison (special sweep mode)
# For each batch spec, we'll test both decode and prefill pipelines
mode: "decode_vs_prefill"
# Query lengths to test (from old benchmark_mla_threshold.py methodology)
# Each query length will be tested with BOTH decode and prefill pipelines:
# - decode: threshold >= query_length (forces decode pipeline)
# - prefill: threshold < query_length (forces prefill pipeline)
#
# We use q<N>s1k format which creates q_len=N, seq_len=1024 requests
# This tests different query lengths with fixed sequence length context
#
# Using batch_spec_ranges for automatic generation:
batch_spec_ranges:
- template: "q{q_len}s1k"
q_len:
start: 1
stop: 16
step: 1
end_inclusive: false
- template: "q{q_len}s1k"
q_len:
start: 16
stop: 64
step: 2
end_inclusive: false
- template: "q{q_len}s1k"
q_len:
start: 64
stop: 1024
step: 4
end_inclusive: true
# Batch sizes to test (from old script)
batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
# Model configuration (DeepSeek V2/V3 defaults)
model:
num_layers: 10
head_dim: 576
num_q_heads: 128
num_kv_heads: 1
block_size: 128
# Benchmark settings
benchmark:
device: "cuda:0"
repeats: 15 # More repeats for spec decode variance
warmup_iters: 5
profile_memory: false
# Output
output:
csv: "reorder_threshold_results.csv"
json: "reorder_threshold_results.json"
# Expected outcome (reproduces old benchmark_mla_threshold.py study):
# - For each batch size, find the crossover point where prefill becomes faster than decode
# - Show decode vs prefill performance across all query lengths
# - Determine optimal reorder_batch_threshold based on last query length where decode is faster
# - Understand how crossover point varies with batch size
# - Provide data-driven guidance for default threshold value
#
# Methodology (from old script):
# - Each query length tested with BOTH pipelines:
# * decode: threshold >= query_length (forces decode pipeline)
# * prefill: threshold < query_length (forces prefill pipeline)
# - Compare which is faster to find crossover point
#
# Speculative decoding benchmark configuration
# Tests reorder_batch_threshold optimization
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128
num_kv_heads: 1
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
batch_specs:
# Pure speculative decode (K-token verification)
- "q2s1k" # 2-token spec, 1k KV
- "q4s1k" # 4-token spec, 1k KV
- "q8s1k" # 8-token spec, 1k KV
- "q16s1k" # 16-token spec, 1k KV
# Speculative with different context lengths
- "q4s2k" # 4-token spec, 2k KV
- "q4s4k" # 4-token spec, 4k KV
- "q8s2k" # 8-token spec, 2k KV
- "q8s4k" # 8-token spec, 4k KV
# Mixed: speculative + regular decode
- "32q4s1k" # 32 spec requests
- "16q4s1k_16q1s1k" # 16 spec + 16 regular
- "8q8s2k_24q1s2k" # 8 spec (8-tok) + 24 regular
# Mixed: speculative + prefill + decode
- "2q1k_16q4s1k_16q1s1k" # 2 prefill + 16 spec + 16 decode
- "4q2k_32q4s2k_32q1s2k" # 4 prefill + 32 spec + 32 decode
# Large batches with speculation
- "64q4s1k" # 64 spec requests
- "32q8s2k" # 32 spec (8-token)
- "16q16s4k" # 16 spec (16-token)
# Backends that support query length > 1
backends:
- flashattn_mla # reorder_batch_threshold = 512
- flashmla # reorder_batch_threshold = 1 (tunable)
# FlashInfer-MLA also supports uniform spec-as-decode but with different mechanism
# - flashinfer_mla
# Benchmark settings
benchmark:
device: "cuda:0"
repeats: 10 # More repeats for statistical significance
warmup_iters: 5
profile_memory: false
# Test these threshold values for optimization
parameter_sweep:
param_name: "reorder_batch_threshold"
values: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
include_auto: false
label_format: "{backend}_threshold_{value}"
# Standard attention backend benchmark configuration
model:
num_layers: 32
num_q_heads: 32
num_kv_heads: 8 # GQA with 4:1 ratio
head_dim: 128
block_size: 16
batch_specs:
# Pure prefill
- "q512" # Small prefill (512 tokens)
- "q2k" # Medium prefill (2048 tokens)
- "q4k" # Large prefill (4096 tokens)
- "q8k" # Very large prefill (8192 tokens)
# Pure decode
- "8q1s1k" # 8 requests, 1k KV cache each
- "16q1s2k" # 16 requests, 2k KV cache each
- "32q1s1k" # 32 requests, 1k KV cache each
- "64q1s4k" # 64 requests, 4k KV cache each
# Mixed prefill/decode
- "2q2k_8q1s1k" # 2 prefill + 8 decode
- "4q1k_16q1s2k" # 4 prefill + 16 decode
- "2q4k_32q1s1k" # 2 large prefill + 32 decode
# Context extension
- "q1ks2k" # 1k query, 2k sequence (chunked prefill)
- "2q1ks4k" # 2 requests: 1k query, 4k sequence
backends:
- flash
- triton
- flashinfer
device: "cuda:0"
repeats: 5
warmup_iters: 3
profile_memory: false
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
MLA benchmark runner - shared utilities for MLA benchmarks.
This module provides helpers for running MLA backends without
needing full VllmConfig integration.
"""
import importlib
import numpy as np
import torch
from batch_spec import parse_batch_spec
from common import (
BenchmarkResult,
MockHfConfig,
MockKVBProj,
MockLayer,
setup_mla_dims,
)
from vllm.config import (
CacheConfig,
CompilationConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
# ============================================================================
# VllmConfig Creation
# ============================================================================
def _add_mock_methods_to_model_config(model_config: ModelConfig) -> None:
"""
Add mock methods for layer-specific queries to ModelConfig.
These methods are needed by metadata builders but aren't normally
present on ModelConfig when used in benchmark contexts.
"""
import types
model_config.get_num_layers = types.MethodType(lambda self: 1, model_config)
model_config.get_sliding_window_for_layer = types.MethodType(
lambda self, _i: None, model_config
)
model_config.get_logits_soft_cap_for_layer = types.MethodType(
lambda self, _i: None, model_config
)
model_config.get_sm_scale_for_layer = types.MethodType(
lambda self, _i: 1.0 / model_config.get_head_size() ** 0.5, model_config
)
def create_minimal_vllm_config(
model_name: str = "deepseek-v3",
block_size: int = 128,
max_num_seqs: int = 256,
mla_dims: dict | None = None,
) -> VllmConfig:
"""
Create minimal VllmConfig for MLA benchmarks.
Args:
model_name: Model name (deepseek-v2, deepseek-v3, etc.) - used if mla_dims not
provided
block_size: KV cache block size
max_num_seqs: Maximum number of sequences
mla_dims: Optional custom MLA dimensions dict. If not provided, uses
setup_mla_dims(model_name)
Returns:
VllmConfig for benchmarking
"""
# Get MLA dimensions - use provided or load from model name
if mla_dims is None:
mla_dims = setup_mla_dims(model_name)
# Create mock HF config first (avoids downloading from HuggingFace)
mock_hf_config = MockHfConfig(mla_dims)
# Create a temporary minimal config.json to avoid HF downloads
# This ensures consistent ModelConfig construction without network access
import json
import os
import shutil
import tempfile
minimal_config = {
"architectures": ["DeepseekV2ForCausalLM"],
"model_type": "deepseek_v2",
"num_attention_heads": mla_dims["num_q_heads"],
"num_key_value_heads": mla_dims["num_kv_heads"],
"hidden_size": mla_dims["head_dim"] * mla_dims["num_q_heads"],
"torch_dtype": "bfloat16",
"max_position_embeddings": 163840, # DeepSeek V3 default
"rope_theta": 10000.0,
"vocab_size": 128256,
}
# Create temporary directory with config.json
temp_dir = tempfile.mkdtemp(prefix="vllm_bench_")
config_path = os.path.join(temp_dir, "config.json")
with open(config_path, "w") as f:
json.dump(minimal_config, f)
try:
# Create model config using local path - no HF downloads
model_config = ModelConfig(
model=temp_dir, # Use local temp directory
tokenizer=None,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
seed=0,
max_model_len=32768,
quantization=None,
quantization_param_path=None,
enforce_eager=False,
max_context_len_to_capture=None,
max_seq_len_to_capture=8192,
max_logprobs=20,
disable_sliding_window=False,
skip_tokenizer_init=True,
served_model_name=None,
limit_mm_per_prompt=None,
use_async_output_proc=True,
config_format="auto",
)
finally:
# Clean up temporary directory
shutil.rmtree(temp_dir, ignore_errors=True)
# Override with our mock config
model_config.hf_config = mock_hf_config
model_config.hf_text_config = mock_hf_config
# Add mock methods for layer-specific queries
_add_mock_methods_to_model_config(model_config)
# Create sub-configs
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=False,
)
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=8192,
max_model_len=32768,
is_encoder_decoder=False,
enable_chunked_prefill=True,
)
parallel_config = ParallelConfig(
tensor_parallel_size=1,
)
compilation_config = CompilationConfig()
return VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
compilation_config=compilation_config,
)
# ============================================================================
# Backend Configuration
# ============================================================================
# Backend name to class name prefix mapping
_BACKEND_NAME_MAP = {
"flashattn_mla": "FlashAttnMLA",
"flashmla": "FlashMLA",
"flashinfer_mla": "FlashInferMLA",
"cutlass_mla": "CutlassMLA",
}
# Special properties that differ from defaults
_BACKEND_PROPERTIES = {
"flashmla": {
"query_format": "concat", # Single concatenated tensor (vs tuple)
"block_size": 64, # FlashMLA uses fixed block size
},
"flashinfer_mla": {
"block_size": 64, # FlashInfer MLA only supports 32 or 64
},
}
def _get_backend_config(backend: str) -> dict:
"""
Get backend configuration using naming conventions.
All MLA backends follow the pattern:
- Module: vllm.v1.attention.backends.mla.{backend}
- Impl: {Name}Impl
- Metadata: {Name}Metadata (or MLACommonMetadata)
- DecodeMetadata: {Name}DecodeMetadata (or MLACommonDecodeMetadata)
- MetadataBuilder: {Name}MetadataBuilder
"""
if backend not in _BACKEND_NAME_MAP:
raise ValueError(f"Unknown backend: {backend}")
name = _BACKEND_NAME_MAP[backend]
props = _BACKEND_PROPERTIES.get(backend, {})
# Check if backend uses common metadata (FlashInfer, CUTLASS)
uses_common = backend in ("flashinfer_mla", "cutlass_mla")
return {
"module": f"vllm.v1.attention.backends.mla.{backend}",
"impl_class": f"{name}Impl",
"metadata_class": "MLACommonMetadata" if uses_common else f"{name}Metadata",
"decode_metadata_class": "MLACommonDecodeMetadata"
if uses_common
else f"{name}DecodeMetadata",
"builder_class": f"{name}MetadataBuilder",
"query_format": props.get("query_format", "tuple"),
"block_size": props.get("block_size", None),
}
# ============================================================================
# Metadata Building Helpers
# ============================================================================
def _build_attention_metadata(
requests: list,
block_size: int,
device: torch.device,
builder_instance,
) -> tuple:
"""
Build attention metadata from batch requests.
Args:
requests: List of BatchRequest objects
block_size: KV cache block size
device: Target device
builder_instance: Metadata builder instance
Returns:
Tuple of (metadata, kv_cache_num_blocks)
"""
q_lens = [r.q_len for r in requests]
kv_lens = [r.kv_len for r in requests]
total_q = sum(q_lens)
max_kv = max(kv_lens)
# Build query start locations
q_start_cpu = torch.tensor(
[0] + [sum(q_lens[: i + 1]) for i in range(len(q_lens))],
dtype=torch.int32,
)
q_start_gpu = q_start_cpu.to(device)
# Build sequence lengths
seq_lens_cpu = torch.tensor(kv_lens, dtype=torch.int32)
seq_lens_gpu = seq_lens_cpu.to(device)
# Build num_computed_tokens (context length for each request)
context_lens = [kv_len - q_len for q_len, kv_len in zip(q_lens, kv_lens)]
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
# Build block table
num_blocks_per_req = [(kv + block_size - 1) // block_size for kv in kv_lens]
max_num_blocks = max(num_blocks_per_req)
block_table_cpu = np.zeros((len(requests), max_num_blocks), dtype=np.int32)
current_block = 0
for i, num_blocks in enumerate(num_blocks_per_req):
for j in range(num_blocks):
block_table_cpu[i, j] = current_block
current_block += 1
block_table_gpu = torch.from_numpy(block_table_cpu).to(device)
# Build slot mapping
slot_mapping_list = []
for i, (q_len, kv_len, num_blocks) in enumerate(
zip(q_lens, kv_lens, num_blocks_per_req)
):
context_len = kv_len - q_len
for j in range(q_len):
token_kv_idx = context_len + j
block_idx = token_kv_idx // block_size
offset_in_block = token_kv_idx % block_size
global_block_id = block_table_cpu[i, block_idx]
slot_id = global_block_id * block_size + offset_in_block
slot_mapping_list.append(slot_id)
slot_mapping = torch.tensor(slot_mapping_list, dtype=torch.int64, device=device)
# Create CommonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
common_attn_metadata = CommonAttentionMetadata(
num_reqs=len(requests),
max_query_len=max(q_lens),
max_seq_len=max_kv,
num_actual_tokens=total_q,
query_start_loc=q_start_gpu,
query_start_loc_cpu=q_start_cpu,
seq_lens=seq_lens_gpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
slot_mapping=slot_mapping,
block_table_tensor=block_table_gpu,
dcp_local_seq_lens=None,
)
# Use the production build() method
metadata = builder_instance.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
fast_build=False,
)
return metadata, current_block
def _create_input_tensors(
total_q: int,
mla_dims: dict,
query_format: str,
device: torch.device,
dtype: torch.dtype,
):
"""
Create input tensors for both decode and prefill modes.
MLA requires different tensor formats for decode vs prefill:
- Decode: Uses kv_lora_rank (512) dimension
- Prefill: Uses qk_nope_head_dim (128) to stay under FlashAttention's 256 limit
Args:
total_q: Total number of query tokens
mla_dims: MLA dimension configuration
query_format: Either "tuple" or "concat"
device: Target device
dtype: Tensor dtype
Returns:
Tuple of (decode_inputs, prefill_inputs)
- decode_inputs: Query tensor(s) for decode mode
- prefill_inputs: Dict with 'q', 'k_c_normed', 'k_pe', 'k_scale' for prefill
"""
if query_format == "tuple":
# Decode mode format: (q_nope, q_pe) where q_nope has kv_lora_rank dim
q_nope_decode = torch.randn(
total_q,
mla_dims["num_q_heads"],
mla_dims["kv_lora_rank"],
device=device,
dtype=dtype,
)
q_pe = torch.randn(
total_q,
mla_dims["num_q_heads"],
mla_dims["qk_rope_head_dim"],
device=device,
dtype=dtype,
)
decode_inputs = (q_nope_decode, q_pe)
# For prefill, we need q with qk_nope_head_dim instead of kv_lora_rank
q_nope_prefill = torch.randn(
total_q,
mla_dims["num_q_heads"],
mla_dims["qk_nope_head_dim"],
device=device,
dtype=dtype,
)
prefill_q = torch.cat([q_nope_prefill, q_pe], dim=-1)
else: # concat
decode_inputs = torch.randn(
total_q,
mla_dims["num_q_heads"],
mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"],
device=device,
dtype=dtype,
)
# For prefill with concat format
prefill_q = torch.randn(
total_q,
mla_dims["num_q_heads"],
mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
device=device,
dtype=dtype,
)
# Create additional inputs needed for prefill forward
k_c_normed = torch.randn(
total_q,
mla_dims["kv_lora_rank"],
device=device,
dtype=dtype,
)
k_pe = torch.randn(
total_q,
1, # Single head for MLA
mla_dims["qk_rope_head_dim"],
device=device,
dtype=dtype,
)
k_scale = torch.ones(1, device=device, dtype=torch.float32)
output = torch.zeros(
total_q,
mla_dims["num_q_heads"] * mla_dims["v_head_dim"],
device=device,
dtype=dtype,
)
prefill_inputs = {
"q": prefill_q,
"k_c_normed": k_c_normed,
"k_pe": k_pe,
"k_scale": k_scale,
"output": output,
}
return decode_inputs, prefill_inputs
# ============================================================================
# Backend Initialization
# ============================================================================
def _create_backend_impl(
backend_cfg: dict,
mla_dims: dict,
vllm_config: VllmConfig,
device: torch.device,
):
"""
Create backend implementation instance.
Args:
backend_cfg: Backend configuration dict
mla_dims: MLA dimension configuration
vllm_config: VllmConfig instance
device: Target device
Returns:
Tuple of (impl, layer, builder_instance)
"""
# Import backend classes
backend_module = importlib.import_module(backend_cfg["module"])
impl_class = getattr(backend_module, backend_cfg["impl_class"])
# Calculate scale
scale = 1.0 / np.sqrt(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"])
# Create mock kv_b_proj layer for prefill mode
mock_kv_b_proj = MockKVBProj(
num_heads=mla_dims["num_q_heads"],
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
v_head_dim=mla_dims["v_head_dim"],
)
# Create impl
impl = impl_class(
num_heads=mla_dims["num_q_heads"],
head_size=mla_dims["head_dim"],
scale=scale,
num_kv_heads=mla_dims["num_kv_heads"],
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=mla_dims["kv_lora_rank"],
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
qk_rope_head_dim=mla_dims["qk_rope_head_dim"],
qk_head_dim=mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
v_head_dim=mla_dims["v_head_dim"],
kv_b_proj=mock_kv_b_proj,
)
# Initialize DCP attributes
if not hasattr(impl, "dcp_world_size") or impl.dcp_world_size in (None, -1):
impl.dcp_world_size = 1
impl.dcp_rank = 0
# Create KV cache spec for MockLayer
from vllm.v1.kv_cache_interface import FullAttentionSpec
kv_cache_spec = FullAttentionSpec(
block_size=backend_cfg["block_size"] or vllm_config.cache_config.block_size,
num_kv_heads=1, # MLA uses 1 KV head
head_size=576, # MLA head dim
dtype=torch.bfloat16,
)
# Create mock layer
layer = MockLayer(device, impl=impl, kv_cache_spec=kv_cache_spec)
# Create builder instance if needed
builder_instance = None
if backend_cfg["builder_class"]:
builder_class = getattr(backend_module, backend_cfg["builder_class"])
# Populate static_forward_context so builder can find the layer
# MockLayer inherits from AttentionLayerBase, so isinstance checks pass
vllm_config.compilation_config.static_forward_context = {"placeholder": layer}
builder_instance = builder_class(
kv_cache_spec=kv_cache_spec,
layer_names=["placeholder"],
vllm_config=vllm_config,
device=device,
)
return impl, layer, builder_instance
# ============================================================================
# Config Helpers
# ============================================================================
def _extract_mla_dims_from_config(config) -> dict | None:
"""
Extract MLA dimensions from BenchmarkConfig if all required fields are present.
Args:
config: BenchmarkConfig instance
Returns:
Dict with MLA dimensions if all fields are provided, None otherwise
"""
# Check if all MLA-specific fields are provided
if all(
[
config.kv_lora_rank is not None,
config.qk_nope_head_dim is not None,
config.qk_rope_head_dim is not None,
config.v_head_dim is not None,
]
):
return {
"kv_lora_rank": config.kv_lora_rank,
"qk_nope_head_dim": config.qk_nope_head_dim,
"qk_rope_head_dim": config.qk_rope_head_dim,
"v_head_dim": config.v_head_dim,
"num_q_heads": config.num_q_heads,
"num_kv_heads": config.num_kv_heads,
"head_dim": config.head_dim,
}
# Fallback: if MLA fields not fully specified, try to construct from basic fields
elif config.head_dim == 576:
# This looks like a DeepSeek MLA config, use standard dimensions with custom
# head count
return {
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"num_q_heads": config.num_q_heads,
"num_kv_heads": config.num_kv_heads,
"head_dim": config.head_dim,
}
return None
# ============================================================================
# Benchmark Execution
# ============================================================================
def _run_single_benchmark(
config,
impl,
layer,
builder_instance,
backend_cfg: dict,
mla_dims: dict,
device: torch.device,
) -> BenchmarkResult:
"""
Run a single benchmark iteration.
Args:
config: BenchmarkConfig instance
impl: Backend implementation instance
layer: MockLayer instance
builder_instance: Metadata builder instance
backend_cfg: Backend configuration dict
mla_dims: MLA dimension configuration
device: Target device
Returns:
BenchmarkResult with timing statistics
"""
# Parse batch spec
requests = parse_batch_spec(config.batch_spec)
q_lens = [r.q_len for r in requests]
total_q = sum(q_lens)
# Determine block size
block_size = backend_cfg["block_size"] or config.block_size
# Build metadata
metadata, num_blocks = _build_attention_metadata(
requests, block_size, device, builder_instance
)
# Create KV cache
kv_cache = torch.zeros(
num_blocks,
block_size,
mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"],
device=device,
dtype=torch.bfloat16,
)
# Create input tensors for both decode and prefill modes
decode_inputs, prefill_inputs = _create_input_tensors(
total_q,
mla_dims,
backend_cfg["query_format"],
device,
torch.bfloat16,
)
# Determine which forward method to use based on metadata
if metadata.decode is not None:
forward_fn = lambda: impl._forward_decode(
decode_inputs, kv_cache, metadata, layer
)
elif metadata.prefill is not None:
forward_fn = lambda: impl._forward_prefill(
prefill_inputs["q"],
prefill_inputs["k_c_normed"],
prefill_inputs["k_pe"],
kv_cache,
metadata,
prefill_inputs["k_scale"],
prefill_inputs["output"],
)
else:
raise RuntimeError("Metadata has neither decode nor prefill metadata")
# Warmup
for _ in range(config.warmup_iters):
forward_fn()
torch.cuda.synchronize()
# Benchmark
times = []
for _ in range(config.repeats):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(config.num_layers):
forward_fn()
end.record()
torch.cuda.synchronize()
elapsed_ms = start.elapsed_time(end)
times.append(elapsed_ms / 1000.0 / config.num_layers)
mean_time = float(np.mean(times))
return BenchmarkResult(
config=config,
mean_time=mean_time,
std_time=float(np.std(times)),
min_time=float(np.min(times)),
max_time=float(np.max(times)),
throughput_tokens_per_sec=total_q / mean_time if mean_time > 0 else 0,
)
def _run_mla_benchmark_batched(
backend: str,
configs_with_params: list[tuple], # [(config, threshold, num_splits), ...]
) -> list[BenchmarkResult]:
"""
Unified batched MLA benchmark runner for all backends.
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
This function reuses backend initialization across multiple benchmarks
to avoid setup/teardown overhead.
Args:
backend: Backend name
configs_with_params: List of (config, threshold, num_splits) tuples
- threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
- num_splits: num_kv_splits (CUTLASS only)
Returns:
List of BenchmarkResult objects
"""
if not configs_with_params:
return []
backend_cfg = _get_backend_config(backend)
device = torch.device(configs_with_params[0][0].device)
torch.cuda.set_device(device)
# Determine block size
config_block_size = configs_with_params[0][0].block_size
block_size = backend_cfg["block_size"] or config_block_size
# Extract MLA dimensions from the first config
first_config = configs_with_params[0][0]
mla_dims = _extract_mla_dims_from_config(first_config)
# If config didn't provide MLA dims, fall back to default model
if mla_dims is None:
mla_dims = setup_mla_dims("deepseek-v3")
# Create and set vLLM config for MLA (reused across all benchmarks)
vllm_config = create_minimal_vllm_config(
model_name="deepseek-v3", # Used only for model path
block_size=block_size,
mla_dims=mla_dims, # Use custom dims from config or default
)
results = []
with set_current_vllm_config(vllm_config):
# Create backend impl, layer, and builder (reused across benchmarks)
impl, layer, builder_instance = _create_backend_impl(
backend_cfg, mla_dims, vllm_config, device
)
# Run each benchmark with the shared impl
for config, threshold, num_splits in configs_with_params:
# Set threshold for this benchmark (FlashAttn/FlashMLA only)
original_threshold = None
if threshold is not None and builder_instance:
original_threshold = builder_instance.reorder_batch_threshold
builder_instance.reorder_batch_threshold = threshold
# Set num_splits for CUTLASS
original_num_splits = None
if num_splits is not None and hasattr(impl, "_num_kv_splits"):
original_num_splits = impl._num_kv_splits
impl._num_kv_splits = num_splits
try:
result = _run_single_benchmark(
config,
impl,
layer,
builder_instance,
backend_cfg,
mla_dims,
device,
)
results.append(result)
finally:
# Restore original threshold
if original_threshold is not None:
builder_instance.reorder_batch_threshold = original_threshold
# Restore original num_splits
if original_num_splits is not None:
impl._num_kv_splits = original_num_splits
return results
# ============================================================================
# Public API
# ============================================================================
def run_mla_benchmark(
backend: str,
config,
reorder_batch_threshold: int | None = None,
num_kv_splits: int | None = None,
) -> BenchmarkResult | list[BenchmarkResult]:
"""
Unified MLA benchmark runner for all backends.
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
Always uses batched execution internally for optimal performance.
Args:
backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla)
config: BenchmarkConfig or list of (BenchmarkConfig, param) tuples
reorder_batch_threshold: Threshold override for FlashAttn/FlashMLA
(single config mode only)
num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
Returns:
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
"""
# Normalize to batched mode: (config, threshold, num_splits)
if isinstance(config, list):
# Already in batched format
if len(config) > 0 and isinstance(config[0], tuple):
# Format: [(cfg, param), ...] where param is threshold or num_splits
if backend in ("flashattn_mla", "flashmla"):
configs_with_params = [(cfg, param, None) for cfg, param in config]
else: # cutlass_mla or flashinfer_mla
configs_with_params = [(cfg, None, param) for cfg, param in config]
else:
# Format: [cfg, ...] - just configs
configs_with_params = [(cfg, None, None) for cfg in config]
return_single = False
else:
# Single config: convert to batched format
configs_with_params = [(config, reorder_batch_threshold, num_kv_splits)]
return_single = True
# Use unified batched execution
results = _run_mla_benchmark_batched(backend, configs_with_params)
# Return single result or list based on input
return results[0] if return_single else results
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Standard attention benchmark runner - shared utilities for non-MLA benchmarks.
This module provides helpers for running standard attention backends
(FlashAttention, Triton, FlashInfer) with real vLLM integration.
"""
import types
import numpy as np
import torch
from batch_spec import parse_batch_spec, reorder_for_flashinfer
from common import BenchmarkConfig, BenchmarkResult, MockLayer, get_attention_scale
from vllm.config import (
CacheConfig,
CompilationConfig,
DeviceConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
# ============================================================================
# Backend Configuration
# ============================================================================
_BACKEND_CONFIG = {
"flash": {
"module": "vllm.v1.attention.backends.flash_attn",
"backend_class": "FlashAttentionBackend",
"dtype": torch.float16,
"cache_layout": "standard",
# ^ [2, num_blocks, block_size, num_kv_heads, head_dim]
},
"triton": {
"module": "vllm.v1.attention.backends.triton_attn",
"backend_class": "TritonAttentionBackend",
"dtype": torch.float32,
"cache_layout": "standard",
},
"flashinfer": {
"module": "vllm.v1.attention.backends.flashinfer",
"backend_class": "FlashInferBackend",
"dtype": torch.float16,
"cache_layout": "flashinfer",
# ^ [num_blocks, 2, block_size, num_kv_heads, head_dim]
},
}
def _get_backend_config(backend: str) -> dict:
if backend not in _BACKEND_CONFIG:
raise ValueError(
f"Unknown backend: {backend}. "
f"Available: {', '.join(_BACKEND_CONFIG.keys())}"
)
return _BACKEND_CONFIG[backend]
# ============================================================================
# Metadata Building Helpers
# ============================================================================
def _build_common_attn_metadata(
q_lens: list[int],
kv_lens: list[int],
block_size: int,
device: torch.device,
) -> CommonAttentionMetadata:
"""Build CommonAttentionMetadata from query/kv lengths."""
batch_size = len(q_lens)
total_tokens = sum(q_lens)
query_start_loc = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
query_start_loc[1:] = torch.tensor(q_lens, dtype=torch.int32, device=device).cumsum(
0
)
query_start_loc_cpu = query_start_loc.cpu()
seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device)
seq_lens_cpu = seq_lens.cpu()
max_seq_len = int(seq_lens_cpu.max())
context_lens = [kv - q for kv, q in zip(kv_lens, q_lens)]
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
max_blocks = (max(kv_lens) + block_size - 1) // block_size
num_blocks = batch_size * max_blocks
block_table_tensor = torch.arange(
num_blocks, dtype=torch.int32, device=device
).view(batch_size, max_blocks)
slot_mapping = torch.arange(total_tokens, dtype=torch.int64, device=device)
max_query_len = max(q_lens)
return CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_size,
num_actual_tokens=total_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
causal=True,
)
def _create_vllm_config(
config: BenchmarkConfig,
dtype: torch.dtype,
max_num_blocks: int,
) -> VllmConfig:
"""Create a VllmConfig for benchmarking with mock model methods."""
model_config = ModelConfig(
model="meta-llama/Meta-Llama-3-8B",
tokenizer="meta-llama/Meta-Llama-3-8B",
trust_remote_code=False,
dtype=dtype,
seed=0,
max_model_len=1024,
)
cache_config = CacheConfig(
block_size=config.block_size,
cache_dtype="auto",
swap_space=0,
)
cache_config.num_gpu_blocks = max_num_blocks
cache_config.num_cpu_blocks = 0
parallel_config = ParallelConfig(tensor_parallel_size=1)
scheduler_config = SchedulerConfig(
max_num_seqs=256,
max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
enable_chunked_prefill=True,
)
device_config = DeviceConfig()
load_config = LoadConfig()
compilation_config = CompilationConfig()
# Add mock methods for benchmark config values
model_config.get_num_layers = types.MethodType(
lambda self: config.num_layers, model_config
)
model_config.get_sliding_window_for_layer = types.MethodType(
lambda self, i: None, model_config
)
model_config.get_logits_soft_cap_for_layer = types.MethodType(
lambda self, i: 0.0, model_config
)
model_config.get_sm_scale_for_layer = types.MethodType(
lambda self, i: 1.0 / config.head_dim**0.5, model_config
)
model_config.get_num_attention_heads = types.MethodType(
lambda self, parallel_config=None: config.num_q_heads, model_config
)
model_config.get_num_kv_heads = types.MethodType(
lambda self, parallel_config=None: config.num_kv_heads, model_config
)
model_config.get_head_size = types.MethodType(
lambda self: config.head_dim, model_config
)
model_config.get_sliding_window = types.MethodType(lambda self: None, model_config)
return VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
load_config=load_config,
compilation_config=compilation_config,
)
# ============================================================================
# Backend Initialization
# ============================================================================
def _create_backend_impl(
backend_cfg: dict,
config: BenchmarkConfig,
device: torch.device,
):
"""Create backend implementation instance."""
import importlib
backend_module = importlib.import_module(backend_cfg["module"])
backend_class = getattr(backend_module, backend_cfg["backend_class"])
scale = get_attention_scale(config.head_dim)
dtype = backend_cfg["dtype"]
impl = backend_class.get_impl_cls()(
num_heads=config.num_q_heads,
head_size=config.head_dim,
scale=scale,
num_kv_heads=config.num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
)
kv_cache_spec = FullAttentionSpec(
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
dtype=dtype,
)
layer = MockLayer(device, kv_cache_spec=kv_cache_spec)
return backend_class, impl, layer, dtype
def _create_metadata_builder(
backend_class,
kv_cache_spec: FullAttentionSpec,
vllm_config: VllmConfig,
device: torch.device,
):
"""Create metadata builder instance."""
return backend_class.get_builder_cls()(
kv_cache_spec=kv_cache_spec,
layer_names=["layer_0"],
vllm_config=vllm_config,
device=device,
)
# ============================================================================
# Tensor Creation Helpers
# ============================================================================
def _create_input_tensors(
config: BenchmarkConfig,
total_q: int,
device: torch.device,
dtype: torch.dtype,
) -> tuple:
"""Create Q, K, V input tensors for all layers."""
q_list = [
torch.randn(
total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype
)
for _ in range(config.num_layers)
]
k_list = [
torch.randn(
total_q, config.num_kv_heads, config.head_dim, device=device, dtype=dtype
)
for _ in range(config.num_layers)
]
v_list = [
torch.randn(
total_q, config.num_kv_heads, config.head_dim, device=device, dtype=dtype
)
for _ in range(config.num_layers)
]
return q_list, k_list, v_list
def _create_kv_cache(
config: BenchmarkConfig,
max_num_blocks: int,
cache_layout: str,
device: torch.device,
dtype: torch.dtype,
) -> list:
"""Create KV cache tensors for all layers."""
if cache_layout == "flashinfer":
# FlashInfer layout: [num_blocks, 2, block_size, num_kv_heads, head_dim]
cache_list = [
torch.zeros(
max_num_blocks,
2,
config.block_size,
config.num_kv_heads,
config.head_dim,
device=device,
dtype=dtype,
)
for _ in range(config.num_layers)
]
else:
# Standard layout: [2, num_blocks, block_size, num_kv_heads, head_dim]
cache_list = [
torch.zeros(
2,
max_num_blocks,
config.block_size,
config.num_kv_heads,
config.head_dim,
device=device,
dtype=dtype,
)
for _ in range(config.num_layers)
]
return cache_list
# ============================================================================
# Benchmark Execution
# ============================================================================
def _run_single_benchmark(
config: BenchmarkConfig,
impl,
layer,
q_list: list,
k_list: list,
v_list: list,
cache_list: list,
attn_metadata,
device: torch.device,
dtype: torch.dtype,
) -> tuple:
"""Run single benchmark iteration with warmup and timing loop."""
total_q = q_list[0].shape[0]
out = torch.empty(
total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype
)
# Warmup
for _ in range(config.warmup_iters):
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
torch.cuda.synchronize()
# Benchmark
times = []
for _ in range(config.repeats):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
end.record()
torch.cuda.synchronize()
elapsed_ms = start.elapsed_time(end)
times.append(elapsed_ms / 1000.0 / config.num_layers) # seconds per layer
mem_stats = {}
if config.profile_memory:
mem_stats = {
"allocated_mb": torch.cuda.memory_allocated(device) / 1024**2,
"reserved_mb": torch.cuda.memory_reserved(device) / 1024**2,
}
return times, mem_stats
# ============================================================================
# Public API
# ============================================================================
def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
"""
Run standard attention benchmark with real kernels.
Supports: flash, triton, flashinfer
Args:
config: Benchmark configuration
Returns:
BenchmarkResult with timing and memory statistics
"""
device = torch.device(config.device)
torch.cuda.set_device(device)
backend_cfg = _get_backend_config(config.backend)
requests = parse_batch_spec(config.batch_spec)
if config.backend == "flashinfer":
requests = reorder_for_flashinfer(requests)
q_lens = [r.q_len for r in requests]
kv_lens = [r.kv_len for r in requests]
total_q = sum(q_lens)
max_kv = max(kv_lens)
max_num_blocks = (max_kv + config.block_size - 1) // config.block_size
backend_class, impl, layer, dtype = _create_backend_impl(
backend_cfg, config, device
)
common_metadata = _build_common_attn_metadata(
q_lens, kv_lens, config.block_size, device
)
kv_cache_spec = FullAttentionSpec(
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
dtype=dtype,
)
vllm_config = _create_vllm_config(config, dtype, max_num_blocks)
builder = _create_metadata_builder(
backend_class, kv_cache_spec, vllm_config, device
)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_metadata,
)
q_list, k_list, v_list = _create_input_tensors(config, total_q, device, dtype)
cache_list = _create_kv_cache(
config, max_num_blocks, backend_cfg["cache_layout"], device, dtype
)
times, mem_stats = _run_single_benchmark(
config,
impl,
layer,
q_list,
k_list,
v_list,
cache_list,
attn_metadata,
device,
dtype,
)
mean_time = np.mean(times)
throughput = total_q / mean_time if mean_time > 0 else 0
return BenchmarkResult(
config=config,
mean_time=mean_time,
std_time=np.std(times),
min_time=np.min(times),
max_time=np.max(times),
throughput_tokens_per_sec=throughput,
memory_allocated_mb=mem_stats.get("allocated_mb"),
memory_reserved_mb=mem_stats.get("reserved_mb"),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment