"docs/vscode:/vscode.git/clone" did not exist on "c47aafa37c7579c3f9b3188b05f43cb71d83dbb5"
Unverified Commit d0d97e29 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Misc] Fix up attention benchmarks (#33810)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 9562912c
...@@ -17,3 +17,14 @@ steps: ...@@ -17,3 +17,14 @@ steps:
- tests/benchmarks/ - tests/benchmarks/
commands: commands:
- pytest -v -s benchmarks/ - pytest -v -s benchmarks/
- label: Attention Benchmarks Smoke Test (B200)
device: b200
num_gpus: 2
optional: true
timeout_in_minutes: 10
source_file_dependencies:
- benchmarks/attention_benchmarks/
- vllm/v1/attention/
commands:
- python benchmarks/attention_benchmarks/benchmark.py --backends flash flashinfer --batch-specs "8q1s1k" --repeats 1 --warmup-iters 1
...@@ -229,3 +229,40 @@ def get_batch_stats(requests: list[BatchRequest]) -> dict: ...@@ -229,3 +229,40 @@ def get_batch_stats(requests: list[BatchRequest]) -> dict:
sum(r.kv_len for r in requests) / len(requests) if requests else 0 sum(r.kv_len for r in requests) / len(requests) if requests else 0
), ),
} }
def get_batch_type(batch_spec: str, spec_decode_threshold: int = 8) -> str:
"""
Classify a batch spec into a type string.
Args:
batch_spec: Batch specification string (e.g., "q2k", "8q1s1k", "2q2k_8q1s1k")
spec_decode_threshold: Max q_len to be considered spec-decode vs extend
Returns:
Type string: "prefill", "decode", "spec-decode", "extend", or "mixed (types...)"
"""
requests = parse_batch_spec(batch_spec)
# Classify each request
types_present = set()
for req in requests:
if req.is_decode:
types_present.add("decode")
elif req.is_prefill:
types_present.add("prefill")
elif req.is_extend:
# Distinguish spec-decode (small q_len) from extend (chunked prefill)
if req.q_len <= spec_decode_threshold:
types_present.add("spec-decode")
else:
types_present.add("extend")
if len(types_present) == 1:
return types_present.pop()
elif len(types_present) > 1:
# Sort for consistent output
sorted_types = sorted(types_present)
return f"mixed ({'+'.join(sorted_types)})"
else:
return "unknown"
...@@ -12,6 +12,7 @@ from typing import Any ...@@ -12,6 +12,7 @@ from typing import Any
import numpy as np import numpy as np
import torch import torch
from batch_spec import get_batch_type, parse_batch_spec
from rich.console import Console from rich.console import Console
from rich.table import Table from rich.table import Table
...@@ -316,12 +317,14 @@ class ResultsFormatter: ...@@ -316,12 +317,14 @@ class ResultsFormatter:
backends: List of backend names being compared backends: List of backend names being compared
compare_to_fastest: Show percentage comparison to fastest compare_to_fastest: Show percentage comparison to fastest
""" """
# Group by batch spec # Group by batch spec, preserving first-occurrence order
by_spec = {} by_spec = {}
specs_order = []
for r in results: for r in results:
spec = r.config.batch_spec spec = r.config.batch_spec
if spec not in by_spec: if spec not in by_spec:
by_spec[spec] = {} by_spec[spec] = {}
specs_order.append(spec)
by_spec[spec][r.config.backend] = r by_spec[spec][r.config.backend] = r
# Create shortened backend names for display # Create shortened backend names for display
...@@ -337,6 +340,8 @@ class ResultsFormatter: ...@@ -337,6 +340,8 @@ class ResultsFormatter:
table = Table(title="Attention Benchmark Results") table = Table(title="Attention Benchmark Results")
table.add_column("Batch\nSpec", no_wrap=True) table.add_column("Batch\nSpec", no_wrap=True)
table.add_column("Type", no_wrap=True)
table.add_column("Batch\nSize", justify="right", no_wrap=True)
multi = len(backends) > 1 multi = len(backends) > 1
for backend in backends: for backend in backends:
...@@ -350,12 +355,14 @@ class ResultsFormatter: ...@@ -350,12 +355,14 @@ class ResultsFormatter:
table.add_column(col_rel, justify="right", no_wrap=False) table.add_column(col_rel, justify="right", no_wrap=False)
# Add rows # Add rows
for spec in sorted(by_spec.keys()): for spec in specs_order:
spec_results = by_spec[spec] spec_results = by_spec[spec]
times = {b: r.mean_time for b, r in spec_results.items() if r.success} times = {b: r.mean_time for b, r in spec_results.items() if r.success}
best_time = min(times.values()) if times else 0.0 best_time = min(times.values()) if times else 0.0
row = [spec] batch_type = get_batch_type(spec)
batch_size = len(parse_batch_spec(spec))
row = [spec, batch_type, str(batch_size)]
for backend in backends: for backend in backends:
if backend in spec_results: if backend in spec_results:
r = spec_results[backend] r = spec_results[backend]
......
...@@ -25,10 +25,18 @@ batch_specs: ...@@ -25,10 +25,18 @@ batch_specs:
- "4q1k_16q1s2k" # 4 prefill + 16 decode - "4q1k_16q1s2k" # 4 prefill + 16 decode
- "2q4k_32q1s1k" # 2 large prefill + 32 decode - "2q4k_32q1s1k" # 2 large prefill + 32 decode
# Context extension # Speculative decode (q <= 8)
- "q1ks2k" # 1k query, 2k sequence (chunked prefill) - "16q2s1k" # 16 requests, 2 spec tokens, 1k KV cache
- "16q4s1k" # 16 requests, 4 spec tokens, 1k KV cache
- "16q8s1k" # 16 requests, 8 spec tokens, 1k KV cache
- "32q4s2k" # 32 requests, 4 spec tokens, 2k KV cache
- "8q8s4k" # 8 requests, 8 spec tokens, 4k KV cache
# Context extension (chunked prefill)
- "q1ks2k" # 1k query, 2k sequence
- "2q1ks4k" # 2 requests: 1k query, 4k sequence - "2q1ks4k" # 2 requests: 1k query, 4k sequence
# Available backends: flash, triton, flashinfer
backends: backends:
- flash - flash
- triton - triton
......
...@@ -8,7 +8,9 @@ This module provides helpers for running standard attention backends ...@@ -8,7 +8,9 @@ This module provides helpers for running standard attention backends
(FlashAttention, Triton, FlashInfer) with real vLLM integration. (FlashAttention, Triton, FlashInfer) with real vLLM integration.
""" """
import logging
import types import types
from contextlib import contextmanager
import numpy as np import numpy as np
import torch import torch
...@@ -24,8 +26,13 @@ from vllm.config import ( ...@@ -24,8 +26,13 @@ from vllm.config import (
ParallelConfig, ParallelConfig,
SchedulerConfig, SchedulerConfig,
VllmConfig, VllmConfig,
set_current_vllm_config,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
get_kv_cache_layout,
set_kv_cache_layout,
) )
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
# ============================================================================ # ============================================================================
...@@ -37,22 +44,14 @@ _BACKEND_CONFIG = { ...@@ -37,22 +44,14 @@ _BACKEND_CONFIG = {
"flash": { "flash": {
"module": "vllm.v1.attention.backends.flash_attn", "module": "vllm.v1.attention.backends.flash_attn",
"backend_class": "FlashAttentionBackend", "backend_class": "FlashAttentionBackend",
"dtype": torch.float16,
"cache_layout": "standard",
# ^ [2, num_blocks, block_size, num_kv_heads, head_dim]
}, },
"triton": { "triton": {
"module": "vllm.v1.attention.backends.triton_attn", "module": "vllm.v1.attention.backends.triton_attn",
"backend_class": "TritonAttentionBackend", "backend_class": "TritonAttentionBackend",
"dtype": torch.float32,
"cache_layout": "standard",
}, },
"flashinfer": { "flashinfer": {
"module": "vllm.v1.attention.backends.flashinfer", "module": "vllm.v1.attention.backends.flashinfer",
"backend_class": "FlashInferBackend", "backend_class": "FlashInferBackend",
"dtype": torch.float16,
"cache_layout": "flashinfer",
# ^ [num_blocks, 2, block_size, num_kv_heads, head_dim]
}, },
} }
...@@ -66,6 +65,18 @@ def _get_backend_config(backend: str) -> dict: ...@@ -66,6 +65,18 @@ def _get_backend_config(backend: str) -> dict:
return _BACKEND_CONFIG[backend] return _BACKEND_CONFIG[backend]
@contextmanager
def log_warnings_and_errors_only():
"""Temporarily set vLLM logger to WARNING level."""
logger = logging.getLogger("vllm")
old_level = logger.level
logger.setLevel(logging.WARNING)
try:
yield
finally:
logger.setLevel(old_level)
# ============================================================================ # ============================================================================
# Metadata Building Helpers # Metadata Building Helpers
# ============================================================================ # ============================================================================
...@@ -88,11 +99,7 @@ def _build_common_attn_metadata( ...@@ -88,11 +99,7 @@ def _build_common_attn_metadata(
query_start_loc_cpu = query_start_loc.cpu() query_start_loc_cpu = query_start_loc.cpu()
seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device) seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device)
seq_lens_cpu = seq_lens.cpu() max_seq_len = int(seq_lens.max().item())
max_seq_len = int(seq_lens_cpu.max())
context_lens = [kv - q for kv, q in zip(kv_lens, q_lens)]
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
max_blocks = (max(kv_lens) + block_size - 1) // block_size max_blocks = (max(kv_lens) + block_size - 1) // block_size
num_blocks = batch_size * max_blocks num_blocks = batch_size * max_blocks
...@@ -107,8 +114,6 @@ def _build_common_attn_metadata( ...@@ -107,8 +114,6 @@ def _build_common_attn_metadata(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu, query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_size, num_reqs=batch_size,
num_actual_tokens=total_tokens, num_actual_tokens=total_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
...@@ -121,7 +126,6 @@ def _build_common_attn_metadata( ...@@ -121,7 +126,6 @@ def _build_common_attn_metadata(
def _create_vllm_config( def _create_vllm_config(
config: BenchmarkConfig, config: BenchmarkConfig,
dtype: torch.dtype,
max_num_blocks: int, max_num_blocks: int,
) -> VllmConfig: ) -> VllmConfig:
"""Create a VllmConfig for benchmarking with mock model methods.""" """Create a VllmConfig for benchmarking with mock model methods."""
...@@ -129,7 +133,7 @@ def _create_vllm_config( ...@@ -129,7 +133,7 @@ def _create_vllm_config(
model="meta-llama/Meta-Llama-3-8B", model="meta-llama/Meta-Llama-3-8B",
tokenizer="meta-llama/Meta-Llama-3-8B", tokenizer="meta-llama/Meta-Llama-3-8B",
trust_remote_code=False, trust_remote_code=False,
dtype=dtype, dtype="auto", # Use model's native dtype
seed=0, seed=0,
max_model_len=1024, max_model_len=1024,
) )
...@@ -198,6 +202,7 @@ def _create_backend_impl( ...@@ -198,6 +202,7 @@ def _create_backend_impl(
backend_cfg: dict, backend_cfg: dict,
config: BenchmarkConfig, config: BenchmarkConfig,
device: torch.device, device: torch.device,
dtype: torch.dtype,
): ):
"""Create backend implementation instance.""" """Create backend implementation instance."""
import importlib import importlib
...@@ -206,7 +211,6 @@ def _create_backend_impl( ...@@ -206,7 +211,6 @@ def _create_backend_impl(
backend_class = getattr(backend_module, backend_cfg["backend_class"]) backend_class = getattr(backend_module, backend_cfg["backend_class"])
scale = get_attention_scale(config.head_dim) scale = get_attention_scale(config.head_dim)
dtype = backend_cfg["dtype"]
impl = backend_class.get_impl_cls()( impl = backend_class.get_impl_cls()(
num_heads=config.num_q_heads, num_heads=config.num_q_heads,
...@@ -227,7 +231,7 @@ def _create_backend_impl( ...@@ -227,7 +231,7 @@ def _create_backend_impl(
layer = MockLayer(device, kv_cache_spec=kv_cache_spec) layer = MockLayer(device, kv_cache_spec=kv_cache_spec)
return backend_class, impl, layer, dtype return backend_class, impl, layer
def _create_metadata_builder( def _create_metadata_builder(
...@@ -235,11 +239,44 @@ def _create_metadata_builder( ...@@ -235,11 +239,44 @@ def _create_metadata_builder(
kv_cache_spec: FullAttentionSpec, kv_cache_spec: FullAttentionSpec,
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
backend_name: str = "",
): ):
"""Create metadata builder instance.""" """Create metadata builder instance."""
return backend_class.get_builder_cls()( layer_names = ["layer_0"]
builder_cls = backend_class.get_builder_cls()
# Flashinfer needs get_per_layer_parameters mocked since we don't have
# real model layers registered
if backend_name == "flashinfer":
import unittest.mock
from vllm.v1.attention.backends.utils import PerLayerParameters
def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
head_size = vllm_config.model_config.get_head_size()
return {
layer_name: PerLayerParameters(
window_left=-1, # No sliding window
logits_soft_cap=0.0, # No soft cap
sm_scale=1.0 / (head_size**0.5), # Standard scale
)
for layer_name in layer_names
}
with unittest.mock.patch(
"vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
mock_get_per_layer_parameters,
):
return builder_cls(
kv_cache_spec=kv_cache_spec, kv_cache_spec=kv_cache_spec,
layer_names=["layer_0"], layer_names=layer_names,
vllm_config=vllm_config,
device=device,
)
return builder_cls(
kv_cache_spec=kv_cache_spec,
layer_names=layer_names,
vllm_config=vllm_config, vllm_config=vllm_config,
device=device, device=device,
) )
...@@ -281,39 +318,44 @@ def _create_input_tensors( ...@@ -281,39 +318,44 @@ def _create_input_tensors(
def _create_kv_cache( def _create_kv_cache(
config: BenchmarkConfig, config: BenchmarkConfig,
max_num_blocks: int, max_num_blocks: int,
cache_layout: str, backend_class,
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
) -> list: ) -> list:
"""Create KV cache tensors for all layers.""" """Create KV cache tensors for all layers using the backend's methods.
if cache_layout == "flashinfer":
# FlashInfer layout: [num_blocks, 2, block_size, num_kv_heads, head_dim] Uses the backend's get_kv_cache_shape() and get_kv_cache_stride_order()
cache_list = [ to create the cache with the correct shape and memory layout.
torch.zeros( """
max_num_blocks, # Get the logical shape from the backend
2, cache_shape = backend_class.get_kv_cache_shape(
config.block_size, num_blocks=max_num_blocks,
config.num_kv_heads, block_size=config.block_size,
config.head_dim, num_kv_heads=config.num_kv_heads,
device=device, head_size=config.head_dim,
dtype=dtype,
)
for _ in range(config.num_layers)
]
else:
# Standard layout: [2, num_blocks, block_size, num_kv_heads, head_dim]
cache_list = [
torch.zeros(
2,
max_num_blocks,
config.block_size,
config.num_kv_heads,
config.head_dim,
device=device,
dtype=dtype,
) )
for _ in range(config.num_layers)
] # Get the stride order for custom memory layout
try:
stride_order = backend_class.get_kv_cache_stride_order()
assert len(stride_order) == len(cache_shape)
except (AttributeError, NotImplementedError):
stride_order = tuple(range(len(cache_shape)))
# Permute shape to physical layout order
physical_shape = tuple(cache_shape[i] for i in stride_order)
# Compute inverse permutation to get back to logical view
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
cache_list = []
for _ in range(config.num_layers):
# Allocate in physical layout order (contiguous in memory)
cache = torch.zeros(*physical_shape, device=device, dtype=dtype)
# Permute to logical view
cache = cache.permute(*inv_order)
cache_list.append(cache)
return cache_list return cache_list
...@@ -418,13 +460,32 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: ...@@ -418,13 +460,32 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
kv_lens = [r.kv_len for r in requests] kv_lens = [r.kv_len for r in requests]
total_q = sum(q_lens) total_q = sum(q_lens)
max_kv = max(kv_lens) max_kv = max(kv_lens)
batch_size = len(q_lens)
# Calculate total blocks needed: batch_size * max_blocks_per_request
max_blocks_per_request = (max_kv + config.block_size - 1) // config.block_size
max_num_blocks = batch_size * max_blocks_per_request
max_num_blocks = (max_kv + config.block_size - 1) // config.block_size # Suppress vLLM logs during setup to reduce spam
with log_warnings_and_errors_only():
# Create vllm_config first - uses model's native dtype via "auto"
vllm_config = _create_vllm_config(config, max_num_blocks)
dtype = vllm_config.model_config.dtype
backend_class, impl, layer, dtype = _create_backend_impl( # Wrap everything in set_current_vllm_config context
backend_cfg, config, device # This is required for backends like flashinfer that need global config
with set_current_vllm_config(vllm_config):
backend_class, impl, layer = _create_backend_impl(
backend_cfg, config, device, dtype
) )
# Set KV cache layout if the backend requires a specific one
# (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM attention)
required_layout = backend_class.get_required_kv_cache_layout()
if required_layout is not None:
set_kv_cache_layout(required_layout)
get_kv_cache_layout.cache_clear()
common_metadata = _build_common_attn_metadata( common_metadata = _build_common_attn_metadata(
q_lens, kv_lens, config.block_size, device q_lens, kv_lens, config.block_size, device
) )
...@@ -436,10 +497,8 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: ...@@ -436,10 +497,8 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
dtype=dtype, dtype=dtype,
) )
vllm_config = _create_vllm_config(config, dtype, max_num_blocks)
builder = _create_metadata_builder( builder = _create_metadata_builder(
backend_class, kv_cache_spec, vllm_config, device backend_class, kv_cache_spec, vllm_config, device, config.backend
) )
attn_metadata = builder.build( attn_metadata = builder.build(
...@@ -447,10 +506,12 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: ...@@ -447,10 +506,12 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
common_attn_metadata=common_metadata, common_attn_metadata=common_metadata,
) )
q_list, k_list, v_list = _create_input_tensors(config, total_q, device, dtype) q_list, k_list, v_list = _create_input_tensors(
config, total_q, device, dtype
)
cache_list = _create_kv_cache( cache_list = _create_kv_cache(
config, max_num_blocks, backend_cfg["cache_layout"], device, dtype config, max_num_blocks, backend_class, device, dtype
) )
times, mem_stats = _run_single_benchmark( times, mem_stats = _run_single_benchmark(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment