Unverified Commit e82fa448 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

Add attention benchmarking tools (#26835)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarClaude <noreply@anthropic.com>
parent d9aa39a3
# vLLM Attention Benchmarking Suite
Fast, flexible benchmarking for vLLM attention and MLA backends with an extended batch specification grammar.
## Quick Start
```bash
cd benchmarks/attention_benchmarks
# Run a pre-configured benchmark
python benchmark.py --config configs/mla_decode.yaml
python benchmark.py --config configs/mla_mixed_batch.yaml
python benchmark.py --config configs/speculative_decode.yaml
python benchmark.py --config configs/standard_attention.yaml
python benchmark.py --config configs/reorder_threshold.yaml
# Or run custom benchmarks
python benchmark.py \
--backends flash flashinfer \
--batch-specs "q2k" "8q1s1k" "2q2k_32q1s1k" \
--output-csv results.csv
```
## Simplified Batch Specification Grammar
Express workloads concisely using query length and sequence length:
```python
"q2k" # 2048-token prefill (q_len=2048, seq_len=2048)
"q1s1k" # Decode: 1 token with 1K sequence
"8q1s1k" # 8 decode requests
"q4s1k" # 4-token extend (e.g., spec decode)
"2q2k_32q1s1k" # Mixed: 2 prefills + 32 decodes
"16q4s1k" # 16 spec decode (4 tokens each)
```
### Grammar Rule
```text
Format: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
- count: Number of identical requests (optional, default=1)
- q_len: Query length (number of new tokens)
- seq_len: Total sequence length (optional, defaults to q_len for prefill)
- 'k': Multiplies value by 1024
Mixed batches: Use _ to combine (e.g., "2q2k_32q1s1k")
```
**Note**: Decode, prefill, and spec decode are just different query lengths - no special syntax needed!
## Pre-configured Benchmarks
The suite includes several pre-configured YAML benchmark configurations:
### MLA Decode Benchmark
Tests pure decode performance across MLA backends with varying batch sizes and sequence lengths.
```bash
python benchmark.py --config configs/mla_decode.yaml
```
### MLA Mixed Batch Benchmark
Tests chunked prefill performance with mixed prefill + decode batches.
```bash
python benchmark.py --config configs/mla_mixed_batch.yaml
```
### Speculative Decoding Benchmark
Tests speculative decode scenarios (K-token verification) and reorder_batch_threshold optimization.
```bash
python benchmark.py --config configs/speculative_decode.yaml
```
### Standard Attention Benchmark
Tests standard attention backends (Flash/Triton/FlashInfer) with pure prefill, decode, and mixed batches.
```bash
python benchmark.py --config configs/standard_attention.yaml
```
### Reorder Threshold Study
**Question:** At what query length does the prefill pipeline become faster than the decode pipeline?
Tests query lengths from 1-1024 across 9 batch sizes to find the crossover point. Uses `decode_vs_prefill` mode to compare both pipelines for each query length.
```bash
python benchmark.py --config configs/reorder_threshold.yaml
```
---
## Universal Benchmark
The `benchmark.py` script handles **all** backends - both standard attention and MLA.
### Standard Attention (Flash/Triton/FlashInfer)
```bash
python benchmark.py \
--backends flash triton flashinfer \
--batch-specs "q2k" "8q1s1k" "2q2k_32q1s1k" \
--num-layers 10 \
--repeats 5 \
--output-csv results.csv
```
### MLA Backends
```bash
# Compare all MLA backends
python benchmark.py \
--backends cutlass_mla flashinfer_mla flashattn_mla flashmla \
--batch-specs "64q1s1k" "64q1s4k" \
--output-csv mla_results.csv
```
### Parameter Sweeps
Use `--sweep-param` and `--sweep-values` to run parameter sweeps from the CLI:
#### CUTLASS MLA num-splits Optimization
**Question:** What is the optimal `num_kv_splits` for CUTLASS MLA?
```bash
python benchmark.py \
--backend cutlass_mla \
--batch-specs "64q1s1k" "64q1s4k" "64q1s16k" \
--sweep-param num_kv_splits \
--sweep-values 1 2 4 8 16 \
--output-json optimal_splits.json
```
#### Reorder Batch Threshold Optimization
**Question:** What's the optimal `reorder_batch_threshold` for speculative decoding?
```bash
python benchmark.py \
--backend flashmla \
--batch-specs "q4s1k" "q8s2k" \
--sweep-param reorder_batch_threshold \
--sweep-values 1 4 16 64 256 512 \
--output-csv threshold_sweep.csv
```
### All Command-Line Options
```text
--config CONFIG # Path to YAML config file (overrides other args)
--backends BACKEND [BACKEND ...] # flash, triton, flashinfer, cutlass_mla,
# flashinfer_mla, flashattn_mla, flashmla
--backend BACKEND # Single backend (alternative to --backends)
--batch-specs SPEC [SPEC ...] # Batch specifications using extended grammar
# Model configuration
--num-layers N # Number of layers
--head-dim N # Head dimension
--num-q-heads N # Query heads
--num-kv-heads N # KV heads
--block-size N # Block size
# Benchmark settings
--device DEVICE # Device (default: cuda:0)
--repeats N # Repetitions
--warmup-iters N # Warmup iterations
--profile-memory # Profile memory usage
# Parameter sweeps
--sweep-param PARAM # Parameter name to sweep (e.g., num_kv_splits,
# reorder_batch_threshold)
--sweep-values N [N ...] # Values to sweep for the parameter
# Output
--output-csv FILE # Save to CSV
--output-json FILE # Save to JSON
```
## Hardware Requirements
| Backend | Hardware |
|---------|----------|
| Flash/Triton/FlashInfer | Any CUDA GPU |
| CUTLASS MLA | Blackwell (SM100+) |
| FlashAttn MLA | Hopper (SM90+) |
| FlashMLA | Hopper (SM90+) |
| FlashInfer-MLA | Any CUDA GPU |
## Using MLA Runner Directly
All MLA backends are available through `mla_runner.run_mla_benchmark()`:
```python
from mla_runner import run_mla_benchmark
from common import BenchmarkConfig
config = BenchmarkConfig(
backend="cutlass_mla",
batch_spec="64q1s4k",
num_layers=10,
head_dim=576,
num_q_heads=128,
num_kv_heads=1,
block_size=128,
device="cuda:0",
repeats=5,
warmup_iters=3,
)
# CUTLASS MLA with specific num_kv_splits
result = run_mla_benchmark("cutlass_mla", config, num_kv_splits=4)
print(f"Time: {result.mean_time:.6f}s")
# FlashInfer-MLA
result = run_mla_benchmark("flashinfer_mla", config)
# FlashAttn MLA (Hopper SM90+)
result = run_mla_benchmark("flashattn_mla", config, reorder_batch_threshold=64)
# FlashMLA (Hopper SM90+)
result = run_mla_benchmark("flashmla", config, reorder_batch_threshold=64)
```
## Python API
```python
from batch_spec import parse_batch_spec, format_batch_spec, get_batch_stats
from common import BenchmarkConfig, BenchmarkResult, ResultsFormatter
# Parse batch specs
requests = parse_batch_spec("2q2k_q4s1k_32q1s1k")
print(format_batch_spec(requests))
# "2 prefill (2x2k), 1 extend (1xq4kv1k), 32 decode (32x1k)"
# Get batch statistics
stats = get_batch_stats(requests)
print(f"Total tokens: {stats['total_tokens']}")
print(f"Num decode: {stats['num_decode']}, Num prefill: {stats['num_prefill']}")
# Format results
formatter = ResultsFormatter()
formatter.save_csv(results, "output.csv")
formatter.save_json(results, "output.json")
```
## Tips
**1. Warmup matters** - Use `--warmup-iters 10` for stable results
**2. Multiple repeats** - Use `--repeats 20` for low variance
**3. Save results** - Always use `--output-csv` or `--output-json`
**4. Test incrementally** - Start with `--num-layers 1 --repeats 1`
**5. Extended grammar** - Leverage spec decode, chunked prefill patterns
**6. Parameter sweeps** - Use `--sweep-param` and `--sweep-values` to find optimal values
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM Attention Benchmarking Suite."""
from .batch_spec import (
BatchRequest,
format_batch_spec,
get_batch_stats,
parse_batch_spec,
reorder_for_flashinfer,
split_by_type,
)
from .common import (
BenchmarkConfig,
BenchmarkResult,
MockLayer,
MockModelConfig,
ResultsFormatter,
get_attention_scale,
is_mla_backend,
setup_mla_dims,
)
__all__ = [
# Batch specification
"BatchRequest",
"parse_batch_spec",
"format_batch_spec",
"reorder_for_flashinfer",
"split_by_type",
"get_batch_stats",
# Benchmarking infrastructure
"BenchmarkConfig",
"BenchmarkResult",
"ResultsFormatter",
# Mock objects
"MockLayer",
"MockModelConfig",
# Utilities
"setup_mla_dims",
"get_attention_scale",
"is_mla_backend",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simplified batch specification grammar for attention benchmarks.
Grammar (underscore-separated segments):
Format: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
- count: Number of identical requests (optional, default=1)
- q_len: Query length (number of new tokens)
- seq_len: Total sequence length (optional, defaults to q_len for prefill)
- 'k' suffix: Multiplies value by 1024
Common patterns:
- Prefill: q_len == seq_len (e.g., "q2k" → 2048 new tokens, 2048 seq)
- Decode: q_len == 1 (e.g., "q1s1k" → 1 token, 1024 seq length)
- Extend: q_len < seq_len (e.g., "q4s1k" → 4 tokens, 1024 seq length)
Examples:
q2k -> [(2048, 2048)] # Prefill: 2048 tokens
q1s1k -> [(1, 1024)] # Decode: 1 token, 1K sequence
8q1s1k -> [(1, 1024)] * 8 # 8 decode requests
q4s1k -> [(4, 1024)] # 4-token extend (spec decode)
2q1k_32q1s1k -> [(1024, 1024)] * 2 + [(1, 1024)] * 32 # Mixed batch
16q4s1k -> [(4, 1024)] * 16 # 16 spec decode requests
"""
from collections import Counter
from dataclasses import dataclass
import regex as re
@dataclass
class BatchRequest:
"""Represents a single request in a batch."""
q_len: int # Query length (number of new tokens)
kv_len: int # Total KV cache length
@property
def is_decode(self) -> bool:
"""True if this is a decode request (q_len == 1)."""
return self.q_len == 1
@property
def is_prefill(self) -> bool:
"""True if this is a pure prefill (q_len == kv_len)."""
return self.q_len == self.kv_len
@property
def is_extend(self) -> bool:
"""True if this is context extension (q_len > 1, kv_len > q_len)."""
return self.q_len > 1 and self.kv_len > self.q_len
@property
def context_len(self) -> int:
"""Context length (KV cache - query)."""
return self.kv_len - self.q_len
def as_tuple(self) -> tuple[int, int]:
"""Return as (q_len, kv_len) tuple for compatibility."""
return (self.q_len, self.kv_len)
def _parse_size(size_str: str, k_suffix: str) -> int:
"""Parse size string with optional 'k' suffix."""
size = int(size_str)
return size * 1024 if k_suffix == "k" else size
def parse_batch_spec(spec: str) -> list[BatchRequest]:
"""
Parse batch specification string into list of BatchRequest objects.
Grammar: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
Args:
spec: Batch specification string (see module docstring for grammar)
Returns:
List of BatchRequest objects
Raises:
ValueError: If spec format is invalid
"""
requests = []
for seg in spec.split("_"):
# Unified pattern: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
m = re.match(r"^(?:(\d+))?q(\d+)(k?)(?:s(\d+)(k?))?$", seg)
if m:
cnt = int(m.group(1)) if m.group(1) else 1
q_len = _parse_size(m.group(2), m.group(3))
kv_len = _parse_size(m.group(4), m.group(5)) if m.group(4) else q_len
requests.extend([BatchRequest(q_len=q_len, kv_len=kv_len)] * cnt)
continue
raise ValueError(f"Invalid batch spec segment: '{seg}'")
return requests
def format_batch_spec(requests: list[BatchRequest]) -> str:
"""
Format list of BatchRequest into human-readable string.
Groups requests by type and provides counts and sizes.
Args:
requests: List of BatchRequest objects
Returns:
Formatted string describing the batch
"""
kinds = {
"prefill": [],
"extend": [],
"decode": [],
}
for req in requests:
tup = (req.q_len, req.kv_len)
if req.is_prefill:
kinds["prefill"].append(tup)
elif req.is_extend:
kinds["extend"].append(tup)
elif req.is_decode:
kinds["decode"].append(tup)
parts = []
for kind in ["prefill", "extend", "decode"]:
lst = kinds[kind]
if not lst:
continue
cnt_total = len(lst)
ctr = Counter(lst)
inner = []
for (q, kv), cnt in ctr.items():
if kind == "prefill":
size = f"{q // 1024}k" if q % 1024 == 0 else str(q)
inner.append(f"{cnt}x{size}")
elif kind == "decode":
size = f"{kv // 1024}k" if kv % 1024 == 0 else str(kv)
inner.append(f"{cnt}x{size}")
else: # extend
qstr = f"{q // 1024}k" if q % 1024 == 0 else str(q)
kstr = f"{kv // 1024}k" if kv % 1024 == 0 else str(kv)
inner.append(f"{cnt}xq{qstr}kv{kstr}")
parts.append(f"{cnt_total} {kind} ({', '.join(inner)})")
return ", ".join(parts)
def reorder_for_flashinfer(requests: list[BatchRequest]) -> list[BatchRequest]:
"""
Reorder requests for FlashInfer: decode first, then prefill.
FlashInfer expects decode requests before prefill requests for
optimal performance.
Args:
requests: Original list of BatchRequest
Returns:
Reordered list with decode requests first
"""
decodes = [r for r in requests if r.is_decode]
non_decodes = [r for r in requests if not r.is_decode]
return decodes + non_decodes
def split_by_type(
requests: list[BatchRequest],
) -> dict[str, list[BatchRequest]]:
"""
Split requests by type for analysis.
Args:
requests: List of BatchRequest
Returns:
Dict with keys: 'decode', 'prefill', 'extend'
"""
result = {
"decode": [],
"prefill": [],
"extend": [],
}
for req in requests:
if req.is_decode:
result["decode"].append(req)
elif req.is_prefill:
result["prefill"].append(req)
elif req.is_extend:
result["extend"].append(req)
return result
def get_batch_stats(requests: list[BatchRequest]) -> dict:
"""
Compute statistics about a batch.
Args:
requests: List of BatchRequest
Returns:
Dict with batch statistics
"""
by_type = split_by_type(requests)
return {
"total_requests": len(requests),
"num_decode": len(by_type["decode"]),
"num_prefill": len(by_type["prefill"]),
"num_extend": len(by_type["extend"]),
"total_tokens": sum(r.q_len for r in requests),
"total_kv_cache": sum(r.kv_len for r in requests),
"max_q_len": max((r.q_len for r in requests), default=0),
"max_kv_len": max((r.kv_len for r in requests), default=0),
"avg_q_len": sum(r.q_len for r in requests) / len(requests) if requests else 0,
"avg_kv_len": (
sum(r.kv_len for r in requests) / len(requests) if requests else 0
),
}
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Common utilities for attention benchmarking."""
import csv
import json
import math
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
import numpy as np
import torch
from rich.console import Console
from rich.table import Table
# Mock classes for vLLM attention infrastructure
class MockHfConfig:
"""Mock HuggingFace config that satisfies vLLM's requirements."""
def __init__(self, mla_dims: dict):
self.num_attention_heads = mla_dims["num_q_heads"]
self.num_key_value_heads = mla_dims["num_kv_heads"]
self.hidden_size = mla_dims["head_dim"] * mla_dims["num_q_heads"]
self.model_type = "deepseek_v2"
self.is_encoder_decoder = False
self.kv_lora_rank = mla_dims["kv_lora_rank"]
self.qk_nope_head_dim = mla_dims["qk_nope_head_dim"]
self.qk_rope_head_dim = mla_dims["qk_rope_head_dim"]
self.v_head_dim = mla_dims["v_head_dim"]
self.qk_head_dim = mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]
def get_text_config(self):
return self
# Import AttentionLayerBase at module level to avoid circular dependencies
try:
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
_HAS_ATTENTION_LAYER_BASE = True
except ImportError:
_HAS_ATTENTION_LAYER_BASE = False
AttentionLayerBase = object # Fallback
class MockKVBProj:
"""Mock KV projection layer for MLA prefill mode.
Mimics ColumnParallelLinear behavior for kv_b_proj in MLA backends.
Projects kv_c_normed to [qk_nope_head_dim + v_head_dim] per head.
"""
def __init__(self, num_heads: int, qk_nope_head_dim: int, v_head_dim: int):
self.num_heads = num_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.v_head_dim = v_head_dim
self.out_dim = qk_nope_head_dim + v_head_dim
def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]:
"""
Project kv_c_normed to output space.
Args:
x: Input tensor [num_tokens, kv_lora_rank]
Returns:
Tuple containing output tensor
[num_tokens, num_heads, qk_nope_head_dim + v_head_dim]
"""
num_tokens = x.shape[0]
result = torch.randn(
num_tokens,
self.num_heads,
self.out_dim,
device=x.device,
dtype=x.dtype,
)
return (result,) # Return as tuple to match ColumnParallelLinear API
class MockLayer(AttentionLayerBase):
"""Mock attention layer with scale parameters and impl.
Inherits from AttentionLayerBase so it passes isinstance checks
in get_layers_from_vllm_config when FlashInfer prefill is enabled.
"""
def __init__(self, device: torch.device, impl=None, kv_cache_spec=None):
# Don't call super().__init__() as AttentionLayerBase doesn't have __init__
self._k_scale = torch.tensor(1.0, device=device)
self._v_scale = torch.tensor(1.0, device=device)
self._q_scale = torch.tensor(1.0, device=device)
# Scalar floats for kernels that need them
self._k_scale_float = float(self._k_scale.item())
self._v_scale_float = float(self._v_scale.item())
self._q_scale_float = float(self._q_scale.item())
# AttentionImpl for metadata builders to query
self.impl = impl
# KV cache spec for get_kv_cache_spec
self._kv_cache_spec = kv_cache_spec
def get_attn_backend(self):
"""Get the attention backend class (required by AttentionLayerBase)."""
# Return None as this is just a mock layer for benchmarking
return None
def get_kv_cache_spec(self):
"""Get the KV cache spec (required by AttentionLayerBase)."""
return self._kv_cache_spec
class MockModelConfig:
"""Mock model configuration."""
def __init__(
self,
num_q_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.float16,
max_model_len: int = 32768,
):
self._n_q = num_q_heads
self._n_kv = num_kv_heads
self._d = head_dim
self.dtype = dtype
self.max_model_len = max_model_len
def get_num_attention_heads(self, _=None) -> int:
return self._n_q
def get_num_kv_heads(self, _=None) -> int:
return self._n_kv
def get_head_size(self) -> int:
return self._d
def get_num_layers(self) -> int:
"""Mock method for layer count queries."""
return 1
def get_sliding_window_for_layer(self, _layer_idx: int):
"""Mock method for sliding window queries."""
return None
def get_logits_soft_cap_for_layer(self, _layer_idx: int):
"""Mock method for logits soft cap queries."""
return None
def get_sm_scale_for_layer(self, _layer_idx: int) -> float:
"""Mock method for SM scale queries."""
return 1.0 / (self.get_head_size() ** 0.5)
class MockParallelConfig:
"""Mock parallel configuration."""
pass
class MockCompilationConfig:
"""Mock compilation configuration."""
def __init__(self):
self.full_cuda_graph = False
self.static_forward_context = {}
class MockVLLMConfig:
"""Mock VLLM configuration."""
def __init__(self):
self.compilation_config = MockCompilationConfig()
class MockRunner:
"""Mock GPU runner for metadata builders."""
def __init__(
self,
seq_lens: np.ndarray,
query_start_locs: np.ndarray,
device: torch.device,
num_q_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
):
self.model_config = MockModelConfig(num_q_heads, num_kv_heads, head_dim, dtype)
self.parallel_config = MockParallelConfig()
self.vllm_config = MockVLLMConfig()
self.seq_lens_np = seq_lens
self.query_start_loc_np = query_start_locs
self.device = device
self.attention_chunk_size = None
self.num_query_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.dtype = dtype
@dataclass
class ParameterSweep:
"""Configuration for sweeping a backend parameter."""
param_name: str # Name of the backend parameter to sweep
values: list[Any] # List of values to test
include_auto: bool = False # Also test with param unset (auto mode)
label_format: str = "{backend}_{param_name}_{value}" # Result label template
def get_label(self, backend: str, value: Any) -> str:
"""Generate a label for a specific parameter value."""
return self.label_format.format(
backend=backend, param_name=self.param_name, value=value
)
@dataclass
class ModelParameterSweep:
"""Configuration for sweeping a model configuration parameter."""
param_name: str # Name of the model config parameter to sweep (e.g., "num_q_heads")
values: list[Any] # List of values to test
label_format: str = "{backend}_{param_name}_{value}" # Result label template
def get_label(self, backend: str, value: Any) -> str:
"""Generate a label for a specific parameter value."""
return self.label_format.format(
backend=backend, param_name=self.param_name, value=value
)
@dataclass
class BenchmarkConfig:
"""Configuration for a single benchmark run."""
backend: str
batch_spec: str
num_layers: int
head_dim: int
num_q_heads: int
num_kv_heads: int
block_size: int
device: str
dtype: torch.dtype = torch.float16
repeats: int = 1
warmup_iters: int = 3
profile_memory: bool = False
use_cuda_graphs: bool = False
# MLA-specific
kv_lora_rank: int | None = None
qk_nope_head_dim: int | None = None
qk_rope_head_dim: int | None = None
v_head_dim: int | None = None
# Backend-specific tuning
num_kv_splits: int | None = None # CUTLASS MLA
reorder_batch_threshold: int | None = None # FlashAttn MLA, FlashMLA
@dataclass
class BenchmarkResult:
"""Results from a single benchmark run."""
config: BenchmarkConfig
mean_time: float # seconds
std_time: float # seconds
min_time: float # seconds
max_time: float # seconds
throughput_tokens_per_sec: float | None = None
memory_allocated_mb: float | None = None
memory_reserved_mb: float | None = None
error: str | None = None
@property
def success(self) -> bool:
"""Whether benchmark completed successfully."""
return self.error is None
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"config": asdict(self.config),
"mean_time": self.mean_time,
"std_time": self.std_time,
"min_time": self.min_time,
"max_time": self.max_time,
"throughput_tokens_per_sec": self.throughput_tokens_per_sec,
"memory_allocated_mb": self.memory_allocated_mb,
"memory_reserved_mb": self.memory_reserved_mb,
"error": self.error,
}
class ResultsFormatter:
"""Format and display benchmark results."""
def __init__(self, console: Console | None = None):
self.console = console or Console()
def print_table(
self,
results: list[BenchmarkResult],
backends: list[str],
compare_to_fastest: bool = True,
):
"""
Print results as a rich table.
Args:
results: List of BenchmarkResult
backends: List of backend names being compared
compare_to_fastest: Show percentage comparison to fastest
"""
# Group by batch spec
by_spec = {}
for r in results:
spec = r.config.batch_spec
if spec not in by_spec:
by_spec[spec] = {}
by_spec[spec][r.config.backend] = r
# Create shortened backend names for display
def shorten_backend_name(name: str) -> str:
"""Shorten long backend names for table display."""
# Remove common prefixes
name = name.replace("flashattn_mla", "famla")
name = name.replace("flashinfer_mla", "fimla")
name = name.replace("flashmla", "fmla")
name = name.replace("cutlass_mla", "cmla")
name = name.replace("numsplits", "ns")
return name
table = Table(title="Attention Benchmark Results")
table.add_column("Batch\nSpec", no_wrap=True)
multi = len(backends) > 1
for backend in backends:
short_name = shorten_backend_name(backend)
# Time column
col_time = f"{short_name}\nTime (s)"
table.add_column(col_time, justify="right", no_wrap=False)
if multi and compare_to_fastest:
# Relative performance column
col_rel = f"{short_name}\nvs Best"
table.add_column(col_rel, justify="right", no_wrap=False)
# Add rows
for spec in sorted(by_spec.keys()):
spec_results = by_spec[spec]
times = {b: r.mean_time for b, r in spec_results.items() if r.success}
best_time = min(times.values()) if times else 0.0
row = [spec]
for backend in backends:
if backend in spec_results:
r = spec_results[backend]
if r.success:
row.append(f"{r.mean_time:.6f}")
if multi and compare_to_fastest:
pct = (
(r.mean_time / best_time * 100) if best_time > 0 else 0
)
pct_str = f"{pct:.1f}%"
if r.mean_time == best_time:
pct_str = f"[bold green]{pct_str}[/]"
row.append(pct_str)
else:
row.append("[red]ERROR[/]")
if multi and compare_to_fastest:
row.append("-")
else:
row.append("-")
if multi and compare_to_fastest:
row.append("-")
table.add_row(*row)
self.console.print(table)
def save_csv(self, results: list[BenchmarkResult], path: str):
"""Save results to CSV file."""
if not results:
return
path_obj = Path(path)
path_obj.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w", newline="") as f:
writer = csv.DictWriter(
f,
fieldnames=[
"backend",
"batch_spec",
"num_layers",
"mean_time",
"std_time",
"throughput",
"memory_mb",
],
)
writer.writeheader()
for r in results:
writer.writerow(
{
"backend": r.config.backend,
"batch_spec": r.config.batch_spec,
"num_layers": r.config.num_layers,
"mean_time": r.mean_time,
"std_time": r.std_time,
"throughput": r.throughput_tokens_per_sec or 0,
"memory_mb": r.memory_allocated_mb or 0,
}
)
self.console.print(f"[green]Saved CSV results to {path}[/]")
def save_json(self, results: list[BenchmarkResult], path: str):
"""Save results to JSON file."""
path_obj = Path(path)
path_obj.parent.mkdir(parents=True, exist_ok=True)
data = [r.to_dict() for r in results]
with open(path, "w") as f:
json.dump(data, f, indent=2, default=str)
self.console.print(f"[green]Saved JSON results to {path}[/]")
def setup_mla_dims(model_name: str = "deepseek-v3") -> dict:
"""
Get MLA dimensions for known models.
Args:
model_name: Model identifier
Returns:
Dict with MLA dimension configuration
"""
configs = {
"deepseek-v2": {
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"num_q_heads": 128,
"num_kv_heads": 1,
"head_dim": 576,
},
"deepseek-v3": {
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"num_q_heads": 128,
"num_kv_heads": 1,
"head_dim": 576,
},
"deepseek-v2-lite": {
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"num_q_heads": 16,
"num_kv_heads": 1,
"head_dim": 576,
},
}
if model_name not in configs:
raise ValueError(
f"Unknown model '{model_name}'. Known models: {list(configs.keys())}"
)
return configs[model_name]
def get_attention_scale(head_dim: int) -> float:
"""Compute attention scale factor (1/sqrt(d))."""
return 1.0 / math.sqrt(head_dim)
def is_mla_backend(backend: str) -> bool:
"""
Check if backend is an MLA backend using the backend's is_mla() property.
Args:
backend: Backend name (e.g., "CUTLASS_MLA", "FLASHINFER_MLA")
Returns:
True if the backend is an MLA backend, False otherwise
"""
from vllm.v1.attention.backends.registry import AttentionBackendEnum
try:
backend_class = AttentionBackendEnum[backend.upper()].get_class()
return backend_class.is_mla()
except (KeyError, ValueError, ImportError):
return False
# MLA decode-only benchmark configuration
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128
num_kv_heads: 1 # MLA uses single latent KV
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128
batch_specs:
# Small batches, varying sequence lengths
- "16q1s512" # 16 requests, 512 KV cache
- "16q1s1k" # 16 requests, 1k KV cache
- "16q1s2k" # 16 requests, 2k KV cache
- "16q1s4k" # 16 requests, 4k KV cache
# Medium batches
- "32q1s1k" # 32 requests, 1k KV cache
- "32q1s2k" # 32 requests, 2k KV cache
- "32q1s4k" # 32 requests, 4k KV cache
- "32q1s8k" # 32 requests, 8k KV cache
# Large batches
- "64q1s1k" # 64 requests, 1k KV cache
- "64q1s2k" # 64 requests, 2k KV cache
- "64q1s4k" # 64 requests, 4k KV cache
- "64q1s8k" # 64 requests, 8k KV cache
# Very large batches
- "128q1s1k" # 128 requests, 1k KV cache
- "128q1s2k" # 128 requests, 2k KV cache
# Long context
- "32q1s16k" # 32 requests, 16k KV cache
- "32q1s32k" # 32 requests, 32k KV cache
backends:
- cutlass_mla
- flashinfer_mla
- flashattn_mla # Hopper only
- flashmla # Hopper only
device: "cuda:0"
repeats: 5
warmup_iters: 3
profile_memory: true
# Backend-specific tuning
cutlass_mla:
num_kv_splits: auto # or specific value like 4, 8, 16
flashattn_mla:
reorder_batch_threshold: 512
flashmla:
reorder_batch_threshold: 1
# MLA mixed batch benchmark (prefill + decode)
# Tests chunked prefill performance
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128
num_kv_heads: 1
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128
batch_specs:
# Small prefill + decode
- "1q1k_8q1s1k" # 1 prefill + 8 decode
- "2q2k_16q1s1k" # 2 prefill + 16 decode
- "4q1k_32q1s2k" # 4 prefill + 32 decode
# Medium prefill + decode
- "2q4k_32q1s2k" # 2 medium prefill + 32 decode
- "4q4k_64q1s2k" # 4 medium prefill + 64 decode
- "8q2k_64q1s4k" # 8 prefill + 64 decode
# Large prefill + decode (chunked prefill stress test)
- "2q8k_32q1s1k" # 2 large prefill + 32 decode
- "1q16k_16q1s2k" # 1 very large prefill + 16 decode
- "2q16k_32q1s4k" # 2 very large prefill + 32 decode
# Context extension + decode
- "2q1kkv2k_16q1s1k" # 2 extend + 16 decode
- "4q2kkv4k_32q1s2k" # 4 extend + 32 decode
- "2q1kkv8k_32q1s2k" # 2 large extend + 32 decode
# Explicitly chunked prefill
- "q8k" # 8k prefill with chunking hint
- "q16k" # 16k prefill with chunking hint
- "2q8k_32q1s2k" # 2 chunked prefill + 32 decode
# High decode ratio (realistic serving)
- "1q2k_63q1s1k" # 1 prefill + 63 decode
- "2q2k_62q1s2k" # 2 prefill + 62 decode
- "4q4k_60q1s4k" # 4 prefill + 60 decode
backends:
- cutlass_mla
- flashinfer_mla
- flashattn_mla # Hopper only
- flashmla # Hopper only
device: "cuda:0"
repeats: 5
warmup_iters: 3
profile_memory: true
# Analyze chunked prefill workspace size impact
chunked_prefill:
test_workspace_sizes: [4096, 8192, 16384, 32768, 65536]
# Study 4: What is optimal reorder_batch_threshold for MLA backends supporting query length > 1?
# Question: At what query length does prefill pipeline become faster than decode pipeline?
# Methodology: For each query length, compare decode vs prefill performance to find crossover point
# Applies to: FlashAttn MLA, FlashMLA
description: "Decode vs Prefill pipeline crossover analysis"
# Test FlashAttn MLA
backend: flashattn_mla
# Mode: decode_vs_prefill comparison (special sweep mode)
# For each batch spec, we'll test both decode and prefill pipelines
mode: "decode_vs_prefill"
# Query lengths to test (from old benchmark_mla_threshold.py methodology)
# Each query length will be tested with BOTH decode and prefill pipelines:
# - decode: threshold >= query_length (forces decode pipeline)
# - prefill: threshold < query_length (forces prefill pipeline)
#
# We use q<N>s1k format which creates q_len=N, seq_len=1024 requests
# This tests different query lengths with fixed sequence length context
#
# Using batch_spec_ranges for automatic generation:
batch_spec_ranges:
- template: "q{q_len}s1k"
q_len:
start: 1
stop: 16
step: 1
end_inclusive: false
- template: "q{q_len}s1k"
q_len:
start: 16
stop: 64
step: 2
end_inclusive: false
- template: "q{q_len}s1k"
q_len:
start: 64
stop: 1024
step: 4
end_inclusive: true
# Batch sizes to test (from old script)
batch_sizes:
- 1
- 2
- 4
- 8
- 16
- 32
- 64
- 128
- 256
# Model configuration (DeepSeek V2/V3 defaults)
model:
num_layers: 10
head_dim: 576
num_q_heads: 128
num_kv_heads: 1
block_size: 128
# Benchmark settings
benchmark:
device: "cuda:0"
repeats: 15 # More repeats for spec decode variance
warmup_iters: 5
profile_memory: false
# Output
output:
csv: "reorder_threshold_results.csv"
json: "reorder_threshold_results.json"
# Expected outcome (reproduces old benchmark_mla_threshold.py study):
# - For each batch size, find the crossover point where prefill becomes faster than decode
# - Show decode vs prefill performance across all query lengths
# - Determine optimal reorder_batch_threshold based on last query length where decode is faster
# - Understand how crossover point varies with batch size
# - Provide data-driven guidance for default threshold value
#
# Methodology (from old script):
# - Each query length tested with BOTH pipelines:
# * decode: threshold >= query_length (forces decode pipeline)
# * prefill: threshold < query_length (forces prefill pipeline)
# - Compare which is faster to find crossover point
#
# Speculative decoding benchmark configuration
# Tests reorder_batch_threshold optimization
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128
num_kv_heads: 1
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
batch_specs:
# Pure speculative decode (K-token verification)
- "q2s1k" # 2-token spec, 1k KV
- "q4s1k" # 4-token spec, 1k KV
- "q8s1k" # 8-token spec, 1k KV
- "q16s1k" # 16-token spec, 1k KV
# Speculative with different context lengths
- "q4s2k" # 4-token spec, 2k KV
- "q4s4k" # 4-token spec, 4k KV
- "q8s2k" # 8-token spec, 2k KV
- "q8s4k" # 8-token spec, 4k KV
# Mixed: speculative + regular decode
- "32q4s1k" # 32 spec requests
- "16q4s1k_16q1s1k" # 16 spec + 16 regular
- "8q8s2k_24q1s2k" # 8 spec (8-tok) + 24 regular
# Mixed: speculative + prefill + decode
- "2q1k_16q4s1k_16q1s1k" # 2 prefill + 16 spec + 16 decode
- "4q2k_32q4s2k_32q1s2k" # 4 prefill + 32 spec + 32 decode
# Large batches with speculation
- "64q4s1k" # 64 spec requests
- "32q8s2k" # 32 spec (8-token)
- "16q16s4k" # 16 spec (16-token)
# Backends that support query length > 1
backends:
- flashattn_mla # reorder_batch_threshold = 512
- flashmla # reorder_batch_threshold = 1 (tunable)
# FlashInfer-MLA also supports uniform spec-as-decode but with different mechanism
# - flashinfer_mla
# Benchmark settings
benchmark:
device: "cuda:0"
repeats: 10 # More repeats for statistical significance
warmup_iters: 5
profile_memory: false
# Test these threshold values for optimization
parameter_sweep:
param_name: "reorder_batch_threshold"
values: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
include_auto: false
label_format: "{backend}_threshold_{value}"
# Standard attention backend benchmark configuration
model:
num_layers: 32
num_q_heads: 32
num_kv_heads: 8 # GQA with 4:1 ratio
head_dim: 128
block_size: 16
batch_specs:
# Pure prefill
- "q512" # Small prefill (512 tokens)
- "q2k" # Medium prefill (2048 tokens)
- "q4k" # Large prefill (4096 tokens)
- "q8k" # Very large prefill (8192 tokens)
# Pure decode
- "8q1s1k" # 8 requests, 1k KV cache each
- "16q1s2k" # 16 requests, 2k KV cache each
- "32q1s1k" # 32 requests, 1k KV cache each
- "64q1s4k" # 64 requests, 4k KV cache each
# Mixed prefill/decode
- "2q2k_8q1s1k" # 2 prefill + 8 decode
- "4q1k_16q1s2k" # 4 prefill + 16 decode
- "2q4k_32q1s1k" # 2 large prefill + 32 decode
# Context extension
- "q1ks2k" # 1k query, 2k sequence (chunked prefill)
- "2q1ks4k" # 2 requests: 1k query, 4k sequence
backends:
- flash
- triton
- flashinfer
device: "cuda:0"
repeats: 5
warmup_iters: 3
profile_memory: false
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Standard attention benchmark runner - shared utilities for non-MLA benchmarks.
This module provides helpers for running standard attention backends
(FlashAttention, Triton, FlashInfer) with real vLLM integration.
"""
import types
import numpy as np
import torch
from batch_spec import parse_batch_spec, reorder_for_flashinfer
from common import BenchmarkConfig, BenchmarkResult, MockLayer, get_attention_scale
from vllm.config import (
CacheConfig,
CompilationConfig,
DeviceConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
# ============================================================================
# Backend Configuration
# ============================================================================
_BACKEND_CONFIG = {
"flash": {
"module": "vllm.v1.attention.backends.flash_attn",
"backend_class": "FlashAttentionBackend",
"dtype": torch.float16,
"cache_layout": "standard",
# ^ [2, num_blocks, block_size, num_kv_heads, head_dim]
},
"triton": {
"module": "vllm.v1.attention.backends.triton_attn",
"backend_class": "TritonAttentionBackend",
"dtype": torch.float32,
"cache_layout": "standard",
},
"flashinfer": {
"module": "vllm.v1.attention.backends.flashinfer",
"backend_class": "FlashInferBackend",
"dtype": torch.float16,
"cache_layout": "flashinfer",
# ^ [num_blocks, 2, block_size, num_kv_heads, head_dim]
},
}
def _get_backend_config(backend: str) -> dict:
if backend not in _BACKEND_CONFIG:
raise ValueError(
f"Unknown backend: {backend}. "
f"Available: {', '.join(_BACKEND_CONFIG.keys())}"
)
return _BACKEND_CONFIG[backend]
# ============================================================================
# Metadata Building Helpers
# ============================================================================
def _build_common_attn_metadata(
q_lens: list[int],
kv_lens: list[int],
block_size: int,
device: torch.device,
) -> CommonAttentionMetadata:
"""Build CommonAttentionMetadata from query/kv lengths."""
batch_size = len(q_lens)
total_tokens = sum(q_lens)
query_start_loc = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
query_start_loc[1:] = torch.tensor(q_lens, dtype=torch.int32, device=device).cumsum(
0
)
query_start_loc_cpu = query_start_loc.cpu()
seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device)
seq_lens_cpu = seq_lens.cpu()
max_seq_len = int(seq_lens_cpu.max())
context_lens = [kv - q for kv, q in zip(kv_lens, q_lens)]
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
max_blocks = (max(kv_lens) + block_size - 1) // block_size
num_blocks = batch_size * max_blocks
block_table_tensor = torch.arange(
num_blocks, dtype=torch.int32, device=device
).view(batch_size, max_blocks)
slot_mapping = torch.arange(total_tokens, dtype=torch.int64, device=device)
max_query_len = max(q_lens)
return CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_size,
num_actual_tokens=total_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
causal=True,
)
def _create_vllm_config(
config: BenchmarkConfig,
dtype: torch.dtype,
max_num_blocks: int,
) -> VllmConfig:
"""Create a VllmConfig for benchmarking with mock model methods."""
model_config = ModelConfig(
model="meta-llama/Meta-Llama-3-8B",
tokenizer="meta-llama/Meta-Llama-3-8B",
trust_remote_code=False,
dtype=dtype,
seed=0,
max_model_len=1024,
)
cache_config = CacheConfig(
block_size=config.block_size,
cache_dtype="auto",
swap_space=0,
)
cache_config.num_gpu_blocks = max_num_blocks
cache_config.num_cpu_blocks = 0
parallel_config = ParallelConfig(tensor_parallel_size=1)
scheduler_config = SchedulerConfig(
max_num_seqs=256,
max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
enable_chunked_prefill=True,
)
device_config = DeviceConfig()
load_config = LoadConfig()
compilation_config = CompilationConfig()
# Add mock methods for benchmark config values
model_config.get_num_layers = types.MethodType(
lambda self: config.num_layers, model_config
)
model_config.get_sliding_window_for_layer = types.MethodType(
lambda self, i: None, model_config
)
model_config.get_logits_soft_cap_for_layer = types.MethodType(
lambda self, i: 0.0, model_config
)
model_config.get_sm_scale_for_layer = types.MethodType(
lambda self, i: 1.0 / config.head_dim**0.5, model_config
)
model_config.get_num_attention_heads = types.MethodType(
lambda self, parallel_config=None: config.num_q_heads, model_config
)
model_config.get_num_kv_heads = types.MethodType(
lambda self, parallel_config=None: config.num_kv_heads, model_config
)
model_config.get_head_size = types.MethodType(
lambda self: config.head_dim, model_config
)
model_config.get_sliding_window = types.MethodType(lambda self: None, model_config)
return VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
load_config=load_config,
compilation_config=compilation_config,
)
# ============================================================================
# Backend Initialization
# ============================================================================
def _create_backend_impl(
backend_cfg: dict,
config: BenchmarkConfig,
device: torch.device,
):
"""Create backend implementation instance."""
import importlib
backend_module = importlib.import_module(backend_cfg["module"])
backend_class = getattr(backend_module, backend_cfg["backend_class"])
scale = get_attention_scale(config.head_dim)
dtype = backend_cfg["dtype"]
impl = backend_class.get_impl_cls()(
num_heads=config.num_q_heads,
head_size=config.head_dim,
scale=scale,
num_kv_heads=config.num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
)
kv_cache_spec = FullAttentionSpec(
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
dtype=dtype,
)
layer = MockLayer(device, kv_cache_spec=kv_cache_spec)
return backend_class, impl, layer, dtype
def _create_metadata_builder(
backend_class,
kv_cache_spec: FullAttentionSpec,
vllm_config: VllmConfig,
device: torch.device,
):
"""Create metadata builder instance."""
return backend_class.get_builder_cls()(
kv_cache_spec=kv_cache_spec,
layer_names=["layer_0"],
vllm_config=vllm_config,
device=device,
)
# ============================================================================
# Tensor Creation Helpers
# ============================================================================
def _create_input_tensors(
config: BenchmarkConfig,
total_q: int,
device: torch.device,
dtype: torch.dtype,
) -> tuple:
"""Create Q, K, V input tensors for all layers."""
q_list = [
torch.randn(
total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype
)
for _ in range(config.num_layers)
]
k_list = [
torch.randn(
total_q, config.num_kv_heads, config.head_dim, device=device, dtype=dtype
)
for _ in range(config.num_layers)
]
v_list = [
torch.randn(
total_q, config.num_kv_heads, config.head_dim, device=device, dtype=dtype
)
for _ in range(config.num_layers)
]
return q_list, k_list, v_list
def _create_kv_cache(
config: BenchmarkConfig,
max_num_blocks: int,
cache_layout: str,
device: torch.device,
dtype: torch.dtype,
) -> list:
"""Create KV cache tensors for all layers."""
if cache_layout == "flashinfer":
# FlashInfer layout: [num_blocks, 2, block_size, num_kv_heads, head_dim]
cache_list = [
torch.zeros(
max_num_blocks,
2,
config.block_size,
config.num_kv_heads,
config.head_dim,
device=device,
dtype=dtype,
)
for _ in range(config.num_layers)
]
else:
# Standard layout: [2, num_blocks, block_size, num_kv_heads, head_dim]
cache_list = [
torch.zeros(
2,
max_num_blocks,
config.block_size,
config.num_kv_heads,
config.head_dim,
device=device,
dtype=dtype,
)
for _ in range(config.num_layers)
]
return cache_list
# ============================================================================
# Benchmark Execution
# ============================================================================
def _run_single_benchmark(
config: BenchmarkConfig,
impl,
layer,
q_list: list,
k_list: list,
v_list: list,
cache_list: list,
attn_metadata,
device: torch.device,
dtype: torch.dtype,
) -> tuple:
"""Run single benchmark iteration with warmup and timing loop."""
total_q = q_list[0].shape[0]
out = torch.empty(
total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype
)
# Warmup
for _ in range(config.warmup_iters):
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
torch.cuda.synchronize()
# Benchmark
times = []
for _ in range(config.repeats):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
end.record()
torch.cuda.synchronize()
elapsed_ms = start.elapsed_time(end)
times.append(elapsed_ms / 1000.0 / config.num_layers) # seconds per layer
mem_stats = {}
if config.profile_memory:
mem_stats = {
"allocated_mb": torch.cuda.memory_allocated(device) / 1024**2,
"reserved_mb": torch.cuda.memory_reserved(device) / 1024**2,
}
return times, mem_stats
# ============================================================================
# Public API
# ============================================================================
def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
"""
Run standard attention benchmark with real kernels.
Supports: flash, triton, flashinfer
Args:
config: Benchmark configuration
Returns:
BenchmarkResult with timing and memory statistics
"""
device = torch.device(config.device)
torch.cuda.set_device(device)
backend_cfg = _get_backend_config(config.backend)
requests = parse_batch_spec(config.batch_spec)
if config.backend == "flashinfer":
requests = reorder_for_flashinfer(requests)
q_lens = [r.q_len for r in requests]
kv_lens = [r.kv_len for r in requests]
total_q = sum(q_lens)
max_kv = max(kv_lens)
max_num_blocks = (max_kv + config.block_size - 1) // config.block_size
backend_class, impl, layer, dtype = _create_backend_impl(
backend_cfg, config, device
)
common_metadata = _build_common_attn_metadata(
q_lens, kv_lens, config.block_size, device
)
kv_cache_spec = FullAttentionSpec(
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
dtype=dtype,
)
vllm_config = _create_vllm_config(config, dtype, max_num_blocks)
builder = _create_metadata_builder(
backend_class, kv_cache_spec, vllm_config, device
)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_metadata,
)
q_list, k_list, v_list = _create_input_tensors(config, total_q, device, dtype)
cache_list = _create_kv_cache(
config, max_num_blocks, backend_cfg["cache_layout"], device, dtype
)
times, mem_stats = _run_single_benchmark(
config,
impl,
layer,
q_list,
k_list,
v_list,
cache_list,
attn_metadata,
device,
dtype,
)
mean_time = np.mean(times)
throughput = total_q / mean_time if mean_time > 0 else 0
return BenchmarkResult(
config=config,
mean_time=mean_time,
std_time=np.std(times),
min_time=np.min(times),
max_time=np.max(times),
throughput_tokens_per_sec=throughput,
memory_allocated_mb=mem_stats.get("allocated_mb"),
memory_reserved_mb=mem_stats.get("reserved_mb"),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment